新增以下文档文件: - PROJECT_OVERVIEW.md 项目总览文档 - BACKEND_ARCHITECTURE.md 后端架构文档 - FRONTEND_ARCHITECTURE.md 前端架构文档 - FLOW_ENGINE.md 流程引擎文档 - SERVICES.md 服务层文档 - ERROR_HANDLING.md 错误处理模块文档 文档内容涵盖项目整体介绍、技术架构、核心模块设计和实现细节
64 KiB
64 KiB
中间件文档
概述
中间件层是 UdminAI 系统的核心组件,基于 Axum 框架实现,提供认证授权、请求日志与性能监控、限流、CORS、WebSocket 支持、SSE(Server-Sent Events)等功能,并附带一个用于发起出站请求的 HTTP 客户端封装;中间件产生的错误以 `AppError` 形式返回并转换为 HTTP 响应。中间件采用洋葱模型:请求由外层中间件依次向内传递,最终到达处理函数,响应再按相反顺序逐层向外返回。
架构设计
中间件模块结构
middlewares/
├── mod.rs # 中间件模块导出
├── auth.rs # 认证中间件
├── cors.rs # CORS 中间件
├── http_client.rs # HTTP 客户端封装(用于发起出站请求,并非 Axum 请求中间件)
├── logging.rs # 请求日志中间件
├── rate_limit.rs # 限流中间件
├── sse.rs # Server-Sent Events 中间件
└── ws.rs # WebSocket 中间件
设计原则
- 模块化: 每个中间件独立实现,职责单一
- 可组合: 中间件可以灵活组合使用
- 高性能: 最小化性能开销
- 类型安全: 利用 Rust 类型系统确保安全
- 可配置: 支持灵活的配置选项
- 可测试: 易于单元测试和集成测试
认证中间件 (auth.rs)
功能特性
- JWT Token 验证
- 用户身份识别
- 权限检查
- Token 刷新机制(本文档的示例代码中未包含其实现)
- 多种认证策略:强制认证(`jwt_auth_middleware`)与可选认证(`optional_auth_middleware`)
实现代码
use axum::{
extract::{Request, State},
http::{header, StatusCode},
middleware::Next,
response::Response,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use tracing::{error, info, warn};
use crate::{
error::AppError,
models::user,
services::user_service,
AppState,
};
/// JWT claims carried in access tokens and decoded by the auth middlewares.
///
/// Field names double as the serde (JSON) keys. `exp`/`iat` are integer timestamps —
/// presumably seconds since the Unix epoch per JWT convention; TODO confirm with the issuer.
/// NOTE(review): this module decodes with `Validation::default()`, which (per the
/// jsonwebtoken docs) does not check `iss`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
pub sub: String, // User ID (JWT subject)
pub username: String, // Username
pub roles: Vec<String>, // Role names
pub permissions: Vec<String>, // Permission names
pub exp: usize, // Expiration time
pub iat: usize, // Issued-at time
pub iss: String, // Issuer
}
/// Per-request authentication context.
///
/// Built from the decoded `Claims` and inserted into the request extensions by
/// `jwt_auth_middleware` / `optional_auth_middleware`; read back by `require_permission`,
/// `require_role` and the per-user rate limiter.
#[derive(Debug, Clone)]
pub struct AuthContext {
pub user_id: String,
pub username: String,
pub roles: Vec<String>,
// Stored as a set so permission checks are hash lookups.
pub permissions: HashSet<String>,
}
/// Axum middleware that requires a valid `Authorization: Bearer <jwt>` header.
///
/// On success it decodes the token into `Claims`, checks that the user exists and is
/// active, inserts an `AuthContext` into the request extensions and forwards the request.
///
/// # Errors
/// Returns `AppError::Unauthorized` when the header is missing or malformed, the token is
/// empty or fails validation (signature, expiry), the user lookup fails, or the account is
/// not active.
pub async fn jwt_auth_middleware(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AppError> {
    // `strip_prefix` removes exactly one "Bearer " prefix (`trim_start_matches` would also
    // strip repeats such as "Bearer Bearer x"). An empty token counts as missing.
    let token = match request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
        .filter(|token| !token.is_empty())
    {
        Some(token) => token,
        None => {
            warn!(target = "udmin", "Missing or invalid authorization header");
            return Err(AppError::Unauthorized("缺少认证令牌".to_string()));
        }
    };

    // Verify the signature and standard claims (e.g. `exp`) with the configured secret.
    let claims = match decode::<Claims>(
        token,
        &DecodingKey::from_secret(state.config.jwt_secret.as_ref()),
        &Validation::default(),
    ) {
        Ok(token_data) => token_data.claims,
        Err(err) => {
            error!(target = "udmin", error = %err, "JWT token validation failed");
            return Err(AppError::Unauthorized("无效的认证令牌".to_string()));
        }
    };

    // A valid signature is not enough: the account must still exist and be active.
    let user = user_service::find_by_id(&state.db, &claims.sub)
        .await
        .map_err(|_| AppError::Unauthorized("用户不存在".to_string()))?;
    if user.status != user::UserStatus::Active {
        warn!(target = "udmin", user_id = %claims.sub, "Inactive user attempted access");
        return Err(AppError::Unauthorized("用户账户已被禁用".to_string()));
    }

    // Roles and permissions move into the context; `sub`/`username` are still logged below.
    let auth_context = AuthContext {
        user_id: claims.sub.clone(),
        username: claims.username.clone(),
        roles: claims.roles,
        permissions: claims.permissions.into_iter().collect(),
    };
    request.extensions_mut().insert(auth_context);

    info!(
        target = "udmin",
        user_id = %claims.sub,
        username = %claims.username,
        "User authenticated successfully"
    );

    Ok(next.run(request).await)
}
/// Build a middleware function that admits only callers holding `required_permission`.
///
/// It must run after an authentication middleware that inserts `AuthContext`.
///
/// # Errors
/// * `AppError::Unauthorized` — no `AuthContext` in the request extensions.
/// * `AppError::Forbidden` — the permission is not in the caller's permission set.
pub fn require_permission(required_permission: &'static str) -> impl Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, AppError>> + Send>> + Clone {
    move |request: Request, next: Next| {
        Box::pin(async move {
            let ctx = match request.extensions().get::<AuthContext>() {
                Some(ctx) => ctx,
                None => return Err(AppError::Unauthorized("未认证的请求".to_string())),
            };

            if ctx.permissions.contains(required_permission) {
                info!(
                    target = "udmin",
                    user_id = %ctx.user_id,
                    permission = %required_permission,
                    "Permission granted"
                );
                Ok(next.run(request).await)
            } else {
                warn!(
                    target = "udmin",
                    user_id = %ctx.user_id,
                    required_permission = %required_permission,
                    "Permission denied"
                );
                Err(AppError::Forbidden("权限不足".to_string()))
            }
        })
    }
}
/// Build a middleware function that admits only callers whose `AuthContext` lists
/// `required_role`.
///
/// It must run after an authentication middleware that inserts `AuthContext`.
///
/// # Errors
/// * `AppError::Unauthorized` — no `AuthContext` in the request extensions.
/// * `AppError::Forbidden` — the role is not among the caller's roles.
pub fn require_role(required_role: &'static str) -> impl Fn(Request, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, AppError>> + Send>> + Clone {
    move |request: Request, next: Next| {
        Box::pin(async move {
            let auth_context = request
                .extensions()
                .get::<AuthContext>()
                .ok_or_else(|| AppError::Unauthorized("未认证的请求".to_string()))?;

            // Compare as `&str` so no `String` is allocated per request.
            let has_role = auth_context
                .roles
                .iter()
                .any(|role| role.as_str() == required_role);

            if !has_role {
                warn!(
                    target = "udmin",
                    user_id = %auth_context.user_id,
                    required_role = %required_role,
                    "Role check failed"
                );
                return Err(AppError::Forbidden("角色权限不足".to_string()));
            }

            info!(
                target = "udmin",
                user_id = %auth_context.user_id,
                role = %required_role,
                "Role check passed"
            );

            Ok(next.run(request).await)
        })
    }
}
/// 可选认证中间件(不强制要求认证)
pub async fn optional_auth_middleware(
State(state): State<AppState>,
mut request: Request,
next: Next,
) -> Response {
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
if let Some(header) = auth_header {
if let Some(token) = header.strip_prefix("Bearer ") {
if let Ok(token_data) = decode::<Claims>(
token,
&DecodingKey::from_secret(state.config.jwt_secret.as_ref()),
&Validation::default(),
) {
let claims = token_data.claims;
// 创建认证上下文
let auth_context = AuthContext {
user_id: claims.sub.clone(),
username: claims.username.clone(),
roles: claims.roles.clone(),
permissions: claims.permissions.into_iter().collect(),
};
request.extensions_mut().insert(auth_context);
}
}
}
next.run(request).await
}
CORS 中间件 (cors.rs)
功能特性
- 跨域资源共享支持
- 可配置的允许源
- 预检请求处理
- 安全头设置
实现代码
use axum::{
extract::Request,
http::{header, HeaderValue, Method, StatusCode},
middleware::Next,
response::Response,
};
use tower_http::cors::{Any, CorsLayer};
use tracing::info;
/// Settings consumed by `create_cors_layer`.
#[derive(Debug, Clone)]
pub struct CorsConfig {
// Exact origins to allow; an entry of "*" means "any origin".
pub allowed_origins: Vec<String>,
pub allowed_methods: Vec<Method>,
// Request header names; expected to be valid HTTP header names.
pub allowed_headers: Vec<String>,
// Preflight cache lifetime in seconds (`Access-Control-Max-Age`).
pub max_age: u64,
pub allow_credentials: bool,
}
impl Default for CorsConfig {
    /// Development defaults:
    /// * origins: the local dev servers on ports 3000 and 5173, via both `localhost` and
    ///   `127.0.0.1`;
    /// * methods: the common REST verbs plus `OPTIONS`;
    /// * a small header allow-list;
    /// * a one-hour preflight cache;
    /// * credentials enabled.
    fn default() -> Self {
        // Host-major order: localhost:3000, localhost:5173, 127.0.0.1:3000, 127.0.0.1:5173.
        let mut allowed_origins = Vec::with_capacity(4);
        for host in ["localhost", "127.0.0.1"].iter() {
            for port in [3000, 5173].iter() {
                allowed_origins.push(format!("http://{}:{}", host, port));
            }
        }

        let allowed_headers = ["content-type", "authorization", "x-requested-with", "x-api-key"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        let allowed_methods = vec![
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::PATCH,
            Method::OPTIONS,
        ];

        Self {
            allowed_origins,
            allowed_methods,
            allowed_headers,
            max_age: 60 * 60,
            allow_credentials: true,
        }
    }
}
/// 创建 CORS 层
pub fn create_cors_layer(config: CorsConfig) -> CorsLayer {
let mut cors = CorsLayer::new()
.allow_methods(config.allowed_methods)
.allow_headers(
config
.allowed_headers
.iter()
.map(|h| h.parse().unwrap())
.collect::<Vec<_>>(),
)
.max_age(std::time::Duration::from_secs(config.max_age));
// 设置允许的源
if config.allowed_origins.contains(&"*".to_string()) {
cors = cors.allow_origin(Any);
} else {
cors = cors.allow_origin(
config
.allowed_origins
.iter()
.map(|origin| origin.parse().unwrap())
.collect::<Vec<HeaderValue>>(),
);
}
// 设置是否允许凭据
if config.allow_credentials {
cors = cors.allow_credentials(true);
}
info!(target = "udmin", "CORS middleware configured");
cors
}
/// 自定义 CORS 中间件
pub async fn cors_middleware(
request: Request,
next: Next,
) -> Response {
let origin = request
.headers()
.get(header::ORIGIN)
.and_then(|v| v.to_str().ok());
let method = request.method();
// 处理预检请求
if method == Method::OPTIONS {
let mut response = Response::builder()
.status(StatusCode::OK)
.body(axum::body::Body::empty())
.unwrap();
let headers = response.headers_mut();
if let Some(origin) = origin {
headers.insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_str(origin).unwrap(),
);
}
headers.insert(
header::ACCESS_CONTROL_ALLOW_METHODS,
HeaderValue::from_static("GET, POST, PUT, DELETE, PATCH, OPTIONS"),
);
headers.insert(
header::ACCESS_CONTROL_ALLOW_HEADERS,
HeaderValue::from_static("content-type, authorization, x-requested-with, x-api-key"),
);
headers.insert(
header::ACCESS_CONTROL_MAX_AGE,
HeaderValue::from_static("3600"),
);
headers.insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
return response;
}
let mut response = next.run(request).await;
// 为实际请求添加 CORS 头
let headers = response.headers_mut();
if let Some(origin) = origin {
headers.insert(
header::ACCESS_CONTROL_ALLOW_ORIGIN,
HeaderValue::from_str(origin).unwrap(),
);
}
headers.insert(
header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
HeaderValue::from_static("true"),
);
response
}
请求日志中间件 (logging.rs)
功能特性
- 请求响应日志记录
- 性能监控
- 错误追踪
- 结构化日志
- 请求ID追踪
实现代码
use axum::{
extract::{MatchedPath, Request},
http::StatusCode,
middleware::Next,
response::Response,
};
use std::time::Instant;
use tracing::{error, info, warn};
use uuid::Uuid;
/// Log the start and the completion of every request.
///
/// Each request gets a fresh UUID `request_id`, so its start and completion lines can be
/// correlated. The completion line includes the status and the elapsed milliseconds. It is
/// logged at `warn` for 4xx, `error` for 5xx and `info` otherwise.
///
/// NOTE(review): `remote_addr` is read from `X-Forwarded-For` / `X-Real-IP` as-is. These
/// headers are client-supplied unless a trusted proxy overwrites them.
pub async fn request_logging_middleware(
    request: Request,
    next: Next,
) -> Response {
    let start_time = Instant::now();
    let request_id = Uuid::new_v4().to_string();

    // `method`/`uri` are cloned and `matched_path` is owned because `request` is moved into
    // `next.run` below, and these values are logged again afterwards.
    let method = request.method().clone();
    let uri = request.uri().clone();
    let version = request.version();
    // Route template (e.g. "/users/:id") when available, otherwise the raw path.
    let matched_path: String = request
        .extensions()
        .get::<MatchedPath>()
        .map(|mp| mp.as_str().to_owned())
        .unwrap_or_else(|| uri.path().to_owned());

    let headers = request.headers();
    let user_agent = headers
        .get("user-agent")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("unknown");
    let remote_addr = headers
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .or_else(|| headers.get("x-real-ip").and_then(|v| v.to_str().ok()))
        .unwrap_or("unknown");

    info!(
        target = "udmin",
        request_id = %request_id,
        method = %method,
        uri = %uri,
        path = %matched_path,
        version = ?version,
        user_agent = %user_agent,
        remote_addr = %remote_addr,
        "Request started"
    );

    let response = next.run(request).await;

    let duration_ms = start_time.elapsed().as_millis();
    let status = response.status().as_u16();

    // tracing needs a static level per call site, so error classes get their own macros;
    // all remaining classes share one `info!` with a class-specific message.
    match status {
        400..=499 => {
            warn!(
                target = "udmin",
                request_id = %request_id,
                method = %method,
                uri = %uri,
                path = %matched_path,
                status = %status,
                duration_ms = %duration_ms,
                "Client error"
            );
        }
        500..=599 => {
            error!(
                target = "udmin",
                request_id = %request_id,
                method = %method,
                uri = %uri,
                path = %matched_path,
                status = %status,
                duration_ms = %duration_ms,
                "Server error"
            );
        }
        _ => {
            let outcome = match status {
                200..=299 => "Request completed successfully",
                300..=399 => "Request redirected",
                _ => "Request completed",
            };
            info!(
                target = "udmin",
                request_id = %request_id,
                method = %method,
                uri = %uri,
                path = %matched_path,
                status = %status,
                duration_ms = %duration_ms,
                "{}",
                outcome
            );
        }
    }

    response
}
/// Record per-request timing.
///
/// Emits a `warn` ("Slow request detected") when the request takes more than 1000 ms.
/// Always emits an `info` record on the `udmin.performance` target with the status and
/// duration in milliseconds.
pub async fn performance_monitoring_middleware(
    request: Request,
    next: Next,
) -> Response {
    let started = Instant::now();
    let method = request.method().clone();
    let uri = request.uri().clone();

    let response = next.run(request).await;
    let elapsed_ms = started.elapsed().as_millis();

    let is_slow = elapsed_ms > 1000;
    if is_slow {
        warn!(
            target = "udmin",
            method = %method,
            uri = %uri,
            duration_ms = %elapsed_ms,
            "Slow request detected"
        );
    }

    let status_code = response.status().as_u16();
    info!(
        target = "udmin.performance",
        method = %method,
        uri = %uri,
        status = %status_code,
        duration_ms = %elapsed_ms,
        "Performance metrics"
    );

    response
}
HTTP 客户端中间件 (http_client.rs)
功能特性
- HTTP 客户端封装
- 请求重试机制
- 超时控制
- 请求日志
- 错误处理
实现代码
use reqwest::{Client, ClientBuilder, Response};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{error, info, warn};
use url::Url;
use crate::error::AppError;
/// Tunables for `HttpClient`: request/connect timeouts, the retry policy and the User-Agent.
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// Overall per-request timeout.
    pub timeout: Duration,
    /// Timeout for establishing the connection.
    pub connect_timeout: Duration,
    /// Retries after the first attempt (so up to `max_retries + 1` attempts in total).
    pub max_retries: u32,
    /// Fixed pause between attempts.
    pub retry_delay: Duration,
    /// Value sent in the `User-Agent` header.
    pub user_agent: String,
}

