feat(ws): 新增WebSocket实时通信支持与SSE独立服务
重构中间件结构,新增ws模块实现WebSocket流程执行实时推送 将SSE服务拆分为独立端口监听,默认8866 优化前端流式模式切换,支持WS/SSE协议选择 统一流式事件处理逻辑,完善错误处理与取消机制 更新Cargo.toml依赖,添加WebSocket相关库 调整代码组织结构,规范导入分组与注释
This commit is contained in:
@ -7,7 +7,7 @@ JWT_SECRET=dev_secret_change_me
|
||||
JWT_ISS=udmin
|
||||
JWT_ACCESS_EXP_SECS=1800
|
||||
JWT_REFRESH_EXP_SECS=1209600
|
||||
CORS_ALLOW_ORIGINS=http://localhost:5173,http://localhost:5174,http://localhost:5175
|
||||
CORS_ALLOW_ORIGINS=http://localhost:5173,http://localhost:5174,http://localhost:5175,http://127.0.0.1:5173,http://127.0.0.1:5174,http://127.0.0.1:5175,http://localhost:8888,http://127.0.0.1:8888
|
||||
|
||||
# Redis配置
|
||||
REDIS_URL=redis://:123456@127.0.0.1:6379/9
|
||||
|
||||
44
backend/Cargo.lock
generated
44
backend/Cargo.lock
generated
@ -244,6 +244,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
|
||||
dependencies = [
|
||||
"axum-core",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"form_urlencoded",
|
||||
"futures-util",
|
||||
@ -263,8 +264,10 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@ -724,6 +727,12 @@ dependencies = [
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
|
||||
|
||||
[[package]]
|
||||
name = "deadpool"
|
||||
version = "0.12.3"
|
||||
@ -3654,6 +3663,18 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.26.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.16"
|
||||
@ -3830,6 +3851,23 @@ version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.26.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.9.2",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typeid"
|
||||
version = "1.0.3"
|
||||
@ -3954,6 +3992,12 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
|
||||
@ -5,7 +5,7 @@ edition = "2024"
|
||||
default-run = "udmin"
|
||||
|
||||
[dependencies]
|
||||
axum = "0.8.4"
|
||||
axum = { version = "0.8.4", features = ["ws"] }
|
||||
tokio = { version = "1.47.1", features = ["full"] }
|
||||
tower = "0.5.2"
|
||||
tower-http = { version = "0.6.6", features = ["cors", "trace"] }
|
||||
|
||||
@ -1,15 +1,40 @@
|
||||
// std
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::{RwLock, Mutex};
|
||||
use futures::future::join_all;
|
||||
|
||||
use rhai::Engine;
|
||||
use tracing::info;
|
||||
use std::time::Instant;
|
||||
|
||||
// === 表达式评估支持:thread_local 引擎与 AST 缓存,避免全局 Sync/Send 限制 ===
|
||||
use std::cell::RefCell;
|
||||
use rhai::AST;
|
||||
// third-party
|
||||
use futures::future::join_all;
|
||||
use regex::Regex;
|
||||
use rhai::{AST, Engine};
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use tracing::info;
|
||||
|
||||
// crate
|
||||
use crate::flow::executors::condition::eval_condition_json;
|
||||
|
||||
// super
|
||||
use super::{
|
||||
context::{DriveOptions, ExecutionMode},
|
||||
domain::{ChainDef, NodeKind},
|
||||
task::TaskRegistry,
|
||||
};
|
||||
|
||||
// 结构体:紧随 use
|
||||
pub struct FlowEngine {
|
||||
pub tasks: TaskRegistry,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DriveError {
|
||||
pub node_id: String,
|
||||
pub ctx: serde_json::Value,
|
||||
pub logs: Vec<String>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
// === 表达式评估支持:thread_local 引擎与 AST 缓存,避免全局 Sync/Send 限制 ===
|
||||
|
||||
// 模块:流程执行引擎(engine.rs)
|
||||
// 作用:驱动 ChainDef 流程图,支持:
|
||||
@ -132,12 +157,6 @@ pub(crate) fn eval_rhai_expr_json(expr: &str, ctx: &serde_json::Value) -> Option
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
use super::{context::{DriveOptions, ExecutionMode}, domain::{ChainDef, NodeKind}, task::TaskRegistry};
|
||||
use crate::flow::executors::condition::eval_condition_json;
|
||||
|
||||
pub struct FlowEngine {
|
||||
pub tasks: TaskRegistry,
|
||||
}
|
||||
|
||||
impl FlowEngine {
|
||||
pub fn new(tasks: TaskRegistry) -> Self { Self { tasks } }
|
||||
@ -491,14 +510,6 @@ impl Default for FlowEngine {
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DriveError {
|
||||
pub node_id: String,
|
||||
pub ctx: serde_json::Value,
|
||||
pub logs: Vec<String>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DriveError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.message)
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
// third-party
|
||||
use anyhow::Result;
|
||||
use serde_json::Value as V;
|
||||
use tracing::info;
|
||||
|
||||
// 业务函数
|
||||
pub(crate) fn eval_condition_json(ctx: &serde_json::Value, cond: &serde_json::Value) -> Result<bool> {
|
||||
// 新增:若 cond 为数组,按 AND 语义评估(全部为 true 才为 true)
|
||||
if let Some(arr) = cond.as_array() {
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
// third-party
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use tracing::info;
|
||||
|
||||
use crate::flow::task::Executor;
|
||||
// crate
|
||||
use crate::flow::domain::{NodeDef, NodeId};
|
||||
use crate::flow::task::Executor;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct DbTask;
|
||||
|
||||
@ -1,12 +1,17 @@
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{Value, json, Map};
|
||||
use tracing::info;
|
||||
// std
|
||||
use std::collections::HashMap;
|
||||
|
||||
// third-party
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Map, Value};
|
||||
use tracing::info;
|
||||
|
||||
// crate
|
||||
use crate::flow::domain::{NodeDef, NodeId};
|
||||
use crate::flow::task::Executor;
|
||||
use crate::middlewares::http_client::{execute_http, HttpClientOptions, HttpRequest};
|
||||
|
||||
use crate::flow::task::Executor;
|
||||
use crate::flow::domain::{NodeDef, NodeId};
|
||||
|
||||
// 结构体:紧随 use
|
||||
#[derive(Default)]
|
||||
pub struct HttpTask;
|
||||
|
||||
@ -18,6 +23,7 @@ struct HttpOpts {
|
||||
http1_only: bool,
|
||||
}
|
||||
|
||||
// 业务实现与函数:置于最后
|
||||
#[async_trait]
|
||||
impl Executor for HttpTask {
|
||||
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {
|
||||
|
||||
@ -1,12 +1,18 @@
|
||||
use async_trait::async_trait
|
||||
;
|
||||
// std
|
||||
use std::fs;
|
||||
use std::time::Instant;
|
||||
|
||||
// third-party
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, info};
|
||||
use std::time::Instant;
|
||||
use std::fs;
|
||||
|
||||
use crate::flow::task::Executor;
|
||||
// crate
|
||||
use crate::flow::domain::{NodeDef, NodeId};
|
||||
use crate::flow::task::Executor;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ScriptJsTask;
|
||||
|
||||
fn read_node_script_file(ctx: &Value, node_id: &str, lang_key: &str) -> Option<String> {
|
||||
if let Some(nodes) = ctx.get("nodes").and_then(|v| v.as_object()) {
|
||||
@ -123,9 +129,6 @@ fn exec_js_file(node_id: &NodeId, path: &str, ctx: &mut Value) -> anyhow::Result
|
||||
exec_js_script(node_id, &code, ctx)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ScriptJsTask;
|
||||
|
||||
#[async_trait]
|
||||
impl Executor for ScriptJsTask {
|
||||
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {
|
||||
|
||||
@ -1,10 +1,17 @@
|
||||
// std
|
||||
use std::time::Instant;
|
||||
|
||||
// third-party
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, info};
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::flow::task::Executor;
|
||||
// crate
|
||||
use crate::flow::domain::{NodeDef, NodeId};
|
||||
use crate::flow::task::Executor;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ScriptPythonTask;
|
||||
|
||||
fn read_node_script_file(ctx: &Value, node_id: &str, lang_key: &str) -> Option<String> {
|
||||
if let Some(nodes) = ctx.get("nodes").and_then(|v| v.as_object()) {
|
||||
@ -20,9 +27,6 @@ fn truncate_str(s: &str, max: usize) -> String {
|
||||
if s.len() <= max { s } else { format!("{}…", &s[..max]) }
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ScriptPythonTask;
|
||||
|
||||
#[async_trait]
|
||||
impl Executor for ScriptPythonTask {
|
||||
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {
|
||||
|
||||
@ -1,13 +1,19 @@
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, info};
|
||||
// std
|
||||
use std::fs;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::flow::domain::NodeId;
|
||||
// third-party
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, info};
|
||||
|
||||
// crate
|
||||
use crate::flow::domain::{NodeDef, NodeId};
|
||||
use crate::flow::engine::eval_rhai_expr_json;
|
||||
use crate::flow::task::Executor;
|
||||
use crate::flow::domain::NodeDef;
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ScriptRhaiTask;
|
||||
|
||||
fn truncate_str(s: &str, max: usize) -> String {
|
||||
let s = s.replace('\n', " ").replace('\r', " ");
|
||||
@ -80,9 +86,6 @@ fn read_node_script_file(ctx: &Value, node_id: &str) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ScriptRhaiTask;
|
||||
|
||||
#[async_trait]
|
||||
impl Executor for ScriptRhaiTask {
|
||||
async fn execute(&self, node_id: &NodeId, _node: &NodeDef, ctx: &mut Value) -> anyhow::Result<()> {
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
// third-party
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{Value, json};
|
||||
use tracing::info;
|
||||
|
||||
use crate::flow::task::Executor;
|
||||
// crate
|
||||
use crate::flow::domain::{NodeDef, NodeId};
|
||||
use crate::flow::engine::eval_rhai_expr_json;
|
||||
use crate::flow::task::Executor;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct VariableTask;
|
||||
|
||||
@ -17,9 +17,9 @@ pub trait FlowLogHandler: Send + Sync {
|
||||
async fn log_error(&self, flow_id: &str, flow_code: Option<&str>, input: &Value, error_msg: &str, operator: Option<(i64, String)>, started_at: DateTime<FixedOffset>, duration_ms: i64) -> anyhow::Result<()>;
|
||||
|
||||
/// 记录流程执行失败(包含部分输出与累计日志)
|
||||
async fn log_error_detail(&self, flow_id: &str, flow_code: Option<&str>, input: &Value, output: &Value, logs: &[String], error_msg: &str, operator: Option<(i64, String)>, started_at: DateTime<FixedOffset>, duration_ms: i64) -> anyhow::Result<()> {
|
||||
async fn log_error_detail(&self, _flow_id: &str, _flow_code: Option<&str>, _input: &Value, _output: &Value, _logs: &[String], error_msg: &str, _operator: Option<(i64, String)>, _started_at: DateTime<FixedOffset>, _duration_ms: i64) -> anyhow::Result<()> {
|
||||
// 默认实现:退化为仅错误信息
|
||||
self.log_error(flow_id, flow_code, input, error_msg, operator, started_at, duration_ms).await
|
||||
self.log_error(_flow_id, _flow_code, _input, error_msg, _operator, _started_at, _duration_ms).await
|
||||
}
|
||||
|
||||
/// 记录流程执行成功
|
||||
|
||||
@ -118,8 +118,18 @@ async fn main() -> anyhow::Result<()> {
|
||||
if let Some(f) = &env_file_used { tracing::info!("env file loaded: {}", f); } else { tracing::info!("env file loaded: <none>"); }
|
||||
tracing::info!("resolved APP_HOST={} APP_PORT={}", app_host, app_port);
|
||||
|
||||
let addr = format!("{}:{}", app_host, app_port);
|
||||
tracing::info!("listening on {}", addr);
|
||||
axum::serve(tokio::net::TcpListener::bind(addr).await?, app).await?;
|
||||
let http_addr = format!("{}:{}", app_host, app_port);
|
||||
tracing::info!("listening on {}", http_addr);
|
||||
|
||||
// HTTP 服务监听
|
||||
let http_listener = tokio::net::TcpListener::bind(http_addr.clone()).await?;
|
||||
let http_server = axum::serve(http_listener, app);
|
||||
|
||||
// WS 服务下沉到中间件
|
||||
let ws_server = middlewares::ws::serve(db.clone());
|
||||
|
||||
// 新增:SSE 服务独立端口监听(默认 8866,可配 SSE_HOST/SSE_PORT)
|
||||
let sse_server = middlewares::sse::serve(db.clone());
|
||||
tokio::try_join!(http_server, ws_server, sse_server)?;
|
||||
Ok(())
|
||||
}
|
||||
@ -2,3 +2,5 @@ pub mod jwt;
|
||||
pub mod logging;
|
||||
pub mod sse;
|
||||
pub mod http_client;
|
||||
pub mod ws;
|
||||
// removed: pub mod sse_server;
|
||||
@ -73,3 +73,132 @@ pub async fn emit_error(
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// 追加所需 imports(避免重复导入 tracing::info 和 chrono::Utc)
|
||||
use axum::{Router, middleware, routing::post, extract::{State, Path, Query}, Json};
|
||||
use axum::http::HeaderMap;
|
||||
use tokio::net::TcpListener;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::db::Db;
|
||||
use crate::error::AppError;
|
||||
use crate::middlewares;
|
||||
use crate::middlewares::jwt::decode_token;
|
||||
use crate::services::flow_service;
|
||||
use tower_http::cors::{CorsLayer, Any, AllowOrigin};
|
||||
use axum::http::{HeaderValue, Method};
|
||||
|
||||
// 构建仅包含 SSE 路由与通用日志中间件的 Router(参照 ws::build_ws_app)
|
||||
pub fn build_sse_app(db: Db) -> Router {
|
||||
// 组装 CORS,与主服务保持一致的允许来源与头方法
|
||||
let allow_origins = std::env::var("CORS_ALLOW_ORIGINS").unwrap_or_else(|_| "http://localhost:5173".into());
|
||||
let origin_values: Vec<HeaderValue> = allow_origins
|
||||
.split(',')
|
||||
.filter_map(|s| HeaderValue::from_str(s.trim()).ok())
|
||||
.collect();
|
||||
|
||||
let allowed_methods = [
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::PATCH,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
];
|
||||
|
||||
let cors = if origin_values.is_empty() {
|
||||
CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(allowed_methods.clone())
|
||||
.allow_headers([
|
||||
axum::http::header::ACCEPT,
|
||||
axum::http::header::CONTENT_TYPE,
|
||||
axum::http::header::AUTHORIZATION,
|
||||
])
|
||||
.allow_credentials(false)
|
||||
} else {
|
||||
CorsLayer::new()
|
||||
.allow_origin(AllowOrigin::list(origin_values))
|
||||
.allow_methods(allowed_methods)
|
||||
.allow_headers([
|
||||
axum::http::header::ACCEPT,
|
||||
axum::http::header::CONTENT_TYPE,
|
||||
axum::http::header::AUTHORIZATION,
|
||||
])
|
||||
.allow_credentials(true)
|
||||
};
|
||||
|
||||
Router::new()
|
||||
.nest(
|
||||
"/api",
|
||||
Router::new().route("/flows/{id}/run/stream", post(run_sse)),
|
||||
)
|
||||
.with_state(db.clone())
|
||||
.layer(cors)
|
||||
.layer(middleware::from_fn_with_state(db, middlewares::logging::request_logger))
|
||||
}
|
||||
|
||||
// 启动 SSE 服务,默认端口 8866,可通过 SSE_HOST/SSE_PORT 覆盖(SSE_HOST 回退 APP_HOST)
|
||||
pub async fn serve(db: Db) -> Result<(), std::io::Error> {
|
||||
let host = std::env::var("SSE_HOST")
|
||||
.ok()
|
||||
.or_else(|| std::env::var("APP_HOST").ok())
|
||||
.unwrap_or_else(|| "0.0.0.0".into());
|
||||
let port = std::env::var("SSE_PORT").unwrap_or_else(|_| "8866".into());
|
||||
let addr = format!("{}:{}", host, port);
|
||||
tracing::info!("sse listening on {}", addr);
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
axum::serve(listener, build_sse_app(db)).await
|
||||
}
|
||||
|
||||
// SSE 路由处理:参考 WebSocket 的鉴权与调用方式
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RunReq { #[serde(default)] input: serde_json::Value }
|
||||
|
||||
async fn run_sse(
|
||||
State(db): State<Db>,
|
||||
Path(id): Path<String>,
|
||||
Query(q): Query<HashMap<String, String>>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<RunReq>,
|
||||
) -> Result<axum::response::sse::Sse<impl futures::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>>, AppError> {
|
||||
use axum::http::header::AUTHORIZATION;
|
||||
|
||||
// 1) 认证:优先 Authorization,其次查询参数 access_token
|
||||
let token_opt = headers
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.strip_prefix("Bearer "))
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| q.get("access_token").cloned());
|
||||
|
||||
let token = token_opt.ok_or(AppError::Unauthorized)?;
|
||||
let secret = std::env::var("JWT_SECRET").map_err(|_| AppError::Unauthorized)?;
|
||||
let claims = decode_token(&token, &secret)?;
|
||||
if claims.typ != "access" { return Err(AppError::Unauthorized); }
|
||||
|
||||
// 可选:Redis 二次校验(与 WS 一致)
|
||||
let redis_validation_enabled = std::env::var("REDIS_TOKEN_VALIDATION")
|
||||
.unwrap_or_else(|_| "true".to_string())
|
||||
.parse::<bool>().unwrap_or(true);
|
||||
if redis_validation_enabled {
|
||||
let is_valid = crate::redis::TokenRedis::validate_access_token(&token, claims.uid).await.unwrap_or(false);
|
||||
if !is_valid { return Err(AppError::Unauthorized); }
|
||||
}
|
||||
|
||||
// 建立 mpsc 通道用于接收引擎的流式事件
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<crate::flow::context::StreamEvent>(16);
|
||||
|
||||
// 启动后台任务运行流程,将事件通过 tx 发送
|
||||
let db_clone = db.clone();
|
||||
let id_clone = id.clone();
|
||||
let input = req.input.clone();
|
||||
let user_info = Some((claims.uid, claims.sub));
|
||||
tokio::spawn(async move {
|
||||
let _ = flow_service::run_with_stream(db_clone, &id_clone, flow_service::RunReq { input }, user_info, tx).await;
|
||||
});
|
||||
|
||||
// 由通用组件把 Receiver 包装为 SSE 响应
|
||||
Ok(from_mpsc(rx))
|
||||
}
|
||||
141
backend/src/middlewares/ws.rs
Normal file
141
backend/src/middlewares/ws.rs
Normal file
@ -0,0 +1,141 @@
|
||||
use axum::{Router, middleware};
|
||||
use crate::db::Db;
|
||||
use crate::routes;
|
||||
use crate::middlewares;
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::info;
|
||||
|
||||
// 新增:错误类型与鉴权/Redis 校验、Flow 运行
|
||||
use crate::error::AppError;
|
||||
use crate::middlewares::jwt::decode_token;
|
||||
use crate::services::flow_service;
|
||||
use crate::flow::context::StreamEvent;
|
||||
|
||||
// 封装 WS 服务构建:返回仅包含 WS 路由与通用日志中间件的 Router
|
||||
pub fn build_ws_app(db: Db) -> Router {
|
||||
Router::new()
|
||||
.nest("/api", routes::flows::ws_router())
|
||||
.with_state(db.clone())
|
||||
.layer(middleware::from_fn_with_state(db, middlewares::logging::request_logger))
|
||||
}
|
||||
|
||||
// 启动 WS 服务,读取 WS_HOST/WS_PORT(回退到 APP_HOST/默认端口),并启动监听
|
||||
pub async fn serve(db: Db) -> Result<(), std::io::Error> {
|
||||
let host = std::env::var("WS_HOST")
|
||||
.ok()
|
||||
.or_else(|| std::env::var("APP_HOST").ok())
|
||||
.unwrap_or_else(|| "0.0.0.0".into());
|
||||
let port = std::env::var("WS_PORT").unwrap_or_else(|_| "8855".into());
|
||||
let addr = format!("{}:{}", host, port);
|
||||
info!("ws listening on {}", addr);
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
axum::serve(listener, build_ws_app(db)).await
|
||||
}
|
||||
|
||||
// ================= 路由处理:仅从路由层转发调用 =================
|
||||
use std::collections::HashMap;
|
||||
use axum::http::{HeaderMap, header::AUTHORIZATION};
|
||||
use axum::response::Response;
|
||||
use axum::extract::{State, Path, Query};
|
||||
use axum::extract::ws::{WebSocketUpgrade, WebSocket, Message, Utf8Bytes};
|
||||
|
||||
pub async fn run_ws(
|
||||
State(db): State<Db>,
|
||||
Path(id): Path<String>,
|
||||
Query(q): Query<HashMap<String, String>>,
|
||||
headers: HeaderMap,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Result<Response, AppError> {
|
||||
// 1) 认证:优先 Authorization,其次查询参数 access_token
|
||||
let token_opt = headers
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.strip_prefix("Bearer "))
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| q.get("access_token").cloned());
|
||||
|
||||
let token = token_opt.ok_or(AppError::Unauthorized)?;
|
||||
let secret = std::env::var("JWT_SECRET").map_err(|_| AppError::Unauthorized)?;
|
||||
let claims = decode_token(&token, &secret)?;
|
||||
if claims.typ != "access" { return Err(AppError::Unauthorized); }
|
||||
|
||||
// 可选:Redis 二次校验(与 AuthUser 提取逻辑一致)
|
||||
let redis_validation_enabled = std::env::var("REDIS_TOKEN_VALIDATION")
|
||||
.unwrap_or_else(|_| "true".to_string())
|
||||
.parse::<bool>().unwrap_or(true);
|
||||
if redis_validation_enabled {
|
||||
let is_valid = crate::redis::TokenRedis::validate_access_token(&token, claims.uid).await.unwrap_or(false);
|
||||
if !is_valid { return Err(AppError::Unauthorized); }
|
||||
}
|
||||
|
||||
let db_clone = db.clone();
|
||||
let id_clone = id.clone();
|
||||
let user_info = Some((claims.uid, claims.sub));
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| async move {
|
||||
handle_ws_flow(socket, db_clone, id_clone, user_info).await;
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn handle_ws_flow(mut socket: WebSocket, db: Db, id: String, user_info: Option<(i64, String)>) {
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tokio::select;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
// 读取首条消息作为输入(5s 超时),格式:{ "input": any }
|
||||
let mut input_value = serde_json::Value::Object(serde_json::Map::new());
|
||||
match timeout(Duration::from_secs(5), socket.recv()).await {
|
||||
Ok(Some(Ok(Message::Text(s)))) => {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(s.as_str()) {
|
||||
if let Some(inp) = v.get("input") { input_value = inp.clone(); }
|
||||
}
|
||||
}
|
||||
Ok(Some(Ok(Message::Binary(b)))) => {
|
||||
if let Ok(s) = std::str::from_utf8(&b) {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(s) {
|
||||
if let Some(inp) = v.get("input") { input_value = inp.clone(); }
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// mpsc 管道:引擎事件 -> 此任务
|
||||
let (tx, mut rx) = mpsc::channel::<StreamEvent>(16);
|
||||
|
||||
// 后台运行流程
|
||||
let db2 = db.clone();
|
||||
let id2 = id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = flow_service::run_with_stream(db2, &id2, flow_service::RunReq { input: input_value }, user_info, tx).await;
|
||||
});
|
||||
|
||||
// 转发事件到 WebSocket
|
||||
loop {
|
||||
select! {
|
||||
maybe_evt = rx.recv() => {
|
||||
match maybe_evt {
|
||||
None => { let _ = socket.send(Message::Close(None)).await; break; }
|
||||
Some(evt) => {
|
||||
let json = match evt {
|
||||
StreamEvent::Node { node_id, logs, ctx } => serde_json::json!({"type":"node","node_id": node_id, "logs": logs, "ctx": ctx}),
|
||||
StreamEvent::Done { ok, ctx, logs } => serde_json::json!({"type":"done","ok": ok, "ctx": ctx, "logs": logs}),
|
||||
StreamEvent::Error { message } => serde_json::json!({"type":"error","message": message}),
|
||||
};
|
||||
let _ = socket.send(Message::Text(Utf8Bytes::from(json.to_string()))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
// 客户端主动关闭或发送其它指令(例如取消)
|
||||
maybe_msg = socket.recv() => {
|
||||
match maybe_msg {
|
||||
None => break,
|
||||
Some(Ok(Message::Close(_))) => break,
|
||||
Some(Ok(_other)) => { /* 当前不处理取消等指令 */ }
|
||||
Some(Err(_)) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,11 +1,27 @@
|
||||
use axum::{Router, routing::{post, get}, extract::{State, Path, Query}, Json};
|
||||
use crate::{db::Db, response::ApiResponse, services::{flow_service, log_service}, error::AppError};
|
||||
use serde::Deserialize;
|
||||
use tracing::{info, error};
|
||||
use crate::middlewares::jwt::AuthUser;
|
||||
// axum
|
||||
use axum::extract::{Path, Query, State, ws::WebSocketUpgrade};
|
||||
use axum::http::HeaderMap;
|
||||
use axum::response::Response;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
|
||||
// std
|
||||
use std::collections::HashMap;
|
||||
|
||||
// third-party
|
||||
use serde::Deserialize;
|
||||
use tracing::{error, info};
|
||||
|
||||
// crate
|
||||
use crate::middlewares::jwt::AuthUser;
|
||||
// 新增:引入通用 SSE 组件
|
||||
use crate::middlewares::sse;
|
||||
use crate::{
|
||||
db::Db,
|
||||
error::AppError,
|
||||
response::ApiResponse,
|
||||
services::{flow_service, log_service},
|
||||
};
|
||||
|
||||
pub fn router() -> Router<Db> {
|
||||
Router::new()
|
||||
@ -14,6 +30,14 @@ pub fn router() -> Router<Db> {
|
||||
.route("/flows/{id}/run", post(run))
|
||||
// 新增:流式运行(SSE)端点
|
||||
.route("/flows/{id}/run/stream", post(run_stream))
|
||||
// 新增:WebSocket 实时输出端点(GET 握手)
|
||||
.route("/flows/{id}/run/ws", get(run_ws))
|
||||
}
|
||||
|
||||
// 新增:仅包含 WS 路由的精简 router,便于在单独端口挂载
|
||||
pub fn ws_router() -> Router<Db> {
|
||||
Router::new()
|
||||
.route("/flows/{id}/run/ws", get(run_ws))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@ -107,3 +131,15 @@ async fn run_stream(State(db): State<Db>, user: AuthUser, Path(id): Path<String>
|
||||
// 由通用组件把 Receiver 包装为 SSE 响应
|
||||
Ok(sse::from_mpsc(rx))
|
||||
}
|
||||
|
||||
// ================= WebSocket 模式:路由仅做转发 =================
|
||||
|
||||
async fn run_ws(
|
||||
State(db): State<Db>,
|
||||
Path(id): Path<String>,
|
||||
Query(q): Query<HashMap<String, String>>,
|
||||
headers: HeaderMap,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Result<Response, AppError> {
|
||||
crate::middlewares::ws::run_ws(State(db), Path(id), Query(q), headers, ws).await
|
||||
}
|
||||
@ -1,5 +1,3 @@
|
||||
// removed unused: use std::collections::HashMap;
|
||||
// removed unused: use std::sync::Mutex;
|
||||
use anyhow::Context as _;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@ -8,8 +6,6 @@ use crate::flow::{self, dsl::FlowDSL, engine::FlowEngine, context::{DriveOptions
|
||||
use crate::db::Db;
|
||||
use crate::models::flow as db_flow;
|
||||
use crate::models::request_log; // 新增:查询最近修改人
|
||||
use crate::services::flow_run_log_service;
|
||||
use crate::services::flow_run_log_service::CreateRunLogInput;
|
||||
use sea_orm::{EntityTrait, ActiveModelTrait, Set, DbErr, ColumnTrait, QueryFilter, PaginatorTrait, QueryOrder};
|
||||
use sea_orm::entity::prelude::DateTimeWithTimeZone; // 新增:时间类型
|
||||
use chrono::{Utc, FixedOffset};
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
justify-content: space-between;
|
||||
gap: 8px;
|
||||
margin: 0 12px 8px 0;
|
||||
flex-wrap: wrap; // 允许在小屏时换行,但尽量保持在一行
|
||||
|
||||
.title {
|
||||
font-size: 15px;
|
||||
@ -18,6 +19,13 @@
|
||||
color: #333;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.toggle {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -50,6 +50,15 @@ export const TestRunSidePanel: FC<TestRunSidePanelProps> = ({ visible, onCancel
|
||||
_setStreamMode(checked);
|
||||
localStorage.setItem('testrun-stream-mode', JSON.stringify(checked));
|
||||
};
|
||||
// 当启用流式时,选择 WS 或 SSE(默认 SSE)
|
||||
const [useWS, _setUseWS] = useState<boolean>(() => {
|
||||
const saved = localStorage.getItem('testrun-ws-mode');
|
||||
return saved ? JSON.parse(saved) : false;
|
||||
});
|
||||
const setUseWS = (checked: boolean) => {
|
||||
_setUseWS(checked);
|
||||
localStorage.setItem('testrun-ws-mode', JSON.stringify(checked));
|
||||
};
|
||||
|
||||
// 流式渲染:实时上下文与日志
|
||||
const [streamCtx, setStreamCtx] = useState<any | undefined>();
|
||||
@ -95,7 +104,25 @@ export const TestRunSidePanel: FC<TestRunSidePanelProps> = ({ visible, onCancel
|
||||
}
|
||||
|
||||
if (streamMode) {
|
||||
const { cancel, done } = customService.runStream(values, {
|
||||
const startStream = () => useWS
|
||||
? customService.runStreamWS(values, {
|
||||
onNode: (evt) => {
|
||||
if (evt.ctx) setStreamCtx((prev: any) => ({ ...(prev || {}), ...(evt.ctx || {}) }));
|
||||
if (evt.logs && evt.logs.length) setStreamLogs((prev: string[]) => [...prev, ...evt.logs!]);
|
||||
},
|
||||
onError: (evt) => {
|
||||
const msg = evt.message || I18n.t('Run failed');
|
||||
setErrors((prev) => [...(prev || []), msg]);
|
||||
},
|
||||
onDone: (evt) => {
|
||||
setResult({ ok: evt.ok, ctx: evt.ctx, logs: evt.logs });
|
||||
},
|
||||
onFatal: (err) => {
|
||||
setErrors((prev) => [...(prev || []), err.message || String(err)]);
|
||||
setRunning(false);
|
||||
},
|
||||
})
|
||||
: customService.runStream(values, {
|
||||
onNode: (evt) => {
|
||||
if (evt.ctx) setStreamCtx((prev: any) => ({ ...(prev || {}), ...(evt.ctx || {}) }));
|
||||
if (evt.logs && evt.logs.length) setStreamLogs((prev: string[]) => [...prev, ...evt.logs!]);
|
||||
@ -113,6 +140,8 @@ export const TestRunSidePanel: FC<TestRunSidePanelProps> = ({ visible, onCancel
|
||||
},
|
||||
});
|
||||
|
||||
const { cancel, done } = startStream();
|
||||
|
||||
cancelRef.current = cancel;
|
||||
|
||||
const finished = await done;
|
||||
@ -212,12 +241,15 @@ export const TestRunSidePanel: FC<TestRunSidePanelProps> = ({ visible, onCancel
|
||||
<div className={styles['testrun-panel-form']}>
|
||||
<div className={styles['testrun-panel-input']}>
|
||||
<div className={styles.title}>{I18n.t('Input Form')}</div>
|
||||
<div className={styles.toggle}>
|
||||
<div>{I18n.t('JSON Mode')}</div>
|
||||
<Switch
|
||||
checked={inputJSONMode}
|
||||
onChange={(checked: boolean) => setInputJSONMode(checked)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
<div className={styles.toggle}>
|
||||
<div>{I18n.t('Streaming Mode')}</div>
|
||||
<Switch
|
||||
checked={streamMode}
|
||||
@ -225,6 +257,17 @@ export const TestRunSidePanel: FC<TestRunSidePanelProps> = ({ visible, onCancel
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
{streamMode && (
|
||||
<div className={styles.toggle}>
|
||||
<div>WS</div>
|
||||
<Switch
|
||||
checked={useWS}
|
||||
onChange={(checked: boolean) => setUseWS(checked)}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{renderStatus}
|
||||
{errors?.map((e) => (
|
||||
<div className={styles.error} key={e}>
|
||||
|
||||
@ -15,6 +15,7 @@ import { I18n } from '@flowgram.ai/free-layout-editor';
|
||||
import api, { type ApiResp } from '../../utils/axios';
|
||||
import { stringifyFlowDoc } from '../utils/yaml';
|
||||
import { postSSE } from '../../utils/sse';
|
||||
import { getToken } from '../../utils/token'
|
||||
|
||||
interface RunResult { ok: boolean; ctx: any; logs: string[] }
|
||||
|
||||
@ -124,7 +125,126 @@ export class CustomService {
|
||||
}
|
||||
}
|
||||
|
||||
// 新增:SSE 流式运行,返回取消函数与完成 Promise
|
||||
// 新增:WebSocket 流式运行
|
||||
runStreamWS(
|
||||
input: any = {},
|
||||
handlers?: {
|
||||
onNode?: (e: StreamEvent & { type: 'node' }) => void;
|
||||
onDone?: (e: StreamEvent & { type: 'done' }) => void;
|
||||
onError?: (e: StreamEvent & { type: 'error' }) => void;
|
||||
onFatal?: (err: Error) => void;
|
||||
}
|
||||
) {
|
||||
const id = getFlowIdFromUrl();
|
||||
if (!id) {
|
||||
const err = new Error(I18n.t('Flow ID is missing, cannot run'));
|
||||
handlers?.onFatal?.(err);
|
||||
return { cancel: () => {}, done: Promise.resolve<RunResult | null>(null) } as const;
|
||||
}
|
||||
|
||||
// 构造 WS URL
|
||||
const base = (api.defaults.baseURL || '') as string; // 可能是 /api 或 http(s)://host/api
|
||||
function toWsUrl(httpUrl: string) {
|
||||
if (httpUrl.startsWith('https://')) return 'wss://' + httpUrl.slice('https://'.length);
|
||||
if (httpUrl.startsWith('http://')) return 'ws://' + httpUrl.slice('http://'.length);
|
||||
// 相对路径:拼 window.location
|
||||
const origin = window.location.origin; // http(s)://host:port
|
||||
const full = origin.replace(/^http/, 'ws') + (httpUrl.startsWith('/') ? httpUrl : '/' + httpUrl);
|
||||
return full;
|
||||
}
|
||||
const path = `/flows/${id}/run/ws`;
|
||||
// 取 token 放到查询参数(WS 握手无法自定义 Authorization 头部)
|
||||
const token = getToken();
|
||||
|
||||
// 新增:WS 使用独立端口,默认 8855,可通过 VITE_WS_PORT 覆盖
|
||||
const wsPort = (import.meta as any).env?.VITE_WS_PORT || '8855';
|
||||
let wsBase: string;
|
||||
if (base.startsWith('http://') || base.startsWith('https://')) {
|
||||
try {
|
||||
const u = new URL(base);
|
||||
const proto = u.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
u.protocol = proto;
|
||||
u.port = wsPort; // 改为 WS 端口
|
||||
wsBase = `${u.protocol}//${u.host}${u.pathname.replace(/\/$/, '')}`;
|
||||
} catch {
|
||||
const loc = window.location;
|
||||
const proto = loc.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
wsBase = `${proto}//${loc.hostname}:${wsPort}${base.startsWith('/') ? base : '/' + base}`;
|
||||
}
|
||||
} else {
|
||||
const loc = window.location;
|
||||
const proto = loc.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
wsBase = `${proto}//${loc.hostname}:${wsPort}${base.startsWith('/') ? base : '/' + base}`;
|
||||
}
|
||||
|
||||
const wsUrl = wsBase + path + (token ? (wsBase.includes('?') ? `&access_token=${encodeURIComponent(token)}` : `?access_token=${encodeURIComponent(token)}`) : '');
|
||||
|
||||
let ws: WebSocket | null = null;
|
||||
let resolveDone: (v: RunResult | null) => void;
|
||||
let rejectDone: (e: any) => void;
|
||||
const done = new Promise<RunResult | null>((resolve, reject) => { resolveDone = resolve; rejectDone = reject; });
|
||||
let finished = false;
|
||||
|
||||
try {
|
||||
ws = new WebSocket(wsUrl);
|
||||
} catch (e: any) {
|
||||
handlers?.onFatal?.(e);
|
||||
return { cancel: () => {}, done: Promise.resolve<RunResult | null>(null) } as const;
|
||||
}
|
||||
|
||||
ws.onopen = () => {
|
||||
try {
|
||||
ws?.send(JSON.stringify({ input }));
|
||||
} catch (e: any) {
|
||||
handlers?.onFatal?.(e);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onmessage = (ev: MessageEvent) => {
|
||||
try {
|
||||
const data = typeof ev.data === 'string' ? ev.data : '' + ev.data;
|
||||
const evt = JSON.parse(data) as StreamEvent;
|
||||
if (evt.type === 'node') {
|
||||
handlers?.onNode?.(evt as any);
|
||||
return;
|
||||
}
|
||||
if (evt.type === 'error') {
|
||||
handlers?.onError?.(evt as any);
|
||||
return;
|
||||
}
|
||||
if (evt.type === 'done') {
|
||||
finished = true;
|
||||
handlers?.onDone?.(evt as any);
|
||||
resolveDone({ ok: evt.ok, ctx: (evt as any).ctx, logs: (evt as any).logs });
|
||||
ws?.close();
|
||||
return;
|
||||
}
|
||||
} catch (e: any) {
|
||||
// 忽略解析错误为致命,仅记录
|
||||
console.warn('WS message parse error', e);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = (ev: Event) => {
|
||||
if (!finished) {
|
||||
handlers?.onFatal?.(new Error('WebSocket error'));
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
if (!finished) {
|
||||
resolveDone(null);
|
||||
}
|
||||
};
|
||||
|
||||
const cancel = () => {
|
||||
try { finished = true; ws?.close(); } catch {}
|
||||
};
|
||||
|
||||
return { cancel, done } as const;
|
||||
}
|
||||
|
||||
// 现有:SSE 流式运行
|
||||
runStream(input: any = {}, handlers?: { onNode?: (e: StreamEvent & { type: 'node' }) => void; onDone?: (e: StreamEvent & { type: 'done' }) => void; onError?: (e: StreamEvent & { type: 'error' }) => void; onFatal?: (err: Error) => void; }) {
|
||||
const id = getFlowIdFromUrl();
|
||||
if (!id) {
|
||||
@ -134,7 +254,25 @@ export class CustomService {
|
||||
}
|
||||
|
||||
const base = (api.defaults.baseURL || '') as string;
|
||||
const url = base ? `${base}/flows/${id}/run/stream` : `/flows/${id}/run/stream`;
|
||||
// 参照 WS:SSE 使用独立端口,默认 8866,可通过 VITE_SSE_PORT 覆盖
|
||||
const ssePort = (import.meta as any).env?.VITE_SSE_PORT || '8866';
|
||||
let sseBase: string;
|
||||
if (base.startsWith('http://') || base.startsWith('https://')) {
|
||||
try {
|
||||
const u = new URL(base);
|
||||
// 协议保持与 base 一致,仅替换端口
|
||||
u.port = ssePort;
|
||||
sseBase = `${u.protocol}//${u.host}${u.pathname.replace(/\/$/, '')}`;
|
||||
} catch {
|
||||
const loc = window.location;
|
||||
sseBase = `${loc.protocol}//${loc.hostname}:${ssePort}${base.startsWith('/') ? base : '/' + base}`;
|
||||
}
|
||||
} else {
|
||||
const loc = window.location;
|
||||
sseBase = `${loc.protocol}//${loc.hostname}:${ssePort}${base.startsWith('/') ? base : '/' + base}`;
|
||||
}
|
||||
|
||||
const url = sseBase + `/flows/${id}/run/stream`;
|
||||
|
||||
const { cancel, done } = postSSE<RunResult | null>(url, { input }, {
|
||||
onMessage: (json: any) => {
|
||||
@ -150,12 +288,12 @@ export class CustomService {
|
||||
}
|
||||
if (evt.type === 'done') {
|
||||
handlers?.onDone?.(evt as any)
|
||||
return { ok: evt.ok, ctx: evt.ctx, logs: evt.logs }
|
||||
return { ok: (evt as any).ok, ctx: (evt as any).ctx, logs: (evt as any).logs }
|
||||
}
|
||||
} catch (_) {}
|
||||
return undefined
|
||||
},
|
||||
onFatal: (e) => handlers?.onFatal?.(e),
|
||||
onFatal: (e: any) => handlers?.onFatal?.(e),
|
||||
})
|
||||
return { cancel, done } as const;
|
||||
}
|
||||
|
||||
123
frontend/src/utils/__tests__/sse.test.ts
Normal file
123
frontend/src/utils/__tests__/sse.test.ts
Normal file
@ -0,0 +1,123 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { postSSE } from '../sse'
|
||||
|
||||
// Mocks for axios instance and token utils used inside postSSE
|
||||
vi.mock('../axios', () => {
|
||||
return {
|
||||
default: {
|
||||
get: vi.fn(async () => ({ data: { code: 0, data: { access_token: 'tok2' } } })),
|
||||
defaults: { baseURL: '/api' },
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
let tokenSeq: string[] = ['tok0']
|
||||
vi.mock('../token', () => {
|
||||
return { getToken: vi.fn(() => (tokenSeq.length ? tokenSeq.shift() : undefined)) }
|
||||
})
|
||||
|
||||
const encoder = new TextEncoder()
|
||||
function makeSSEStream(chunks: string[]): ReadableStream<Uint8Array> {
|
||||
return new ReadableStream<Uint8Array>({
|
||||
start(controller) {
|
||||
for (const c of chunks) controller.enqueue(encoder.encode(c))
|
||||
controller.close()
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
describe('postSSE', () => {
|
||||
const originalFetch = global.fetch as any
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers()
|
||||
tokenSeq = ['tok0']
|
||||
})
|
||||
afterEach(() => {
|
||||
vi.useRealTimers()
|
||||
vi.restoreAllMocks()
|
||||
// restore fetch if we replaced it
|
||||
if (originalFetch) {
|
||||
// @ts-ignore
|
||||
global.fetch = originalFetch
|
||||
}
|
||||
})
|
||||
|
||||
it('streams node and done events (CRLF) and resolves with done payload', async () => {
|
||||
const bodyStream = makeSSEStream([
|
||||
'data: {"type":"node","node_id":"n1","logs":["l1"],"server_ts":"2025-01-01T00:00:00.000Z"}\r\n\r\n',
|
||||
'data: {"type":"done","ok":true,"ctx":{},"logs":[],"server_ts":"2025-01-01T00:00:01.000Z"}\r\n\r\n',
|
||||
])
|
||||
|
||||
const resp = new Response(bodyStream, { status: 200 })
|
||||
const fetchMock = vi.fn(async () => resp)
|
||||
// replace global fetch
|
||||
// @ts-ignore
|
||||
global.fetch = fetchMock
|
||||
|
||||
const events: any[] = []
|
||||
const { done } = postSSE<{ ok: boolean; ctx: any; logs: string[] }>('http://example.com/sse', { a: 1 }, {
|
||||
onMessage: (json) => {
|
||||
events.push(json)
|
||||
if (json.type === 'done') return { ok: json.ok, ctx: json.ctx, logs: json.logs }
|
||||
},
|
||||
})
|
||||
|
||||
const result = await done
|
||||
expect(result).toEqual({ ok: true, ctx: {}, logs: [] })
|
||||
expect(events.some(e => e.type === 'node')).toBe(true)
|
||||
expect(events.some(e => e.type === 'done')).toBe(true)
|
||||
expect(fetchMock).toHaveBeenCalledTimes(1)
|
||||
const call = fetchMock.mock.calls[0] as any[]
|
||||
const init = call?.[1]
|
||||
expect((init as any).headers.Authorization).toBe('Bearer tok0')
|
||||
})
|
||||
|
||||
it('refreshes on 401 then retries with new token and resolves', async () => {
|
||||
tokenSeq = ['tok1', 'tok2']
|
||||
const first = new Response(null, { status: 401 })
|
||||
const sse = makeSSEStream(['data: {"type":"done","ok":true,"ctx":{},"logs":[]}\n\n'])
|
||||
const second = new Response(sse, { status: 200 })
|
||||
const fetchMock = vi.fn()
|
||||
.mockResolvedValueOnce(first)
|
||||
.mockResolvedValueOnce(second)
|
||||
// replace global fetch
|
||||
// @ts-ignore
|
||||
global.fetch = fetchMock
|
||||
|
||||
const { done } = postSSE('http://example.com/sse', {}, {
|
||||
onMessage: (j) => { if (j.type === 'done') return { ok: j.ok, ctx: j.ctx, logs: j.logs } },
|
||||
})
|
||||
const res = await done
|
||||
expect(res).toEqual({ ok: true, ctx: {}, logs: [] })
|
||||
expect(fetchMock).toHaveBeenCalledTimes(2)
|
||||
const call1 = fetchMock.mock.calls[0] as any[]
|
||||
const call2 = fetchMock.mock.calls[1] as any[]
|
||||
const init1 = call1?.[1]
|
||||
const init2 = call2?.[1]
|
||||
expect((init1 as any).headers.Authorization).toBe('Bearer tok1')
|
||||
expect((init2 as any).headers.Authorization).toBe('Bearer tok2')
|
||||
})
|
||||
|
||||
it('cancel aborts the stream and resolves to null', async () => {
|
||||
// a long stream without done
|
||||
const longStream = new ReadableStream<Uint8Array>({
|
||||
start(controller) {
|
||||
controller.enqueue(encoder.encode('data: {"type":"node","logs":["l1"]}\n\n'))
|
||||
// do not close to simulate long running
|
||||
},
|
||||
pull() {},
|
||||
cancel() {},
|
||||
})
|
||||
const fetchMock = vi.fn(async () => new Response(longStream, { status: 200 }))
|
||||
// replace global fetch
|
||||
// @ts-ignore
|
||||
global.fetch = fetchMock
|
||||
|
||||
const { cancel, done } = postSSE('http://example.com/sse', {}, { onMessage: () => {} })
|
||||
// cancel immediately
|
||||
cancel()
|
||||
const out = await done
|
||||
expect(out).toBeNull()
|
||||
})
|
||||
})
|
||||
@ -31,6 +31,7 @@ export default defineConfig(({ mode }) => {
|
||||
'/api': {
|
||||
target: proxyTarget,
|
||||
changeOrigin: true,
|
||||
ws: true,
|
||||
// 为 SSE 透传加固:禁用超时并保持连接
|
||||
proxyTimeout: 0,
|
||||
timeout: 0,
|
||||
|
||||
Reference in New Issue
Block a user