Files
udmin/backend/src/flow/executors/db.rs
ayou 81757eecf5 feat(flow): 重构流程引擎与任务执行器架构
重构流程引擎核心组件,引入执行器接口Executor替代原有TaskComponent,优化节点配置映射逻辑:
1. 新增mappers模块集中处理节点配置提取
2. 为存储层添加Storage trait抽象
3. 移除对ctx魔法字段的依赖,直接传递节点信息
4. 增加构建器模式支持引擎创建
5. 完善DSL解析的输入校验

同时标记部分未使用代码为allow(dead_code)
2025-09-16 23:58:28 +08:00

258 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use async_trait::async_trait;
use serde_json::{json, Value};
use tracing::info;
use crate::flow::task::Executor;
use crate::flow::domain::{NodeDef, NodeId};
/// Flow task that executes a single SQL statement for a node.
///
/// The struct carries no state: all configuration is read from
/// `ctx.nodes.<node_id>.db` each time the task is executed.
#[derive(Default)]
pub struct DbTask;
/// `Executor` implementation for SQL nodes.
///
/// The database connection must be declared on the node itself
/// (`nodes.<id>.db.connection`); there is deliberately no fallback to a
/// project-wide database, so a flow cannot accidentally touch it.
#[async_trait]
impl Executor for DbTask {
    /// Runs the SQL configured at `ctx.nodes.<node_id>.db` and stores the result.
    ///
    /// Steps:
    /// 1. Read `ctx.nodes.<node_id>.db`; when it is missing the node is skipped (`Ok(())`).
    /// 2. Parse it with `parse_db_config` (SQL, params, `outputKey`, connection, mode).
    /// 3. Open a temporary connection built from `db.connection`.
    /// 4. Row-returning statements (see `is_row_returning_sql`) run through
    ///    `query_all`; every other statement runs through `execute`.
    /// 5. Write the result into the node entry under `outputKey` (default
    ///    `"db_response"`) and mask the connection credentials stored there.
    ///
    /// Result shape depends on the mode (`connection.mode`, then `db.output.mode`,
    /// `db.outputMode`, `db.mode`):
    /// - row-returning statements: array of row objects; `"fields"`/`"first"` give the
    ///   first row object, or `null` when there are no rows;
    /// - other statements: number of affected rows; `"rows"` gives an empty array.
    ///
    /// # Errors
    /// Fails when the config is invalid, `db.connection` is missing or incomplete,
    /// a parameter cannot be converted, or connecting / running the statement fails.
    async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {
        use sea_orm::{ConnectOptions, ConnectionTrait, Database, Statement};
        use std::time::Duration;

        let node_key: &str = &node_id.0;

        // 1) Only the node-level `db` config is used; falling back to a global `ctx.db`
        //    could silently run the statement against the project database.
        let Some(cfg) = ctx
            .get("nodes")
            .and_then(|nodes| nodes.get(node_key))
            .and_then(|node| node.get("db"))
            .cloned()
        else {
            info!(target = "udmin.flow", "db task: no config found, skip");
            return Ok(());
        };

        // 2) Parse the config; `connection.mode` takes precedence over the db-level mode keys.
        let (sql, params, output_key, conn, mode_from_db) = parse_db_config(cfg)?;
        let result_mode = get_result_mode_from_conn(&conn).or(mode_from_db);
        info!(target = "udmin.flow", "db task: exec sql: {}", sql);

        // 3) An explicit `db.connection` is mandatory: never fall back to the project database.
        let conn_cfg = conn.ok_or_else(|| anyhow::anyhow!("db task: connection config is required (db.connection)"))?;
        let url = extract_connection_url(conn_cfg)?;
        let mut opt = ConnectOptions::new(url);
        opt.max_connections(20)
            .min_connections(1)
            .connect_timeout(Duration::from_secs(8))
            .idle_timeout(Duration::from_secs(120))
            .sqlx_logging(true);
        // Local to this call: the connection is dropped when `execute` returns.
        let db = Database::connect(opt).await?;

        // 4) Bind parameters and run the statement.
        let params_vec = bind_params(params)?;
        let stmt = Statement::from_sql_and_values(db.get_database_backend(), &sql, params_vec);
        let result = if is_row_returning_sql(&sql) {
            let rows = rows_to_json(&db.query_all(stmt).await?);
            match result_mode.as_deref() {
                // First row as an object, or null when the query returned nothing.
                Some("fields") | Some("first") => rows.into_iter().next().unwrap_or(Value::Null),
                // Default and explicit "rows": the whole array.
                _ => Value::Array(rows),
            }
        } else {
            let exec = db.execute(stmt).await?;
            match result_mode.as_deref() {
                // Explicit "rows" on a non-query statement yields an empty array.
                Some("rows") => json!([]),
                // Default: number of affected rows.
                _ => json!(exec.rows_affected()),
            }
        };

        // 5) Write back into ctx and mask credentials kept in the node config.
        let write_key = output_key.unwrap_or_else(|| "db_response".to_string());
        store_result(ctx, node_key, write_key, result);
        Ok(())
    }
}