impl Default for HttpClientConfig {
    /// 30 s request timeout, 10 s connect timeout, up to 3 retries spaced 1 s apart,
    /// and a `UdminAI/1.0` User-Agent.
    fn default() -> Self {
        let user_agent = String::from("UdminAI/1.0");
        Self {
            max_retries: 3,
            retry_delay: Duration::from_secs(1),
            timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            user_agent,
        }
    }
}
/// Thin wrapper around a `reqwest::Client` that adds retries and request logging,
/// driven by an `HttpClientConfig`.
#[derive(Debug, Clone)]
pub struct HttpClient {
client: Client,
config: HttpClientConfig,
}
impl HttpClient {
    /// Build a client from `config` (request timeout, connect timeout, User-Agent).
    ///
    /// # Errors
    /// Returns `AppError::InternalServerError` if the `reqwest` client cannot be built.
    pub fn new(config: HttpClientConfig) -> Result<Self, AppError> {
        let client = ClientBuilder::new()
            .timeout(config.timeout)
            .connect_timeout(config.connect_timeout)
            .user_agent(&config.user_agent)
            .build()
            .map_err(|e| AppError::InternalServerError(format!("创建HTTP客户端失败: {}", e)))?;
        Ok(Self { client, config })
    }

    /// Send a GET request (retried; see `request_with_retry`).
    pub async fn get(&self, url: &str) -> Result<Response, AppError> {
        // `body` is `Option<&T>`, so "no body" must be `None::<&()>`
        // (`None::<()>` has type `Option<()>` and does not type-check).
        self.request_with_retry("GET", url, None::<&()>).await
    }

