# 中间件文档 ## 概述 中间件层是 UdminAI 系统的核心组件,基于 Axum 框架实现,提供了认证授权、请求日志、错误处理、CORS、WebSocket 支持、SSE(Server-Sent Events)等功能。中间件采用洋葱模型,按顺序处理请求和响应。 ## 架构设计 ### 中间件模块结构 ``` middlewares/ ├── mod.rs # 中间件模块导出 ├── auth.rs # 认证中间件 ├── cors.rs # CORS 中间件 ├── http_client.rs # HTTP 客户端中间件 ├── logging.rs # 请求日志中间件 ├── rate_limit.rs # 限流中间件 ├── sse.rs # Server-Sent Events 中间件 └── ws.rs # WebSocket 中间件 ``` ### 设计原则 - **模块化**: 每个中间件独立实现,职责单一 - **可组合**: 中间件可以灵活组合使用 - **高性能**: 最小化性能开销 - **类型安全**: 利用 Rust 类型系统确保安全 - **可配置**: 支持灵活的配置选项 - **可测试**: 易于单元测试和集成测试 ## 认证中间件 (auth.rs) ### 功能特性 - JWT Token 验证 - 用户身份识别 - 权限检查 - Token 刷新机制 - 多种认证策略 ### 实现代码 ```rust use axum::{ extract::{Request, State}, http::{header, StatusCode}, middleware::Next, response::Response, }; use jsonwebtoken::{decode, DecodingKey, Validation}; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use tracing::{error, info, warn}; use crate::{ error::AppError, models::user, services::user_service, AppState, }; /// JWT Claims 结构 #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Claims { pub sub: String, // 用户ID pub username: String, // 用户名 pub roles: Vec<String>, // 角色列表 pub permissions: Vec<String>, // 权限列表 pub exp: usize, // 过期时间 pub iat: usize, // 签发时间 pub iss: String, // 签发者 } /// 认证上下文 #[derive(Debug, Clone)] pub struct AuthContext { pub user_id: String, pub username: String, pub roles: Vec<String>, pub permissions: HashSet<String>, } /// JWT 认证中间件 pub async fn jwt_auth_middleware( State(state): State<AppState>, mut request: Request, next: Next, ) -> Result<Response, AppError> { let auth_header = request .headers() .get(header::AUTHORIZATION) .and_then(|header| header.to_str().ok()); // 使用 strip_prefix 只去掉一个 "Bearer " 前缀(trim_start_matches 会重复去除多个前缀) let token = match auth_header.and_then(|header| header.strip_prefix("Bearer ")) { Some(token) => token, None => { warn!(target = "udmin", "Missing or invalid authorization header"); return Err(AppError::Unauthorized("缺少认证令牌".to_string())); } }; // 验证 JWT Token let claims = match decode::<Claims>( token, &DecodingKey::from_secret(state.config.jwt_secret.as_ref()), 
&Validation::default(), ) { Ok(token_data) => token_data.claims, Err(err) => { error!(target = "udmin", error = %err, "JWT token validation failed"); return Err(AppError::Unauthorized("无效的认证令牌".to_string())); } }; // 检查用户是否存在且活跃 let user = user_service::find_by_id(&state.db, &claims.sub) .await .map_err(|_| AppError::Unauthorized("用户不存在".to_string()))?; if user.status != user::UserStatus::Active { warn!(target = "udmin", user_id = %claims.sub, "Inactive user attempted access"); return Err(AppError::Unauthorized("用户账户已被禁用".to_string())); } // 创建认证上下文 let auth_context = AuthContext { user_id: claims.sub.clone(), username: claims.username.clone(), roles: claims.roles.clone(), permissions: claims.permissions.into_iter().collect(), }; // 将认证上下文添加到请求扩展中 request.extensions_mut().insert(auth_context); info!( target = "udmin", user_id = %claims.sub, username = %claims.username, "User authenticated successfully" ); Ok(next.run(request).await) } /// 权限检查中间件 pub fn require_permission(required_permission: &'static str) -> impl Fn(Request, Next) -> std::pin::Pin> + Send>> + Clone { move |request: Request, next: Next| { Box::pin(async move { let auth_context = request .extensions() .get::() .ok_or_else(|| AppError::Unauthorized("未认证的请求".to_string()))?; if !auth_context.permissions.contains(required_permission) { warn!( target = "udmin", user_id = %auth_context.user_id, required_permission = %required_permission, "Permission denied" ); return Err(AppError::Forbidden("权限不足".to_string())); } info!( target = "udmin", user_id = %auth_context.user_id, permission = %required_permission, "Permission granted" ); Ok(next.run(request).await) }) } } /// 角色检查中间件 pub fn require_role(required_role: &'static str) -> impl Fn(Request, Next) -> std::pin::Pin> + Send>> + Clone { move |request: Request, next: Next| { Box::pin(async move { let auth_context = request .extensions() .get::() .ok_or_else(|| AppError::Unauthorized("未认证的请求".to_string()))?; if 
!auth_context.roles.contains(&required_role.to_string()) { warn!( target = "udmin", user_id = %auth_context.user_id, required_role = %required_role, "Role check failed" ); return Err(AppError::Forbidden("角色权限不足".to_string())); } info!( target = "udmin", user_id = %auth_context.user_id, role = %required_role, "Role check passed" ); Ok(next.run(request).await) }) } } /// 可选认证中间件(不强制要求认证) pub async fn optional_auth_middleware( State(state): State, mut request: Request, next: Next, ) -> Response { let auth_header = request .headers() .get(header::AUTHORIZATION) .and_then(|header| header.to_str().ok()); if let Some(header) = auth_header { if let Some(token) = header.strip_prefix("Bearer ") { if let Ok(token_data) = decode::( token, &DecodingKey::from_secret(state.config.jwt_secret.as_ref()), &Validation::default(), ) { let claims = token_data.claims; // 创建认证上下文 let auth_context = AuthContext { user_id: claims.sub.clone(), username: claims.username.clone(), roles: claims.roles.clone(), permissions: claims.permissions.into_iter().collect(), }; request.extensions_mut().insert(auth_context); } } } next.run(request).await } ``` ## CORS 中间件 (cors.rs) ### 功能特性 - 跨域资源共享支持 - 可配置的允许源 - 预检请求处理 - 安全头设置 ### 实现代码 ```rust use axum::{ extract::Request, http::{header, HeaderValue, Method, StatusCode}, middleware::Next, response::Response, }; use tower_http::cors::{Any, CorsLayer}; use tracing::info; /// CORS 配置 #[derive(Debug, Clone)] pub struct CorsConfig { pub allowed_origins: Vec, pub allowed_methods: Vec, pub allowed_headers: Vec, pub max_age: u64, pub allow_credentials: bool, } impl Default for CorsConfig { fn default() -> Self { Self { allowed_origins: vec![ "http://localhost:3000".to_string(), "http://localhost:5173".to_string(), "http://127.0.0.1:3000".to_string(), "http://127.0.0.1:5173".to_string(), ], allowed_methods: vec![ Method::GET, Method::POST, Method::PUT, Method::DELETE, Method::PATCH, Method::OPTIONS, ], allowed_headers: vec![ "content-type".to_string(), 
"authorization".to_string(), "x-requested-with".to_string(), "x-api-key".to_string(), ], max_age: 3600, allow_credentials: true, } } } /// 创建 CORS 层 pub fn create_cors_layer(config: CorsConfig) -> CorsLayer { let mut cors = CorsLayer::new() .allow_methods(config.allowed_methods) .allow_headers( config .allowed_headers .iter() .map(|h| h.parse().unwrap()) .collect::>(), ) .max_age(std::time::Duration::from_secs(config.max_age)); // 设置允许的源 if config.allowed_origins.contains(&"*".to_string()) { cors = cors.allow_origin(Any); } else { cors = cors.allow_origin( config .allowed_origins .iter() .map(|origin| origin.parse().unwrap()) .collect::>(), ); } // 设置是否允许凭据 if config.allow_credentials { cors = cors.allow_credentials(true); } info!(target = "udmin", "CORS middleware configured"); cors } /// 自定义 CORS 中间件 pub async fn cors_middleware( request: Request, next: Next, ) -> Response { let origin = request .headers() .get(header::ORIGIN) .and_then(|v| v.to_str().ok()); let method = request.method(); // 处理预检请求 if method == Method::OPTIONS { let mut response = Response::builder() .status(StatusCode::OK) .body(axum::body::Body::empty()) .unwrap(); let headers = response.headers_mut(); if let Some(origin) = origin { headers.insert( header::ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_str(origin).unwrap(), ); } headers.insert( header::ACCESS_CONTROL_ALLOW_METHODS, HeaderValue::from_static("GET, POST, PUT, DELETE, PATCH, OPTIONS"), ); headers.insert( header::ACCESS_CONTROL_ALLOW_HEADERS, HeaderValue::from_static("content-type, authorization, x-requested-with, x-api-key"), ); headers.insert( header::ACCESS_CONTROL_MAX_AGE, HeaderValue::from_static("3600"), ); headers.insert( header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"), ); return response; } let mut response = next.run(request).await; // 为实际请求添加 CORS 头 let headers = response.headers_mut(); if let Some(origin) = origin { headers.insert( header::ACCESS_CONTROL_ALLOW_ORIGIN, 
HeaderValue::from_str(origin).unwrap(), ); } headers.insert( header::ACCESS_CONTROL_ALLOW_CREDENTIALS, HeaderValue::from_static("true"), ); response } ``` ## 请求日志中间件 (logging.rs) ### 功能特性 - 请求响应日志记录 - 性能监控 - 错误追踪 - 结构化日志 - 请求ID追踪 ### 实现代码 ```rust use axum::{ extract::{MatchedPath, Request}, http::StatusCode, middleware::Next, response::Response, }; use std::time::Instant; use tracing::{error, info, warn}; use uuid::Uuid; /// 请求日志中间件 pub async fn request_logging_middleware( request: Request, next: Next, ) -> Response { let start_time = Instant::now(); let request_id = Uuid::new_v4().to_string(); // 提取请求信息 let method = request.method().clone(); let uri = request.uri().clone(); let version = request.version(); let user_agent = request .headers() .get("user-agent") .and_then(|v| v.to_str().ok()) .unwrap_or("unknown"); let remote_addr = request .headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) .or_else(|| { request .headers() .get("x-real-ip") .and_then(|v| v.to_str().ok()) }) .unwrap_or("unknown"); // 获取匹配的路径模式 let matched_path = request .extensions() .get::() .map(|mp| mp.as_str()) .unwrap_or(uri.path()); info!( target = "udmin", request_id = %request_id, method = %method, uri = %uri, path = %matched_path, version = ?version, user_agent = %user_agent, remote_addr = %remote_addr, "Request started" ); // 执行请求 let response = next.run(request).await; // 计算请求耗时 let duration = start_time.elapsed(); let status = response.status(); // 根据状态码选择日志级别 match status.as_u16() { 200..=299 => { info!( target = "udmin", request_id = %request_id, method = %method, uri = %uri, path = %matched_path, status = %status.as_u16(), duration_ms = %duration.as_millis(), "Request completed successfully" ); } 300..=399 => { info!( target = "udmin", request_id = %request_id, method = %method, uri = %uri, path = %matched_path, status = %status.as_u16(), duration_ms = %duration.as_millis(), "Request redirected" ); } 400..=499 => { warn!( target = "udmin", request_id = %request_id, 
method = %method, uri = %uri, path = %matched_path, status = %status.as_u16(), duration_ms = %duration.as_millis(), "Client error" ); } 500..=599 => { error!( target = "udmin", request_id = %request_id, method = %method, uri = %uri, path = %matched_path, status = %status.as_u16(), duration_ms = %duration.as_millis(), "Server error" ); } _ => { info!( target = "udmin", request_id = %request_id, method = %method, uri = %uri, path = %matched_path, status = %status.as_u16(), duration_ms = %duration.as_millis(), "Request completed" ); } } response } /// 性能监控中间件 pub async fn performance_monitoring_middleware( request: Request, next: Next, ) -> Response { let start_time = Instant::now(); let method = request.method().clone(); let uri = request.uri().clone(); let response = next.run(request).await; let duration = start_time.elapsed(); // 记录慢请求 if duration.as_millis() > 1000 { warn!( target = "udmin", method = %method, uri = %uri, duration_ms = %duration.as_millis(), "Slow request detected" ); } // 记录性能指标 info!( target = "udmin.performance", method = %method, uri = %uri, status = %response.status().as_u16(), duration_ms = %duration.as_millis(), "Performance metrics" ); response } ``` ## HTTP 客户端中间件 (http_client.rs) ### 功能特性 - HTTP 客户端封装 - 请求重试机制 - 超时控制 - 请求日志 - 错误处理 ### 实现代码 ```rust use reqwest::{Client, ClientBuilder, Response}; use serde::{Deserialize, Serialize}; use std::time::Duration; use tracing::{error, info, warn}; use url::Url; use crate::error::AppError; /// HTTP 客户端配置 #[derive(Debug, Clone)] pub struct HttpClientConfig { pub timeout: Duration, pub connect_timeout: Duration, pub max_retries: u32, pub retry_delay: Duration, pub user_agent: String, } impl Default for HttpClientConfig { fn default() -> Self { Self { timeout: Duration::from_secs(30), connect_timeout: Duration::from_secs(10), max_retries: 3, retry_delay: Duration::from_millis(1000), user_agent: "UdminAI/1.0".to_string(), } } } /// HTTP 客户端包装器 #[derive(Debug, Clone)] pub struct HttpClient { client: 
Client, config: HttpClientConfig, } impl HttpClient { /// 创建新的 HTTP 客户端 pub fn new(config: HttpClientConfig) -> Result { let client = ClientBuilder::new() .timeout(config.timeout) .connect_timeout(config.connect_timeout) .user_agent(&config.user_agent) .build() .map_err(|e| AppError::InternalServerError(format!("创建HTTP客户端失败: {}", e)))?; Ok(Self { client, config }) } /// 发送 GET 请求 pub async fn get(&self, url: &str) -> Result { self.request_with_retry("GET", url, None::<()>).await } /// 发送 POST 请求 pub async fn post(&self, url: &str, body: &T) -> Result { self.request_with_retry("POST", url, Some(body)).await } /// 发送 PUT 请求 pub async fn put(&self, url: &str, body: &T) -> Result { self.request_with_retry("PUT", url, Some(body)).await } /// 发送 DELETE 请求 pub async fn delete(&self, url: &str) -> Result { self.request_with_retry("DELETE", url, None::<()>).await } /// 带重试的请求 async fn request_with_retry( &self, method: &str, url: &str, body: Option<&T>, ) -> Result { let parsed_url = Url::parse(url) .map_err(|e| AppError::BadRequest(format!("无效的URL: {}", e)))?; let mut last_error = None; for attempt in 0..=self.config.max_retries { let start_time = std::time::Instant::now(); info!( target = "udmin", method = %method, url = %url, attempt = %attempt, "HTTP request started" ); let result = self.make_request(method, &parsed_url, body).await; let duration = start_time.elapsed(); match result { Ok(response) => { let status = response.status(); info!( target = "udmin", method = %method, url = %url, status = %status.as_u16(), duration_ms = %duration.as_millis(), attempt = %attempt, "HTTP request completed" ); // 如果是服务器错误且还有重试次数,则重试 if status.is_server_error() && attempt < self.config.max_retries { warn!( target = "udmin", method = %method, url = %url, status = %status.as_u16(), attempt = %attempt, "Server error, retrying" ); tokio::time::sleep(self.config.retry_delay).await; continue; } return Ok(response); } Err(e) => { error!( target = "udmin", method = %method, url = %url, error 
= %e, duration_ms = %duration.as_millis(), attempt = %attempt, "HTTP request failed" ); last_error = Some(e); // 如果还有重试次数,则等待后重试 if attempt < self.config.max_retries { tokio::time::sleep(self.config.retry_delay).await; } } } } Err(last_error.unwrap_or_else(|| { AppError::InternalServerError("HTTP请求失败".to_string()) })) } /// 执行实际的 HTTP 请求 async fn make_request( &self, method: &str, url: &Url, body: Option<&T>, ) -> Result { let mut request_builder = match method { "GET" => self.client.get(url.clone()), "POST" => self.client.post(url.clone()), "PUT" => self.client.put(url.clone()), "DELETE" => self.client.delete(url.clone()), _ => { return Err(AppError::BadRequest(format!( "不支持的HTTP方法: {}", method ))) } }; // 添加请求体 if let Some(body) = body { request_builder = request_builder.json(body); } // 发送请求 let response = request_builder .send() .await .map_err(|e| AppError::InternalServerError(format!("HTTP请求失败: {}", e)))?; Ok(response) } /// 发送 JSON 请求并解析响应 pub async fn json_request( &self, method: &str, url: &str, body: Option<&T>, ) -> Result where T: Serialize, R: for<'de> Deserialize<'de>, { let response = match method { "GET" => self.get(url).await?, "POST" => { if let Some(body) = body { self.post(url, body).await? } else { return Err(AppError::BadRequest("POST请求需要请求体".to_string())); } } "PUT" => { if let Some(body) = body { self.put(url, body).await? 
} else { return Err(AppError::BadRequest("PUT请求需要请求体".to_string())); } } "DELETE" => self.delete(url).await?, _ => { return Err(AppError::BadRequest(format!( "不支持的HTTP方法: {}", method ))) } }; // 检查响应状态 if !response.status().is_success() { let status = response.status(); let error_text = response .text() .await .unwrap_or_else(|_| "无法读取错误响应".to_string()); return Err(AppError::InternalServerError(format!( "HTTP请求失败: {} - {}", status, error_text ))); } // 解析 JSON 响应 let json_response = response .json::() .await .map_err(|e| AppError::InternalServerError(format!("解析JSON响应失败: {}", e)))?; Ok(json_response) } } ``` ## 限流中间件 (rate_limit.rs) ### 功能特性 - 基于令牌桶的限流 - 支持不同的限流策略 - IP 级别限流 - 用户级别限流 - 动态配置 ### 实现代码 ```rust use axum::{ extract::{Request, State}, http::StatusCode, middleware::Next, response::Response, }; use std::{ collections::HashMap, sync::{Arc, Mutex}, time::{Duration, Instant}, }; use tracing::{info, warn}; use crate::{error::AppError, middlewares::auth::AuthContext}; /// 令牌桶 #[derive(Debug, Clone)] struct TokenBucket { capacity: u32, tokens: u32, last_refill: Instant, refill_rate: u32, // tokens per second } impl TokenBucket { fn new(capacity: u32, refill_rate: u32) -> Self { Self { capacity, tokens: capacity, last_refill: Instant::now(), refill_rate, } } fn try_consume(&mut self, tokens: u32) -> bool { self.refill(); if self.tokens >= tokens { self.tokens -= tokens; true } else { false } } fn refill(&mut self) { let now = Instant::now(); let elapsed = now.duration_since(self.last_refill); let tokens_to_add = (elapsed.as_secs_f64() * self.refill_rate as f64) as u32; if tokens_to_add > 0 { self.tokens = (self.tokens + tokens_to_add).min(self.capacity); self.last_refill = now; } } } /// 限流器 #[derive(Debug)] pub struct RateLimiter { buckets: Arc>>, default_capacity: u32, default_refill_rate: u32, } impl RateLimiter { pub fn new(default_capacity: u32, default_refill_rate: u32) -> Self { Self { buckets: Arc::new(Mutex::new(HashMap::new())), default_capacity, 
default_refill_rate, } } pub fn check_rate_limit(&self, key: &str, tokens: u32) -> bool { let mut buckets = self.buckets.lock().unwrap(); let bucket = buckets .entry(key.to_string()) .or_insert_with(|| TokenBucket::new(self.default_capacity, self.default_refill_rate)); bucket.try_consume(tokens) } /// 清理过期的桶 pub fn cleanup_expired_buckets(&self, max_idle_duration: Duration) { let mut buckets = self.buckets.lock().unwrap(); let now = Instant::now(); buckets.retain(|_, bucket| { now.duration_since(bucket.last_refill) < max_idle_duration }); } } /// 限流配置 #[derive(Debug, Clone)] pub struct RateLimitConfig { pub requests_per_second: u32, pub burst_capacity: u32, pub enable_ip_limit: bool, pub enable_user_limit: bool, } impl Default for RateLimitConfig { fn default() -> Self { Self { requests_per_second: 100, burst_capacity: 200, enable_ip_limit: true, enable_user_limit: true, } } } /// IP 限流中间件 pub async fn ip_rate_limit_middleware( State(rate_limiter): State>, request: Request, next: Next, ) -> Result { // 获取客户端 IP let client_ip = request .headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) .and_then(|v| v.split(',').next()) .or_else(|| { request .headers() .get("x-real-ip") .and_then(|v| v.to_str().ok()) }) .unwrap_or("unknown") .trim(); let rate_limit_key = format!("ip:{}", client_ip); // 检查限流 if !rate_limiter.check_rate_limit(&rate_limit_key, 1) { warn!( target = "udmin", client_ip = %client_ip, "IP rate limit exceeded" ); return Err(AppError::TooManyRequests( "请求过于频繁,请稍后再试".to_string(), )); } info!( target = "udmin", client_ip = %client_ip, "IP rate limit check passed" ); Ok(next.run(request).await) } /// 用户限流中间件 pub async fn user_rate_limit_middleware( State(rate_limiter): State>, request: Request, next: Next, ) -> Result { // 获取用户认证信息 if let Some(auth_context) = request.extensions().get::() { let rate_limit_key = format!("user:{}", auth_context.user_id); // 检查用户级别限流 if !rate_limiter.check_rate_limit(&rate_limit_key, 1) { warn!( target = "udmin", 
user_id = %auth_context.user_id, "User rate limit exceeded" ); return Err(AppError::TooManyRequests( "用户请求过于频繁,请稍后再试".to_string(), )); } info!( target = "udmin", user_id = %auth_context.user_id, "User rate limit check passed" ); } Ok(next.run(request).await) } /// 创建限流中间件 pub fn create_rate_limit_middleware( config: RateLimitConfig, ) -> Arc { let rate_limiter = Arc::new(RateLimiter::new( config.burst_capacity, config.requests_per_second, )); // 启动清理任务 let cleanup_limiter = rate_limiter.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(300)); // 5分钟清理一次 loop { interval.tick().await; cleanup_limiter.cleanup_expired_buckets(Duration::from_secs(3600)); // 清理1小时未使用的桶 } }); info!(target = "udmin", "Rate limiter initialized"); rate_limiter } ``` ## WebSocket 中间件 (ws.rs) ### 功能特性 - WebSocket 连接管理 - 实时消息推送 - 连接认证 - 心跳检测 - 广播支持 ### 实现代码 ```rust use axum::{ extract::{ws::WebSocket, Query, State, WebSocketUpgrade}, response::Response, }; use futures_util::{sink::SinkExt, stream::StreamExt}; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, sync::{Arc, RwLock}, time::{Duration, Instant}, }; use tokio::sync::mpsc; use tracing::{error, info, warn}; use uuid::Uuid; use crate::{ error::AppError, services::user_service, AppState, }; /// WebSocket 连接信息 #[derive(Debug, Clone)] pub struct ConnectionInfo { pub id: String, pub user_id: Option, pub connected_at: Instant, pub last_ping: Instant, pub sender: mpsc::UnboundedSender, } /// WebSocket 消息类型 #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", content = "data")] pub enum WebSocketMessage { /// 认证消息 Auth { token: String }, /// 认证成功 AuthSuccess { user_id: String }, /// 认证失败 AuthError { message: String }, /// 心跳包 Ping, /// 心跳响应 Pong, /// 流程执行状态更新 FlowExecutionUpdate { execution_id: String, status: String, progress: Option, message: Option, }, /// 任务执行状态更新 JobExecutionUpdate { job_id: String, execution_id: String, status: String, message: Option, }, 
/// 系统通知 SystemNotification { title: String, message: String, level: String, }, /// 用户通知 UserNotification { id: String, title: String, message: String, created_at: String, }, /// 错误消息 Error { message: String }, } /// WebSocket 连接管理器 #[derive(Debug)] pub struct WebSocketManager { connections: Arc>>, user_connections: Arc>>>, // user_id -> connection_ids } impl WebSocketManager { pub fn new() -> Self { Self { connections: Arc::new(RwLock::new(HashMap::new())), user_connections: Arc::new(RwLock::new(HashMap::new())), } } /// 添加连接 pub fn add_connection(&self, connection_info: ConnectionInfo) { let connection_id = connection_info.id.clone(); let user_id = connection_info.user_id.clone(); // 添加到连接列表 self.connections.write().unwrap().insert(connection_id.clone(), connection_info); // 如果有用户ID,添加到用户连接映射 if let Some(user_id) = user_id { self.user_connections .write() .unwrap() .entry(user_id) .or_insert_with(Vec::new) .push(connection_id.clone()); } info!( target = "udmin", connection_id = %connection_id, "WebSocket connection added" ); } /// 移除连接 pub fn remove_connection(&self, connection_id: &str) { let mut connections = self.connections.write().unwrap(); if let Some(connection_info) = connections.remove(connection_id) { // 从用户连接映射中移除 if let Some(user_id) = &connection_info.user_id { let mut user_connections = self.user_connections.write().unwrap(); if let Some(user_conn_list) = user_connections.get_mut(user_id) { user_conn_list.retain(|id| id != connection_id); if user_conn_list.is_empty() { user_connections.remove(user_id); } } } info!( target = "udmin", connection_id = %connection_id, user_id = ?connection_info.user_id, "WebSocket connection removed" ); } } /// 向指定连接发送消息 pub fn send_to_connection(&self, connection_id: &str, message: WebSocketMessage) -> bool { let connections = self.connections.read().unwrap(); if let Some(connection_info) = connections.get(connection_id) { match connection_info.sender.send(message) { Ok(_) => true, Err(e) => { warn!( target = "udmin", 
connection_id = %connection_id, error = %e, "Failed to send message to connection" ); false } } } else { false } } /// 向指定用户的所有连接发送消息 pub fn send_to_user(&self, user_id: &str, message: WebSocketMessage) -> usize { // 先拷贝连接ID并立即释放 user_connections 读锁:remove_connection 按 connections -> user_connections 的顺序加锁,持有 user_connections 时再获取 connections 会形成相反的加锁顺序,可能死锁 let connection_ids: Vec<String> = self .user_connections .read() .unwrap() .get(user_id) .cloned() .unwrap_or_default(); let mut sent_count = 0; for connection_id in &connection_ids { if self.send_to_connection(connection_id, message.clone()) { sent_count += 1; } } info!( target = "udmin", user_id = %user_id, sent_count = %sent_count, "Message sent to user connections" ); sent_count } /// 广播消息给所有连接 pub fn broadcast(&self, message: WebSocketMessage) -> usize { let connections = self.connections.read().unwrap(); let mut sent_count = 0; for (connection_id, connection_info) in connections.iter() { match connection_info.sender.send(message.clone()) { Ok(_) => sent_count += 1, Err(e) => { warn!( target = "udmin", connection_id = %connection_id, error = %e, "Failed to broadcast message" ); } } } info!( target = "udmin", sent_count = %sent_count, total_connections = %connections.len(), "Message broadcasted" ); sent_count } /// 获取连接统计信息 pub fn get_stats(&self) -> WebSocketStats { let connections = self.connections.read().unwrap(); let user_connections = self.user_connections.read().unwrap(); WebSocketStats { total_connections: connections.len(), authenticated_connections: connections.values().filter(|c| c.user_id.is_some()).count(), unique_users: user_connections.len(), } } /// 清理过期连接 pub fn cleanup_stale_connections(&self, timeout: Duration) { let now = Instant::now(); let connections = self.connections.read().unwrap(); let stale_connections: Vec<String> = connections .iter() .filter(|(_, info)| now.duration_since(info.last_ping) > timeout) .map(|(id, _)| id.clone()) .collect(); drop(connections); for connection_id in stale_connections { self.remove_connection(&connection_id); warn!( target = "udmin", connection_id = %connection_id, "Removed stale WebSocket 
connection" ); } } } /// WebSocket 统计信息 #[derive(Debug, Serialize)] pub struct WebSocketStats { pub total_connections: usize, pub authenticated_connections: usize, pub unique_users: usize, } /// WebSocket 查询参数 #[derive(Debug, Deserialize)] pub struct WebSocketQuery { pub token: Option<String>, } /// WebSocket 升级处理器 pub async fn websocket_handler( ws: WebSocketUpgrade, Query(params): Query<WebSocketQuery>, State(state): State<AppState>, ) -> Response { ws.on_upgrade(move |socket| handle_websocket(socket, params.token, state)) } /// 处理 WebSocket 连接 async fn handle_websocket( socket: WebSocket, token: Option<String>, state: AppState, ) { let connection_id = Uuid::new_v4().to_string(); let (mut sender, mut receiver) = socket.split(); let (tx, mut rx) = mpsc::unbounded_channel::<WebSocketMessage>(); info!( target = "udmin", connection_id = %connection_id, "WebSocket connection established" ); // 创建连接信息 let connection_info = ConnectionInfo { id: connection_id.clone(), user_id: None, connected_at: Instant::now(), last_ping: Instant::now(), sender: tx, }; // 添加到连接管理器 state.ws_manager.add_connection(connection_info); // 如果提供了token,尝试认证 let mut authenticated_user_id = None; if let Some(token) = token { match authenticate_websocket_token(&token, &state).await { Ok(user_id) => { authenticated_user_id = Some(user_id.clone()); // 更新连接信息,并同步登记到 user_id -> connection_ids 映射,否则 send_to_user 找不到该连接;用代码块限定锁的作用域,确保在 await 之前释放 { let mut connections = state.ws_manager.connections.write().unwrap(); if let Some(connection_info) = connections.get_mut(&connection_id) { connection_info.user_id = Some(user_id.clone()); state .ws_manager .user_connections .write() .unwrap() .entry(user_id.clone()) .or_default() .push(connection_id.clone()); } } // 发送认证成功消息 let _ = sender.send(axum::extract::ws::Message::Text( serde_json::to_string(&WebSocketMessage::AuthSuccess { user_id }).unwrap() )).await; info!( target = "udmin", connection_id = %connection_id, user_id = %authenticated_user_id.as_ref().unwrap(), "WebSocket connection authenticated" ); } Err(e) => { warn!( target = "udmin", connection_id = %connection_id, error = %e, "WebSocket authentication failed" ); let _ = sender.send(axum::extract::ws::Message::Text( serde_json::to_string(&WebSocketMessage::AuthError { message: "认证失败".to_string() }).unwrap() 
)).await; } } } // 启动消息发送任务 let connection_id_clone = connection_id.clone(); let ws_manager_clone = state.ws_manager.clone(); let send_task = tokio::spawn(async move { while let Some(message) = rx.recv().await { let text = match serde_json::to_string(&message) { Ok(text) => text, Err(e) => { error!( target = "udmin", connection_id = %connection_id_clone, error = %e, "Failed to serialize WebSocket message" ); continue; } }; if sender.send(axum::extract::ws::Message::Text(text)).await.is_err() { break; } } // 移除连接 ws_manager_clone.remove_connection(&connection_id_clone); }); // 启动心跳任务 let connection_id_clone = connection_id.clone(); let ws_manager_clone = state.ws_manager.clone(); let heartbeat_task = tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(30)); loop { interval.tick().await; if !ws_manager_clone.send_to_connection( &connection_id_clone, WebSocketMessage::Ping, ) { break; } } }); // 处理接收到的消息 while let Some(msg) = receiver.next().await { match msg { Ok(axum::extract::ws::Message::Text(text)) => { match serde_json::from_str::<WebSocketMessage>(&text) { Ok(WebSocketMessage::Auth { token }) => { // 处理认证消息 match authenticate_websocket_token(&token, &state).await { Ok(user_id) => { authenticated_user_id = Some(user_id.clone()); let _ = state.ws_manager.send_to_connection( &connection_id, WebSocketMessage::AuthSuccess { user_id }, ); } Err(_) => { let _ = state.ws_manager.send_to_connection( &connection_id, WebSocketMessage::AuthError { message: "认证失败".to_string(), }, ); } } } Ok(WebSocketMessage::Pong) => { // 收到心跳响应时刷新 last_ping,否则 cleanup_stale_connections 会在超时后清理仍然存活的连接 if let Some(connection_info) = state .ws_manager .connections .write() .unwrap() .get_mut(&connection_id) { connection_info.last_ping = Instant::now(); } } Ok(_) => { // 处理其他消息类型 } Err(e) => { warn!( target = "udmin", connection_id = %connection_id, error = %e, "Failed to parse WebSocket message" ); } } } Ok(axum::extract::ws::Message::Close(_)) => { info!( target = "udmin", connection_id = %connection_id, "WebSocket connection closed by client" ); break; } Err(e) => { error!( target = "udmin", connection_id = %connection_id, error = %e, "WebSocket error" 
); break; } _ => {} } } // 清理任务 send_task.abort(); heartbeat_task.abort(); state.ws_manager.remove_connection(&connection_id); info!( target = "udmin", connection_id = %connection_id, "WebSocket connection closed" ); } /// 认证 WebSocket Token async fn authenticate_websocket_token( token: &str, state: &AppState, ) -> Result { use jsonwebtoken::{decode, DecodingKey, Validation}; use crate::middlewares::auth::Claims; // 验证 JWT Token let claims = decode::( token, &DecodingKey::from_secret(state.config.jwt_secret.as_ref()), &Validation::default(), ) .map_err(|_| AppError::Unauthorized("无效的认证令牌".to_string()))? .claims; // 检查用户是否存在且活跃 let user = user_service::find_by_id(&state.db, &claims.sub) .await .map_err(|_| AppError::Unauthorized("用户不存在".to_string()))?; if user.status != crate::models::user::UserStatus::Active { return Err(AppError::Unauthorized("用户账户已被禁用".to_string())); } Ok(claims.sub) } /// 启动 WebSocket 管理器清理任务 pub fn start_websocket_cleanup_task(ws_manager: Arc) { tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(60)); loop { interval.tick().await; ws_manager.cleanup_stale_connections(Duration::from_secs(300)); // 5分钟超时 } }); } ``` ## SSE 中间件 (sse.rs) ### 功能特性 - Server-Sent Events 支持 - 实时事件推送 - 连接管理 - 事件过滤 - 重连支持 ### 实现代码 ```rust use axum::{ extract::{Query, State}, http::{header, HeaderValue, StatusCode}, response::{sse::Event, Response, Sse}, }; use futures_util::stream::{self, Stream}; use serde::{Deserialize, Serialize}; use std::{ collections::HashMap, convert::Infallible, sync::{Arc, RwLock}, time::{Duration, SystemTime, UNIX_EPOCH}, }; use tokio::sync::mpsc; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{error, info, warn}; use uuid::Uuid; use crate::{ error::AppError, middlewares::auth::AuthContext, AppState, }; /// SSE 事件类型 #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", content = "data")] pub enum SseEvent { /// 流程执行状态更新 FlowExecutionUpdate { execution_id: String, flow_id: 
String, status: String, progress: Option, message: Option, timestamp: u64, }, /// 任务执行状态更新 JobExecutionUpdate { job_id: String, execution_id: String, status: String, message: Option, timestamp: u64, }, /// 系统通知 SystemNotification { id: String, title: String, message: String, level: String, timestamp: u64, }, /// 用户通知 UserNotification { id: String, title: String, message: String, timestamp: u64, }, /// 系统状态更新 SystemStatus { cpu_usage: f64, memory_usage: f64, disk_usage: f64, active_connections: usize, timestamp: u64, }, /// 心跳事件 Heartbeat { timestamp: u64, }, } /// SSE 连接信息 #[derive(Debug)] struct SseConnection { id: String, user_id: Option<String>, sender: mpsc::UnboundedSender<Result<Event, Infallible>>, filters: Vec<String>, created_at: SystemTime, } /// SSE 连接管理器 #[derive(Debug)] pub struct SseManager { connections: Arc<RwLock<HashMap<String, SseConnection>>>, user_connections: Arc<RwLock<HashMap<String, Vec<String>>>>, } impl SseManager { pub fn new() -> Self { Self { connections: Arc::new(RwLock::new(HashMap::new())), user_connections: Arc::new(RwLock::new(HashMap::new())), } } /// 添加 SSE 连接 pub fn add_connection( &self, connection_id: String, user_id: Option<String>, sender: mpsc::UnboundedSender<Result<Event, Infallible>>, filters: Vec<String>, ) { let connection = SseConnection { id: connection_id.clone(), user_id: user_id.clone(), sender, filters, created_at: SystemTime::now(), }; // 添加到连接列表 self.connections.write().unwrap().insert(connection_id.clone(), connection); // 如果有用户ID,添加到用户连接映射 if let Some(user_id) = user_id { self.user_connections .write() .unwrap() .entry(user_id) .or_insert_with(Vec::new) .push(connection_id.clone()); } info!( target = "udmin", connection_id = %connection_id, "SSE connection added" ); } /// 移除 SSE 连接 pub fn remove_connection(&self, connection_id: &str) { let mut connections = self.connections.write().unwrap(); if let Some(connection) = connections.remove(connection_id) { // 从用户连接映射中移除 if let Some(user_id) = &connection.user_id { let mut user_connections = self.user_connections.write().unwrap(); if let Some(user_conn_list) = user_connections.get_mut(user_id) { user_conn_list.retain(|id| id != connection_id); if user_conn_list.is_empty() { user_connections.remove(user_id); } } } info!( target = 
"udmin", connection_id = %connection_id, user_id = ?connection.user_id, "SSE connection removed" ); } } /// 发送事件到指定连接 pub fn send_to_connection(&self, connection_id: &str, event: SseEvent) -> bool { let connections = self.connections.read().unwrap(); if let Some(connection) = connections.get(connection_id) { // 检查事件过滤器 if !connection.filters.is_empty() { let event_type = match &event { SseEvent::FlowExecutionUpdate { .. } => "flow_execution", SseEvent::JobExecutionUpdate { .. } => "job_execution", SseEvent::SystemNotification { .. } => "system_notification", SseEvent::UserNotification { .. } => "user_notification", SseEvent::SystemStatus { .. } => "system_status", SseEvent::Heartbeat { .. } => "heartbeat", }; if !connection.filters.contains(&event_type.to_string()) { return true; // 过滤掉,但不算失败 } } let sse_event = match create_sse_event(&event) { Ok(event) => event, Err(e) => { error!( target = "udmin", connection_id = %connection_id, error = %e, "Failed to create SSE event" ); return false; } }; match connection.sender.send(Ok(sse_event)) { Ok(_) => true, Err(e) => { warn!( target = "udmin", connection_id = %connection_id, error = %e, "Failed to send SSE event" ); false } } } else { false } } /// 发送事件到指定用户的所有连接 pub fn send_to_user(&self, user_id: &str, event: SseEvent) -> usize { let user_connections = self.user_connections.read().unwrap(); let mut sent_count = 0; if let Some(connection_ids) = user_connections.get(user_id) { for connection_id in connection_ids { if self.send_to_connection(connection_id, event.clone()) { sent_count += 1; } } } info!( target = "udmin", user_id = %user_id, sent_count = %sent_count, "SSE event sent to user connections" ); sent_count } /// 广播事件到所有连接 pub fn broadcast(&self, event: SseEvent) -> usize { let connections = self.connections.read().unwrap(); let mut sent_count = 0; for (connection_id, _) in connections.iter() { if self.send_to_connection(connection_id, event.clone()) { sent_count += 1; } } info!( target = "udmin", sent_count = 
%sent_count, total_connections = %connections.len(), "SSE event broadcasted" ); sent_count } /// 获取连接统计信息 pub fn get_stats(&self) -> SseStats { let connections = self.connections.read().unwrap(); let user_connections = self.user_connections.read().unwrap(); SseStats { total_connections: connections.len(), authenticated_connections: connections.values().filter(|c| c.user_id.is_some()).count(), unique_users: user_connections.len(), } } } /// SSE 统计信息 #[derive(Debug, Serialize)] pub struct SseStats { pub total_connections: usize, pub authenticated_connections: usize, pub unique_users: usize, } /// SSE 查询参数 #[derive(Debug, Deserialize)] pub struct SseQuery { pub filters: Option, // 逗号分隔的事件类型过滤器 pub last_event_id: Option, } /// SSE 处理器 pub async fn sse_handler( Query(params): Query, State(state): State, auth_context: Option, ) -> Result { let connection_id = Uuid::new_v4().to_string(); let (tx, rx) = mpsc::unbounded_channel(); // 解析过滤器 let filters = params .filters .map(|f| f.split(',').map(|s| s.trim().to_string()).collect()) .unwrap_or_default(); // 添加连接到管理器 state.sse_manager.add_connection( connection_id.clone(), auth_context.as_ref().map(|ctx| ctx.user_id.clone()), tx, filters, ); info!( target = "udmin", connection_id = %connection_id, user_id = ?auth_context.as_ref().map(|ctx| &ctx.user_id), "SSE connection established" ); // 发送初始心跳 let initial_heartbeat = SseEvent::Heartbeat { timestamp: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(), }; state.sse_manager.send_to_connection(&connection_id, initial_heartbeat); // 创建事件流 let stream = UnboundedReceiverStream::new(rx); // 启动心跳任务 let connection_id_clone = connection_id.clone(); let sse_manager_clone = state.sse_manager.clone(); tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(30)); loop { interval.tick().await; let heartbeat = SseEvent::Heartbeat { timestamp: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(), }; if 
!sse_manager_clone.send_to_connection(&connection_id_clone, heartbeat) { break; } } // 清理连接 sse_manager_clone.remove_connection(&connection_id_clone); }); let sse = Sse::new(stream).keep_alive( axum::response::sse::KeepAlive::new() .interval(Duration::from_secs(15)) .text("keep-alive-text"), ); Ok(sse.into_response()) } /// 创建 SSE 事件 fn create_sse_event(event: &SseEvent) -> Result { let event_type = match event { SseEvent::FlowExecutionUpdate { .. } => "flow_execution_update", SseEvent::JobExecutionUpdate { .. } => "job_execution_update", SseEvent::SystemNotification { .. } => "system_notification", SseEvent::UserNotification { .. } => "user_notification", SseEvent::SystemStatus { .. } => "system_status", SseEvent::Heartbeat { .. } => "heartbeat", }; let data = serde_json::to_string(event)?; Ok(Event::default() .event(event_type) .data(data)) } /// 启动 SSE 管理器清理任务 pub fn start_sse_cleanup_task(sse_manager: Arc) { tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(300)); // 5分钟清理一次 loop { interval.tick().await; // 这里可以添加清理逻辑,比如移除长时间未活动的连接 let stats = sse_manager.get_stats(); info!( target = "udmin", total_connections = %stats.total_connections, authenticated_connections = %stats.authenticated_connections, unique_users = %stats.unique_users, "SSE connection stats" ); } }); } ``` ## 中间件组合和配置 ### 中间件栈配置 ```rust use axum::{ middleware, Router, }; use std::sync::Arc; use tower::ServiceBuilder; use tower_http::{ compression::CompressionLayer, timeout::TimeoutLayer, trace::TraceLayer, }; use crate::{ middlewares::{ auth::{jwt_auth_middleware, optional_auth_middleware}, cors::create_cors_layer, logging::{request_logging_middleware, performance_monitoring_middleware}, rate_limit::{create_rate_limit_middleware, ip_rate_limit_middleware, user_rate_limit_middleware}, }, AppState, }; /// 创建中间件栈 pub fn create_middleware_stack(state: AppState) -> ServiceBuilder< impl tower::Layer< axum::routing::Router, Service = impl tower::Service< 
http::Request, // NOTE(review): the body type argument of `http::Request` appears to have been lost in this snippet — confirm.
            Response = axum::response::Response,
            Error = std::convert::Infallible,
        > + Clone + Send + 'static,
    > + Clone,
