feat(ws): 新增WebSocket实时通信支持与SSE独立服务

重构中间件结构,新增ws模块实现WebSocket流程执行实时推送
将SSE服务拆分为独立端口监听,默认8866
优化前端流式模式切换,支持WS/SSE协议选择
统一流式事件处理逻辑,完善错误处理与取消机制
更新Cargo.toml依赖,添加WebSocket相关库
调整代码组织结构,规范导入分组与注释
This commit is contained in:
2025-09-21 22:15:33 +08:00
parent dd7857940f
commit 30716686ed
23 changed files with 805 additions and 101 deletions

View File

@ -7,7 +7,7 @@ JWT_SECRET=dev_secret_change_me
JWT_ISS=udmin
JWT_ACCESS_EXP_SECS=1800
JWT_REFRESH_EXP_SECS=1209600
CORS_ALLOW_ORIGINS=http://localhost:5173,http://localhost:5174,http://localhost:5175
CORS_ALLOW_ORIGINS=http://localhost:5173,http://localhost:5174,http://localhost:5175,http://127.0.0.1:5173,http://127.0.0.1:5174,http://127.0.0.1:5175,http://localhost:8888,http://127.0.0.1:8888
# Redis配置
REDIS_URL=redis://:123456@127.0.0.1:6379/9

44
backend/Cargo.lock generated
View File

@ -244,6 +244,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [
"axum-core",
"base64 0.22.1",
"bytes",
"form_urlencoded",
"futures-util",
@ -263,8 +264,10 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sha1",
"sync_wrapper",
"tokio",
"tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
@ -724,6 +727,12 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]]
name = "deadpool"
version = "0.12.3"
@ -3654,6 +3663,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.16"
@ -3830,6 +3851,23 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand 0.9.2",
"sha1",
"thiserror",
"utf-8",
]
[[package]]
name = "typeid"
version = "1.0.3"
@ -3954,6 +3992,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf8_iter"
version = "1.0.4"

View File

@ -5,7 +5,7 @@ edition = "2024"
default-run = "udmin"
[dependencies]
axum = "0.8.4"
axum = { version = "0.8.4", features = ["ws"] }
tokio = { version = "1.47.1", features = ["full"] }
tower = "0.5.2"
tower-http = { version = "0.6.6", features = ["cors", "trace"] }

View File

@ -1,15 +1,40 @@
// std
use std::cell::RefCell;
use std::collections::HashMap;
use tokio::sync::{RwLock, Mutex};
use futures::future::join_all;
use rhai::Engine;
use tracing::info;
use std::time::Instant;
// === 表达式评估支持thread_local 引擎与 AST 缓存,避免全局 Sync/Send 限制 ===
use std::cell::RefCell;
use rhai::AST;
// third-party
use futures::future::join_all;
use regex::Regex;
use rhai::{AST, Engine};
use tokio::sync::{Mutex, RwLock};
use tracing::info;
// crate
use crate::flow::executors::condition::eval_condition_json;
// super
use super::{
context::{DriveOptions, ExecutionMode},
domain::{ChainDef, NodeKind},
task::TaskRegistry,
};
// 结构体:紧随 use
pub struct FlowEngine {
pub tasks: TaskRegistry,
}
#[derive(Debug, Clone)]
pub struct DriveError {
pub node_id: String,
pub ctx: serde_json::Value,
pub logs: Vec<String>,
pub message: String,
}
// === 表达式评估支持thread_local 引擎与 AST 缓存,避免全局 Sync/Send 限制 ===
// 模块流程执行引擎engine.rs
// 作用:驱动 ChainDef 流程图,支持:
@ -132,12 +157,6 @@ pub(crate) fn eval_rhai_expr_json(expr: &str, ctx: &serde_json::Value) -> Option
Err(_) => None,
}
}
use super::{context::{DriveOptions, ExecutionMode}, domain::{ChainDef, NodeKind}, task::TaskRegistry};
use crate::flow::executors::condition::eval_condition_json;
pub struct FlowEngine {
pub tasks: TaskRegistry,
}
impl FlowEngine {
pub fn new(tasks: TaskRegistry) -> Self { Self { tasks } }
@ -491,14 +510,6 @@ impl Default for FlowEngine {
}
#[derive(Debug, Clone)]
pub struct DriveError {
pub node_id: String,
pub ctx: serde_json::Value,
pub logs: Vec<String>,
pub message: String,
}
impl std::fmt::Display for DriveError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)

