feat(ws): 新增WebSocket实时通信支持与SSE独立服务
重构中间件结构,新增ws模块实现WebSocket流程执行实时推送 将SSE服务拆分为独立端口监听,默认8866 优化前端流式模式切换,支持WS/SSE协议选择 统一流式事件处理逻辑,完善错误处理与取消机制 更新Cargo.toml依赖,添加WebSocket相关库 调整代码组织结构,规范导入分组与注释
This commit is contained in:
@ -1,4 +1,6 @@
|
||||
pub mod jwt;
|
||||
pub mod logging;
|
||||
pub mod sse;
|
||||
pub mod http_client;
|
||||
pub mod http_client;
|
||||
pub mod ws;
|
||||
// removed: pub mod sse_server;
|
||||
@ -72,4 +72,133 @@ pub async fn emit_error(
|
||||
message: msg,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// 追加所需 imports(避免重复导入 tracing::info 和 chrono::Utc)
|
||||
use axum::{Router, middleware, routing::post, extract::{State, Path, Query}, Json};
|
||||
use axum::http::HeaderMap;
|
||||
use tokio::net::TcpListener;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::db::Db;
|
||||
use crate::error::AppError;
|
||||
use crate::middlewares;
|
||||
use crate::middlewares::jwt::decode_token;
|
||||
use crate::services::flow_service;
|
||||
use tower_http::cors::{CorsLayer, Any, AllowOrigin};
|
||||
use axum::http::{HeaderValue, Method};
|
||||
|
||||
// 构建仅包含 SSE 路由与通用日志中间件的 Router(参照 ws::build_ws_app)
|
||||
pub fn build_sse_app(db: Db) -> Router {
|
||||
// 组装 CORS,与主服务保持一致的允许来源与头方法
|
||||
let allow_origins = std::env::var("CORS_ALLOW_ORIGINS").unwrap_or_else(|_| "http://localhost:5173".into());
|
||||
let origin_values: Vec<HeaderValue> = allow_origins
|
||||
.split(',')
|
||||
.filter_map(|s| HeaderValue::from_str(s.trim()).ok())
|
||||
.collect();
|
||||
|
||||
let allowed_methods = [
|
||||
Method::GET,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::PATCH,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
];
|
||||
|
||||
let cors = if origin_values.is_empty() {
|
||||
CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(allowed_methods.clone())
|
||||
.allow_headers([
|
||||
axum::http::header::ACCEPT,
|
||||
axum::http::header::CONTENT_TYPE,
|
||||
axum::http::header::AUTHORIZATION,
|
||||
])
|
||||
.allow_credentials(false)
|
||||
} else {
|
||||
CorsLayer::new()
|
||||
.allow_origin(AllowOrigin::list(origin_values))
|
||||
.allow_methods(allowed_methods)
|
||||
.allow_headers([
|
||||
axum::http::header::ACCEPT,
|
||||
axum::http::header::CONTENT_TYPE,
|
||||
axum::http::header::AUTHORIZATION,
|
||||
])
|
||||
.allow_credentials(true)
|
||||
};
|
||||
|
||||
Router::new()
|
||||
.nest(
|
||||
"/api",
|
||||
Router::new().route("/flows/{id}/run/stream", post(run_sse)),
|
||||
)
|
||||
.with_state(db.clone())
|
||||
.layer(cors)
|
||||
.layer(middleware::from_fn_with_state(db, middlewares::logging::request_logger))
|
||||
}
|
||||
|
||||
// 启动 SSE 服务,默认端口 8866,可通过 SSE_HOST/SSE_PORT 覆盖(SSE_HOST 回退 APP_HOST)
|
||||
pub async fn serve(db: Db) -> Result<(), std::io::Error> {
|
||||
let host = std::env::var("SSE_HOST")
|
||||
.ok()
|
||||
.or_else(|| std::env::var("APP_HOST").ok())
|
||||
.unwrap_or_else(|| "0.0.0.0".into());
|
||||
let port = std::env::var("SSE_PORT").unwrap_or_else(|_| "8866".into());
|
||||
let addr = format!("{}:{}", host, port);
|
||||
tracing::info!("sse listening on {}", addr);
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
axum::serve(listener, build_sse_app(db)).await
|
||||
}
|
||||
|
||||
// SSE 路由处理:参考 WebSocket 的鉴权与调用方式
|
||||
#[derive(serde::Deserialize)]
|
||||
struct RunReq { #[serde(default)] input: serde_json::Value }
|
||||
|
||||
async fn run_sse(
|
||||
State(db): State<Db>,
|
||||
Path(id): Path<String>,
|
||||
Query(q): Query<HashMap<String, String>>,
|
||||
headers: HeaderMap,
|
||||
Json(req): Json<RunReq>,
|
||||
) -> Result<axum::response::sse::Sse<impl futures::Stream<Item = Result<axum::response::sse::Event, std::convert::Infallible>>>, AppError> {
|
||||
use axum::http::header::AUTHORIZATION;
|
||||
|
||||
// 1) 认证:优先 Authorization,其次查询参数 access_token
|
||||
let token_opt = headers
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.strip_prefix("Bearer "))
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| q.get("access_token").cloned());
|
||||
|
||||
let token = token_opt.ok_or(AppError::Unauthorized)?;
|
||||
let secret = std::env::var("JWT_SECRET").map_err(|_| AppError::Unauthorized)?;
|
||||
let claims = decode_token(&token, &secret)?;
|
||||
if claims.typ != "access" { return Err(AppError::Unauthorized); }
|
||||
|
||||
// 可选:Redis 二次校验(与 WS 一致)
|
||||
let redis_validation_enabled = std::env::var("REDIS_TOKEN_VALIDATION")
|
||||
.unwrap_or_else(|_| "true".to_string())
|
||||
.parse::<bool>().unwrap_or(true);
|
||||
if redis_validation_enabled {
|
||||
let is_valid = crate::redis::TokenRedis::validate_access_token(&token, claims.uid).await.unwrap_or(false);
|
||||
if !is_valid { return Err(AppError::Unauthorized); }
|
||||
}
|
||||
|
||||
// 建立 mpsc 通道用于接收引擎的流式事件
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<crate::flow::context::StreamEvent>(16);
|
||||
|
||||
// 启动后台任务运行流程,将事件通过 tx 发送
|
||||
let db_clone = db.clone();
|
||||
let id_clone = id.clone();
|
||||
let input = req.input.clone();
|
||||
let user_info = Some((claims.uid, claims.sub));
|
||||
tokio::spawn(async move {
|
||||
let _ = flow_service::run_with_stream(db_clone, &id_clone, flow_service::RunReq { input }, user_info, tx).await;
|
||||
});
|
||||
|
||||
// 由通用组件把 Receiver 包装为 SSE 响应
|
||||
Ok(from_mpsc(rx))
|
||||
}
|
||||
141
backend/src/middlewares/ws.rs
Normal file
141
backend/src/middlewares/ws.rs
Normal file
@ -0,0 +1,141 @@
|
||||
use axum::{Router, middleware};
|
||||
use crate::db::Db;
|
||||
use crate::routes;
|
||||
use crate::middlewares;
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::info;
|
||||
|
||||
// 新增:错误类型与鉴权/Redis 校验、Flow 运行
|
||||
use crate::error::AppError;
|
||||
use crate::middlewares::jwt::decode_token;
|
||||
use crate::services::flow_service;
|
||||
use crate::flow::context::StreamEvent;
|
||||
|
||||
// 封装 WS 服务构建:返回仅包含 WS 路由与通用日志中间件的 Router
|
||||
pub fn build_ws_app(db: Db) -> Router {
|
||||
Router::new()
|
||||
.nest("/api", routes::flows::ws_router())
|
||||
.with_state(db.clone())
|
||||
.layer(middleware::from_fn_with_state(db, middlewares::logging::request_logger))
|
||||
}
|
||||
|
||||
// 启动 WS 服务,读取 WS_HOST/WS_PORT(回退到 APP_HOST/默认端口),并启动监听
|
||||
pub async fn serve(db: Db) -> Result<(), std::io::Error> {
|
||||
let host = std::env::var("WS_HOST")
|
||||
.ok()
|
||||
.or_else(|| std::env::var("APP_HOST").ok())
|
||||
.unwrap_or_else(|| "0.0.0.0".into());
|
||||
let port = std::env::var("WS_PORT").unwrap_or_else(|_| "8855".into());
|
||||
let addr = format!("{}:{}", host, port);
|
||||
info!("ws listening on {}", addr);
|
||||
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
axum::serve(listener, build_ws_app(db)).await
|
||||
}
|
||||
|
||||
// ================= 路由处理:仅从路由层转发调用 =================
|
||||
use std::collections::HashMap;
|
||||
use axum::http::{HeaderMap, header::AUTHORIZATION};
|
||||
use axum::response::Response;
|
||||
use axum::extract::{State, Path, Query};
|
||||
use axum::extract::ws::{WebSocketUpgrade, WebSocket, Message, Utf8Bytes};
|
||||
|
||||
pub async fn run_ws(
|
||||
State(db): State<Db>,
|
||||
Path(id): Path<String>,
|
||||
Query(q): Query<HashMap<String, String>>,
|
||||
headers: HeaderMap,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Result<Response, AppError> {
|
||||
// 1) 认证:优先 Authorization,其次查询参数 access_token
|
||||
let token_opt = headers
|
||||
.get(AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.strip_prefix("Bearer "))
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| q.get("access_token").cloned());
|
||||
|
||||
let token = token_opt.ok_or(AppError::Unauthorized)?;
|
||||
let secret = std::env::var("JWT_SECRET").map_err(|_| AppError::Unauthorized)?;
|
||||
let claims = decode_token(&token, &secret)?;
|
||||
if claims.typ != "access" { return Err(AppError::Unauthorized); }
|
||||
|
||||
// 可选:Redis 二次校验(与 AuthUser 提取逻辑一致)
|
||||
let redis_validation_enabled = std::env::var("REDIS_TOKEN_VALIDATION")
|
||||
.unwrap_or_else(|_| "true".to_string())
|
||||
.parse::<bool>().unwrap_or(true);
|
||||
if redis_validation_enabled {
|
||||
let is_valid = crate::redis::TokenRedis::validate_access_token(&token, claims.uid).await.unwrap_or(false);
|
||||
if !is_valid { return Err(AppError::Unauthorized); }
|
||||
}
|
||||
|
||||
let db_clone = db.clone();
|
||||
let id_clone = id.clone();
|
||||
let user_info = Some((claims.uid, claims.sub));
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| async move {
|
||||
handle_ws_flow(socket, db_clone, id_clone, user_info).await;
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn handle_ws_flow(mut socket: WebSocket, db: Db, id: String, user_info: Option<(i64, String)>) {
|
||||
use tokio::time::{timeout, Duration};
|
||||
use tokio::select;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
// 读取首条消息作为输入(5s 超时),格式:{ "input": any }
|
||||
let mut input_value = serde_json::Value::Object(serde_json::Map::new());
|
||||
match timeout(Duration::from_secs(5), socket.recv()).await {
|
||||
Ok(Some(Ok(Message::Text(s)))) => {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(s.as_str()) {
|
||||
if let Some(inp) = v.get("input") { input_value = inp.clone(); }
|
||||
}
|
||||
}
|
||||
Ok(Some(Ok(Message::Binary(b)))) => {
|
||||
if let Ok(s) = std::str::from_utf8(&b) {
|
||||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(s) {
|
||||
if let Some(inp) = v.get("input") { input_value = inp.clone(); }
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// mpsc 管道:引擎事件 -> 此任务
|
||||
let (tx, mut rx) = mpsc::channel::<StreamEvent>(16);
|
||||
|
||||
// 后台运行流程
|
||||
let db2 = db.clone();
|
||||
let id2 = id.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = flow_service::run_with_stream(db2, &id2, flow_service::RunReq { input: input_value }, user_info, tx).await;
|
||||
});
|
||||
|
||||
// 转发事件到 WebSocket
|
||||
loop {
|
||||
select! {
|
||||
maybe_evt = rx.recv() => {
|
||||
match maybe_evt {
|
||||
None => { let _ = socket.send(Message::Close(None)).await; break; }
|
||||
Some(evt) => {
|
||||
let json = match evt {
|
||||
StreamEvent::Node { node_id, logs, ctx } => serde_json::json!({"type":"node","node_id": node_id, "logs": logs, "ctx": ctx}),
|
||||
StreamEvent::Done { ok, ctx, logs } => serde_json::json!({"type":"done","ok": ok, "ctx": ctx, "logs": logs}),
|
||||
StreamEvent::Error { message } => serde_json::json!({"type":"error","message": message}),
|
||||
};
|
||||
let _ = socket.send(Message::Text(Utf8Bytes::from(json.to_string()))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
// 客户端主动关闭或发送其它指令(例如取消)
|
||||
maybe_msg = socket.recv() => {
|
||||
match maybe_msg {
|
||||
None => break,
|
||||
Some(Ok(Message::Close(_))) => break,
|
||||
Some(Ok(_other)) => { /* 当前不处理取消等指令 */ }
|
||||
Some(Err(_)) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user