/// Returns `true` when `sql` produces a result set and must run through `query_all`.
///
/// Leading whitespace and opening parentheses are skipped, then the first keyword is
/// compared case-insensitively with statements that return rows: `SELECT`, `WITH`,
/// `SHOW`, `EXPLAIN`, `DESCRIBE`/`DESC`, `PRAGMA` and `VALUES`.
/// A data-modifying CTE (`WITH … UPDATE/DELETE`) is also routed here, so it reports
/// its returned rows (usually none) rather than an affected-row count.
fn is_row_returning_sql(sql: &str) -> bool {
    let head = sql.trim_start_matches(|c: char| c.is_whitespace() || c == '(');
    let keyword: String = head.chars().take_while(char::is_ascii_alphabetic).collect();
    matches!(
        keyword.to_ascii_uppercase().as_str(),
        "SELECT" | "WITH" | "SHOW" | "EXPLAIN" | "DESCRIBE" | "DESC" | "PRAGMA" | "VALUES"
    )
}

/// Converts the optional `params` config into positional bind values.
///
/// - missing → no parameters;
/// - array → one value per element, in order;
/// - object → its values in the map's iteration order (names are ignored, so the SQL
///   must use positional placeholders; the order depends on serde_json's
///   `preserve_order` feature);
/// - any other value (including `null`) → a single parameter.
fn bind_params(params: Option<Value>) -> anyhow::Result<Vec<sea_orm::Value>> {
    match params {
        None => Ok(Vec::new()),
        Some(Value::Array(arr)) => arr.into_iter().map(json_to_db_value).collect(),
        Some(Value::Object(obj)) => obj.into_iter().map(|(_, v)| json_to_db_value(v)).collect(),
        Some(other) => Ok(vec![json_to_db_value(other)?]),
    }
}

/// Converts query rows into JSON objects keyed by column name.
///
/// Column values are read with `try_get_as_json`.
fn rows_to_json(rows: &[sea_orm::QueryResult]) -> Vec<Value> {
    rows.iter()
        .map(|row| {
            let obj: serde_json::Map<String, Value> = row
                .column_names()
                .iter()
                .map(|col| {
                    let key = col.to_string();
                    let val = try_get_as_json(row, &key);
                    (key, val)
                })
                .collect();
            Value::Object(obj)
        })
        .collect()
}

/// Stores `result` under `key` in `ctx.nodes.<node_key>` when that entry is an object.
///
/// In that case the credentials in `db.connection` are also masked. Otherwise `key` is
/// inserted at the top level of `ctx` (only if `ctx` itself is an object).
fn store_result(ctx: &mut Value, node_key: &str, key: String, result: Value) {
    let target = ctx
        .get_mut("nodes")
        .and_then(Value::as_object_mut)
        .and_then(|nodes| nodes.get_mut(node_key))
        .and_then(Value::as_object_mut);
    if let Some(target) = target {
        target.insert(key, result);
        if let Some(conn) = target
            .get_mut("db")
            .and_then(Value::as_object_mut)
            .and_then(|db| db.get_mut("connection"))
        {
            mask_connection_secrets(conn);
        }
        return;
    }
    if let Value::Object(map) = ctx {
        map.insert(key, result);
    }
}