View File

@ -1,7 +1,9 @@
// third-party
use anyhow::Result;
use serde_json::Value as V;
use tracing::info;
// 业务函数
pub(crate) fn eval_condition_json(ctx: &serde_json::Value, cond: &serde_json::Value) -> Result<bool> {
// 新增:若 cond 为数组,按 AND 语义评估(全部为 true 才为 true
if let Some(arr) = cond.as_array() {

View File

@ -1,9 +1,11 @@
// third-party
use async_trait::async_trait;
use serde_json::{json, Value};
use tracing::info;
use crate::flow::task::Executor;
// crate
use crate::flow::domain::{NodeDef, NodeId};
use crate::flow::task::Executor;
#[derive(Default)]
pub struct DbTask;

View File

@ -1,12 +1,17 @@
use async_trait::async_trait;
use serde_json::{Value, json, Map};
use tracing::info;
// std
use std::collections::HashMap;
// third-party
use async_trait::async_trait;
use serde_json::{json, Map, Value};
use tracing::info;
// crate
use crate::flow::domain::{NodeDef, NodeId};
use crate::flow::task::Executor;
use crate::middlewares::http_client::{execute_http, HttpClientOptions, HttpRequest};
use crate::flow::task::Executor;
use crate::flow::domain::{NodeDef, NodeId};
// 结构体:紧随 use
#[derive(Default)]
pub struct HttpTask;
@ -18,6 +23,7 @@ struct HttpOpts {
http1_only: bool,
}
// 业务实现与函数:置于最后
#[async_trait]
impl Executor for HttpTask {
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {

View File

@ -1,12 +1,18 @@
use async_trait::async_trait
;
// std
use std::fs;
use std::time::Instant;
// third-party
use async_trait::async_trait;
use serde_json::Value;
use tracing::{debug, info};
use std::time::Instant;
use std::fs;
use crate::flow::task::Executor;
// crate
use crate::flow::domain::{NodeDef, NodeId};
use crate::flow::task::Executor;
#[derive(Default)]
pub struct ScriptJsTask;
fn read_node_script_file(ctx: &Value, node_id: &str, lang_key: &str) -> Option<String> {
if let Some(nodes) = ctx.get("nodes").and_then(|v| v.as_object()) {
@ -123,9 +129,6 @@ fn exec_js_file(node_id: &NodeId, path: &str, ctx: &mut Value) -> anyhow::Result
exec_js_script(node_id, &code, ctx)
}
#[derive(Default)]
pub struct ScriptJsTask;
#[async_trait]
impl Executor for ScriptJsTask {
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {

View File

@ -1,10 +1,17 @@
// std
use std::time::Instant;
// third-party
use async_trait::async_trait;
use serde_json::Value;
use tracing::{debug, info};
use std::time::Instant;
use crate::flow::task::Executor;
// crate
use crate::flow::domain::{NodeDef, NodeId};
use crate::flow::task::Executor;
#[derive(Default)]
pub struct ScriptPythonTask;
fn read_node_script_file(ctx: &Value, node_id: &str, lang_key: &str) -> Option<String> {
if let Some(nodes) = ctx.get("nodes").and_then(|v| v.as_object()) {
@ -20,9 +27,6 @@ fn truncate_str(s: &str, max: usize) -> String {
if s.len() <= max { s } else { format!("{}", &s[..max]) }
}
#[derive(Default)]
pub struct ScriptPythonTask;
#[async_trait]
impl Executor for ScriptPythonTask {
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {

View File

@ -1,13 +1,19 @@
use serde_json::Value;
use tracing::{debug, info};
// std
use std::fs;
use std::time::Instant;
use crate::flow::domain::NodeId;
// third-party
use async_trait::async_trait;
use serde_json::Value;
use tracing::{debug, info};
// crate
use crate::flow::domain::{NodeDef, NodeId};
use crate::flow::engine::eval_rhai_expr_json;
use crate::flow::task::Executor;
use crate::flow::domain::NodeDef;
use async_trait::async_trait;
#[derive(Default)]
pub struct ScriptRhaiTask;
fn truncate_str(s: &str, max: usize) -> String {
let s = s.replace('\n', " ").replace('\r', " ");
@ -80,9 +86,6 @@ fn read_node_script_file(ctx: &Value, node_id: &str) -> Option<String> {
None
}
#[derive(Default)]
pub struct ScriptRhaiTask;
#[async_trait]
impl Executor for ScriptRhaiTask {
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {

View File

@ -1,10 +1,12 @@
// third-party
use async_trait::async_trait;
use serde_json::{Value, json};
use tracing::info;
use crate::flow::task::Executor;
// crate
use crate::flow::domain::{NodeDef, NodeId};
use crate::flow::engine::eval_rhai_expr_json;
use crate::flow::task::Executor;
#[derive(Default)]
pub struct VariableTask;

View File

@ -17,9 +17,9 @@ pub trait FlowLogHandler: Send + Sync {
async fn log_error(&self, flow_id: &str, flow_code: Option<&str>, input: &Value, error_msg: &str, operator: Option<(i64, String)>, started_at: DateTime<FixedOffset>, duration_ms: i64) -> anyhow::Result<()>;
/// 记录流程执行失败(包含部分输出与累计日志)
async fn log_error_detail(&self, flow_id: &str, flow_code: Option<&str>, input: &Value, output: &Value, logs: &[String], error_msg: &str, operator: Option<(i64, String)>, started_at: DateTime<FixedOffset>, duration_ms: i64) -> anyhow::Result<()> {
async fn log_error_detail(&self, _flow_id: &str, _flow_code: Option<&str>, _input: &Value, _output: &Value, _logs: &[String], error_msg: &str, _operator: Option<(i64, String)>, _started_at: DateTime<FixedOffset>, _duration_ms: i64) -> anyhow::Result<()> {
// 默认实现:退化为仅错误信息
self.log_error(flow_id, flow_code, input, error_msg, operator, started_at, duration_ms).await
self.log_error(_flow_id, _flow_code, _input, error_msg, _operator, _started_at, _duration_ms).await
}
/// 记录流程执行成功

View File

@ -118,8 +118,18 @@ async fn main() -> anyhow::Result<()> {
if let Some(f) = &env_file_used { tracing::info!("env file loaded: {}", f); } else { tracing::info!("env file loaded: <none>"); }
tracing::info!("resolved APP_HOST={} APP_PORT={}", app_host, app_port);
let addr = format!("{}:{}", app_host, app_port);
tracing::info!("listening on {}", addr);
axum::serve(tokio::net::TcpListener::bind(addr).await?, app).await?;
let http_addr = format!("{}:{}", app_host, app_port);
tracing::info!("listening on {}", http_addr);
// HTTP 服务监听
let http_listener = tokio::net::TcpListener::bind(http_addr.clone()).await?;
let http_server = axum::serve(http_listener, app);
// WS 服务下沉到中间件
let ws_server = middlewares::ws::serve(db.clone());
// 新增SSE 服务独立端口监听(默认 8866可配 SSE_HOST/SSE_PORT
let sse_server = middlewares::sse::serve(db.clone());
tokio::try_join!(http_server, ws_server, sse_server)?;
Ok(())
}

View File

@ -1,4 +1,6 @@
pub mod jwt;
pub mod logging;
pub mod sse;
pub mod http_client;
pub mod http_client;
pub mod ws;
// removed: pub mod sse_server;

View File

@ -72,4 +72,133 @@ pub async fn emit_error(
message: msg,
})
.await;
}
// 追加所需 imports避免重复导入 tracing::info 和 chrono::Utc
use axum::{Router, middleware, routing::post, extract::{State, Path, Query}, Json};
use axum::http::HeaderMap;
use tokio::net::TcpListener;
use std::collections::HashMap;
use crate::db::Db;
use crate::error::AppError;
use crate::middlewares;
use crate::middlewares::jwt::decode_token;
use crate::services::flow_service;
use tower_http::cors::{CorsLayer, Any, AllowOrigin};
use axum::http::{HeaderValue, Method};
// 构建仅包含 SSE 路由与通用日志中间件的 Router参照 ws::build_ws_app
pub fn build_sse_app(db: Db) -> Router {
// 组装 CORS与主服务保持一致的允许来源与头方法
let allow_origins = std::env::var("CORS_ALLOW_ORIGINS").unwrap_or_else(|_| "http://localhost:5173".into());
let origin_values: Vec<HeaderValue> = allow_origins
.split(',')
.filter_map(|s| HeaderValue::from_str(s.trim()).ok())
.collect();
let allowed_methods = [
Method::GET,
Method::POST,
Method::PUT,
Method::PATCH,
Method::DELETE,
Method::OPTIONS,
];
let cors = if origin_values.is_empty() {
CorsLayer::new()
.allow_origin(Any)
.allow_methods(allowed_methods.clone())
.allow_headers([
axum::http::header::ACCEPT,
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
])
.allow_credentials(false)
} else {
CorsLayer::new()
.allow_origin(AllowOrigin::list(origin_values))
.allow_methods(allowed_methods)
.allow_headers([
axum::http::header::ACCEPT,
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
])
.allow_credentials(true)
};
Router::new()
.nest(
"/api",
Router::new().route("/flows/{id}/run/stream", post(run_sse)),
)
.with_state(db.clone())
.layer(cors)
.layer(middleware::from_fn_with_state(db, middlewares::logging::request_logger))
}
// 启动 SSE 服务,默认端口 8866可通过 SSE_HOST/SSE_PORT 覆盖SSE_HOST 回退 APP_HOST
pub async fn serve(db: Db) -> Result<(), std::io::Error> {
let host = std::env::var("SSE_HOST")
.ok()
.or_else(|| std::env::var("APP_HOST").ok())
.unwrap_or_else(|| "0.0.0.0".into());
let port = std::env::var("SSE_PORT").unwrap_or_else(|_| "8866".into());
let addr = format!("{}:{}", host, port);
tracing::info!("sse listening on {}", addr);
let listener = TcpListener::bind(addr).await?;
axum::serve(listener, build_sse_app(db)).await
}
// SSE 路由处理:参考 WebSocket 的鉴权与调用方式
#[derive(serde::Deserialize)]
struct RunReq { #[serde(default)] input: serde_json::Value }
async fn run_sse(
State(db): State<Db>,
Path(id): Path<String>,
Query(q): Query<HashMap<String, String>>,
headers: HeaderMap,
Json(req): Json<RunReq>,
) -> Result<axum::response::sse::Sse<impl futures::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>>, AppError> {
use axum::http::header::AUTHORIZATION;
// 1) 认证:优先 Authorization其次查询参数 access_token
let token_opt = headers
.get(AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "))
.map(|s| s.to_string())
.or_else(|| q.get("access_token").cloned());
let token = token_opt.ok_or(AppError::Unauthorized)?;
let secret = std::env::var("JWT_SECRET").map_err(|_| AppError::Unauthorized)?;
let claims = decode_token(&token, &secret)?;
if claims.typ != "access" { return Err(AppError::Unauthorized); }
// 可选Redis 二次校验(与 WS 一致)
let redis_validation_enabled = std::env::var("REDIS_TOKEN_VALIDATION")
.unwrap_or_else(|_| "true".to_string())
.parse::<bool>().unwrap_or(true);
if redis_validation_enabled {
let is_valid = crate::redis::TokenRedis::validate_access_token(&token, claims.uid).await.unwrap_or(false);
if !is_valid { return Err(AppError::Unauthorized); }
}
// 建立 mpsc 通道用于接收引擎的流式事件
let (tx, rx) = tokio::sync::mpsc::channel::<crate::flow::context::StreamEvent>(16);
// 启动后台任务运行流程,将事件通过 tx 发送
let db_clone = db.clone();
let id_clone = id.clone();
let input = req.input.clone();
let user_info = Some((claims.uid, claims.sub));
tokio::spawn(async move {
let _ = flow_service::run_with_stream(db_clone, &id_clone, flow_service::RunReq { input }, user_info, tx).await;
});
// 由通用组件把 Receiver 包装为 SSE 响应
Ok(from_mpsc(rx))
}

View File

@ -0,0 +1,141 @@
use axum::{Router, middleware};
use crate::db::Db;
use crate::routes;
use crate::middlewares;
use tokio::net::TcpListener;
use tracing::info;
// 新增:错误类型与鉴权/Redis 校验、Flow 运行
use crate::error::AppError;
use crate::middlewares::jwt::decode_token;
use crate::services::flow_service;
use crate::flow::context::StreamEvent;
// 封装 WS 服务构建:返回仅包含 WS 路由与通用日志中间件的 Router
pub fn build_ws_app(db: Db) -> Router {
Router::new()
.nest("/api", routes::flows::ws_router())
.with_state(db.clone())
.layer(middleware::from_fn_with_state(db, middlewares::logging::request_logger))
}
// 启动 WS 服务,读取 WS_HOST/WS_PORT回退到 APP_HOST/默认端口),并启动监听
pub async fn serve(db: Db) -> Result<(), std::io::Error> {
let host = std::env::var("WS_HOST")
.ok()
.or_else(|| std::env::var("APP_HOST").ok())
.unwrap_or_else(|| "0.0.0.0".into());
let port = std::env::var("WS_PORT").unwrap_or_else(|_| "8855".into());
let addr = format!("{}:{}", host, port);
info!("ws listening on {}", addr);
let listener = TcpListener::bind(addr).await?;
axum::serve(listener, build_ws_app(db)).await
}
// ================= 路由处理:仅从路由层转发调用 =================
use std::collections::HashMap;
use axum::http::{HeaderMap, header::AUTHORIZATION};
use axum::response::Response;
use axum::extract::{State, Path, Query};
use axum::extract::ws::{WebSocketUpgrade, WebSocket, Message, Utf8Bytes};
pub async fn run_ws(
State(db): State<Db>,
Path(id): Path<String>,
Query(q): Query<HashMap<String, String>>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> Result<Response, AppError> {
// 1) 认证:优先 Authorization其次查询参数 access_token
let token_opt = headers
.get(AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.strip_prefix("Bearer "))
.map(|s| s.to_string())
.or_else(|| q.get("access_token").cloned());
let token = token_opt.ok_or(AppError::Unauthorized)?;
let secret = std::env::var("JWT_SECRET").map_err(|_| AppError::Unauthorized)?;
let claims = decode_token(&token, &secret)?;
if claims.typ != "access" { return Err(AppError::Unauthorized); }
// 可选Redis 二次校验(与 AuthUser 提取逻辑一致)
let redis_validation_enabled = std::env::var("REDIS_TOKEN_VALIDATION")
.unwrap_or_else(|_| "true".to_string())
.parse::<bool>().unwrap_or(true);
if redis_validation_enabled {
let is_valid = crate::redis::TokenRedis::validate_access_token(&token, claims.uid).await.unwrap_or(false);
if !is_valid { return Err(AppError::Unauthorized); }
}
let db_clone = db.clone();
let id_clone = id.clone();
let user_info = Some((claims.uid, claims.sub));
Ok(ws.on_upgrade(move |socket| async move {
handle_ws_flow(socket, db_clone, id_clone, user_info).await;
}))
}
pub async fn handle_ws_flow(mut socket: WebSocket, db: Db, id: String, user_info: Option<(i64, String)>) {
use tokio::time::{timeout, Duration};
use tokio::select;
use tokio::sync::mpsc;
// 读取首条消息作为输入5s 超时),格式:{ "input": any }
let mut input_value = serde_json::Value::Object(serde_json::Map::new());
match timeout(Duration::from_secs(5), socket.recv()).await {
Ok(Some(Ok(Message::Text(s)))) => {
if let Ok(v) = serde_json::from_str::<serde_json::Value>(s.as_str()) {
if let Some(inp) = v.get("input") { input_value = inp.clone(); }
}
}
Ok(Some(Ok(Message::Binary(b)))) => {
if let Ok(s) = std::str::from_utf8(&b) {
if let Ok(v) = serde_json::from_str::<serde_json::Value>(s) {
if let Some(inp) = v.get("input") { input_value = inp.clone(); }
}
}
}
_ => {}
}
// mpsc 管道:引擎事件 -> 此任务
let (tx, mut rx) = mpsc::channel::<StreamEvent>(16);
// 后台运行流程
let db2 = db.clone();
let id2 = id.clone();
tokio::spawn(async move {
let _ = flow_service::run_with_stream(db2, &id2, flow_service::RunReq { input: input_value }, user_info, tx).await;
});
// 转发事件到 WebSocket
loop {
select! {
maybe_evt = rx.recv() => {
match maybe_evt {
None => { let _ = socket.send(Message::Close(None)).await; break; }
Some(evt) => {
let json = match evt {
StreamEvent::Node { node_id, logs, ctx } => serde_json::json!({"type":"node","node_id": node_id, "logs": logs, "ctx": ctx}),
StreamEvent::Done { ok, ctx, logs } => serde_json::json!({"type":"done","ok": ok, "ctx": ctx, "logs": logs}),
StreamEvent::Error { message } => serde_json::json!({"type":"error","message": message}),
};
let _ = socket.send(Message::Text(Utf8Bytes::from(json.to_string()))).await;
}
}
}
// 客户端主动关闭或发送其它指令(例如取消)
maybe_msg = socket.recv() => {
match maybe_msg {
None => break,
Some(Ok(Message::Close(_))) => break,
Some(Ok(_other)) => { /* 当前不处理取消等指令 */ }
Some(Err(_)) => break,
}
}
}
}
}

View File

@ -1,11 +1,27 @@
use axum::{Router, routing::{post, get}, extract::{State, Path, Query}, Json};
use crate::{db::Db, response::ApiResponse, services::{flow_service, log_service}, error::AppError};
use serde::Deserialize;
use tracing::{info, error};
use crate::middlewares::jwt::AuthUser;
// axum
use axum::extract::{Path, Query, State, ws::WebSocketUpgrade};
use axum::http::HeaderMap;
use axum::response::Response;
use axum::routing::{get, post};
use axum::{Json, Router};
// std
use std::collections::HashMap;
// third-party
use serde::Deserialize;
use tracing::{error, info};
// crate
use crate::middlewares::jwt::AuthUser;
// 新增:引入通用 SSE 组件
use crate::middlewares::sse;
use crate::{
db::Db,
error::AppError,
response::ApiResponse,
services::{flow_service, log_service},
};
pub fn router() -> Router<Db> {
Router::new()
@ -14,6 +30,14 @@ pub fn router() -> Router<Db> {
.route("/flows/{id}/run", post(run))
// 新增流式运行SSE端点
.route("/flows/{id}/run/stream", post(run_stream))
// 新增WebSocket 实时输出端点GET 握手)
.route("/flows/{id}/run/ws", get(run_ws))
}
// 新增:仅包含 WS 路由的精简 router便于在单独端口挂载
pub fn ws_router() -> Router<Db> {
Router::new()
.route("/flows/{id}/run/ws", get(run_ws))
}
#[derive(Deserialize)]
@ -106,4 +130,16 @@ async fn run_stream(State(db): State<Db>, user: AuthUser, Path(id): Path<String>
// 由通用组件把 Receiver 包装为 SSE 响应
Ok(sse::from_mpsc(rx))
}
// ================= WebSocket 模式:路由仅做转发 =================
async fn run_ws(
State(db): State<Db>,
Path(id): Path<String>,
Query(q): Query<HashMap<String, String>>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> Result<Response, AppError> {
crate::middlewares::ws::run_ws(State(db), Path(id), Query(q), headers, ws).await
}

View File

@ -1,5 +1,3 @@
// removed unused: use std::collections::HashMap;
// removed unused: use std::sync::Mutex;
use anyhow::Context as _;
use serde::{Deserialize, Serialize};
@ -8,8 +6,6 @@ use crate::flow::{self, dsl::FlowDSL, engine::FlowEngine, context::{DriveOptions
use crate::db::Db;
use crate::models::flow as db_flow;
use crate::models::request_log; // 新增:查询最近修改人
use crate::services::flow_run_log_service;
use crate::services::flow_run_log_service::CreateRunLogInput;
use sea_orm::{EntityTrait, ActiveModelTrait, Set, DbErr, ColumnTrait, QueryFilter, PaginatorTrait, QueryOrder};
use sea_orm::entity::prelude::DateTimeWithTimeZone; // 新增:时间类型
use chrono::{Utc, FixedOffset};