> {
    // Create the rate limiter; the IP and user rate-limit layers below share it.
    // NOTE(review): the `state` parameter is not used anywhere in this body.
    let rate_limiter = create_rate_limit_middleware(Default::default());

    ServiceBuilder::new()
        // Request tracing
        .layer(TraceLayer::new_for_http())
        // Request timeout (30 seconds)
        .layer(TimeoutLayer::new(std::time::Duration::from_secs(30)))
        // Response compression
        .layer(CompressionLayer::new())
        // CORS (default configuration)
        .layer(create_cors_layer(Default::default()))
        // Request logging
        .layer(middleware::from_fn(request_logging_middleware))
        // Performance monitoring
        .layer(middleware::from_fn(performance_monitoring_middleware))
        // IP rate limiting
        .layer(middleware::from_fn_with_state(
            Arc::clone(&rate_limiter),
            ip_rate_limit_middleware,
        ))
        // User rate limiting (must run after authentication).
        // NOTE(review): this stack adds no authentication layer itself, so the
        // ordering depends on how it is combined with the auth stacks below.
        .layer(middleware::from_fn_with_state(
            Arc::clone(&rate_limiter),
            user_rate_limit_middleware,
        ))
}

/// Creates the middleware stack for routes that require authentication.
///
/// The stack consists of a single layer running `jwt_auth_middleware` with a
/// clone of `state`.
pub fn create_auth_middleware_stack(state: AppState) -> ServiceBuilder<
    impl tower::Layer<
        axum::routing::Router,
        Service = impl tower::Service<
            http::Request,
            Response = axum::response::Response,
            Error = std::convert::Infallible,
        > + Clone + Send + 'static,
    > + Clone,