    /// Send a POST request with `body` serialized as JSON (retried).
    pub async fn post<T: Serialize>(&self, url: &str, body: &T) -> Result<Response, AppError> {
        self.request_with_retry("POST", url, Some(body)).await
    }

    /// Send a PUT request with `body` serialized as JSON (retried).
    pub async fn put<T: Serialize>(&self, url: &str, body: &T) -> Result<Response, AppError> {
        self.request_with_retry("PUT", url, Some(body)).await
    }

    /// Send a DELETE request (retried).
    pub async fn delete(&self, url: &str) -> Result<Response, AppError> {
        self.request_with_retry("DELETE", url, None::<&()>).await
    }

    /// Send a request, retrying on transport errors and 5xx responses.
    ///
    /// At most `max_retries + 1` attempts are made, with a fixed `retry_delay` sleep between
    /// them. Any non-5xx response (including 4xx) is returned immediately. A 5xx on the final
    /// attempt is also returned as `Ok`, so callers must still check the status.
    ///
    /// NOTE(review): POST is retried as well. If the server applied the request before
    /// failing, a retry repeats its side effects; consider retrying only idempotent methods.
    ///
    /// # Errors
    /// `AppError::BadRequest` if `url` does not parse; otherwise the error from the last
    /// failed attempt.
    async fn request_with_retry<T: Serialize>(
        &self,
        method: &str,
        url: &str,
        body: Option<&T>,
    ) -> Result<Response, AppError> {
        let parsed_url = Url::parse(url)
            .map_err(|e| AppError::BadRequest(format!("无效的URL: {}", e)))?;
        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            let start_time = std::time::Instant::now();
            info!(
                target = "udmin",
                method = %method,
                url = %url,
                attempt = %attempt,
                "HTTP request started"
            );

            let result = self.make_request(method, &parsed_url, body).await;
            let duration = start_time.elapsed();

            match result {
                Ok(response) => {
                    let status = response.status();
                    info!(
                        target = "udmin",
                        method = %method,
                        url = %url,
                        status = %status.as_u16(),
                        duration_ms = %duration.as_millis(),
                        attempt = %attempt,
                        "HTTP request completed"
                    );
                    // Retry server errors while attempts remain.
                    if status.is_server_error() && attempt < self.config.max_retries {
                        warn!(
                            target = "udmin",
                            method = %method,
                            url = %url,
                            status = %status.as_u16(),
                            attempt = %attempt,
                            "Server error, retrying"
                        );
                        tokio::time::sleep(self.config.retry_delay).await;
                        continue;
                    }
                    return Ok(response);
                }
                Err(e) => {
                    error!(
                        target = "udmin",
                        method = %method,
                        url = %url,
                        error = %e,
                        duration_ms = %duration.as_millis(),
                        attempt = %attempt,
                        "HTTP request failed"
                    );
                    last_error = Some(e);
                    // Wait before the next attempt, but not after the last one.
                    if attempt < self.config.max_retries {
                        tokio::time::sleep(self.config.retry_delay).await;
                    }
                }
            }
        }

        Err(last_error.unwrap_or_else(|| {
            AppError::InternalServerError("HTTP请求失败".to_string())
        }))
    }

    /// Perform a single HTTP request without retries, sending `body` as JSON if given.
    ///
    /// # Errors
    /// `AppError::BadRequest` for a method other than GET/POST/PUT/DELETE;
    /// `AppError::InternalServerError` if sending fails.
    async fn make_request<T: Serialize>(
        &self,
        method: &str,
        url: &Url,
        body: Option<&T>,
    ) -> Result<Response, AppError> {
        let mut request_builder = match method {
            "GET" => self.client.get(url.clone()),
            "POST" => self.client.post(url.clone()),
            "PUT" => self.client.put(url.clone()),
            "DELETE" => self.client.delete(url.clone()),
            _ => {
                return Err(AppError::BadRequest(format!(
                    "不支持的HTTP方法: {}",
                    method
                )))
            }
        };

        if let Some(body) = body {
            request_builder = request_builder.json(body);
        }

        let response = request_builder
            .send()
            .await
            .map_err(|e| AppError::InternalServerError(format!("HTTP请求失败: {}", e)))?;
        Ok(response)
    }

    /// Send a request and deserialize a successful (2xx) response body as JSON into `R`.
    ///
    /// POST and PUT require `body`; it is ignored for GET and DELETE.
    ///
    /// # Errors
    /// * `AppError::BadRequest` for an unsupported method or a missing POST/PUT body.
    /// * `AppError::InternalServerError` for a non-2xx status (the message includes the
    ///   status and the response text) or an undecodable body.
    /// * Any error from the underlying request.
    pub async fn json_request<T, R>(
        &self,
        method: &str,
        url: &str,
        body: Option<&T>,
    ) -> Result<R, AppError>
    where
        T: Serialize,
        R: for<'de> Deserialize<'de>,
    {
        let response = match method {
            "GET" => self.get(url).await?,
            "POST" => {
                if let Some(body) = body {
                    self.post(url, body).await?
                } else {
                    return Err(AppError::BadRequest("POST请求需要请求体".to_string()));
                }
            }
            "PUT" => {
                if let Some(body) = body {
                    self.put(url, body).await?
                } else {
                    return Err(AppError::BadRequest("PUT请求需要请求体".to_string()));
                }
            }
            "DELETE" => self.delete(url).await?,
            _ => {
                return Err(AppError::BadRequest(format!(
                    "不支持的HTTP方法: {}",
                    method
                )))
            }
        };

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "无法读取错误响应".to_string());
            return Err(AppError::InternalServerError(format!(
                "HTTP请求失败: {} - {}",
                status, error_text
            )));
        }

        let json_response = response
            .json::<R>()
            .await
            .map_err(|e| AppError::InternalServerError(format!("解析JSON响应失败: {}", e)))?;
        Ok(json_response)
    }
}
限流中间件 (rate_limit.rs)
功能特性
- 基于令牌桶的限流
- 支持不同的限流策略
- IP 级别限流
- 用户级别限流
- 动态配置
实现代码
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use tracing::{info, warn};
use crate::{error::AppError, middlewares::auth::AuthContext};
/// A token bucket: it holds up to `capacity` tokens and regains `refill_rate` tokens per
/// second of elapsed (monotonic) time.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Maximum tokens the bucket can hold (the burst size).
    capacity: u32,
    /// Tokens currently available.
    tokens: u32,
    /// Reference instant for the next refill. `RateLimiter::cleanup_expired_buckets` also
    /// reads it as the bucket's last-activity time.
    last_refill: Instant,
    refill_rate: u32, // tokens per second
}

