udmin/docs/MIDDLEWARES.md
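ayou a3f2f99a68 docs: add project documentation, including overview, architecture, flow engine and service layer
Added the following documentation files:
- PROJECT_OVERVIEW.md: project overview
- BACKEND_ARCHITECTURE.md: backend architecture
- FRONTEND_ARCHITECTURE.md: frontend architecture
- FLOW_ENGINE.md: flow engine
- SERVICES.md: service layer
- ERROR_HANDLING.md: error-handling module

The documentation covers the project introduction, technical architecture, core module design and implementation details.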
2025-09-24 20:21:45 +08:00

# Middleware Documentation
## Overview
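The middleware layer is a core component of the UdminAI system. Built on the Axum framework, it provides authentication and authorization, request logging, error handling, CORS, rate limiting, WebSocket support, SSE (Server-Sent Events) and more. Middlewares follow the onion model: a request passes through the layers from the outside in, and the response passes back through the same layers in reverse order.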
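To make the onion model concrete, here is a minimal std-only sketch. It does not use Axum, and `Handler`, `layer` and `build_stack` are hypothetical names for illustration only. Each layer wraps everything added before it, which mirrors how Axum documents `Router::layer`: the layer added last is the outermost one, so it sees the request first and the response last.

```rust
/// A handler in this model: takes a path, appends to a trace log, returns a status code.
type Handler = Box<dyn Fn(&str, &mut Vec<String>) -> u16>;

/// Wrap `inner` in a named layer that records when it sees the request and the response.
fn layer(name: &'static str, inner: Handler) -> Handler {
    Box::new(move |path: &str, log: &mut Vec<String>| -> u16 {
        log.push(format!("{name}: request"));
        // Hand the request to the inner layers, then observe their response
        let status = inner(path, &mut *log);
        log.push(format!("{name}: response {status}"));
        status
    })
}

/// Build `logging -> auth -> handler`, adding layers the way chained `Router::layer` calls do.
fn build_stack() -> Handler {
    let handler: Handler = Box::new(|_path: &str, log: &mut Vec<String>| -> u16 {
        log.push("handler".to_string());
        200
    });
    let with_auth = layer("auth", handler); // added first: inner layer
    layer("logging", with_auth) // added last: outermost layer
}
```

Running a request through `build_stack()` records `logging` and `auth` on the way in and the same two layers in reverse order on the way out.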
## Architecture
### Module Layout
```
middlewares/
├── mod.rs          # module exports
├── auth.rs         # authentication middleware
├── cors.rs         # CORS middleware
├── http_client.rs  # outbound HTTP client
├── logging.rs      # request logging middleware
├── rate_limit.rs   # rate limiting middleware
├── sse.rs          # Server-Sent Events
└── ws.rs           # WebSocket
```
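### Design Principles
- **Modularity**: each middleware is implemented independently and has a single responsibility
- **Composability**: middlewares can be combined freely
- **Performance**: per-request overhead is kept to a minimum
- **Type safety**: the Rust type system is used to enforce correctness
- **Configurability**: behavior is controlled through configuration structs
- **Testability**: each middleware is easy to cover with unit and integration tests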
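## Authentication Middleware (auth.rs)
### Features
- JWT token validation
- User identification
- Permission and role checks (`require_permission`, `require_role`)
- Token refresh mechanism (not shown in the code below)
- Multiple authentication strategies (mandatory `jwt_auth_middleware`, optional `optional_auth_middleware`)
### Implementation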
```rust
use axum::{
    extract::{Request, State},
    http::header,
    middleware::Next,
    response::Response,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use tracing::{error, info, warn};

use crate::{
    error::AppError,
    models::user,
    services::user_service,
    AppState,
};

/// JWT claims
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    pub sub: String,              // user ID
    pub username: String,         // username
    pub roles: Vec<String>,       // role list
    pub permissions: Vec<String>, // permission list
    pub exp: usize,               // expiration time (Unix timestamp)
    pub iat: usize,               // issued-at time (Unix timestamp)
    pub iss: String,              // issuer
}

/// Authentication context stored in the request extensions
#[derive(Debug, Clone)]
pub struct AuthContext {
    pub user_id: String,
    pub username: String,
    pub roles: Vec<String>,
    pub permissions: HashSet<String>,
}

/// JWT authentication middleware
pub async fn jwt_auth_middleware(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Result<Response, AppError> {
    // Expect `Authorization: Bearer <token>`; `strip_prefix` removes the prefix exactly once
    let token = match request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
    {
        Some(token) => token,
        None => {
            warn!(target: "udmin", "Missing or invalid authorization header");
            return Err(AppError::Unauthorized("缺少认证令牌".to_string()));
        }
    };

    // Validate the JWT (`Validation::default()` checks the HS256 signature and `exp`)
    let claims = match decode::<Claims>(
        token,
        &DecodingKey::from_secret(state.config.jwt_secret.as_ref()),
        &Validation::default(),
    ) {
        Ok(token_data) => token_data.claims,
        Err(err) => {
            error!(target: "udmin", error = %err, "JWT token validation failed");
            return Err(AppError::Unauthorized("无效的认证令牌".to_string()));
        }
    };

    // Make sure the user still exists and is active
    let user = user_service::find_by_id(&state.db, &claims.sub)
        .await
        .map_err(|_| AppError::Unauthorized("用户不存在".to_string()))?;

    if user.status != user::UserStatus::Active {
        warn!(target: "udmin", user_id = %claims.sub, "Inactive user attempted access");
        return Err(AppError::Unauthorized("用户账户已被禁用".to_string()));
    }

    info!(
        target: "udmin",
        user_id = %claims.sub,
        username = %claims.username,
        "User authenticated successfully"
    );

    // Build the auth context and attach it to the request extensions
    let auth_context = AuthContext {
        user_id: claims.sub,
        username: claims.username,
        roles: claims.roles,
        permissions: claims.permissions.into_iter().collect(),
    };
    request.extensions_mut().insert(auth_context);

    Ok(next.run(request).await)
}
/// Boxed future returned by the permission and role check closures
type AuthFuture = std::pin::Pin<
    Box<dyn std::future::Future<Output = Result<Response, AppError>> + Send>,
>;

/// Permission check middleware (must run after `jwt_auth_middleware`)
pub fn require_permission(
    required_permission: &'static str,
) -> impl Fn(Request, Next) -> AuthFuture + Clone {
    move |request: Request, next: Next| -> AuthFuture {
        Box::pin(async move {
            let auth_context = request
                .extensions()
                .get::<AuthContext>()
                .ok_or_else(|| AppError::Unauthorized("未认证的请求".to_string()))?;

            if !auth_context.permissions.contains(required_permission) {
                warn!(
                    target: "udmin",
                    user_id = %auth_context.user_id,
                    required_permission = %required_permission,
                    "Permission denied"
                );
                return Err(AppError::Forbidden("权限不足".to_string()));
            }

            info!(
                target: "udmin",
                user_id = %auth_context.user_id,
                permission = %required_permission,
                "Permission granted"
            );
            Ok(next.run(request).await)
        })
    }
}

/// Role check middleware (must run after `jwt_auth_middleware`)
pub fn require_role(
    required_role: &'static str,
) -> impl Fn(Request, Next) -> AuthFuture + Clone {
    move |request: Request, next: Next| -> AuthFuture {
        Box::pin(async move {
            let auth_context = request
                .extensions()
                .get::<AuthContext>()
                .ok_or_else(|| AppError::Unauthorized("未认证的请求".to_string()))?;

            // Compare without allocating a String for `required_role`
            if !auth_context.roles.iter().any(|role| role == required_role) {
                warn!(
                    target: "udmin",
                    user_id = %auth_context.user_id,
                    required_role = %required_role,
                    "Role check failed"
                );
                return Err(AppError::Forbidden("角色权限不足".to_string()));
            }

            info!(
                target: "udmin",
                user_id = %auth_context.user_id,
                role = %required_role,
                "Role check passed"
            );
            Ok(next.run(request).await)
        })
    }
}
/// Optional authentication middleware: attaches an `AuthContext` when a valid
/// token is present, but never rejects the request
pub async fn optional_auth_middleware(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> Response {
    let claims = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "))
        .and_then(|token| {
            decode::<Claims>(
                token,
                &DecodingKey::from_secret(state.config.jwt_secret.as_ref()),
                &Validation::default(),
            )
            .ok()
        })
        .map(|token_data| token_data.claims);

    // Unlike `jwt_auth_middleware`, this does not check that the user is still active
    if let Some(claims) = claims {
        request.extensions_mut().insert(AuthContext {
            user_id: claims.sub,
            username: claims.username,
            roles: claims.roles,
            permissions: claims.permissions.into_iter().collect(),
        });
    }

    next.run(request).await
}
```
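The header parsing and the 401/403 split above can be checked in isolation. This std-only sketch is hypothetical (`extract_bearer`, `authorize` and `AuthDecision` are illustrative names, not part of auth.rs); `permissions` stands in for the permission set taken from the decoded JWT claims, with `None` meaning the token failed validation.

```rust
use std::collections::HashSet;

/// Outcome of the checks done by `jwt_auth_middleware` + `require_permission`.
#[derive(Debug, PartialEq)]
pub enum AuthDecision {
    Unauthorized, // missing, malformed or invalid token
    Forbidden,    // valid token, but the permission is missing
    Allowed,
}

/// Extract the token from an `Authorization: Bearer <token>` header value.
pub fn extract_bearer(header: Option<&str>) -> Option<&str> {
    header?
        .strip_prefix("Bearer ")
        .map(str::trim)
        .filter(|token| !token.is_empty())
}

/// Hypothetical helper: `permissions` is `None` when JWT validation failed.
pub fn authorize(
    header: Option<&str>,
    permissions: Option<&HashSet<String>>,
    required: &str,
) -> AuthDecision {
    if extract_bearer(header).is_none() {
        return AuthDecision::Unauthorized;
    }
    match permissions {
        None => AuthDecision::Unauthorized,
        Some(p) if p.contains(required) => AuthDecision::Allowed,
        Some(_) => AuthDecision::Forbidden,
    }
}
```

A missing or invalid token corresponds to `AppError::Unauthorized` and a missing permission to `AppError::Forbidden`, presumably rendered as 401 and 403 by the error-handling module.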
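## CORS Middleware (cors.rs)
### Features
- Cross-origin resource sharing (CORS) support
- Configurable allowed origins
- Preflight (`OPTIONS`) request handling
- CORS response headers (`Access-Control-*`)
### Implementation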
```rust
use axum::{
    extract::Request,
    http::{header, HeaderName, HeaderValue, Method, StatusCode},
    middleware::Next,
    response::Response,
};
use tower_http::cors::{Any, CorsLayer};
use tracing::{info, warn};

/// CORS configuration
#[derive(Debug, Clone)]
pub struct CorsConfig {
    pub allowed_origins: Vec<String>, // "*" allows any origin
    pub allowed_methods: Vec<Method>,
    pub allowed_headers: Vec<String>,
    pub max_age: u64, // preflight cache duration in seconds
    pub allow_credentials: bool,
}

impl Default for CorsConfig {
    fn default() -> Self {
        Self {
            allowed_origins: vec![
                "http://localhost:3000".to_string(),
                "http://localhost:5173".to_string(),
                "http://127.0.0.1:3000".to_string(),
                "http://127.0.0.1:5173".to_string(),
            ],
            allowed_methods: vec![
                Method::GET,
                Method::POST,
                Method::PUT,
                Method::DELETE,
                Method::PATCH,
                Method::OPTIONS,
            ],
            allowed_headers: vec![
                "content-type".to_string(),
                "authorization".to_string(),
                "x-requested-with".to_string(),
                "x-api-key".to_string(),
            ],
            max_age: 3600,
            allow_credentials: true,
        }
    }
}

/// Build a `tower_http` CORS layer from `CorsConfig`
pub fn create_cors_layer(config: CorsConfig) -> CorsLayer {
    let allowed_headers: Vec<HeaderName> = config
        .allowed_headers
        .iter()
        .map(|h| h.parse().expect("invalid header name in CorsConfig"))
        .collect();

    let mut cors = CorsLayer::new()
        .allow_methods(config.allowed_methods)
        .allow_headers(allowed_headers)
        .max_age(std::time::Duration::from_secs(config.max_age));

    // Allowed origins
    let wildcard = config.allowed_origins.iter().any(|o| o == "*");
    if wildcard {
        cors = cors.allow_origin(Any);
    } else {
        let origins: Vec<HeaderValue> = config
            .allowed_origins
            .iter()
            .map(|origin| origin.parse().expect("invalid origin in CorsConfig"))
            .collect();
        cors = cors.allow_origin(origins);
    }

    // tower-http panics when `allow_credentials(true)` is combined with a wildcard
    // origin (browsers reject that combination), so only enable credentials for
    // an explicit origin list.
    if config.allow_credentials {
        if wildcard {
            warn!(target: "udmin", "allow_credentials ignored because allowed_origins contains \"*\"");
        } else {
            cors = cors.allow_credentials(true);
        }
    }

    info!(target: "udmin", "CORS middleware configured");
    cors
}
/// Hand-written CORS middleware (alternative to `create_cors_layer`).
///
/// NOTE: it echoes back whatever `Origin` the client sends together with
/// `Access-Control-Allow-Credentials: true`. Check the origin against an allowlist
/// (e.g. `CorsConfig::allowed_origins`) before using it outside development.
pub async fn cors_middleware(request: Request, next: Next) -> Response {
    // Clone the Origin header so it is still available after `request` is moved
    let origin: Option<HeaderValue> = request.headers().get(header::ORIGIN).cloned();

    // Answer preflight requests directly, without calling the inner service
    if *request.method() == Method::OPTIONS {
        let mut response = Response::builder()
            .status(StatusCode::OK)
            .body(axum::body::Body::empty())
            .unwrap();
        let headers = response.headers_mut();
        if let Some(origin) = origin {
            headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
            // The response varies by Origin, so caches must key on it
            headers.insert(header::VARY, HeaderValue::from_static("Origin"));
        }
        headers.insert(
            header::ACCESS_CONTROL_ALLOW_METHODS,
            HeaderValue::from_static("GET, POST, PUT, DELETE, PATCH, OPTIONS"),
        );
        headers.insert(
            header::ACCESS_CONTROL_ALLOW_HEADERS,
            HeaderValue::from_static("content-type, authorization, x-requested-with, x-api-key"),
        );
        headers.insert(header::ACCESS_CONTROL_MAX_AGE, HeaderValue::from_static("3600"));
        headers.insert(
            header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
            HeaderValue::from_static("true"),
        );
        return response;
    }

    let mut response = next.run(request).await;

    // Add CORS headers to the actual response
    let headers = response.headers_mut();
    if let Some(origin) = origin {
        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin);
        headers.insert(header::VARY, HeaderValue::from_static("Origin"));
    }
    headers.insert(
        header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
        HeaderValue::from_static("true"),
    );
    response
}
```
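The origin check is the part the hand-written `cors_middleware` leaves out: it echoes any `Origin`. The std-only sketch below (the function name `allow_origin_header` is hypothetical, and `allowed` plays the role of `CorsConfig::allowed_origins`) returns the `Access-Control-Allow-Origin` value to send, if any. It also encodes the rule that browsers reject a wildcard origin on credentialed requests, which is why tower-http refuses that combination.

```rust
/// Decide which `Access-Control-Allow-Origin` value to send, if any.
pub fn allow_origin_header(
    origin: Option<&str>,
    allowed: &[&str],
    allow_credentials: bool,
) -> Option<String> {
    // No Origin header: not a cross-origin browser request, no CORS headers needed
    let origin = origin?;
    if allowed.contains(&"*") {
        // `*` is only valid for requests without credentials
        return (!allow_credentials).then(|| "*".to_string());
    }
    // Echo the origin only if it is on the allowlist
    allowed.contains(&origin).then(|| origin.to_string())
}
```

When this returns `Some(origin)`, the middleware would also add `Vary: Origin`; when it returns `None`, it would omit the CORS headers so the browser blocks the response.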
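## Request Logging Middleware (logging.rs)
### Features
- Request and response logging
- Performance monitoring (requests slower than 1 s are flagged)
- Error tracking (log level chosen by status code)
- Structured logging via `tracing`
- Per-request ID (UUID v4) included in log entries
### Implementation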
```rust
use axum::{
    extract::{MatchedPath, Request},
    middleware::Next,
    response::Response,
};
use std::time::Instant;
use tracing::{error, info, warn};
use uuid::Uuid;

/// Request logging middleware
pub async fn request_logging_middleware(request: Request, next: Next) -> Response {
    let start_time = Instant::now();
    let request_id = Uuid::new_v4().to_string();

    // Extract request metadata
    let method = request.method().clone();
    let uri = request.uri().clone();
    let version = request.version();
    let user_agent = request
        .headers()
        .get("user-agent")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("unknown");
    // Client address as forwarded by a reverse proxy (client-controlled otherwise)
    let remote_addr = request
        .headers()
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .or_else(|| {
            request
                .headers()
                .get("x-real-ip")
                .and_then(|v| v.to_str().ok())
        })
        .unwrap_or("unknown");

    // Matched route pattern; owned so it can still be logged after `request` is moved
    let matched_path = request
        .extensions()
        .get::<MatchedPath>()
        .map(|mp| mp.as_str().to_owned())
        .unwrap_or_else(|| uri.path().to_owned());

    info!(
        target: "udmin",
        request_id = %request_id,
        method = %method,
        uri = %uri,
        path = %matched_path,
        version = ?version,
        user_agent = %user_agent,
        remote_addr = %remote_addr,
        "Request started"
    );

    // Run the rest of the middleware stack and the handler
    let response = next.run(request).await;

    let duration = start_time.elapsed();
    let status = response.status();

    // Pick the log level from the status code
    match status.as_u16() {
        200..=299 => {
            info!(
                target: "udmin",
                request_id = %request_id,
                method = %method,
                uri = %uri,
                path = %matched_path,
                status = %status.as_u16(),
                duration_ms = %duration.as_millis(),
                "Request completed successfully"
            );
        }
        300..=399 => {
            info!(
                target: "udmin",
                request_id = %request_id,
                method = %method,
                uri = %uri,
                path = %matched_path,
                status = %status.as_u16(),
                duration_ms = %duration.as_millis(),
                "Request redirected"
            );
        }
        400..=499 => {
            warn!(
                target: "udmin",
                request_id = %request_id,
                method = %method,
                uri = %uri,
                path = %matched_path,
                status = %status.as_u16(),
                duration_ms = %duration.as_millis(),
                "Client error"
            );
        }
        500..=599 => {
            error!(
                target: "udmin",
                request_id = %request_id,
                method = %method,
                uri = %uri,
                path = %matched_path,
                status = %status.as_u16(),
                duration_ms = %duration.as_millis(),
                "Server error"
            );
        }
        _ => {
            info!(
                target: "udmin",
                request_id = %request_id,
                method = %method,
                uri = %uri,
                path = %matched_path,
                status = %status.as_u16(),
                duration_ms = %duration.as_millis(),
                "Request completed"
            );
        }
    }

    response
}
/// Performance monitoring middleware
pub async fn performance_monitoring_middleware(request: Request, next: Next) -> Response {
    let start_time = Instant::now();
    let method = request.method().clone();
    let uri = request.uri().clone();

    let response = next.run(request).await;
    let duration = start_time.elapsed();

    // Flag slow requests (over 1000 ms)
    if duration.as_millis() > 1000 {
        warn!(
            target: "udmin",
            method = %method,
            uri = %uri,
            duration_ms = %duration.as_millis(),
            "Slow request detected"
        );
    }

    // Record performance metrics under a dedicated log target
    info!(
        target: "udmin.performance",
        method = %method,
        uri = %uri,
        status = %response.status().as_u16(),
        duration_ms = %duration.as_millis(),
        "Performance metrics"
    );

    response
}
```
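The level selection in `request_logging_middleware` and the slow-request threshold in `performance_monitoring_middleware` reduce to two small pure functions. In this std-only sketch a hypothetical `Level` enum stands in for `tracing`'s `info!`/`warn!`/`error!` macros so the mapping can be unit-tested.

```rust
use std::time::Duration;

/// Stand-in for the tracing level chosen by the logging middleware.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Level {
    Info,
    Warn,
    Error,
}

/// Same mapping as the `match status.as_u16()` in `request_logging_middleware`.
pub fn level_for_status(status: u16) -> Level {
    match status {
        400..=499 => Level::Warn,
        500..=599 => Level::Error,
        _ => Level::Info, // 2xx, 3xx and any other code
    }
}

/// Same threshold as `performance_monitoring_middleware` (strictly more than 1000 ms).
pub fn is_slow(duration: Duration) -> bool {
    duration.as_millis() > 1000
}
```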
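## HTTP Client Middleware (http_client.rs)
### Features
- HTTP client wrapper around `reqwest` for outbound requests (a helper, not an Axum request middleware)
- Request retries
- Timeout control
- Request logging
- Error handling
### Implementation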
```rust
use reqwest::{Client, ClientBuilder, Response};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{error, info, warn};
use url::Url;

use crate::error::AppError;

/// HTTP client configuration
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    pub timeout: Duration,         // total request timeout
    pub connect_timeout: Duration, // connection timeout
    pub max_retries: u32,          // retries after the first attempt
    pub retry_delay: Duration,     // fixed delay between attempts
    pub user_agent: String,
}

impl Default for HttpClientConfig {
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            connect_timeout: Duration::from_secs(10),
            max_retries: 3,
            retry_delay: Duration::from_millis(1000),
            user_agent: "UdminAI/1.0".to_string(),
        }
    }
}

/// HTTP client wrapper
#[derive(Debug, Clone)]
pub struct HttpClient {
    client: Client,
    config: HttpClientConfig,
}

impl HttpClient {
    /// Create a new HTTP client
    pub fn new(config: HttpClientConfig) -> Result<Self, AppError> {
        let client = ClientBuilder::new()
            .timeout(config.timeout)
            .connect_timeout(config.connect_timeout)
            .user_agent(config.user_agent.as_str())
            .build()
            .map_err(|e| AppError::InternalServerError(format!("创建HTTP客户端失败: {}", e)))?;
        Ok(Self { client, config })
    }

    /// Send a GET request
    pub async fn get(&self, url: &str) -> Result<Response, AppError> {
        // No body: `None::<&()>` only fixes the generic body type
        self.request_with_retry("GET", url, None::<&()>).await
    }

    /// Send a POST request with a JSON body
    pub async fn post<T: Serialize>(&self, url: &str, body: &T) -> Result<Response, AppError> {
        self.request_with_retry("POST", url, Some(body)).await
    }

    /// Send a PUT request with a JSON body
    pub async fn put<T: Serialize>(&self, url: &str, body: &T) -> Result<Response, AppError> {
        self.request_with_retry("PUT", url, Some(body)).await
    }

    /// Send a DELETE request
    pub async fn delete(&self, url: &str) -> Result<Response, AppError> {
        self.request_with_retry("DELETE", url, None::<&()>).await
    }
    /// Request with retries.
    ///
    /// Makes up to `max_retries + 1` attempts and retries on transport errors and
    /// 5xx responses. POST and PUT are retried too, so the remote endpoint should
    /// be idempotent.
    async fn request_with_retry<T: Serialize>(
        &self,
        method: &str,
        url: &str,
        body: Option<&T>,
    ) -> Result<Response, AppError> {
        let parsed_url = Url::parse(url)
            .map_err(|e| AppError::BadRequest(format!("无效的URL: {}", e)))?;

        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            let start_time = std::time::Instant::now();

            info!(
                target: "udmin",
                method = %method,
                url = %url,
                attempt = %attempt,
                "HTTP request started"
            );

            let result = self.make_request(method, &parsed_url, body).await;
            let duration = start_time.elapsed();

            match result {
                Ok(response) => {
                    let status = response.status();

                    info!(
                        target: "udmin",
                        method = %method,
                        url = %url,
                        status = %status.as_u16(),
                        duration_ms = %duration.as_millis(),
                        attempt = %attempt,
                        "HTTP request completed"
                    );

                    // Retry on a server error while attempts remain
                    if status.is_server_error() && attempt < self.config.max_retries {
                        warn!(
                            target: "udmin",
                            method = %method,
                            url = %url,
                            status = %status.as_u16(),
                            attempt = %attempt,
                            "Server error, retrying"
                        );
                        tokio::time::sleep(self.config.retry_delay).await;
                        continue;
                    }

                    return Ok(response);
                }
                Err(e) => {
                    error!(
                        target: "udmin",
                        method = %method,
                        url = %url,
                        error = %e,
                        duration_ms = %duration.as_millis(),
                        attempt = %attempt,
                        "HTTP request failed"
                    );

                    last_error = Some(e);

                    // Wait before the next attempt, if any
                    if attempt < self.config.max_retries {
                        tokio::time::sleep(self.config.retry_delay).await;
                    }
                }
            }
        }

        Err(last_error.unwrap_or_else(|| {
            AppError::InternalServerError("HTTP请求失败".to_string())
        }))
    }

    /// Perform a single HTTP request
    async fn make_request<T: Serialize>(
        &self,
        method: &str,
        url: &Url,
        body: Option<&T>,
    ) -> Result<Response, AppError> {
        let mut request_builder = match method {
            "GET" => self.client.get(url.clone()),
            "POST" => self.client.post(url.clone()),
            "PUT" => self.client.put(url.clone()),
            "DELETE" => self.client.delete(url.clone()),
            _ => {
                return Err(AppError::BadRequest(format!(
                    "不支持的HTTP方法: {}",
                    method
                )))
            }
        };

        // Attach the JSON body, if any
        if let Some(body) = body {
            request_builder = request_builder.json(body);
        }

        // Send the request
        let response = request_builder
            .send()
            .await
            .map_err(|e| AppError::InternalServerError(format!("HTTP请求失败: {}", e)))?;

        Ok(response)
    }
    /// Send a JSON request and deserialize the JSON response
    pub async fn json_request<T, R>(
        &self,
        method: &str,
        url: &str,
        body: Option<&T>,
    ) -> Result<R, AppError>
    where
        T: Serialize,
        R: for<'de> Deserialize<'de>,
    {
        let response = match method {
            "GET" => self.get(url).await?,
            "POST" => match body {
                Some(body) => self.post(url, body).await?,
                None => return Err(AppError::BadRequest("POST请求需要请求体".to_string())),
            },
            "PUT" => match body {
                Some(body) => self.put(url, body).await?,
                None => return Err(AppError::BadRequest("PUT请求需要请求体".to_string())),
            },
            "DELETE" => self.delete(url).await?,
            _ => {
                return Err(AppError::BadRequest(format!(
                    "不支持的HTTP方法: {}",
                    method
                )))
            }
        };

        // Treat any non-2xx status as an error and include the response body
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "无法读取错误响应".to_string());
            return Err(AppError::InternalServerError(format!(
                "HTTP请求失败: {} - {}",
                status, error_text
            )));
        }

        // Deserialize the JSON body
        let json_response = response
            .json::<R>()
            .await
            .map_err(|e| AppError::InternalServerError(format!("解析JSON响应失败: {}", e)))?;

        Ok(json_response)
    }
}
```
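The retry policy of `request_with_retry` is easier to reason about as a plain loop. In this hypothetical std-only sketch, `send` stands in for one HTTP attempt and returns either a status code or a transport error. Like the method above, it makes at most `max_retries + 1` attempts, retries transport errors and 5xx responses, and hands back the last 5xx response as-is once the attempts are used up; it additionally reports how many attempts were made.

```rust
use std::time::Duration;

/// Outcome of one simulated attempt: a status code or a transport error.
pub type Attempt = Result<u16, String>;

/// Retry loop with the same control flow as `request_with_retry`.
/// Returns the final outcome and the number of attempts made.
pub fn retry<F: FnMut(u32) -> Attempt>(
    max_retries: u32,
    delay: Duration,
    mut send: F,
) -> (Attempt, u32) {
    let mut last = Err("no attempt made".to_string());
    for attempt in 0..=max_retries {
        last = send(attempt);
        let retryable = match &last {
            Ok(status) => (500..600).contains(status),
            Err(_) => true,
        };
        if !retryable || attempt == max_retries {
            return (last, attempt + 1);
        }
        // The async client uses `tokio::time::sleep` here instead
        std::thread::sleep(delay);
    }
    (last, max_retries + 1)
}
```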
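## Rate Limiting Middleware (rate_limit.rs)
### Features
- Token-bucket rate limiting
- Support for different limiting strategies
- Per-IP rate limiting
- Per-user rate limiting
- Configurable limits (`RateLimitConfig`)
### Implementation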
```rust
use axum::{
    extract::{Request, State},
    middleware::Next,
    response::Response,
};
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
    time::{Duration, Instant},
};
use tracing::{info, warn};

use crate::{error::AppError, middlewares::auth::AuthContext};

/// Token bucket
#[derive(Debug, Clone)]
struct TokenBucket {
    capacity: u32,
    tokens: u32,
    last_refill: Instant,
    refill_rate: u32, // tokens per second
}

impl TokenBucket {
    fn new(capacity: u32, refill_rate: u32) -> Self {
        Self {
            capacity,
            tokens: capacity,
            last_refill: Instant::now(),
            refill_rate,
        }
    }

    /// Refill, then take `tokens` if enough are available
    fn try_consume(&mut self, tokens: u32) -> bool {
        self.refill();
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Add the whole tokens earned since `last_refill`
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        // The float-to-int cast truncates (and saturates at u32::MAX)
        let tokens_to_add = (elapsed.as_secs_f64() * self.refill_rate as f64) as u32;
        if tokens_to_add > 0 {
            // `saturating_add` avoids u32 overflow (a panic in debug builds) after a long idle period
            self.tokens = self.tokens.saturating_add(tokens_to_add).min(self.capacity);
            // Resetting to `now` discards the fractional part of the refill
            self.last_refill = now;
        }
    }
}

/// Rate limiter: one token bucket per key (e.g. `ip:<addr>` or `user:<id>`)
#[derive(Debug)]
pub struct RateLimiter {
    buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
    default_capacity: u32,
    default_refill_rate: u32,
}

impl RateLimiter {
    pub fn new(default_capacity: u32, default_refill_rate: u32) -> Self {
        Self {
            buckets: Arc::new(Mutex::new(HashMap::new())),
            default_capacity,
            default_refill_rate,
        }
    }

    /// Returns `true` if the request identified by `key` is allowed
    pub fn check_rate_limit(&self, key: &str, tokens: u32) -> bool {
        let mut buckets = self.buckets.lock().unwrap();
        let bucket = buckets
            .entry(key.to_string())
            .or_insert_with(|| TokenBucket::new(self.default_capacity, self.default_refill_rate));
        bucket.try_consume(tokens)
    }

    /// Remove buckets that have not been refilled within `max_idle_duration`
    pub fn cleanup_expired_buckets(&self, max_idle_duration: Duration) {
        let mut buckets = self.buckets.lock().unwrap();
        let now = Instant::now();
        buckets.retain(|_, bucket| now.duration_since(bucket.last_refill) < max_idle_duration);
    }
}
/// Rate limiting configuration
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    pub requests_per_second: u32, // refill rate
    pub burst_capacity: u32,      // bucket capacity
    pub enable_ip_limit: bool,
    pub enable_user_limit: bool,
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        Self {
            requests_per_second: 100,
            burst_capacity: 200,
            enable_ip_limit: true,
            enable_user_limit: true,
        }
    }
}

/// Per-IP rate limiting middleware
pub async fn ip_rate_limit_middleware(
    State(rate_limiter): State<Arc<RateLimiter>>,
    request: Request,
    next: Next,
) -> Result<Response, AppError> {
    // Client IP: first entry of X-Forwarded-For, else X-Real-IP. These headers are
    // client-controlled unless a trusted reverse proxy overwrites them.
    let client_ip = request
        .headers()
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.split(',').next())
        .or_else(|| {
            request
                .headers()
                .get("x-real-ip")
                .and_then(|v| v.to_str().ok())
        })
        .unwrap_or("unknown")
        .trim();

    let rate_limit_key = format!("ip:{}", client_ip);

    // Consume one token for this request
    if !rate_limiter.check_rate_limit(&rate_limit_key, 1) {
        warn!(target: "udmin", client_ip = %client_ip, "IP rate limit exceeded");
        return Err(AppError::TooManyRequests(
            "请求过于频繁,请稍后再试".to_string(),
        ));
    }

    info!(target: "udmin", client_ip = %client_ip, "IP rate limit check passed");

    Ok(next.run(request).await)
}
/// Per-user rate limiting middleware (must run after authentication; requests
/// without an `AuthContext` are not limited here)
pub async fn user_rate_limit_middleware(
    State(rate_limiter): State<Arc<RateLimiter>>,
    request: Request,
    next: Next,
) -> Result<Response, AppError> {
    if let Some(auth_context) = request.extensions().get::<AuthContext>() {
        let rate_limit_key = format!("user:{}", auth_context.user_id);

        // Consume one token from this user's bucket
        if !rate_limiter.check_rate_limit(&rate_limit_key, 1) {
            warn!(target: "udmin", user_id = %auth_context.user_id, "User rate limit exceeded");
            return Err(AppError::TooManyRequests(
                "用户请求过于频繁,请稍后再试".to_string(),
            ));
        }

        info!(target: "udmin", user_id = %auth_context.user_id, "User rate limit check passed");
    }

    Ok(next.run(request).await)
}

/// Create the shared rate limiter and start its cleanup task.
/// Uses `tokio::spawn`, so it must be called from within a Tokio runtime.
pub fn create_rate_limit_middleware(config: RateLimitConfig) -> Arc<RateLimiter> {
    let rate_limiter = Arc::new(RateLimiter::new(
        config.burst_capacity,
        config.requests_per_second,
    ));

    // Every 5 minutes, drop buckets that have been idle for 1 hour
    let cleanup_limiter = rate_limiter.clone();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(300));
        loop {
            interval.tick().await;
            cleanup_limiter.cleanup_expired_buckets(Duration::from_secs(3600));
        }
    });

    info!(target: "udmin", "Rate limiter initialized");
    rate_limiter
}
```
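The `TokenBucket` above only adds whole tokens and then resets `last_refill`, so the fractional part of each refill is discarded and, under sustained load, the effective rate can fall below `refill_rate` (for example, at 1 token/s with a request every 0.7 s, a token is granted only every 1.4 s). A hypothetical alternative is sketched below (std-only; `Bucket` and `try_consume_at` are illustrative names): it stores tokens as `f64` and takes the current time as a parameter, so it can be tested deterministically.

```rust
use std::time::Instant;

/// Token bucket that keeps fractional tokens between calls.
pub struct Bucket {
    capacity: f64,
    tokens: f64,
    refill_per_sec: f64,
    last: Instant,
}

impl Bucket {
    /// Start full, like `TokenBucket::new`.
    pub fn new(capacity: u32, refill_per_sec: u32, now: Instant) -> Self {
        Self {
            capacity: f64::from(capacity),
            tokens: f64::from(capacity),
            refill_per_sec: f64::from(refill_per_sec),
            last: now,
        }
    }

    /// Refill proportionally to the elapsed time, then try to take `cost` tokens.
    pub fn try_consume_at(&mut self, now: Instant, cost: f64) -> bool {
        let elapsed = now.saturating_duration_since(self.last).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        self.last = now; // nothing is lost: the fraction stays in `tokens`
        if self.tokens >= cost {
            self.tokens -= cost;
            true
        } else {
            false
        }
    }
}
```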
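## WebSocket Middleware (ws.rs)
### Features
- WebSocket connection management
- Real-time message push
- Connection authentication (JWT via the `token` query parameter or an `Auth` message)
- Heartbeat detection
- Broadcast support
### Implementation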
```rust
use axum::{
    extract::{ws::WebSocket, Query, State, WebSocketUpgrade},
    response::Response,
};
use futures_util::{sink::SinkExt, stream::StreamExt};
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    sync::{Arc, RwLock},
    time::{Duration, Instant},
};
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use uuid::Uuid;

use crate::{error::AppError, services::user_service, AppState};

/// WebSocket connection info
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    pub id: String,
    pub user_id: Option<String>,
    pub connected_at: Instant,
    pub last_ping: Instant, // time of the last heartbeat reply
    pub sender: mpsc::UnboundedSender<WebSocketMessage>,
}

/// WebSocket message types, serialized as `{"type": "...", "data": {...}}`
/// (unit variants such as `Ping` have no `data` field)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum WebSocketMessage {
    /// Authentication request
    Auth { token: String },
    /// Authentication succeeded
    AuthSuccess { user_id: String },
    /// Authentication failed
    AuthError { message: String },
    /// Heartbeat ping (sent by the server)
    Ping,
    /// Heartbeat reply (sent by the client)
    Pong,
    /// Flow execution status update
    FlowExecutionUpdate {
        execution_id: String,
        status: String,
        progress: Option<f64>,
        message: Option<String>,
    },
    /// Job execution status update
    JobExecutionUpdate {
        job_id: String,
        execution_id: String,
        status: String,
        message: Option<String>,
    },
    /// System notification
    SystemNotification {
        title: String,
        message: String,
        level: String,
    },
    /// User notification
    UserNotification {
        id: String,
        title: String,
        message: String,
        created_at: String,
    },
    /// Error message
    Error { message: String },
}
/// WebSocket connection manager
#[derive(Debug)]
pub struct WebSocketManager {
    connections: Arc<RwLock<HashMap<String, ConnectionInfo>>>,
    user_connections: Arc<RwLock<HashMap<String, Vec<String>>>>, // user_id -> connection_ids
}

impl WebSocketManager {
    pub fn new() -> Self {
        Self {
            connections: Arc::new(RwLock::new(HashMap::new())),
            user_connections: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Register a connection
    pub fn add_connection(&self, connection_info: ConnectionInfo) {
        let connection_id = connection_info.id.clone();
        let user_id = connection_info.user_id.clone();

        // Add to the connection table (the write lock is released at the end of the statement)
        self.connections
            .write()
            .unwrap()
            .insert(connection_id.clone(), connection_info);

        // If the connection already has a user ID, index it by user
        if let Some(user_id) = user_id {
            self.user_connections
                .write()
                .unwrap()
                .entry(user_id)
                .or_insert_with(Vec::new)
                .push(connection_id.clone());
        }

        info!(target: "udmin", connection_id = %connection_id, "WebSocket connection added");
    }

    /// Remove a connection
    pub fn remove_connection(&self, connection_id: &str) {
        // Lock order: `connections` first, then `user_connections`
        let mut connections = self.connections.write().unwrap();
        if let Some(connection_info) = connections.remove(connection_id) {
            // Remove it from the per-user index
            if let Some(user_id) = &connection_info.user_id {
                let mut user_connections = self.user_connections.write().unwrap();
                if let Some(user_conn_list) = user_connections.get_mut(user_id) {
                    user_conn_list.retain(|id| id != connection_id);
                    if user_conn_list.is_empty() {
                        user_connections.remove(user_id);
                    }
                }
            }

            info!(
                target: "udmin",
                connection_id = %connection_id,
                user_id = ?connection_info.user_id,
                "WebSocket connection removed"
            );
        }
    }
    /// Send a message to a single connection
    pub fn send_to_connection(&self, connection_id: &str, message: WebSocketMessage) -> bool {
        let connections = self.connections.read().unwrap();
        if let Some(connection_info) = connections.get(connection_id) {
            match connection_info.sender.send(message) {
                Ok(_) => true,
                Err(e) => {
                    warn!(
                        target: "udmin",
                        connection_id = %connection_id,
                        error = %e,
                        "Failed to send message to connection"
                    );
                    false
                }
            }
        } else {
            false
        }
    }

    /// Send a message to every connection of a user
    pub fn send_to_user(&self, user_id: &str, message: WebSocketMessage) -> usize {
        // Copy the IDs and release `user_connections` before sending. Holding it while
        // `send_to_connection` locks `connections` would invert the lock order used by
        // `remove_connection` and can deadlock.
        let connection_ids: Vec<String> = self
            .user_connections
            .read()
            .unwrap()
            .get(user_id)
            .cloned()
            .unwrap_or_default();

        let sent_count = connection_ids
            .iter()
            .filter(|id| self.send_to_connection(id, message.clone()))
            .count();

        info!(
            target: "udmin",
            user_id = %user_id,
            sent_count = %sent_count,
            "Message sent to user connections"
        );
        sent_count
    }

    /// Broadcast a message to all connections
    pub fn broadcast(&self, message: WebSocketMessage) -> usize {
        let connections = self.connections.read().unwrap();
        let mut sent_count = 0;

        for (connection_id, connection_info) in connections.iter() {
            match connection_info.sender.send(message.clone()) {
                Ok(_) => sent_count += 1,
                Err(e) => {
                    warn!(
                        target: "udmin",
                        connection_id = %connection_id,
                        error = %e,
                        "Failed to broadcast message"
                    );
                }
            }
        }

        info!(
            target: "udmin",
            sent_count = %sent_count,
            total_connections = %connections.len(),
            "Message broadcasted"
        );
        sent_count
    }

    /// Connection statistics
    pub fn get_stats(&self) -> WebSocketStats {
        let connections = self.connections.read().unwrap();
        let user_connections = self.user_connections.read().unwrap();

        WebSocketStats {
            total_connections: connections.len(),
            authenticated_connections: connections.values().filter(|c| c.user_id.is_some()).count(),
            unique_users: user_connections.len(),
        }
    }

    /// Remove connections whose last heartbeat reply is older than `timeout`
    pub fn cleanup_stale_connections(&self, timeout: Duration) {
        let now = Instant::now();
        let stale_connections: Vec<String> = self
            .connections
            .read()
            .unwrap()
            .iter()
            .filter(|(_, info)| now.duration_since(info.last_ping) > timeout)
            .map(|(id, _)| id.clone())
            .collect();
        // The read lock is released before `remove_connection` takes the write lock

        for connection_id in stale_connections {
            self.remove_connection(&connection_id);
            warn!(
                target: "udmin",
                connection_id = %connection_id,
                "Removed stale WebSocket connection"
            );
        }
    }
}
/// WebSocket statistics
#[derive(Debug, Serialize)]
pub struct WebSocketStats {
    pub total_connections: usize,
    pub authenticated_connections: usize,
    pub unique_users: usize,
}

/// WebSocket query parameters
#[derive(Debug, Deserialize)]
pub struct WebSocketQuery {
    /// Optional JWT (browsers cannot set an `Authorization` header on the WebSocket handshake)
    pub token: Option<String>,
}

/// WebSocket upgrade handler
pub async fn websocket_handler(
    ws: WebSocketUpgrade,
    Query(params): Query<WebSocketQuery>,
    State(state): State<AppState>,
) -> Response {
    ws.on_upgrade(move |socket| handle_websocket(socket, params.token, state))
}
/// Connection state updates used by `handle_websocket`
impl WebSocketManager {
    /// Attach a user ID to a registered connection and index it by user
    pub fn set_user_id(&self, connection_id: &str, user_id: &str) {
        // Same lock order as `remove_connection`: `connections`, then `user_connections`
        let mut connections = self.connections.write().unwrap();
        if let Some(info) = connections.get_mut(connection_id) {
            let mut user_connections = self.user_connections.write().unwrap();
            // Drop the index entry of a previously authenticated user, if any
            if let Some(old) = info.user_id.replace(user_id.to_string()) {
                if let Some(ids) = user_connections.get_mut(&old) {
                    ids.retain(|id| id != connection_id);
                    if ids.is_empty() {
                        user_connections.remove(&old);
                    }
                }
            }
            user_connections
                .entry(user_id.to_string())
                .or_default()
                .push(connection_id.to_string());
        }
    }

    /// Record a heartbeat reply so `cleanup_stale_connections` keeps the connection
    pub fn touch(&self, connection_id: &str) {
        if let Some(info) = self.connections.write().unwrap().get_mut(connection_id) {
            info.last_ping = Instant::now();
        }
    }
}

/// Handle one WebSocket connection
async fn handle_websocket(socket: WebSocket, token: Option<String>, state: AppState) {
    let connection_id = Uuid::new_v4().to_string();
    let (mut sender, mut receiver) = socket.split();
    let (tx, mut rx) = mpsc::unbounded_channel::<WebSocketMessage>();

    info!(target: "udmin", connection_id = %connection_id, "WebSocket connection established");

    // Register the connection (not yet authenticated)
    state.ws_manager.add_connection(ConnectionInfo {
        id: connection_id.clone(),
        user_id: None,
        connected_at: Instant::now(),
        last_ping: Instant::now(),
        sender: tx,
    });

    // If a token was passed in the query string, authenticate right away
    if let Some(token) = token {
        let reply = match authenticate_websocket_token(&token, &state).await {
            Ok(user_id) => {
                state.ws_manager.set_user_id(&connection_id, &user_id);
                info!(
                    target: "udmin",
                    connection_id = %connection_id,
                    user_id = %user_id,
                    "WebSocket connection authenticated"
                );
                WebSocketMessage::AuthSuccess { user_id }
            }
            Err(e) => {
                warn!(
                    target: "udmin",
                    connection_id = %connection_id,
                    error = %e,
                    "WebSocket authentication failed"
                );
                WebSocketMessage::AuthError { message: "认证失败".to_string() }
            }
        };
        // Queued on the channel; the send task below forwards it to the socket
        state.ws_manager.send_to_connection(&connection_id, reply);
    }

    // Send task: forward queued messages to the socket
    let connection_id_clone = connection_id.clone();
    let ws_manager_clone = state.ws_manager.clone();
    let send_task = tokio::spawn(async move {
        while let Some(message) = rx.recv().await {
            let text = match serde_json::to_string(&message) {
                Ok(text) => text,
                Err(e) => {
                    error!(
                        target: "udmin",
                        connection_id = %connection_id_clone,
                        error = %e,
                        "Failed to serialize WebSocket message"
                    );
                    continue;
                }
            };
            if sender
                .send(axum::extract::ws::Message::Text(text.into()))
                .await
                .is_err()
            {
                break;
            }
        }
        ws_manager_clone.remove_connection(&connection_id_clone);
    });

    // Heartbeat task: send an application-level `Ping` every 30 seconds
    let connection_id_clone = connection_id.clone();
    let ws_manager_clone = state.ws_manager.clone();
    let heartbeat_task = tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(30));
        loop {
            interval.tick().await;
            if !ws_manager_clone.send_to_connection(&connection_id_clone, WebSocketMessage::Ping) {
                break;
            }
        }
    });

    // Handle incoming messages
    while let Some(msg) = receiver.next().await {
        match msg {
            Ok(axum::extract::ws::Message::Text(text)) => {
                match serde_json::from_str::<WebSocketMessage>(&text) {
                    Ok(WebSocketMessage::Auth { token }) => {
                        let reply = match authenticate_websocket_token(&token, &state).await {
                            Ok(user_id) => {
                                state.ws_manager.set_user_id(&connection_id, &user_id);
                                WebSocketMessage::AuthSuccess { user_id }
                            }
                            Err(_) => WebSocketMessage::AuthError {
                                message: "认证失败".to_string(),
                            },
                        };
                        state.ws_manager.send_to_connection(&connection_id, reply);
                    }
                    Ok(WebSocketMessage::Pong) => {
                        // Heartbeat reply: refresh `last_ping`
                        state.ws_manager.touch(&connection_id);
                    }
                    Ok(_) => {
                        // Other message types are not handled yet
                    }
                    Err(e) => {
                        warn!(
                            target: "udmin",
                            connection_id = %connection_id,
                            error = %e,
                            "Failed to parse WebSocket message"
                        );
                    }
                }
            }
            Ok(axum::extract::ws::Message::Close(_)) => {
                info!(target: "udmin", connection_id = %connection_id, "WebSocket connection closed by client");
                break;
            }
            Err(e) => {
                error!(target: "udmin", connection_id = %connection_id, error = %e, "WebSocket error");
                break;
            }
            _ => {}
        }
    }

    // Stop background tasks and unregister the connection
    send_task.abort();
    heartbeat_task.abort();
    state.ws_manager.remove_connection(&connection_id);

    info!(target: "udmin", connection_id = %connection_id, "WebSocket connection closed");
}
/// Authenticate a WebSocket token and return the user ID
async fn authenticate_websocket_token(token: &str, state: &AppState) -> Result<String, AppError> {
    use crate::middlewares::auth::Claims;
    use jsonwebtoken::{decode, DecodingKey, Validation};

    // Validate the JWT
    let claims = decode::<Claims>(
        token,
        &DecodingKey::from_secret(state.config.jwt_secret.as_ref()),
        &Validation::default(),
    )
    .map_err(|_| AppError::Unauthorized("无效的认证令牌".to_string()))?
    .claims;

    // Make sure the user still exists and is active
    let user = user_service::find_by_id(&state.db, &claims.sub)
        .await
        .map_err(|_| AppError::Unauthorized("用户不存在".to_string()))?;

    if user.status != crate::models::user::UserStatus::Active {
        return Err(AppError::Unauthorized("用户账户已被禁用".to_string()));
    }

    Ok(claims.sub)
}

/// Start the periodic WebSocket cleanup task
pub fn start_websocket_cleanup_task(ws_manager: Arc<WebSocketManager>) {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(60));
        loop {
            interval.tick().await;
            // Remove connections without a heartbeat reply for 5 minutes
            ws_manager.cleanup_stale_connections(Duration::from_secs(300));
        }
    });
}
```
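The heartbeat only works if client replies update `last_ping`; otherwise `cleanup_stale_connections` removes every connection about 5 minutes after it was opened, whether or not it is alive. This std-only sketch (the `Heartbeats` type is hypothetical) isolates that bookkeeping with an injectable clock: `touch` corresponds to handling a `Pong`, and `evict_stale` to the periodic cleanup.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Last heartbeat reply per connection ID.
#[derive(Default)]
pub struct Heartbeats {
    last_seen: HashMap<String, Instant>,
}

impl Heartbeats {
    /// Call when a connection is opened.
    pub fn register(&mut self, id: &str, now: Instant) {
        self.last_seen.insert(id.to_string(), now);
    }

    /// Call on every `Pong` from the client.
    pub fn touch(&mut self, id: &str, now: Instant) {
        if let Some(t) = self.last_seen.get_mut(id) {
            *t = now;
        }
    }

    /// Remove and return (sorted) the IDs whose last reply is older than `timeout`.
    pub fn evict_stale(&mut self, now: Instant, timeout: Duration) -> Vec<String> {
        let mut stale: Vec<String> = self
            .last_seen
            .iter()
            .filter(|(_, t)| now.saturating_duration_since(**t) > timeout)
            .map(|(id, _)| id.clone())
            .collect();
        stale.sort();
        for id in &stale {
            self.last_seen.remove(id);
        }
        stale
    }
}
```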
## SSE Middleware (sse.rs)
### Features
- Server-Sent Events support
- Real-time event push
- Connection management
- Event filtering
- Reconnection support
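For reference, this is the `text/event-stream` framing an SSE response carries; axum's `Event` builder writes these fields for you. The std-only helper below (`format_sse` is a hypothetical name) shows why `id` and `retry` matter for reconnection support: after a dropped connection the browser waits `retry` milliseconds, then reconnects and sends the last received `id` in the `Last-Event-ID` header, so the server can resume from there.

```rust
/// Serialize one event in the `text/event-stream` wire format.
pub fn format_sse(id: Option<&str>, event: Option<&str>, data: &str, retry_ms: Option<u64>) -> String {
    let mut out = String::new();
    if let Some(id) = id {
        out.push_str(&format!("id: {id}\n"));
    }
    if let Some(event) = event {
        out.push_str(&format!("event: {event}\n"));
    }
    if let Some(ms) = retry_ms {
        out.push_str(&format!("retry: {ms}\n"));
    }
    // A multi-line payload becomes one `data:` line per line
    for line in data.split('\n') {
        out.push_str(&format!("data: {line}\n"));
    }
    out.push('\n'); // a blank line terminates the event
    out
}
```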
### Implementation
```rust
use axum::{
extract::{Query, State},
// IntoResponse is required for `sse.into_response()` in sse_handler
response::{
sse::{Event, Sse},
IntoResponse, Response,
},
};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
convert::Infallible,
sync::{Arc, RwLock},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info, warn};
use uuid::Uuid;
use crate::{
error::AppError,
middlewares::auth::AuthContext,
AppState,
};
/// SSE event types
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "data")]
pub enum SseEvent {
/// Flow execution status update
FlowExecutionUpdate {
execution_id: String,
flow_id: String,
status: String,
progress: Option<f64>,
message: Option<String>,
timestamp: u64,
},
/// Job execution status update
JobExecutionUpdate {
job_id: String,
execution_id: String,
status: String,
message: Option<String>,
timestamp: u64,
},
/// System-wide notification
SystemNotification {
id: String,
title: String,
message: String,
level: String,
timestamp: u64,
},
/// Notification for a specific user
UserNotification {
id: String,
title: String,
message: String,
timestamp: u64,
},
/// System resource status update
SystemStatus {
cpu_usage: f64,
memory_usage: f64,
disk_usage: f64,
active_connections: usize,
timestamp: u64,
},
/// Heartbeat event
Heartbeat {
timestamp: u64,
},
}
/// State kept for one SSE connection
#[derive(Debug)]
struct SseConnection {
id: String,
user_id: Option<String>,
sender: mpsc::UnboundedSender<Result<Event, Infallible>>,
filters: Vec<String>,
created_at: SystemTime,
}
/// SSE connection manager
#[derive(Debug)]
pub struct SseManager {
connections: Arc<RwLock<HashMap<String, SseConnection>>>,
user_connections: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
impl SseManager {
pub fn new() -> Self {
Self {
connections: Arc::new(RwLock::new(HashMap::new())),
user_connections: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Register a new SSE connection
pub fn add_connection(
&self,
connection_id: String,
user_id: Option<String>,
sender: mpsc::UnboundedSender<Result<Event, Infallible>>,
filters: Vec<String>,
) {
let connection = SseConnection {
id: connection_id.clone(),
user_id: user_id.clone(),
sender,
filters,
created_at: SystemTime::now(),
};
// Add to the connection table
self.connections.write().unwrap().insert(connection_id.clone(), connection);
// For authenticated connections, also index the connection by user ID
if let Some(user_id) = user_id {
self.user_connections
.write()
.unwrap()
.entry(user_id)
.or_insert_with(Vec::new)
.push(connection_id.clone());
}
info!(
target = "udmin",
connection_id = %connection_id,
"SSE connection added"
);
}
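/// Remove an SSE connection
pub fn remove_connection(&self, connection_id: &str) {
// Take the entry out and release the connections lock before touching
// user_connections, so the two locks are never held at the same time
let removed = self.connections.write().unwrap().remove(connection_id);
if let Some(connection) = removed {
// Remove it from the user -> connections index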
if let Some(user_id) = &connection.user_id {
let mut user_connections = self.user_connections.write().unwrap();
if let Some(user_conn_list) = user_connections.get_mut(user_id) {
user_conn_list.retain(|id| id != connection_id);
if user_conn_list.is_empty() {
user_connections.remove(user_id);
}
}
}
info!(
target = "udmin",
connection_id = %connection_id,
user_id = ?connection.user_id,
"SSE connection removed"
);
}
}
/// Send an event to one connection; returns `false` if the connection is unknown or the send fails
pub fn send_to_connection(&self, connection_id: &str, event: SseEvent) -> bool {
let connections = self.connections.read().unwrap();
if let Some(connection) = connections.get(connection_id) {
// Apply the event-type filter. Heartbeats always pass, so the heartbeat task can
// still detect a closed connection even when the client did not subscribe to them
if !connection.filters.is_empty() && !matches!(event, SseEvent::Heartbeat { .. }) {
let event_type = match &event {
SseEvent::FlowExecutionUpdate { .. } => "flow_execution",
SseEvent::JobExecutionUpdate { .. } => "job_execution",
SseEvent::SystemNotification { .. } => "system_notification",
SseEvent::UserNotification { .. } => "user_notification",
SseEvent::SystemStatus { .. } => "system_status",
SseEvent::Heartbeat { .. } => "heartbeat",
};
if !connection.filters.iter().any(|f| f == event_type) {
return true; // Filtered out, but not counted as a failure
}
}
let sse_event = match create_sse_event(&event) {
Ok(event) => event,
Err(e) => {
error!(
target = "udmin",
connection_id = %connection_id,
error = %e,
"Failed to create SSE event"
);
return false;
}
};
match connection.sender.send(Ok(sse_event)) {
Ok(_) => true,
Err(e) => {
warn!(
target = "udmin",
connection_id = %connection_id,
error = %e,
"Failed to send SSE event"
);
false
}
}
} else {
false
}
}
/// Send an event to all connections of the given user
pub fn send_to_user(&self, user_id: &str, event: SseEvent) -> usize {
// Copy the connection IDs and drop the user_connections lock before sending,
// so this method never holds both locks at once
let connection_ids: Vec<String> = self
.user_connections
.read()
.unwrap()
.get(user_id)
.cloned()
.unwrap_or_default();
let mut sent_count = 0;
for connection_id in &connection_ids {
if self.send_to_connection(connection_id, event.clone()) {
sent_count += 1;
}
}
info!(
target = "udmin",
user_id = %user_id,
sent_count = %sent_count,
"SSE event sent to user connections"
);
sent_count
}
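/// Broadcast an event to all connections
pub fn broadcast(&self, event: SseEvent) -> usize {
// Snapshot the IDs first: send_to_connection takes the read lock itself, and
// acquiring a std::sync::RwLock read lock twice on one thread may deadlock or panic
let connection_ids: Vec<String> = self.connections.read().unwrap().keys().cloned().collect();
let total_connections = connection_ids.len();
let mut sent_count = 0;
for connection_id in &connection_ids {
if self.send_to_connection(connection_id, event.clone()) {
sent_count += 1;
}
}
info!(
target = "udmin",
sent_count = %sent_count,
total_connections = %total_connections,
"SSE event broadcasted"
);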
sent_count
}
/// Return connection statistics
pub fn get_stats(&self) -> SseStats {
let connections = self.connections.read().unwrap();
let user_connections = self.user_connections.read().unwrap();
SseStats {
total_connections: connections.len(),
authenticated_connections: connections.values().filter(|c| c.user_id.is_some()).count(),
unique_users: user_connections.len(),
}
}
}
/// SSE connection statistics
#[derive(Debug, Serialize)]
pub struct SseStats {
pub total_connections: usize,
pub authenticated_connections: usize,
pub unique_users: usize,
}
/// SSE query parameters
#[derive(Debug, Deserialize)]
pub struct SseQuery {
pub filters: Option<String>, // Comma-separated event-type filters, e.g. "flow_execution,heartbeat"
pub last_event_id: Option<String>,
}
/// SSE endpoint handler
pub async fn sse_handler(
Query(params): Query<SseQuery>,
State(state): State<AppState>,
auth_context: Option<AuthContext>,
) -> Result<Response, AppError> {
let connection_id = Uuid::new_v4().to_string();
let (tx, rx) = mpsc::unbounded_channel();
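// Parse the comma-separated filters; empty entries (e.g. `?filters=`) are ignored,
// otherwise an empty string would become a filter that matches nothing
let filters: Vec<String> = params
.filters
.map(|f| {
f.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect()
})
.unwrap_or_default();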
// Register the connection with the manager
state.sse_manager.add_connection(
connection_id.clone(),
auth_context.as_ref().map(|ctx| ctx.user_id.clone()),
tx,
filters,
);
info!(
target = "udmin",
connection_id = %connection_id,
user_id = ?auth_context.as_ref().map(|ctx| &ctx.user_id),
"SSE connection established"
);
// Queue an initial heartbeat; the unbounded channel buffers it until the stream is polled
let initial_heartbeat = SseEvent::Heartbeat {
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
};
state.sse_manager.send_to_connection(&connection_id, initial_heartbeat);
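// Wrap the receiver as the SSE event stream
let stream = UnboundedReceiverStream::new(rx);
// Heartbeat task: once a send fails (the client disconnected and the stream was dropped), stop and remove the connection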
let connection_id_clone = connection_id.clone();
let sse_manager_clone = state.sse_manager.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(30));
loop {
interval.tick().await;
let heartbeat = SseEvent::Heartbeat {
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
};
if !sse_manager_clone.send_to_connection(&connection_id_clone, heartbeat) {
break;
}
}
// The client has gone away: remove the connection
sse_manager_clone.remove_connection(&connection_id_clone);
});
let sse = Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(Duration::from_secs(15))
.text("keep-alive-text"),
);
Ok(sse.into_response())
}
/// Convert an SseEvent into an axum SSE Event (event name plus JSON data)
fn create_sse_event(event: &SseEvent) -> Result<Event, serde_json::Error> {
let event_type = match event {
SseEvent::FlowExecutionUpdate { .. } => "flow_execution_update",
SseEvent::JobExecutionUpdate { .. } => "job_execution_update",
SseEvent::SystemNotification { .. } => "system_notification",
SseEvent::UserNotification { .. } => "user_notification",
SseEvent::SystemStatus { .. } => "system_status",
SseEvent::Heartbeat { .. } => "heartbeat",
};
let data = serde_json::to_string(event)?;
Ok(Event::default()
.event(event_type)
.data(data))
}
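/// Spawn a background task that periodically logs SSE connection statistics
pub fn start_sse_cleanup_task(sse_manager: Arc<SseManager>) {
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(300)); // runs every 5 minutes
loop {
interval.tick().await;
// Cleanup logic could be added here (e.g. removing long-idle connections); for now it only logs stats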
let stats = sse_manager.get_stats();
info!(
target = "udmin",
total_connections = %stats.total_connections,
authenticated_connections = %stats.authenticated_connections,
unique_users = %stats.unique_users,
"SSE connection stats"
);
}
});
}
```
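Because `SseManager` guards its two maps with separate `RwLock`s, the safest rule is never to hold both at once: copy the IDs you need out of one map, drop that guard, then touch the other. The std-only sketch below shows that snapshot-then-send pattern together with comma-separated filter parsing, "filtered events still count as delivered", and heartbeats exempt from filtering so liveness checks keep working. `MiniHub` and its `Vec` outboxes are hypothetical stand-ins for the manager and its mpsc channels; the write lock in `send_to_connection` is only needed because the toy outbox is a `Vec`, whereas `UnboundedSender::send` takes `&self` and works under a read lock.

```rust
use std::collections::HashMap;
use std::sync::RwLock;

/// Hypothetical stand-in for SseManager: each connection's channel is modelled
/// as a Vec "outbox" so the example stays synchronous and std-only.
#[derive(Default)]
pub struct MiniHub {
    // connection ID -> (event-type filters, delivered payloads)
    connections: RwLock<HashMap<String, (Vec<String>, Vec<String>)>>,
    // user ID -> connection IDs
    user_connections: RwLock<HashMap<String, Vec<String>>>,
}

impl MiniHub {
    /// Register a connection; `filters` is the raw comma-separated query value.
    pub fn add(&self, conn_id: &str, user_id: &str, filters: &str) {
        let filters: Vec<String> = filters
            .split(',')
            .map(|s| s.trim().to_string())
            .filter(|s| !s.is_empty())
            .collect();
        self.connections
            .write()
            .unwrap()
            .insert(conn_id.to_string(), (filters, Vec::new()));
        self.user_connections
            .write()
            .unwrap()
            .entry(user_id.to_string())
            .or_default()
            .push(conn_id.to_string());
    }

    /// Deliver to one connection; a filtered-out event still returns true.
    pub fn send_to_connection(&self, conn_id: &str, event_type: &str, payload: &str) -> bool {
        let mut conns = self.connections.write().unwrap();
        match conns.get_mut(conn_id) {
            Some((filters, outbox)) => {
                let allowed = event_type == "heartbeat"
                    || filters.is_empty()
                    || filters.iter().any(|f| f == event_type);
                if allowed {
                    outbox.push(payload.to_string());
                }
                true
            }
            None => false,
        }
    }

    /// Snapshot the user's connection IDs, release that lock, then send.
    pub fn send_to_user(&self, user_id: &str, event_type: &str, payload: &str) -> usize {
        let ids: Vec<String> = self
            .user_connections
            .read()
            .unwrap()
            .get(user_id)
            .cloned()
            .unwrap_or_default(); // the read guard is dropped at the end of this statement
        let mut sent = 0;
        for id in &ids {
            if self.send_to_connection(id, event_type, payload) {
                sent += 1;
            }
        }
        sent
    }

    /// Payloads delivered to a connection so far.
    pub fn delivered(&self, conn_id: &str) -> Vec<String> {
        self.connections
            .read()
            .unwrap()
            .get(conn_id)
            .map(|(_, outbox)| outbox.clone())
            .unwrap_or_default()
    }
}
```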
## Middleware Composition and Configuration
### Middleware Stack Configuration
```rust
use axum::{middleware, Router};
use std::{sync::Arc, time::Duration};
use tower::ServiceBuilder;
use tower_http::{
    compression::CompressionLayer,
    timeout::TimeoutLayer,
    trace::TraceLayer,
};

use crate::{
    middlewares::{
        auth::{jwt_auth_middleware, optional_auth_middleware},
        cors::create_cors_layer,
        logging::{performance_monitoring_middleware, request_logging_middleware},
        // Assumes the limiter type is `rate_limit::RateLimiter`, shared via `Arc`
        rate_limit::{ip_rate_limit_middleware, user_rate_limit_middleware, RateLimiter},
    },
    AppState,
};

/// Wrap `router` in the global middleware stack.
///
/// In a `ServiceBuilder` the first layer added is the outermost one, so a request
/// passes through: tracing -> timeout -> compression -> CORS -> request logging
/// -> performance monitoring -> per-IP rate limiting -> handler.
///
/// `rate_limiter` is typically created once with
/// `create_rate_limit_middleware(Default::default())` and shared with
/// `apply_auth_middleware_stack`.
pub fn apply_middleware_stack<S>(router: Router<S>, rate_limiter: Arc<RateLimiter>) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    router.layer(
        ServiceBuilder::new()
            // Request tracing
            .layer(TraceLayer::new_for_http())
            // Request timeout: responds with 408 after 30 seconds
            .layer(TimeoutLayer::new(Duration::from_secs(30)))
            // Response compression
            .layer(CompressionLayer::new())
            // CORS
            .layer(create_cors_layer(Default::default()))
            // Request logging
            .layer(middleware::from_fn(request_logging_middleware))
            // Performance monitoring
            .layer(middleware::from_fn(performance_monitoring_middleware))
            // Per-IP rate limiting
            .layer(middleware::from_fn_with_state(rate_limiter, ip_rate_limit_middleware)),
    )
}

/// Middleware for routes that require authentication.
///
/// `route_layer` only runs for requests that match a route, so unknown paths
/// still get 404 rather than 401. Per-user rate limiting comes after JWT
/// verification because it needs the authenticated user.
pub fn apply_auth_middleware_stack<S>(
    router: Router<S>,
    state: AppState,
    rate_limiter: Arc<RateLimiter>,
) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    router.route_layer(
        ServiceBuilder::new()
            .layer(middleware::from_fn_with_state(state, jwt_auth_middleware))
            .layer(middleware::from_fn_with_state(rate_limiter, user_rate_limit_middleware)),
    )
}

/// Middleware for routes where authentication is optional.
pub fn apply_optional_auth_middleware_stack<S>(router: Router<S>, state: AppState) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    router.route_layer(middleware::from_fn_with_state(state, optional_auth_middleware))
}
```
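The ordering rule behind this stack (the onion model from the overview) is that, within one `ServiceBuilder`, the layer added first sees the request first and the response last. The std-only sketch below models layers as boxed closures to make that ordering visible; `Handler`, `wrap`, and `build_stack` are illustrative names, not tower APIs.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Toy service: request string in, response string out.
pub type Handler = Box<dyn Fn(&str) -> String>;

/// Wrap `inner` in a layer named `name` that records when it runs
/// before and after the inner service.
pub fn wrap(name: &'static str, log: Rc<RefCell<Vec<String>>>, inner: Handler) -> Handler {
    Box::new(move |req: &str| {
        log.borrow_mut().push(format!("{name} before"));
        let resp = inner(req);
        log.borrow_mut().push(format!("{name} after"));
        resp
    })
}

/// Mimic ServiceBuilder: `layers[0]` ends up outermost, so it is applied last.
pub fn build_stack(layers: &[&'static str], log: Rc<RefCell<Vec<String>>>) -> Handler {
    let handler_log = Rc::clone(&log);
    let mut svc: Handler = Box::new(move |req: &str| {
        handler_log.borrow_mut().push("handler".to_string());
        format!("OK {req}")
    });
    // Wrap from the innermost (last-listed) layer outwards
    for &name in layers.iter().rev() {
        svc = wrap(name, Rc::clone(&log), svc);
    }
    svc
}
```

Calling `Router::layer` repeatedly behaves the other way round: each call wraps everything added before it, so the last `.layer(...)` call becomes the outermost layer.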
## Testing
### Middleware Tests
```rust
#[cfg(test)]
mod tests {
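use super::*;
use axum::{
body::Body,
http::{Request, StatusCode},
middleware, Router,
};
use std::sync::Arc;
use tower::ServiceExt;
// Explicit imports for the items used below; the RateLimiter path is assumed
use crate::middlewares::{
cors::create_cors_layer,
rate_limit::{ip_rate_limit_middleware, RateLimiter},
};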
#[tokio::test]
async fn test_cors_middleware() {
let app = Router::new()
.route("/test", axum::routing::get(|| async { "OK" }))
.layer(create_cors_layer(Default::default()));
let request = Request::builder()
.method("OPTIONS")
.uri("/test")
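.header("Origin", "http://localhost:3000")
// A real CORS preflight request also carries Access-Control-Request-Method
.header("Access-Control-Request-Method", "GET")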
.body(Body::empty())
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().contains_key("access-control-allow-origin"));
}
#[tokio::test]
async fn test_rate_limit_middleware() {
let rate_limiter = Arc::new(RateLimiter::new(1, 1)); // 1 token, 1 per second
let app = Router::new()
.route("/test", axum::routing::get(|| async { "OK" }))
.layer(middleware::from_fn_with_state(
rate_limiter,
ip_rate_limit_middleware,
));
// The first request should succeed
let request1 = Request::builder()
.uri("/test")
.body(Body::empty())
.unwrap();
let response1 = app.clone().oneshot(request1).await.unwrap();
assert_eq!(response1.status(), StatusCode::OK);
// The second request should be rate limited
let request2 = Request::builder()
.uri("/test")
.body(Body::empty())
.unwrap();
let response2 = app.oneshot(request2).await.unwrap();
assert_eq!(response2.status(), StatusCode::TOO_MANY_REQUESTS);
}
}
```
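The rate-limit test above reads `RateLimiter::new(1, 1)` as a bucket of one token refilled at one token per second. Below is a std-only token-bucket sketch under that assumption, with the clock passed in explicitly so the behaviour is deterministic in tests; `TokenBucket` is hypothetical, and the project's `RateLimiter` may be implemented differently.

```rust
use std::time::{Duration, Instant};

/// Hypothetical token bucket: holds up to `capacity` tokens and refills at
/// `refill_per_sec` tokens per second. Time is passed in explicitly.
pub struct TokenBucket {
    capacity: f64,
    refill_per_sec: f64,
    tokens: f64,
    last_refill: Instant,
}

impl TokenBucket {
    pub fn new(capacity: u32, refill_per_sec: u32, now: Instant) -> Self {
        Self {
            capacity: f64::from(capacity),
            refill_per_sec: f64::from(refill_per_sec),
            tokens: f64::from(capacity), // the bucket starts full
            last_refill: now,
        }
    }

    /// Try to take one token at time `now`; `false` corresponds to HTTP 429.
    pub fn try_acquire(&mut self, now: Instant) -> bool {
        let elapsed: Duration = now.saturating_duration_since(self.last_refill);
        // Refill for the elapsed time, never beyond capacity
        self.tokens = (self.tokens + elapsed.as_secs_f64() * self.refill_per_sec).min(self.capacity);
        self.last_refill = self.last_refill.max(now);
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
```

Injecting `now` instead of calling `Instant::now()` inside the limiter is what makes a "second request is rejected" test reliable regardless of how fast the test runs.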
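## Best Practices
### Middleware Design
- **Single responsibility**: each middleware handles exactly one concern
- **Composability**: middleware can be combined freely into different stacks
- **Low overhead**: keep per-request work in middleware to a minimum
- **Error handling**: handle errors inside middleware gracefully, returning a proper error response instead of panicking
- **Logging**: log key operations and errors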
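### Security Considerations
- **Authentication**: strictly verify the user's identity
- **Authorization**: make sure the user holds the required permissions
- **Input validation**: validate all incoming data
- **Rate limiting**: protect against abusive request floods and DDoS attacks
- **CORS**: configure cross-origin resource sharing correctly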
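The authorization bullet can be as simple as a set-membership check against the `permissions: HashSet<String>` carried in `AuthContext`. A minimal std-only sketch follows; the function names and the `flow:read` style permission strings are hypothetical, not taken from the project.

```rust
use std::collections::HashSet;

/// True when every permission in `required` has been granted.
/// An empty `required` list imposes no requirement.
pub fn has_all_permissions(granted: &HashSet<String>, required: &[&str]) -> bool {
    required.iter().all(|p| granted.contains(*p))
}

/// True when at least one permission in `required` has been granted.
/// An empty `required` list grants nothing (deny by default).
pub fn has_any_permission(granted: &HashSet<String>, required: &[&str]) -> bool {
    required.iter().any(|p| granted.contains(*p))
}

/// Map the check to an HTTP status: 403 Forbidden when a permission is missing.
pub fn authorize(granted: &HashSet<String>, required: &[&str]) -> Result<(), u16> {
    if has_all_permissions(granted, required) {
        Ok(())
    } else {
        Err(403)
    }
}
```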
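### Performance
- **Caching**: cache authentication results and permission data
- **Connection pooling**: size the database connection pool appropriately
- **Asynchronous processing**: use async operations to improve concurrency
- **Resource cleanup**: promptly release stale connections and resources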
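For the caching bullet, one option is a small TTL map in front of per-request lookups such as the user-status check in `authenticate_websocket_token`. The std-only sketch below is hypothetical (`TtlCache` is not a project type). The trade-off is that a disabled account may keep passing the cached check for up to one TTL, so the TTL should stay short.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Hypothetical TTL cache, e.g. user ID -> "account is active".
pub struct TtlCache<V> {
    ttl: Duration,
    entries: HashMap<String, (V, Instant)>, // value plus insertion time
}

impl<V: Clone> TtlCache<V> {
    pub fn new(ttl: Duration) -> Self {
        Self { ttl, entries: HashMap::new() }
    }

    pub fn insert(&mut self, key: &str, value: V, now: Instant) {
        self.entries.insert(key.to_string(), (value, now));
    }

    /// Return the cached value only if it is younger than the TTL; expired entries are evicted.
    pub fn get(&mut self, key: &str, now: Instant) -> Option<V> {
        let expired = match self.entries.get(key) {
            None => return None,
            Some((_, inserted)) => now.saturating_duration_since(*inserted) >= self.ttl,
        };
        if expired {
            // Evict so the caller falls back to the database lookup
            self.entries.remove(key);
            None
        } else {
            self.entries.get(key).map(|(value, _)| value.clone())
        }
    }
}
```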
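## Summary
The middleware layer is a core part of the UdminAI system, providing authentication and authorization, request logging, error handling, CORS, WebSocket, SSE, and rate limiting. Through modular design and flexible composition, it handles the system's cross-cutting concerns and underpins its security, reliability, and performance.
Key characteristics:
- **Modular design**: each middleware is implemented independently with a single responsibility
- **Flexible composition**: middleware can be combined and configured as needed
- **Low overhead**: middleware aims to add as little per-request cost as possible
- **Type safety**: the Rust type system catches many errors at compile time
- **Real-time communication**: WebSocket and SSE support
- **Security**: authentication, authorization, and rate limiting
With this middleware system, UdminAI can provide secure, efficient, and reliable web services that meet typical enterprise application requirements.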