use async_trait::async_trait; use serde_json::{json, Value}; use tracing::info; use crate::flow::task::Executor; use crate::flow::domain::{NodeDef, NodeId}; #[derive(Default)] pub struct DbTask; #[async_trait] impl Executor for DbTask { async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> { // 1) 读取 db 配置:仅节点级 db,不再回退到全局 ctx.db,避免误用项目数据库 let node_id_opt = Some(node_id.0.clone()); let cfg = match (&node_id_opt, ctx.get("nodes")) { (Some(node_id), Some(nodes)) => nodes.get(&node_id).and_then(|n| n.get("db")).cloned(), _ => None, }; let Some(cfg) = cfg else { info!(target = "udmin.flow", "db task: no config found, skip"); return Ok(()); }; // 3) 解析配置(包含可选连接信息) let (sql, params, output_key, conn, mode_from_db) = parse_db_config(cfg)?; // 提前读取结果模式,优先 connection.mode,其次 db.output.mode/db.outputMode/db.mode let result_mode = get_result_mode_from_conn(&conn).or(mode_from_db); info!(target = "udmin.flow", "db task: exec sql: {}", sql); // 4) 获取连接:必须显式声明 db.connection,禁止回退到项目全局数据库,避免安全风险 let db: std::borrow::Cow<'_, crate::db::Db>; let tmp_conn; // 用于在本作用域内持有临时连接 use sea_orm::{Statement, ConnectionTrait}; let conn_cfg = conn.ok_or_else(|| anyhow::anyhow!("db task: connection config is required (db.connection)"))?; // 构造 URL 并建立临时连接 let url = extract_connection_url(conn_cfg)?; use sea_orm::{ConnectOptions, Database}; use std::time::Duration; let mut opt = ConnectOptions::new(url); opt.max_connections(20) .min_connections(1) .connect_timeout(Duration::from_secs(8)) .idle_timeout(Duration::from_secs(120)) .sqlx_logging(true); tmp_conn = Database::connect(opt).await?; db = std::borrow::Cow::Owned(tmp_conn); // 判定是否为 SELECT:简单判断前缀,允许前导空白与括号 let is_select = { let s = sql.trim_start(); let s = s.trim_start_matches('('); s.to_uppercase().starts_with("SELECT") }; // 构建参数列表(支持位置和命名两种形式) let params_vec: Vec = match params { None => vec![], Some(Value::Array(arr)) => arr.into_iter().map(json_to_db_value).collect::>()?, Some(Value::Object(obj)) => { // 对命名参数对象,保持插入顺序不可控,这里仅将值收集为位置绑定,建议 SQL 使用 `?` 占位 obj.into_iter().map(|(_, v)| json_to_db_value(v)).collect::>()? } Some(v) => { // 其它类型:当作单个位置参数 vec![json_to_db_value(v)?] } }; let stmt = Statement::from_sql_and_values(db.get_database_backend(), &sql, params_vec); let result = if is_select { let rows = db.query_all(stmt).await?; // 将 QueryResult 转换为 JSON 数组 let mut out = Vec::with_capacity(rows.len()); for row in rows { let mut obj = serde_json::Map::new(); // 读取列名列表 let cols = row.column_names(); for col_name in cols.iter() { let key = col_name.to_string(); // 尝试以通用 JSON 值提取(优先字符串、数值、布尔、二进制、null) let val = try_get_as_json(&row, &key); obj.insert(key, val); } out.push(Value::Object(obj)); } // 默认 rows 模式:直接返回数组 match result_mode.as_deref() { // 返回首行字段对象(无则 Null) Some("fields") | Some("first") => { if let Some(Value::Object(m)) = out.get(0) { Value::Object(m.clone()) } else { Value::Null } } // 默认与显式 rows 都返回数组 _ => Value::Array(out), } } else { let exec = db.execute(stmt).await?; // 非 SELECT 默认返回受影响行数 match result_mode.as_deref() { // 如显式要求 rows,则返回空数组 Some("rows") => json!([]), _ => json!(exec.rows_affected()), } }; // 5) 写回 ctx(并对敏感信息脱敏) let write_key = output_key.unwrap_or_else(|| "db_response".to_string()); if let (Some(node_id), Some(obj)) = (node_id_opt, ctx.as_object_mut()) { if let Some(nodes) = obj.get_mut("nodes").and_then(|v| v.as_object_mut()) { if let Some(target) = nodes.get_mut(&node_id).and_then(|v| v.as_object_mut()) { // 写入结果 target.insert(write_key, result); // 对密码字段脱敏(保留其它配置不变) if let Some(dbv) = target.get_mut("db") { if let Some(dbo) = dbv.as_object_mut() { if let Some(connv) = dbo.get_mut("connection") { match connv { Value::Object(m) => { if let Some(pw) = m.get_mut("password") { *pw = Value::String("***".to_string()); } if let Some(Value::String(url)) = m.get_mut("url") { *url = "***".to_string(); } } Value::String(s) => { *s = "***".to_string(); } _ => {} } } } } return Ok(()); } } } if let Value::Object(map) = ctx { map.insert(write_key, result); } Ok(()) } } fn parse_db_config(cfg: Value) -> anyhow::Result<(String, Option, Option, Option, Option)> { match cfg { Value::String(sql) => Ok((sql, None, None, None, None)), Value::Object(mut m) => { let sql = m .remove("sql") .and_then(|v| v.as_str().map(|s| s.to_string())) .ok_or_else(|| anyhow::anyhow!("db config missing sql"))?; let params = m.remove("params"); let output_key = m.remove("outputKey").and_then(|v| v.as_str().map(|s| s.to_string())); // 在移除 connection 前,从 db 层读取可能的输出模式 let mode_from_db = { // db.output.mode let from_output = m.get("output").and_then(|v| v.as_object()).and_then(|o| o.get("mode")).and_then(|v| v.as_str()).map(|s| s.to_string()); // db.outputMode 或 db.mode let from_flat = m.get("outputMode").and_then(|v| v.as_str()).map(|s| s.to_string()) .or_else(|| m.get("mode").and_then(|v| v.as_str()).map(|s| s.to_string())); from_output.or(from_flat) }; let conn = m.remove("connection"); // 安全策略:必须显式声明连接,禁止默认落到全局数据库 if conn.is_none() { return Err(anyhow::anyhow!("db config missing connection (db.connection is required)")); } Ok((sql, params, output_key, conn, mode_from_db)) } _ => Err(anyhow::anyhow!("invalid db config")), } } fn extract_connection_url(cfg: Value) -> anyhow::Result { match cfg { Value::String(url) => Ok(url), Value::Object(mut m) => { if let Some(url) = m.remove("url").and_then(|v| v.as_str().map(|s| s.to_string())) { return Ok(url); } let driver = m .remove("driver") .and_then(|v| v.as_str().map(|s| s.to_string())) .unwrap_or_else(|| "mysql".to_string()); // sqlite 特殊处理:仅需要 database(文件路径或 :memory:) if driver == "sqlite" { let database = m.remove("database").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.database is required for sqlite unless url provided"))?; return Ok(format!("sqlite://{}", database)); } let host = m.remove("host").and_then(|v| v.as_str().map(|s| s.to_string())).unwrap_or_else(|| "localhost".to_string()); let port = m.remove("port").map(|v| match v { Value::Number(n) => n.to_string(), Value::String(s) => s, _ => String::new() }); let database = m.remove("database").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.database is required unless url provided"))?; let username = m.remove("username").and_then(|v| v.as_str().map(|s| s.to_string())).ok_or_else(|| anyhow::anyhow!("connection.username is required unless url provided"))?; let password = m.remove("password").and_then(|v| v.as_str().map(|s| s.to_string())).unwrap_or_default(); let port_part = port.filter(|s| !s.is_empty()).map(|s| format!(":{}", s)).unwrap_or_default(); let url = format!( "{}://{}:{}@{}{}{}", driver, percent_encoding::utf8_percent_encode(&username, percent_encoding::NON_ALPHANUMERIC), percent_encoding::utf8_percent_encode(&password, percent_encoding::NON_ALPHANUMERIC), host, port_part, format!("/{}", database) ); Ok(url) } _ => Err(anyhow::anyhow!("invalid connection config")), } } fn get_result_mode_from_conn(conn: &Option) -> Option { match conn { Some(Value::Object(m)) => m.get("mode").and_then(|v| v.as_str()).map(|s| s.to_string()), _ => None, } } fn json_to_db_value(v: Value) -> anyhow::Result { use sea_orm::Value as DbValue; let dv = match v { Value::Null => DbValue::String(None), Value::Bool(b) => DbValue::Bool(Some(b)), Value::Number(n) => { if let Some(i) = n.as_i64() { DbValue::BigInt(Some(i)) } else if let Some(u) = n.as_u64() { DbValue::BigUnsigned(Some(u)) } else if let Some(f) = n.as_f64() { DbValue::Double(Some(f)) } else { DbValue::String(None) } } Value::String(s) => DbValue::String(Some(Box::new(s))), Value::Array(arr) => { // 无通用跨库数组类型:存为 JSON 字符串 let s = serde_json::to_string(&Value::Array(arr))?; DbValue::String(Some(Box::new(s))) } Value::Object(obj) => { let s = serde_json::to_string(&Value::Object(obj))?; DbValue::String(Some(Box::new(s))) } }; Ok(dv) } fn try_get_as_json(row: &sea_orm::QueryResult, col_name: &str) -> Value { // 该函数在原文件其余部分定义,保持不变 #[allow(unused)] fn guess_text(bytes: &[u8]) -> Option { String::from_utf8(bytes.to_vec()).ok() } row.try_get::("", col_name) .map(Value::String) .or_else(|_| row.try_get::("", col_name).map(|v| Value::Number(v.into()))) .or_else(|_| row.try_get::("", col_name).map(|v| Value::Number(v.into()))) .or_else(|_| row.try_get::("", col_name).map(|v| serde_json::Number::from_f64(v).map(Value::Number).unwrap_or(Value::Null))) .or_else(|_| row.try_get::("", col_name).map(Value::Bool)) .or_else(|_| row.try_get::>("", col_name).map(|v| guess_text(&v).map(Value::String).unwrap_or(Value::Null))) .unwrap_or_else(|_| Value::Null) }