impl TokenBucket {
    /// Create a full bucket.
    fn new(capacity: u32, refill_rate: u32) -> Self {
        Self {
            capacity,
            tokens: capacity,
            last_refill: Instant::now(),
            refill_rate,
        }
    }

    /// Refill, then take `tokens` if enough are available.
    ///
    /// Returns `true` on success. On failure nothing is consumed.
    fn try_consume(&mut self, tokens: u32) -> bool {
        self.refill();
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Add the whole tokens earned since `last_refill`, capped at `capacity`.
    ///
    /// Only the time actually converted into tokens is consumed from the clock. Fractional
    /// progress therefore carries over to the next call instead of being discarded; discarding
    /// it would make the effective rate lower than `refill_rate`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        let rate = f64::from(self.refill_rate);
        // Float-to-int `as` saturates, so very long idle periods cannot wrap around.
        let tokens_to_add = (elapsed.as_secs_f64() * rate) as u32;
        if tokens_to_add == 0 {
            // Also covers `refill_rate == 0`, so the division below never sees zero.
            return;
        }

        // `saturating_add` avoids an overflow panic (debug builds) on huge refills.
        let refilled = self.tokens.saturating_add(tokens_to_add);
        if refilled >= self.capacity {
            // A full bucket cannot bank extra time, so restart the clock.
            self.tokens = self.capacity;
            self.last_refill = now;
        } else {
            self.tokens = refilled;
            let consumed = Duration::from_secs_f64(f64::from(tokens_to_add) / rate);
            // Clamp so float rounding can never move the reference point past `now`.
            self.last_refill = (self.last_refill + consumed).min(now);
        }
    }
}
/// Keyed token-bucket rate limiter shared across requests.
///
/// Each key (e.g. `"ip:<addr>"`, `"user:<id>"`) gets its own bucket. Buckets are created
/// lazily with `default_capacity` tokens and `default_refill_rate` tokens per second.
#[derive(Debug)]
pub struct RateLimiter {
buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
default_capacity: u32,
default_refill_rate: u32,
}
impl RateLimiter {
    /// Create a limiter with no buckets yet.
    ///
    /// Every key seen later gets its own bucket, holding up to `default_capacity` tokens
    /// and refilling `default_refill_rate` tokens per second.
    pub fn new(default_capacity: u32, default_refill_rate: u32) -> Self {
        let buckets = Arc::new(Mutex::new(HashMap::new()));
        Self {
            default_capacity,
            default_refill_rate,
            buckets,
        }
    }

    /// Try to spend `tokens` from `key`'s bucket, creating a full bucket on first use.
    ///
    /// Returns `true` when the tokens were available, i.e. the request is allowed.
    ///
    /// # Panics
    /// If the bucket mutex is poisoned.
    pub fn check_rate_limit(&self, key: &str, tokens: u32) -> bool {
        let (capacity, refill_rate) = (self.default_capacity, self.default_refill_rate);
        let mut guard = self.buckets.lock().unwrap();
        let bucket = guard
            .entry(key.to_owned())
            .or_insert_with(|| TokenBucket::new(capacity, refill_rate));
        bucket.try_consume(tokens)
    }

    /// Forget buckets whose `last_refill` is `max_idle_duration` or more in the past.
    ///
    /// # Panics
    /// If the bucket mutex is poisoned.
    pub fn cleanup_expired_buckets(&self, max_idle_duration: Duration) {
        let mut guard = self.buckets.lock().unwrap();
        let now = Instant::now();
        guard.retain(|_key, bucket| {
            let idle_for = now.duration_since(bucket.last_refill);
            idle_for < max_idle_duration
        });
    }
}
/// Settings for the token-bucket rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Steady-state refill rate, in tokens (requests) per second.
    pub requests_per_second: u32,
    /// Bucket capacity, i.e. the largest burst admitted at once.
    pub burst_capacity: u32,
    /// NOTE(review): not read by the middleware functions shown in this module; confirm
    /// where it is consumed.
    pub enable_ip_limit: bool,
    /// NOTE(review): not read by the middleware functions shown in this module; confirm
    /// where it is consumed.
    pub enable_user_limit: bool,
}