> {
    ServiceBuilder::new()
        .layer(middleware::from_fn_with_state(
            state.clone(),
            jwt_auth_middleware,
        ))
}

/// Creates the middleware stack for routes where authentication is optional.
///
/// The stack consists of a single layer running `optional_auth_middleware`
/// with a clone of `state`.
pub fn create_optional_auth_middleware_stack(state: AppState) -> ServiceBuilder<
    impl tower::Layer<
        axum::routing::Router,
        Service = impl tower::Service<
            http::Request,
            Response = axum::response::Response,
            Error = std::convert::Infallible,
        > + Clone + Send + 'static,
    > + Clone,
> {
    ServiceBuilder::new()
        .layer(middleware::from_fn_with_state(
            state.clone(),
            optional_auth_middleware,
        ))
}
```

## 测试支持

### 中间件测试

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use tower::ServiceExt;

    /// Sends an `OPTIONS /test` request carrying an `Origin` header through a
    /// router wrapped in the default CORS layer, and expects `200 OK` with an
    /// `access-control-allow-origin` response header.
    #[tokio::test]
    async fn test_cors_middleware() {
        let app = Router::new()
            .route("/test", axum::routing::get(|| async { "OK" }))
            .layer(create_cors_layer(Default::default()));

        let request = Request::builder()
            .method("OPTIONS")
            .uri("/test")
            .header("Origin", "http://localhost:3000")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key("access-control-allow-origin"));
    }

    /// Sends two back-to-back requests through the IP rate-limit middleware
    /// and expects the first to succeed and the second to be rejected with
    /// `429 Too Many Requests`.
    #[tokio::test]
    async fn test_rate_limit_middleware() {
        let rate_limiter = Arc::new(RateLimiter::new(1, 1)); // 1 token, 1 per second

        let app = Router::new()
            .route("/test", axum::routing::get(|| async { "OK" }))
            .layer(middleware::from_fn_with_state(
                rate_limiter,
                ip_rate_limit_middleware,
            ));

        // The first request should succeed.
        let request1 = Request::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();

        let response1 = app.clone().oneshot(request1).await.unwrap();
        assert_eq!(response1.status(), StatusCode::OK);

        // The second request should be rate limited.
        let request2 = Request::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();

        let response2 = app.oneshot(request2).await.unwrap();
        assert_eq!(response2.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}
```