/// Replaces secrets in a `db.connection` value with `"***"` and leaves the other fields unchanged.
///
/// In an object, `password` (any type) and a string `url` are masked. A plain string
/// connection (a URL) is masked entirely.
fn mask_connection_secrets(conn: &mut Value) {
    const MASK: &str = "***";
    match conn {
        Value::Object(m) => {
            if let Some(pw) = m.get_mut("password") {
                *pw = Value::String(MASK.to_string());
            }
            if let Some(Value::String(url)) = m.get_mut("url") {
                *url = MASK.to_string();
            }
        }
        Value::String(s) => *s = MASK.to_string(),
        _ => {}
    }
}
fn parse_db_config(cfg: Value) -> anyhow::Result<(String, Option<Value>, Option<String>, Option<Value>, Option<String>)> {
match cfg {
Value::String(sql) => Ok((sql, None, None, None, None)),
Value::Object(mut m) => {
let sql = m
.remove("sql")
.and_then(|v| v.as_str().map(|s| s.to_string()))
.ok_or_else(|| anyhow::anyhow!("db config missing sql"))?;
let params = m.remove("params");
let output_key = m.remove("outputKey").and_then(|v| v.as_str().map(|s| s.to_string()));
// 在移除 connection 前,从 db 层读取可能的输出模式
let mode_from_db = {
// db.output.mode
let from_output = m.get("output").and_then(|v| v.as_object()).and_then(|o| o.get("mode")).and_then(|v| v.as_str()).map(|s| s.to_string());
// db.outputMode 或 db.mode
let from_flat = m.get("outputMode").and_then(|v| v.as_str()).map(|s| s.to_string())
.or_else(|| m.get("mode").and_then(|v| v.as_str()).map(|s| s.to_string()));
from_output.or(from_flat)
};
let conn = m.remove("connection");
// 安全策略:必须显式声明连接,禁止默认落到全局数据库
if conn.is_none() {
return Err(anyhow::anyhow!("db config missing connection (db.connection is required)"));
}
Ok((sql, params, output_key, conn, mode_from_db))
}
_ => Err(anyhow::anyhow!("invalid db config")),
}
}
fn extract_connection_url(cfg: Value) -> anyhow::Result<String> {
match cfg {
Value::String(url) => Ok(url),
Value::Object(mut m) => {
if let Some(url) = m.remove("url").and_then(|v| v.as_str().map(|s| s.to_string())) {
return Ok(url);
}
let driver = m
.remove("driver")
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "mysql".to_string());
// sqlite 特殊处理:仅需要 database文件路径或 :memory:
if driver == "sqlite" {
let database = m.remove("database").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.database is required for sqlite unless url provided"))?;
return Ok(format!("sqlite://{}", database));
}
let host = m.remove("host").and_then(|v| v.as_str().map(|s| s.to_string())).unwrap_or_else(|| "localhost".to_string());
let port = m.remove("port").map(|v| match v { Value::Number(n) => n.to_string(), Value::String(s) => s, _ => String::new() });
let database = m.remove("database").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.database is required unless url provided"))?;
let username = m.remove("username").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.username is required unless url provided"))?;
let password = m.remove("password").and_then(|v| v.as_str().map(|s| s.to_string())).unwrap_or_default();
let port_part = port.filter(|s| !s.is_empty()).map(|s| format!(":{}", s)).unwrap_or_default();
let url = format!(
"{}://{}:{}@{}{}{}",
driver,
percent_encoding::utf8_percent_encode(&username, percent_encoding::NON_ALPHANUMERIC),
percent_encoding::utf8_percent_encode(&password, percent_encoding::NON_ALPHANUMERIC),
host,
port_part,
format!("/{}", database)
);
Ok(url)
}
_ => Err(anyhow::anyhow!("invalid connection config")),
}
}
/// Returns `connection.mode` when the connection config is an object with a string `mode`.
///
/// A string connection (plain URL), a missing connection, or a missing/non-string
/// `mode` yields `None`.
fn get_result_mode_from_conn(conn: &Option<Value>) -> Option<String> {
    let Some(Value::Object(fields)) = conn else {
        return None;
    };
    fields.get("mode")?.as_str().map(String::from)
}
/// Converts a JSON value into a positional SQL bind value.
///
/// - `null` becomes a `NULL` carried as `String(None)`;
/// - booleans become `Bool`;
/// - numbers become `BigInt` if they fit in `i64`, else `BigUnsigned` if they fit in
///   `u64`, else `Double`;
/// - strings become `String`;
/// - arrays and objects are bound as their JSON text, since there is no portable
///   cross-database array/object type.
///
/// NOTE(review): on PostgreSQL a text-typed NULL may not coerce into non-text columns — confirm.
fn json_to_db_value(v: Value) -> anyhow::Result<sea_orm::Value> {
    use sea_orm::Value as DbValue;
    Ok(match v {
        Value::Null => DbValue::String(None),
        Value::Bool(flag) => DbValue::Bool(Some(flag)),
        Value::Number(num) => match (num.as_i64(), num.as_u64(), num.as_f64()) {
            (Some(signed), _, _) => DbValue::BigInt(Some(signed)),
            (None, Some(unsigned), _) => DbValue::BigUnsigned(Some(unsigned)),
            (None, None, Some(float)) => DbValue::Double(Some(float)),
            (None, None, None) => DbValue::String(None),
        },
        Value::String(text) => DbValue::String(Some(Box::new(text))),
        nested @ (Value::Array(_) | Value::Object(_)) => {
            let encoded = serde_json::to_string(&nested)?;
            DbValue::String(Some(Box::new(encoded)))
        }
    })
}
fn try_get_as_json(row: &sea_orm::QueryResult, col_name: &str) -> Value {
// 该函数在原文件其余部分定义,保持不变
#[allow(unused)]
fn guess_text(bytes: &[u8]) -> Option<String> {
String::from_utf8(bytes.to_vec()).ok()
}
row.try_get::<String>("", col_name)
.map(Value::String)
.or_else(|_| row.try_get::<i64>("", col_name).map(|v| Value::Number(v.into())))
.or_else(|_| row.try_get::<u64>("", col_name).map(|v| Value::Number(v.into())))
.or_else(|_| row.try_get::<f64>("", col_name).map(|v| serde_json::Number::from_f64(v).map(Value::Number).unwrap_or(Value::Null)))
.or_else(|_| row.try_get::<bool>("", col_name).map(Value::Bool))
.or_else(|_| row.try_get::<Vec<u8>>("", col_name).map(|v| guess_text(&v).map(Value::String).unwrap_or(Value::Null)))
.unwrap_or_else(|_| Value::Null)
}