impl Default for RateLimitConfig {
    /// 100 requests/second with bursts of up to 200, and both IP and user limits enabled.
    fn default() -> Self {
        const RPS: u32 = 100;
        Self {
            burst_capacity: RPS * 2,
            requests_per_second: RPS,
            enable_user_limit: true,
            enable_ip_limit: true,
        }
    }
}
/// Per-client-IP rate limiting middleware (one token per request).
///
/// The client address is the first entry of `X-Forwarded-For`, falling back to
/// `X-Real-IP`. Blank values are ignored, so an empty `X-Forwarded-For` no longer maps every
/// such client onto the shared key `"ip:"`. Requests with neither header share the single
/// `"ip:unknown"` bucket.
///
/// SECURITY NOTE(review): both headers are client-controlled unless a trusted reverse proxy
/// overwrites them. A client reaching the server directly can rotate `X-Forwarded-For` to
/// evade this limit. Prefer the peer address (`ConnectInfo<SocketAddr>`) when no trusted
/// proxy is in front.
///
/// # Errors
/// `AppError::TooManyRequests` when the bucket for this IP is empty.
pub async fn ip_rate_limit_middleware(
    State(rate_limiter): State<Arc<RateLimiter>>,
    request: Request,
    next: Next,
) -> Result<Response, AppError> {
    let forwarded_for = request
        .headers()
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.split(',').next())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    let real_ip = request
        .headers()
        .get("x-real-ip")
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .filter(|ip| !ip.is_empty());
    let client_ip = forwarded_for.or(real_ip).unwrap_or("unknown");

    let rate_limit_key = format!("ip:{}", client_ip);

    if !rate_limiter.check_rate_limit(&rate_limit_key, 1) {
        warn!(
            target = "udmin",
            client_ip = %client_ip,
            "IP rate limit exceeded"
        );
        return Err(AppError::TooManyRequests(
            "请求过于频繁,请稍后再试".to_string(),
        ));
    }

    info!(
        target = "udmin",
        client_ip = %client_ip,
        "IP rate limit check passed"
    );

    Ok(next.run(request).await)
}
/// Per-user rate limiting middleware (one token per request).
///
/// It applies only when an `AuthContext` is present in the request extensions, i.e. an
/// auth middleware ran earlier. Anonymous requests pass straight through. Buckets are keyed
/// `"user:<user_id>"` in the shared `RateLimiter`.
///
/// # Errors
/// `AppError::TooManyRequests` when the user's bucket is empty.
pub async fn user_rate_limit_middleware(
    State(rate_limiter): State<Arc<RateLimiter>>,
    request: Request,
    next: Next,
) -> Result<Response, AppError> {
    let user_id = request
        .extensions()
        .get::<AuthContext>()
        .map(|auth_context| auth_context.user_id.as_str());

    if let Some(user_id) = user_id {
        let allowed = rate_limiter.check_rate_limit(&format!("user:{}", user_id), 1);
        if !allowed {
            warn!(
                target = "udmin",
                user_id = %user_id,
                "User rate limit exceeded"
            );
            return Err(AppError::TooManyRequests(
                "用户请求过于频繁,请稍后再试".to_string(),
            ));
        }
        info!(
            target = "udmin",
            user_id = %user_id,
            "User rate limit check passed"
        );
    }

    Ok(next.run(request).await)
}
/// 创建限流中间件
pub fn create_rate_limit_middleware(
config: RateLimitConfig,
) -> Arc<RateLimiter> {
let rate_limiter = Arc::new(RateLimiter::new(
config.burst_capacity,
config.requests_per_second,
));
// 启动清理任务
let cleanup_limiter = rate_limiter.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(300)); // 5分钟清理一次
loop {
interval.tick().await;
cleanup_limiter.cleanup_expired_buckets(Duration::from_secs(3600)); // 清理1小时未使用的桶
}
});
info!(target = "udmin", "Rate limiter initialized");
rate_limiter
}
WebSocket 中间件 (ws.rs)
功能特性
- WebSocket 连接管理
- 实时消息推送
- 连接认证
- 心跳检测
- 广播支持
实现代码
use axum::{
extract::{ws::WebSocket, Query, State, WebSocketUpgrade},
response::Response,
};
use futures_util::{sink::SinkExt, stream::StreamExt};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
sync::{Arc, RwLock},
time::{Duration, Instant},
};
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{
error::AppError,
services::user_service,
AppState,
};
/// Bookkeeping for one live WebSocket connection.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
pub id: String,
// Created as `None` in `handle_websocket`; meant to be set once the connection authenticates.
pub user_id: Option<String>,
pub connected_at: Instant,
// Compared against a timeout by `WebSocketManager::cleanup_stale_connections`.
pub last_ping: Instant,
// Queue drained by the connection's writer task, which serializes each message onto the socket.
pub sender: mpsc::UnboundedSender<WebSocketMessage>,
}
/// Messages exchanged over the WebSocket, in both directions.
///
/// Serialized as adjacently tagged JSON: the variant name under `"type"` and the variant's
/// fields (if any) under `"data"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum WebSocketMessage {
/// Client -> server: authenticate this connection with a JWT.
Auth { token: String },
/// Server -> client: authentication succeeded.
AuthSuccess { user_id: String },
/// Server -> client: authentication failed.
AuthError { message: String },
/// Heartbeat; `handle_websocket` queues one every 30 seconds.
Ping,
/// Heartbeat reply from the client.
Pong,
/// Flow execution status update.
FlowExecutionUpdate {
execution_id: String,
status: String,
// Presumably a completion fraction or percentage; the scale is not shown here — confirm with producers.
progress: Option<f64>,
message: Option<String>,
},
/// Job execution status update.
JobExecutionUpdate {
job_id: String,
execution_id: String,
status: String,
message: Option<String>,
},
/// System-wide notification.
SystemNotification {
title: String,
message: String,
level: String,
},
/// Notification addressed to a user.
UserNotification {
id: String,
title: String,
message: String,
created_at: String,
},
/// Error message.
Error { message: String },
}
/// Registry of live WebSocket connections plus a per-user index, used to target a
/// connection, a user, or broadcast to everyone.
#[derive(Debug)]
pub struct WebSocketManager {
connections: Arc<RwLock<HashMap<String, ConnectionInfo>>>,
user_connections: Arc<RwLock<HashMap<String, Vec<String>>>>, // user_id -> connection_ids
}
impl WebSocketManager {
    /// Create an empty manager with no registered connections.
    pub fn new() -> Self {
        Self {
            connections: Arc::new(RwLock::new(HashMap::new())),
            user_connections: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Register `connection_info`. If it already carries a `user_id`, also index it under
    /// that user. The two locks are taken one after the other, never nested.
    pub fn add_connection(&self, connection_info: ConnectionInfo) {
        let connection_id = connection_info.id.clone();
        let user_id = connection_info.user_id.clone();

        self.connections
            .write()
            .unwrap()
            .insert(connection_id.clone(), connection_info);

        if let Some(user_id) = user_id {
            self.user_connections
                .write()
                .unwrap()
                .entry(user_id)
                .or_insert_with(Vec::new)
                .push(connection_id.clone());
        }

        info!(
            target = "udmin",
            connection_id = %connection_id,
            "WebSocket connection added"
        );
    }

    /// Unregister a connection and drop it from its user's index, removing the user entry
    /// when it becomes empty. Unknown ids are ignored.
    ///
    /// The `connections` lock is released before `user_connections` is locked. Holding the
    /// first while waiting for the second would invert the order used by `send_to_user`
    /// (index first, then connections) and could deadlock.
    pub fn remove_connection(&self, connection_id: &str) {
        // The guard is a temporary, so the `connections` lock is released at the end of
        // this statement.
        let removed = self.connections.write().unwrap().remove(connection_id);

        if let Some(connection_info) = removed {
            if let Some(user_id) = &connection_info.user_id {
                let mut user_connections = self.user_connections.write().unwrap();
                if let Some(user_conn_list) = user_connections.get_mut(user_id) {
                    user_conn_list.retain(|id| id != connection_id);
                    if user_conn_list.is_empty() {
                        user_connections.remove(user_id);
                    }
                }
            }

            info!(
                target = "udmin",
                connection_id = %connection_id,
                user_id = ?connection_info.user_id,
                "WebSocket connection removed"
            );
        }
    }

    /// Queue `message` for one connection.
    ///
    /// Returns `true` if it was queued, and `false` if the connection is unknown or its
    /// receiving side has been dropped.
    pub fn send_to_connection(&self, connection_id: &str, message: WebSocketMessage) -> bool {
        let connections = self.connections.read().unwrap();
        if let Some(connection_info) = connections.get(connection_id) {
            match connection_info.sender.send(message) {
                Ok(_) => true,
                Err(e) => {
                    warn!(
                        target = "udmin",
                        connection_id = %connection_id,
                        error = %e,
                        "Failed to send message to connection"
                    );
                    false
                }
            }
        } else {
            false
        }
    }

    /// Queue `message` for every connection indexed under `user_id`.
    ///
    /// Returns how many connections accepted it. The id list is copied and the index lock
    /// released before sending, so this method never holds both locks at once.
    pub fn send_to_user(&self, user_id: &str, message: WebSocketMessage) -> usize {
        let connection_ids: Vec<String> = self
            .user_connections
            .read()
            .unwrap()
            .get(user_id)
            .cloned()
            .unwrap_or_default();

        let mut sent_count = 0;
        for connection_id in &connection_ids {
            if self.send_to_connection(connection_id, message.clone()) {
                sent_count += 1;
            }
        }

        info!(
            target = "udmin",
            user_id = %user_id,
            sent_count = %sent_count,
            "Message sent to user connections"
        );
        sent_count
    }

    /// Queue `message` for every registered connection; returns how many accepted it.
    pub fn broadcast(&self, message: WebSocketMessage) -> usize {
        let connections = self.connections.read().unwrap();
        let mut sent_count = 0;

        for (connection_id, connection_info) in connections.iter() {
            match connection_info.sender.send(message.clone()) {
                Ok(_) => sent_count += 1,
                Err(e) => {
                    warn!(
                        target = "udmin",
                        connection_id = %connection_id,
                        error = %e,
                        "Failed to broadcast message"
                    );
                }
            }
        }

        info!(
            target = "udmin",
            sent_count = %sent_count,
            total_connections = %connections.len(),
            "Message broadcasted"
        );
        sent_count
    }

    /// Snapshot of connection counts.
    pub fn get_stats(&self) -> WebSocketStats {
        let connections = self.connections.read().unwrap();
        let user_connections = self.user_connections.read().unwrap();

        WebSocketStats {
            total_connections: connections.len(),
            authenticated_connections: connections.values().filter(|c| c.user_id.is_some()).count(),
            unique_users: user_connections.len(),
        }
    }

    /// Remove connections whose `last_ping` is older than `timeout`.
    ///
    /// Stale ids are collected under the read lock, which is dropped before any removal.
    pub fn cleanup_stale_connections(&self, timeout: Duration) {
        let now = Instant::now();
        let connections = self.connections.read().unwrap();
        let stale_connections: Vec<String> = connections
            .iter()
            .filter(|(_, info)| now.duration_since(info.last_ping) > timeout)
            .map(|(id, _)| id.clone())
            .collect();
        drop(connections);

        for connection_id in stale_connections {
            self.remove_connection(&connection_id);
            warn!(
                target = "udmin",
                connection_id = %connection_id,
                "Removed stale WebSocket connection"
            );
        }
    }
}
/// Snapshot of connection counts returned by `WebSocketManager::get_stats`.
#[derive(Debug, Serialize)]
pub struct WebSocketStats {
// All registered connections.
pub total_connections: usize,
// Connections whose `user_id` is set.
pub authenticated_connections: usize,
// Distinct users present in the per-user index.
pub unique_users: usize,
}
/// Query-string parameters accepted by `websocket_handler`.
#[derive(Debug, Deserialize)]
pub struct WebSocketQuery {
// Optional JWT used to authenticate at connect time (`?token=...`).
// NOTE(review): a token in the query string is part of the request URI, so any middleware
// or proxy that logs full URIs will record it; the in-band `Auth` message avoids that.
pub token: Option<String>,
}
/// WebSocket 升级处理器
pub async fn websocket_handler(
ws: WebSocketUpgrade,
Query(params): Query<WebSocketQuery>,
State(state): State<AppState>,
) -> Response {
ws.on_upgrade(move |socket| handle_websocket(socket, params.token, state))
}
/// 处理 WebSocket 连接
async fn handle_websocket(
socket: WebSocket,
token: Option<String>,
state: AppState,
) {
let connection_id = Uuid::new_v4().to_string();
let (mut sender, mut receiver) = socket.split();
let (tx, mut rx) = mpsc::unbounded_channel::<WebSocketMessage>();
info!(
target = "udmin",
connection_id = %connection_id,
"WebSocket connection established"
);
// 创建连接信息
let connection_info = ConnectionInfo {
id: connection_id.clone(),
user_id: None,
connected_at: Instant::now(),
last_ping: Instant::now(),
sender: tx,
};
// 添加到连接管理器
state.ws_manager.add_connection(connection_info);
// 如果提供了token,尝试认证
let mut authenticated_user_id = None;
if let Some(token) = token {
match authenticate_websocket_token(&token, &state).await {
Ok(user_id) => {
authenticated_user_id = Some(user_id.clone());
// 更新连接信息
if let Some(mut connection_info) = state.connections.write().unwrap().get_mut(&connection_id) {
connection_info.user_id = Some(user_id.clone());
}
// 发送认证成功消息
let _ = sender.send(axum::extract::ws::Message::Text(
serde_json::to_string(&WebSocketMessage::AuthSuccess { user_id }).unwrap()
)).await;
info!(
target = "udmin",
connection_id = %connection_id,
user_id = %authenticated_user_id.as_ref().unwrap(),
"WebSocket connection authenticated"
);
}
Err(e) => {
warn!(
target = "udmin",
connection_id = %connection_id,
error = %e,
"WebSocket authentication failed"
);
let _ = sender.send(axum::extract::ws::Message::Text(
serde_json::to_string(&WebSocketMessage::AuthError {
message: "认证失败".to_string()
}).unwrap()
)).await;
}
}
}
// 启动消息发送任务
let connection_id_clone = connection_id.clone();
let ws_manager_clone = state.ws_manager.clone();
let send_task = tokio::spawn(async move {
while let Some(message) = rx.recv().await {
let text = match serde_json::to_string(&message) {
Ok(text) => text,
Err(e) => {
error!(
target = "udmin",
connection_id = %connection_id_clone,
error = %e,
"Failed to serialize WebSocket message"
);
continue;
}
};
if sender.send(axum::extract::ws::Message::Text(text)).await.is_err() {
break;
}
}
// 移除连接
ws_manager_clone.remove_connection(&connection_id_clone);
});
// 启动心跳任务
let connection_id_clone = connection_id.clone();
let ws_manager_clone = state.ws_manager.clone();
let heartbeat_task = tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
loop {
interval.tick().await;
if !ws_manager_clone.send_to_connection(
&connection_id_clone,
WebSocketMessage::Ping,
) {
break;
}
}
});
// 处理接收到的消息
while let Some(msg) = receiver.next().await {
match msg {
Ok(axum::extract::ws::Message::Text(text)) => {
match serde_json::from_str::<WebSocketMessage>(&text) {
Ok(WebSocketMessage::Auth { token }) => {
// 处理认证消息
match authenticate_websocket_token(&token, &state).await {
Ok(user_id) => {
authenticated_user_id = Some(user_id.clone());
let _ = state.ws_manager.send_to_connection(
&connection_id,
WebSocketMessage::AuthSuccess { user_id },
);
}
Err(_) => {
let _ = state.ws_manager.send_to_connection(
&connection_id,
WebSocketMessage::AuthError {
message: "认证失败".to_string(),
},
);
}
}
}
Ok(WebSocketMessage::Pong) => {
// 更新最后ping时间
// 这里可以更新连接的last_ping时间
}
Ok(_) => {
// 处理其他消息类型
}
Err(e) => {
warn!(
target = "udmin",
connection_id = %connection_id,
error = %e,
"Failed to parse WebSocket message"
);
}
}
}
Ok(axum::extract::ws::Message::Close(_)) => {
info!(
target = "udmin",
connection_id = %connection_id,
"WebSocket connection closed by client"
);
break;
}
Err(e) => {
error!(
target = "udmin",
connection_id = %connection_id,
error = %e,
"WebSocket error"
);
break;
}
_ => {}
}
}
// 清理任务
send_task.abort();
heartbeat_task.abort();
state.ws_manager.remove_connection(&connection_id);
info!(
target = "udmin",
connection_id = %connection_id,
"WebSocket connection closed"
);
}
/// Authenticate a WebSocket client from a JWT.
///
/// Decodes `token` with the configured secret, then confirms that the `sub` user exists and
/// is active. Returns that user id on success.
///
/// # Errors
/// `AppError::Unauthorized` if the token is invalid, the user lookup fails, or the account
/// is not active.
async fn authenticate_websocket_token(
    token: &str,
    state: &AppState,
) -> Result<String, AppError> {
    use crate::middlewares::auth::Claims;
    use jsonwebtoken::{decode, DecodingKey, Validation};

    let secret = DecodingKey::from_secret(state.config.jwt_secret.as_ref());
    let user_id = decode::<Claims>(token, &secret, &Validation::default())
        .map(|data| data.claims.sub)
        .map_err(|_| AppError::Unauthorized("无效的认证令牌".to_string()))?;

    let account = user_service::find_by_id(&state.db, &user_id)
        .await
        .map_err(|_| AppError::Unauthorized("用户不存在".to_string()))?;

    if account.status == crate::models::user::UserStatus::Active {
        Ok(user_id)
    } else {
        Err(AppError::Unauthorized("用户账户已被禁用".to_string()))
    }
}
/// 启动 WebSocket 管理器清理任务
pub fn start_websocket_cleanup_task(ws_manager: Arc<WebSocketManager>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
ws_manager.cleanup_stale_connections(Duration::from_secs(300)); // 5分钟超时
}
});
}
SSE 中间件 (sse.rs)
功能特性
- Server-Sent Events 支持
- 实时事件推送
- 连接管理
- 事件过滤
- 重连支持
实现代码
use axum::{
extract::{Query, State},
http::{header, HeaderValue, StatusCode},
response::{sse::Event, Response, Sse},
};
use futures_util::stream::{self, Stream};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
sync::{Arc, RwLock},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{
error::AppError,
middlewares::auth::AuthContext,
AppState,
};
/// Events pushed to clients over Server-Sent Events.
///
/// Serde uses adjacent tagging, so a value serializes as
/// `{"type": "<VariantName>", "data": { ...variant fields... }}`.
/// Every variant carries a `u64` `timestamp`; the heartbeats built in this module fill it
/// with Unix seconds.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum SseEvent {
    /// Status change of a flow execution.
    FlowExecutionUpdate {
        /// Execution the update refers to.
        execution_id: String,
        /// Flow being executed.
        flow_id: String,
        /// Free-form status label.
        status: String,
        /// Optional progress value (scale not fixed here — TODO confirm 0–1 vs 0–100).
        progress: Option<f64>,
        /// Optional human-readable detail.
        message: Option<String>,
        /// Event time.
        timestamp: u64,
    },
    /// Status change of a scheduled job run.
    JobExecutionUpdate {
        /// Job the run belongs to.
        job_id: String,
        /// The specific run.
        execution_id: String,
        /// Free-form status label.
        status: String,
        /// Optional human-readable detail.
        message: Option<String>,
        /// Event time.
        timestamp: u64,
    },
    /// Notification addressed to every client.
    SystemNotification {
        /// Notification identifier.
        id: String,
        /// Short headline.
        title: String,
        /// Body text.
        message: String,
        /// Severity label (free-form string).
        level: String,
        /// Event time.
        timestamp: u64,
    },
    /// Notification addressed to a single user.
    UserNotification {
        /// Notification identifier.
        id: String,
        /// Short headline.
        title: String,
        /// Body text.
        message: String,
        /// Event time.
        timestamp: u64,
    },
    /// Snapshot of host resource usage.
    SystemStatus {
        /// CPU usage (unit not fixed here — presumably a percentage; confirm with producer).
        cpu_usage: f64,
        /// Memory usage (same caveat as `cpu_usage`).
        memory_usage: f64,
        /// Disk usage (same caveat as `cpu_usage`).
        disk_usage: f64,
        /// Number of active connections at snapshot time.
        active_connections: usize,
        /// Event time.
        timestamp: u64,
    },
    /// Periodic liveness ping.
    Heartbeat {
        /// Event time.
        timestamp: u64,
    },
}
/// Book-keeping for one open SSE stream.
#[derive(Debug)]
struct SseConnection {
    /// Connection ID; `add_connection` stores the same value as the map key.
    id: String,
    /// Authenticated user, or `None` for anonymous streams.
    user_id: Option<String>,
    /// Sending half of the channel that feeds the client's response stream.
    sender: mpsc::UnboundedSender<Result<Event, Infallible>>,
    /// Event-type names the client opted into; an empty list means "everything".
    filters: Vec<String>,
    /// Registration time. NOTE(review): not read anywhere in this chunk.
    created_at: SystemTime,
}
/// Registry of live SSE connections, indexed by connection ID and by user ID.
///
/// Each index sits behind its own `RwLock`; the two are updated one after the other
/// rather than atomically together.
#[derive(Debug)]
pub struct SseManager {
    /// Connection ID → connection state.
    connections: Arc<RwLock<HashMap<String, SseConnection>>>,
    /// User ID → IDs of that user's open connections.
    user_connections: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
impl SseManager {
pub fn new() -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
user_connections: Arc<RwLock::new(HashMap::new())),
}
}
/// 添加 SSE 连接
pub fn add_connection(
&self,
connection_id: String,
user_id: Option<String>,
sender: mpsc::UnboundedSender<Result<Event, Infallible>>,
filters: Vec<String>,
) {
let connection = SseConnection {
id: connection_id.clone(),
user_id: user_id.clone(),
sender,
filters,
created_at: SystemTime::now(),
};
// 添加到连接列表
self.connections.write().unwrap().insert(connection_id.clone(), connection);
// 如果有用户ID,添加到用户连接映射
if let Some(user_id) = user_id {
self.user_connections
.write()
.unwrap()
.entry(user_id)
.or_insert_with(Vec::new)
.push(connection_id.clone());
}
info!(
target = "udmin",
connection_id = %connection_id,
"SSE connection added"
);
}
/// 移除 SSE 连接
pub fn remove_connection(&self, connection_id: &str) {
let mut connections = self.connections.write().unwrap();
if let Some(connection) = connections.remove(connection_id) {
// 从用户连接映射中移除
if let Some(user_id) = &connection.user_id {
let mut user_connections = self.user_connections.write().unwrap();
if let Some(user_conn_list) = user_connections.get_mut(user_id) {
user_conn_list.retain(|id| id != connection_id);
if user_conn_list.is_empty() {
user_connections.remove(user_id);
}
}
}
info!(
target = "udmin",
connection_id = %connection_id,
user_id = ?connection.user_id,
"SSE connection removed"
);
}
}
/// 发送事件到指定连接
pub fn send_to_connection(&self, connection_id: &str, event: SseEvent) -> bool {
let connections = self.connections.read().unwrap();
if let Some(connection) = connections.get(connection_id) {
// 检查事件过滤器
if !connection.filters.is_empty() {
let event_type = match &event {
SseEvent::FlowExecutionUpdate { .. } => "flow_execution",
SseEvent::JobExecutionUpdate { .. } => "job_execution",
SseEvent::SystemNotification { .. } => "system_notification",
SseEvent::UserNotification { .. } => "user_notification",
SseEvent::SystemStatus { .. } => "system_status",
SseEvent::Heartbeat { .. } => "heartbeat",
};
if !connection.filters.contains(&event_type.to_string()) {
return true; // 过滤掉,但不算失败
}
}
let sse_event = match create_sse_event(&event) {
Ok(event) => event,
Err(e) => {
error!(
target = "udmin",
connection_id = %connection_id,
error = %e,
"Failed to create SSE event"
);
return false;
}
};
match connection.sender.send(Ok(sse_event)) {
Ok(_) => true,
Err(e) => {
warn!(
target = "udmin",
connection_id = %connection_id,
error = %e,
"Failed to send SSE event"
);
false
}
}
} else {
false
}
}
/// 发送事件到指定用户的所有连接
pub fn send_to_user(&self, user_id: &str, event: SseEvent) -> usize {
let user_connections = self.user_connections.read().unwrap();
let mut sent_count = 0;
if let Some(connection_ids) = user_connections.get(user_id) {
for connection_id in connection_ids {
if self.send_to_connection(connection_id, event.clone()) {
sent_count += 1;
}
}
}
info!(
target = "udmin",
user_id = %user_id,
sent_count = %sent_count,
"SSE event sent to user connections"
);
sent_count
}
/// 广播事件到所有连接
pub fn broadcast(&self, event: SseEvent) -> usize {
let connections = self.connections.read().unwrap();
let mut sent_count = 0;
for (connection_id, _) in connections.iter() {
if self.send_to_connection(connection_id, event.clone()) {
sent_count += 1;
}
}
info!(
target = "udmin",
sent_count = %sent_count,
total_connections = %connections.len(),
"SSE event broadcasted"
);
sent_count
}
/// 获取连接统计信息
pub fn get_stats(&self) -> SseStats {
let connections = self.connections.read().unwrap();
let user_connections = self.user_connections.read().unwrap();
SseStats {
total_connections: connections.len(),
authenticated_connections: connections.values().filter(|c| c.user_id.is_some()).count(),
unique_users: user_connections.len(),
}
}
}
/// Connection counts reported by `SseManager::get_stats`.
#[derive(Debug, Serialize)]
pub struct SseStats {
    /// Number of registered connections.
    pub total_connections: usize,
    /// Connections registered with a user ID.
    pub authenticated_connections: usize,
    /// Distinct users with at least one entry in the per-user index.
    pub unique_users: usize,
}
/// Query-string parameters accepted by the SSE endpoint.
#[derive(Debug, Deserialize)]
pub struct SseQuery {
    /// Comma-separated event-type filter, e.g. `filters=flow_execution,heartbeat`.
    pub filters: Option<String>,
    /// Last event ID seen by the client. NOTE(review): not read by `sse_handler` in this chunk.
    pub last_event_id: Option<String>,
}
/// SSE 处理器
pub async fn sse_handler(
Query(params): Query<SseQuery>,
State(state): State<AppState>,
auth_context: Option<AuthContext>,
) -> Result<Response, AppError> {
let connection_id = Uuid::new_v4().to_string();
let (tx, rx) = mpsc::unbounded_channel();
// 解析过滤器
let filters = params
.filters
.map(|f| f.split(',').map(|s| s.trim().to_string()).collect())
.unwrap_or_default();
// 添加连接到管理器
state.sse_manager.add_connection(
connection_id.clone(),
auth_context.as_ref().map(|ctx| ctx.user_id.clone()),
tx,
filters,
);
info!(
target = "udmin",
connection_id = %connection_id,
user_id = ?auth_context.as_ref().map(|ctx| &ctx.user_id),
"SSE connection established"
);
// 发送初始心跳
let initial_heartbeat = SseEvent::Heartbeat {
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
};
state.sse_manager.send_to_connection(&connection_id, initial_heartbeat);
// 创建事件流
let stream = UnboundedReceiverStream::new(rx);
// 启动心跳任务
let connection_id_clone = connection_id.clone();
let sse_manager_clone = state.sse_manager.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
loop {
interval.tick().await;
let heartbeat = SseEvent::Heartbeat {
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
};
if !sse_manager_clone.send_to_connection(&connection_id_clone, heartbeat) {
break;
}
}
// 清理连接
sse_manager_clone.remove_connection(&connection_id_clone);
});
let sse = Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive-text"),
);
Ok(sse.into_response())
}
/// Converts an [`SseEvent`] into an axum SSE `Event`.
///
/// The SSE `event:` field is the snake_case name of the variant (e.g.
/// `flow_execution_update`). The `data:` field is the serde JSON encoding of the whole
/// enum value, i.e. `{"type": "<VariantName>", "data": {...}}`, because of the enum's
/// adjacent tagging.
///
/// # Errors
/// Propagates any `serde_json` serialization error.
fn create_sse_event(event: &SseEvent) -> Result<Event, serde_json::Error> {
    let name = match event {
        SseEvent::Heartbeat { .. } => "heartbeat",
        SseEvent::SystemStatus { .. } => "system_status",
        SseEvent::UserNotification { .. } => "user_notification",
        SseEvent::SystemNotification { .. } => "system_notification",
        SseEvent::JobExecutionUpdate { .. } => "job_execution_update",
        SseEvent::FlowExecutionUpdate { .. } => "flow_execution_update",
    };

    let payload = serde_json::to_string(event)?;
    let sse_event = Event::default().event(name).data(payload);
    Ok(sse_event)
}
/// 启动 SSE 管理器清理任务
pub fn start_sse_cleanup_task(sse_manager: Arc<SseManager>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(300)); // 5分钟清理一次
loop {
interval.tick().await;
// 这里可以添加清理逻辑,比如移除长时间未活动的连接
let stats = sse_manager.get_stats();
info!(
target = "udmin",
total_connections = %stats.total_connections,
authenticated_connections = %stats.authenticated_connections,
unique_users = %stats.unique_users,
"SSE connection stats"
);
}
});
}
中间件组合和配置
中间件栈配置
use axum::{
middleware,
Router,
};
use std::sync::Arc;
use tower::ServiceBuilder;
use tower_http::{
compression::CompressionLayer,
timeout::TimeoutLayer,
trace::TraceLayer,
};
use crate::{
middlewares::{
auth::{jwt_auth_middleware, optional_auth_middleware},
cors::create_cors_layer,
logging::{request_logging_middleware, performance_monitoring_middleware},
rate_limit::{create_rate_limit_middleware, ip_rate_limit_middleware, user_rate_limit_middleware},
},
AppState,
};
/// Builds the global middleware stack shared by all routes.
///
/// Layers are added in this order: tracing, 30-second request timeout, response
/// compression, CORS, request logging, performance monitoring, IP rate limiting, then user
/// rate limiting. With `tower::ServiceBuilder`, layers added first wrap outermost and see
/// the request first.
///
/// NOTE(review): `state` is not used in this body. Confirm whether it is meant for the
/// rate limiter or auth, or can be dropped.
/// NOTE(review): the declared `Service` bounds (`Response = axum::response::Response`,
/// `Error = Infallible`) should be checked against the real layer types. For example,
/// `CompressionLayer` wraps the response body type.
pub fn create_middleware_stack(state: AppState) -> ServiceBuilder<
    impl tower::Layer<
        axum::routing::Router,
        Service = impl tower::Service<
            http::Request<axum::body::Body>,
            Response = axum::response::Response,
            Error = std::convert::Infallible,
        > + Clone + Send + 'static,
    > + Clone,
