重构流程引擎核心组件,引入执行器接口Executor替代原有TaskComponent,优化节点配置映射逻辑: 1. 新增mappers模块集中处理节点配置提取 2. 为存储层添加Storage trait抽象 3. 移除对ctx魔法字段的依赖,直接传递节点信息 4. 增加构建器模式支持引擎创建 5. 完善DSL解析的输入校验 同时标记部分未使用代码为allow(dead_code)
258 lines
12 KiB
Rust
258 lines
12 KiB
Rust
use async_trait::async_trait;
|
||
use serde_json::{json, Value};
|
||
use tracing::info;
|
||
|
||
use crate::flow::task::Executor;
|
||
use crate::flow::domain::{NodeDef, NodeId};
|
||
|
||
/// Flow executor that runs one SQL statement described by a node's `db` configuration.
///
/// The type is a stateless unit struct; everything it needs is read from the
/// arguments passed to `Executor::execute`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DbTask;
|
||
|
||
#[async_trait]
|
||
impl Executor for DbTask {
|
||
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {
|
||
// 1) 读取 db 配置:仅节点级 db,不再回退到全局 ctx.db,避免误用项目数据库
|
||
let node_id_opt = Some(node_id.0.clone());
|
||
let cfg = match (&node_id_opt, ctx.get("nodes")) {
|
||
(Some(node_id), Some(nodes)) => nodes.get(&node_id).and_then(|n| n.get("db")).cloned(),
|
||
_ => None,
|
||
};
|
||
|
||
let Some(cfg) = cfg else {
|
||
info!(target = "udmin.flow", "db task: no config found, skip");
|
||
return Ok(());
|
||
};
|
||
|
||
// 3) 解析配置(包含可选连接信息)
|
||
let (sql, params, output_key, conn, mode_from_db) = parse_db_config(cfg)?;
|
||
// 提前读取结果模式,优先 connection.mode,其次 db.output.mode/db.outputMode/db.mode
|
||
let result_mode = get_result_mode_from_conn(&conn).or(mode_from_db);
|
||
info!(target = "udmin.flow", "db task: exec sql: {}", sql);
|
||
|
||
// 4) 获取连接:必须显式声明 db.connection,禁止回退到项目全局数据库,避免安全风险
|
||
let db: std::borrow::Cow<'_, crate::db::Db>;
|
||
let tmp_conn; // 用于在本作用域内持有临时连接
|
||
use sea_orm::{Statement, ConnectionTrait};
|
||
|
||
let conn_cfg = conn.ok_or_else(|| anyhow::anyhow!("db task: connection config is required (db.connection)"))?;
|
||
// 构造 URL 并建立临时连接
|
||
let url = extract_connection_url(conn_cfg)?;
|
||
use sea_orm::{ConnectOptions, Database};
|
||
use std::time::Duration;
|
||
let mut opt = ConnectOptions::new(url);
|
||
opt.max_connections(20)
|
||
.min_connections(1)
|
||
.connect_timeout(Duration::from_secs(8))
|
||
.idle_timeout(Duration::from_secs(120))
|
||
.sqlx_logging(true);
|
||
tmp_conn = Database::connect(opt).await?;
|
||
db = std::borrow::Cow::Owned(tmp_conn);
|
||
|
||
// 判定是否为 SELECT:简单判断前缀,允许前导空白与括号
|
||
let is_select = {
|
||
let s = sql.trim_start();
|
||
let s = s.trim_start_matches('(');
|
||
s.to_uppercase().starts_with("SELECT")
|
||
};
|
||
|
||
// 构建参数列表(支持位置和命名两种形式)
|
||
let params_vec: Vec<sea_orm::Value> = match params {
|
||
None => vec![],
|
||
Some(Value::Array(arr)) => arr.into_iter().map(json_to_db_value).collect::<anyhow::Result<_>>()?,
|
||
Some(Value::Object(obj)) => {
|
||
// 对命名参数对象,保持插入顺序不可控,这里仅将值收集为位置绑定,建议 SQL 使用 `?` 占位
|
||
obj.into_iter().map(|(_, v)| json_to_db_value(v)).collect::<anyhow::Result<_>>()?
|
||
}
|
||
Some(v) => {
|
||
// 其它类型:当作单个位置参数
|
||
vec![json_to_db_value(v)?]
|
||
}
|
||
};
|
||
|
||
let stmt = Statement::from_sql_and_values(db.get_database_backend(), &sql, params_vec);
|
||
|
||
let result = if is_select {
|
||
let rows = db.query_all(stmt).await?;
|
||
// 将 QueryResult 转换为 JSON 数组
|
||
let mut out = Vec::with_capacity(rows.len());
|
||
for row in rows {
|
||
let mut obj = serde_json::Map::new();
|
||
// 读取列名列表
|
||
let cols = row.column_names();
|
||
for col_name in cols.iter() {
|
||
let key = col_name.to_string();
|
||
// 尝试以通用 JSON 值提取(优先字符串、数值、布尔、二进制、null)
|
||
let val = try_get_as_json(&row, &key);
|
||
obj.insert(key, val);
|
||
}
|
||
out.push(Value::Object(obj));
|
||
}
|
||
// 默认 rows 模式:直接返回数组
|
||
match result_mode.as_deref() {
|
||
// 返回首行字段对象(无则 Null)
|
||
Some("fields") | Some("first") => {
|
||
if let Some(Value::Object(m)) = out.get(0) { Value::Object(m.clone()) } else { Value::Null }
|
||
}
|
||
// 默认与显式 rows 都返回数组
|
||
_ => Value::Array(out),
|
||
}
|
||
} else {
|
||
let exec = db.execute(stmt).await?;
|
||
// 非 SELECT 默认返回受影响行数
|
||
match result_mode.as_deref() {
|
||
// 如显式要求 rows,则返回空数组
|
||
Some("rows") => json!([]),
|
||
_ => json!(exec.rows_affected()),
|
||
}
|
||
};
|
||
|
||
// 5) 写回 ctx(并对敏感信息脱敏)
|
||
let write_key = output_key.unwrap_or_else(|| "db_response".to_string());
|
||
if let (Some(node_id), Some(obj)) = (node_id_opt, ctx.as_object_mut()) {
|
||
if let Some(nodes) = obj.get_mut("nodes").and_then(|v| v.as_object_mut()) {
|
||
if let Some(target) = nodes.get_mut(&node_id).and_then(|v| v.as_object_mut()) {
|
||
// 写入结果
|
||
target.insert(write_key, result);
|
||
// 对密码字段脱敏(保留其它配置不变)
|
||
if let Some(dbv) = target.get_mut("db") {
|
||
if let Some(dbo) = dbv.as_object_mut() {
|
||
if let Some(connv) = dbo.get_mut("connection") {
|
||
match connv {
|
||
Value::Object(m) => {
|
||
if let Some(pw) = m.get_mut("password") {
|
||
*pw = Value::String("***".to_string());
|
||
}
|
||
if let Some(Value::String(url)) = m.get_mut("url") {
|
||
*url = "***".to_string();
|
||
}
|
||
}
|
||
Value::String(s) => { *s = "***".to_string(); }
|
||
_ => {}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return Ok(());
|
||
}
|
||
}
|
||
}
|
||
if let Value::Object(map) = ctx { map.insert(write_key, result); }
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
fn parse_db_config(cfg: Value) -> anyhow::Result<(String, Option<Value>, Option<String>, Option<Value>, Option<String>)> {
|
||
match cfg {
|
||
Value::String(sql) => Ok((sql, None, None, None, None)),
|
||
Value::Object(mut m) => {
|
||
let sql = m
|
||
.remove("sql")
|
||
.and_then(|v| v.as_str().map(|s| s.to_string()))
|
||
.ok_or_else(|| anyhow::anyhow!("db config missing sql"))?;
|
||
let params = m.remove("params");
|
||
let output_key = m.remove("outputKey").and_then(|v| v.as_str().map(|s| s.to_string()));
|
||
// 在移除 connection 前,从 db 层读取可能的输出模式
|
||
let mode_from_db = {
|
||
// db.output.mode
|
||
let from_output = m.get("output").and_then(|v| v.as_object()).and_then(|o| o.get("mode")).and_then(|v| v.as_str()).map(|s| s.to_string());
|
||
// db.outputMode 或 db.mode
|
||
let from_flat = m.get("outputMode").and_then(|v| v.as_str()).map(|s| s.to_string())
|
||
.or_else(|| m.get("mode").and_then(|v| v.as_str()).map(|s| s.to_string()));
|
||
from_output.or(from_flat)
|
||
};
|
||
let conn = m.remove("connection");
|
||
// 安全策略:必须显式声明连接,禁止默认落到全局数据库
|
||
if conn.is_none() {
|
||
return Err(anyhow::anyhow!("db config missing connection (db.connection is required)"));
|
||
}
|
||
Ok((sql, params, output_key, conn, mode_from_db))
|
||
}
|
||
_ => Err(anyhow::anyhow!("invalid db config")),
|
||
}
|
||
}
|
||
|
||
fn extract_connection_url(cfg: Value) -> anyhow::Result<String> {
|
||
match cfg {
|
||
Value::String(url) => Ok(url),
|
||
Value::Object(mut m) => {
|
||
if let Some(url) = m.remove("url").and_then(|v| v.as_str().map(|s| s.to_string())) {
|
||
return Ok(url);
|
||
}
|
||
let driver = m
|
||
.remove("driver")
|
||
.and_then(|v| v.as_str().map(|s| s.to_string()))
|
||
.unwrap_or_else(|| "mysql".to_string());
|
||
// sqlite 特殊处理:仅需要 database(文件路径或 :memory:)
|
||
if driver == "sqlite" {
|
||
let database = m.remove("database").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.database is required for sqlite unless url provided"))?;
|
||
return Ok(format!("sqlite://{}", database));
|
||
}
|
||
|
||
let host = m.remove("host").and_then(|v| v.as_str().map(|s| s.to_string())).unwrap_or_else(|| "localhost".to_string());
|
||
let port = m.remove("port").map(|v| match v { Value::Number(n) => n.to_string(), Value::String(s) => s, _ => String::new() });
|
||
let database = m.remove("database").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.database is required unless url provided"))?;
|
||
let username = m.remove("username").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.username is required unless url provided"))?;
|
||
let password = m.remove("password").and_then(|v| v.as_str().map(|s| s.to_string())).unwrap_or_default();
|
||
let port_part = port.filter(|s| !s.is_empty()).map(|s| format!(":{}", s)).unwrap_or_default();
|
||
let url = format!(
|
||
"{}://{}:{}@{}{}{}",
|
||
driver,
|
||
percent_encoding::utf8_percent_encode(&username, percent_encoding::NON_ALPHANUMERIC),
|
||
percent_encoding::utf8_percent_encode(&password, percent_encoding::NON_ALPHANUMERIC),
|
||
host,
|
||
port_part,
|
||
format!("/{}", database)
|
||
);
|
||
Ok(url)
|
||
}
|
||
_ => Err(anyhow::anyhow!("invalid connection config")),
|
||
}
|
||
}
|
||
|
||
fn get_result_mode_from_conn(conn: &Option<Value>) -> Option<String> {
|
||
match conn {
|
||
Some(Value::Object(m)) => m.get("mode").and_then(|v| v.as_str()).map(|s| s.to_string()),
|
||
_ => None,
|
||
}
|
||
}
|
||
|
||
fn json_to_db_value(v: Value) -> anyhow::Result<sea_orm::Value> {
|
||
use sea_orm::Value as DbValue;
|
||
let dv = match v {
|
||
Value::Null => DbValue::String(None),
|
||
Value::Bool(b) => DbValue::Bool(Some(b)),
|
||
Value::Number(n) => {
|
||
if let Some(i) = n.as_i64() { DbValue::BigInt(Some(i)) }
|
||
else if let Some(u) = n.as_u64() { DbValue::BigUnsigned(Some(u)) }
|
||
else if let Some(f) = n.as_f64() { DbValue::Double(Some(f)) }
|
||
else { DbValue::String(None) }
|
||
}
|
||
Value::String(s) => DbValue::String(Some(Box::new(s))),
|
||
Value::Array(arr) => {
|
||
// 无通用跨库数组类型:存为 JSON 字符串
|
||
let s = serde_json::to_string(&Value::Array(arr))?;
|
||
DbValue::String(Some(Box::new(s)))
|
||
}
|
||
Value::Object(obj) => {
|
||
let s = serde_json::to_string(&Value::Object(obj))?;
|
||
DbValue::String(Some(Box::new(s)))
|
||
}
|
||
};
|
||
Ok(dv)
|
||
}
|
||
|
||
fn try_get_as_json(row: &sea_orm::QueryResult, col_name: &str) -> Value {
|
||
// 该函数在原文件其余部分定义,保持不变
|
||
#[allow(unused)]
|
||
fn guess_text(bytes: &[u8]) -> Option<String> {
|
||
String::from_utf8(bytes.to_vec()).ok()
|
||
}
|
||
row.try_get::<String>("", col_name)
|
||
.map(Value::String)
|
||
.or_else(|_| row.try_get::<i64>("", col_name).map(|v| Value::Number(v.into())))
|
||
.or_else(|_| row.try_get::<u64>("", col_name).map(|v| Value::Number(v.into())))
|
||
.or_else(|_| row.try_get::<f64>("", col_name).map(|v| serde_json::Number::from_f64(v).map(Value::Number).unwrap_or(Value::Null)))
|
||
.or_else(|_| row.try_get::<bool>("", col_name).map(Value::Bool))
|
||
.or_else(|_| row.try_get::<Vec<u8>>("", col_name).map(|v| guess_text(&v).map(Value::String).unwrap_or(Value::Null)))
|
||
.unwrap_or_else(|_| Value::Null)
|
||
} |