## 最佳实践

### 中间件设计

- **单一职责**: 每个中间件只负责一个特定功能
- **可组合性**: 中间件可以灵活组合使用
- **性能优化**: 最小化中间件的性能开销
- **错误处理**: 优雅地处理中间件中的错误
- **日志记录**: 记录关键操作和错误信息

### 安全考虑

- **认证验证**: 严格验证用户身份
- **权限检查**: 确保用户有足够的权限
- **输入验证**: 验证所有输入数据
- **限流保护**: 防止恶意请求和DDoS攻击
- **CORS配置**: 正确配置跨域资源共享

### 性能优化

- **缓存策略**: 缓存认证结果和权限信息
- **连接池**: 合理配置数据库连接池
- **异步处理**: 使用异步操作提高并发性能
- **资源清理**: 及时清理过期的连接和资源

## 总结

中间件层是 UdminAI 系统的重要组成部分,提供了认证授权、请求日志、错误处理、CORS、WebSocket、SSE、限流等核心功能。通过模块化设计和灵活的组合机制,中间件层为系统提供了强大的横切关注点支持,确保了系统的安全性、可靠性和性能。

主要特点:

- **模块化设计**: 每个中间件独立实现,职责单一
- **灵活组合**: 支持灵活的中间件组合和配置
- **高性能**: 优化的实现确保最小的性能开销
- **类型安全**: 利用 Rust 类型系统确保安全性
- **实时通信**: 支持 WebSocket 和 SSE 实时通信
- **安全防护**: 提供认证、授权、限流等安全机制

通过这套完整的中间件系统,UdminAI 能够提供安全、高效、可靠的 Web 服务,满足企业级应用的各种需求。

```