> {
    // Create one rate limiter shared by the IP and user limit layers below.
    let rate_limiter = create_rate_limit_middleware(Default::default());
    ServiceBuilder::new()
        // Request tracing
        .layer(TraceLayer::new_for_http())
        // Request timeout (30 s)
        .layer(TimeoutLayer::new(std::time::Duration::from_secs(30)))
        // Response compression
        .layer(CompressionLayer::new())
        // CORS
        .layer(create_cors_layer(Default::default()))
        // Request logging
        .layer(middleware::from_fn(request_logging_middleware))
        // Performance monitoring
        .layer(middleware::from_fn(performance_monitoring_middleware))
        // Per-IP rate limiting
        .layer(middleware::from_fn_with_state(
            Arc::clone(&rate_limiter),
            ip_rate_limit_middleware,
        ))
        // Per-user rate limiting (intended to run after authentication).
        // NOTE(review): this stack contains no auth layer. If auth is applied via
        // `create_auth_middleware_stack` on inner routes, this layer runs *before* auth and
        // may not see an authenticated user. Confirm the intended ordering.
        .layer(middleware::from_fn_with_state(
            Arc::clone(&rate_limiter),
            user_rate_limit_middleware,
        ))
}
/// Builds the middleware stack for routes that require authentication.
///
/// The stack consists of a single `jwt_auth_middleware` layer, which receives `state` as
/// its axum state.
pub fn create_auth_middleware_stack(state: AppState) -> ServiceBuilder<
    impl tower::Layer<
        axum::routing::Router,
        Service = impl tower::Service<
            http::Request<axum::body::Body>,
            Response = axum::response::Response,
            Error = std::convert::Infallible,
        > + Clone + Send + 'static,
    > + Clone,
> {
    let jwt_layer = middleware::from_fn_with_state(state, jwt_auth_middleware);
    ServiceBuilder::new().layer(jwt_layer)
}
/// Builds the middleware stack for routes where authentication is optional.
///
/// The stack consists of a single `optional_auth_middleware` layer, which receives `state`
/// as its axum state.
pub fn create_optional_auth_middleware_stack(state: AppState) -> ServiceBuilder<
    impl tower::Layer<
        axum::routing::Router,
        Service = impl tower::Service<
            http::Request<axum::body::Body>,
            Response = axum::response::Response,
            Error = std::convert::Infallible,
        > + Clone + Send + 'static,
    > + Clone,
> {
    let optional_auth_layer = middleware::from_fn_with_state(state, optional_auth_middleware);
    ServiceBuilder::new().layer(optional_auth_layer)
}
测试支持
中间件测试
// Unit tests for the CORS and rate-limit layers, driven through `tower::ServiceExt::oneshot`.
#[cfg(test)]
mod tests {
    // NOTE(review): `super::*` only brings in names the parent module imports. `RateLimiter`
    // (used below) is not in the parent's `use` list. Confirm it is reachable, e.g. via
    // `crate::middlewares::rate_limit::RateLimiter`.
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
    };
    use tower::ServiceExt;
    // An OPTIONS request with an `Origin` header should return 200 and carry the
    // `access-control-allow-origin` header added by the CORS layer.
    #[tokio::test]
    async fn test_cors_middleware() {
        let app = Router::new()
            .route("/test", axum::routing::get(|| async { "OK" }))
            .layer(create_cors_layer(Default::default()));
        let request = Request::builder()
            .method("OPTIONS")
            .uri("/test")
            .header("Origin", "http://localhost:3000")
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert!(response.headers().contains_key("access-control-allow-origin"));
    }
    // With a limiter built as `RateLimiter::new(1, 1)`, two back-to-back requests should
    // yield 200 then 429.
    // NOTE(review): if `ip_rate_limit_middleware` reads the client IP from `ConnectInfo`,
    // `oneshot` requests carry none. Confirm the middleware handles that case.
    #[tokio::test]
    async fn test_rate_limit_middleware() {
        let rate_limiter = Arc::new(RateLimiter::new(1, 1)); // 1 token, 1 per second
        let app = Router::new()
            .route("/test", axum::routing::get(|| async { "OK" }))
            .layer(middleware::from_fn_with_state(
                rate_limiter,
                ip_rate_limit_middleware,
            ));
        // The first request should succeed.
        let request1 = Request::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();
        let response1 = app.clone().oneshot(request1).await.unwrap();
        assert_eq!(response1.status(), StatusCode::OK);
        // The second request should be rate limited.
        let request2 = Request::builder()
            .uri("/test")
            .body(Body::empty())
            .unwrap();
        let response2 = app.oneshot(request2).await.unwrap();
        assert_eq!(response2.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}
最佳实践
中间件设计
- 单一职责: 每个中间件只负责一个特定功能
- 可组合性: 中间件可以灵活组合使用
- 性能优化: 最小化中间件的性能开销
- 错误处理: 优雅地处理中间件中的错误
- 日志记录: 记录关键操作和错误信息
安全考虑
- 认证验证: 严格验证用户身份
- 权限检查: 确保用户有足够的权限
- 输入验证: 验证所有输入数据
- 限流保护: 按 IP/用户限制请求速率,缓解恶意请求和暴力尝试;应用层限流只能缓解、无法单独抵御大规模 DDoS 攻击
- CORS配置: 正确配置跨域资源共享
性能优化
- 缓存策略: 缓存认证结果和权限信息
- 连接池: 合理配置数据库连接池
- 异步处理: 使用异步操作提高并发性能
- 资源清理: 及时清理过期的连接和资源
总结
中间件层是 UdminAI 系统的重要组成部分,提供了认证授权、请求日志、错误处理、CORS、WebSocket、SSE、限流等核心功能。通过模块化设计和灵活的组合机制,中间件层为系统提供了强大的横切关注点支持,确保了系统的安全性、可靠性和性能。
主要特点:
- 模块化设计: 每个中间件独立实现,职责单一
- 灵活组合: 支持灵活的中间件组合和配置
- 高性能: 优化的实现确保最小的性能开销
- 类型安全: 利用 Rust 类型系统确保安全性
- 实时通信: 支持 WebSocket 和 SSE 实时通信
- 安全防护: 提供认证、授权、限流等安全机制
通过这套完整的中间件系统,UdminAI 能够提供安全、高效、可靠的 Web 服务,满足企业级应用的各种需求。