From a3f2f99a68800d823b5e737014e47ac82f0ff436 Mon Sep 17 00:00:00 2001 From: ayou <550244300@qq.com> Date: Wed, 24 Sep 2025 20:21:45 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20=E6=B7=BB=E5=8A=A0=E9=A1=B9=E7=9B=AE?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=8C=85=E6=8B=AC=E6=80=BB=E8=A7=88=E3=80=81?= =?UTF-8?q?=E6=9E=B6=E6=9E=84=E3=80=81=E6=B5=81=E7=A8=8B=E5=BC=95=E6=93=8E?= =?UTF-8?q?=E5=92=8C=E6=9C=8D=E5=8A=A1=E5=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增以下文档文件: - PROJECT_OVERVIEW.md 项目总览文档 - BACKEND_ARCHITECTURE.md 后端架构文档 - FRONTEND_ARCHITECTURE.md 前端架构文档 - FLOW_ENGINE.md 流程引擎文档 - SERVICES.md 服务层文档 - ERROR_HANDLING.md 错误处理模块文档 文档内容涵盖项目整体介绍、技术架构、核心模块设计和实现细节 --- README.md | 138 ++ docs/BACKEND_ARCHITECTURE.md | 308 +++++ docs/DATABASE.md | 1056 ++++++++++++++++ docs/ERROR_HANDLING.md | 878 +++++++++++++ docs/FLOW_ENGINE.md | 484 +++++++ docs/FRONTEND_ARCHITECTURE.md | 439 +++++++ docs/MIDDLEWARES.md | 2227 +++++++++++++++++++++++++++++++++ docs/MODELS.md | 1650 ++++++++++++++++++++++++ docs/PROJECT_OVERVIEW.md | 137 ++ docs/RESPONSE.md | 1161 +++++++++++++++++ docs/ROUTES.md | 1201 ++++++++++++++++++ docs/SERVICES.md | 853 +++++++++++++ docs/UTILS.md | 1271 +++++++++++++++++++ 13 files changed, 11803 insertions(+) create mode 100644 README.md create mode 100644 docs/BACKEND_ARCHITECTURE.md create mode 100644 docs/DATABASE.md create mode 100644 docs/ERROR_HANDLING.md create mode 100644 docs/FLOW_ENGINE.md create mode 100644 docs/FRONTEND_ARCHITECTURE.md create mode 100644 docs/MIDDLEWARES.md create mode 100644 docs/MODELS.md create mode 100644 docs/PROJECT_OVERVIEW.md create mode 100644 docs/RESPONSE.md create mode 100644 docs/ROUTES.md create mode 100644 docs/SERVICES.md create mode 100644 docs/UTILS.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..5457f5a --- /dev/null +++ b/README.md @@ -0,0 +1,138 @@ +# UdminAI 项目文档 + +欢迎来到 UdminAI 项目文档中心。本文档集合提供了项目的完整技术文档,涵盖了架构设计、模块说明、API 文档和最佳实践等内容。 + +## 📚 文档导航 + +### 🏗️ 架构文档 + +- **[项目概览](PROJECT_OVERVIEW.md)** - UdminAI 项目的整体介绍、技术架构和核心功能 +- **[后端架构](BACKEND_ARCHITECTURE.md)** - 后端系统的详细架构设计和模块组织 +- **[前端架构](FRONTEND_ARCHITECTURE.md)** - 前端应用的架构设计、技术栈和组件体系 + +### 🔧 核心模块 + +- **[流程引擎](FLOW_ENGINE.md)** - 流程编排引擎的设计原理、执行机制和扩展能力 +- **[服务层](SERVICES.md)** - 业务服务层的设计模式、核心服务和集成方式 +- **[路由层](ROUTES.md)** - API 路由的设计规范、接口定义和中间件集成 +- **[数据模型](MODELS.md)** - 数据模型的设计原则、实体定义和关系映射 +- **[中间件](MIDDLEWARES.md)** - 中间件系统的设计理念、核心组件和使用方式 + +### 🛠️ 基础设施 + +- **[数据库](DATABASE.md)** - 数据库连接管理、查询优化和性能调优 +- **[错误处理](ERROR_HANDLING.md)** - 统一错误处理机制、错误类型定义和处理策略 +- **[响应格式](RESPONSE.md)** - API 响应格式规范、数据结构和最佳实践 +- **[工具函数](UTILS.md)** - 通用工具函数库、辅助工具和实用程序 + +### 📋 专项文档 + +- **[Redis 集成](REDIS_INTEGRATION.md)** - Redis 缓存系统的集成方案和使用指南 +- **[ID 生成分析](ID_GENERATION_ANALYSIS.md)** - 分布式 ID 生成策略和实现分析 + +### 🎯 流程编辑器 + +- **[变量节点使用](variable-node-usage.md)** - 流程变量节点的使用方法和最佳实践 +- **[固定布局演示](flow-fixed-layout-demo.md)** - 固定布局流程编辑器的演示和说明 +- **[自由布局演示](flow-free-layout-demo.md)** - 自由布局流程编辑器的功能展示 +- **[基础自由布局](flow-free-layout-base-demo.md)** - 自由布局的基础功能和操作指南 +- **[简单自由布局](flow-free-layout-simple-demo.md)** - 简化版自由布局编辑器的使用说明 +- **[自由布局 JSON](flow-free-layout-json.md)** - 自由布局的 JSON 数据结构定义 +- **[SJ 自由布局演示](flow-free-layout-sj-demo.md)** - SJ 版本自由布局编辑器的特性说明 + +## 🚀 快速开始 + +### 新手指南 + +如果你是第一次接触 UdminAI 项目,建议按以下顺序阅读文档: + +1. **[项目概览](PROJECT_OVERVIEW.md)** - 了解项目整体架构和核心概念 +2. **[后端架构](BACKEND_ARCHITECTURE.md)** - 理解后端系统的设计思路 +3. **[前端架构](FRONTEND_ARCHITECTURE.md)** - 掌握前端应用的组织结构 +4. **[流程引擎](FLOW_ENGINE.md)** - 深入了解核心的流程编排能力 + +### 开发者指南 + +对于参与开发的团队成员,重点关注以下文档: + +- **[服务层](SERVICES.md)** - 业务逻辑的实现规范 +- **[路由层](ROUTES.md)** - API 接口的设计标准 +- **[数据模型](MODELS.md)** - 数据结构的定义规则 +- **[错误处理](ERROR_HANDLING.md)** - 错误处理的统一方案 + +### 运维指南 + +对于系统运维和部署,参考以下文档: + +- **[数据库](DATABASE.md)** - 数据库的配置和优化 +- **[Redis 集成](REDIS_INTEGRATION.md)** - 缓存系统的部署和管理 +- **[中间件](MIDDLEWARES.md)** - 中间件的配置和监控 + +## 📖 文档约定 + +### 文档结构 + +每个模块文档都遵循统一的结构: + +- **概述** - 模块的基本介绍和设计目标 +- **设计原则** - 核心的设计理念和约束条件 +- **核心功能** - 主要功能特性和实现方式 +- **使用示例** - 具体的代码示例和使用方法 +- **最佳实践** - 推荐的使用模式和注意事项 +- **总结** - 模块特点和价值总结 + +### 代码示例 + +文档中的代码示例都经过验证,可以直接在项目中使用。代码遵循项目的编码规范: + +- **Rust 代码** - 遵循 Rust 2021 edition 和项目编码规范 +- **TypeScript 代码** - 遵循 TypeScript 严格模式和 ESLint 规则 +- **配置文件** - 使用 YAML/TOML 格式,保持简洁清晰 + +### 更新机制 + +文档与代码同步更新,确保文档的时效性和准确性: + +- **版本控制** - 文档与代码使用相同的版本管理 +- **持续集成** - 代码变更时自动检查文档的一致性 +- **定期审查** - 定期审查和更新文档内容 + +## 🤝 贡献指南 + +### 文档贡献 + +欢迎为项目文档做出贡献: + +1. **发现问题** - 如果发现文档中的错误或不准确之处,请提交 Issue +2. **改进建议** - 对文档结构或内容有改进建议,欢迎讨论 +3. **新增内容** - 可以补充缺失的文档或添加新的使用案例 +4. **翻译工作** - 可以帮助将文档翻译成其他语言 + +### 文档规范 + +贡献文档时请遵循以下规范: + +- **Markdown 格式** - 使用标准的 Markdown 语法 +- **中文写作** - 使用简洁明了的中文表达 +- **代码高亮** - 为代码块指定正确的语言类型 +- **链接检查** - 确保所有链接都是有效的 + +## 📞 获取帮助 + +如果在使用过程中遇到问题,可以通过以下方式获取帮助: + +- **查阅文档** - 首先查看相关的模块文档 +- **搜索 Issue** - 在项目 Issue 中搜索类似问题 +- **提交 Issue** - 如果问题未解决,请提交新的 Issue +- **社区讨论** - 参与项目的社区讨论 + +## 📄 许可证 + +本文档遵循与项目相同的许可证。详细信息请参考项目根目录的 LICENSE 文件。 + +--- + +**UdminAI 团队** +*构建智能化的流程管理平台* + +> 💡 **提示**: 建议将本文档加入浏览器书签,方便随时查阅。文档会持续更新,请关注最新版本。 \ No newline at end of file diff --git a/docs/BACKEND_ARCHITECTURE.md b/docs/BACKEND_ARCHITECTURE.md new file mode 100644 index 0000000..ff6603d --- /dev/null +++ b/docs/BACKEND_ARCHITECTURE.md @@ -0,0 +1,308 @@ +# 后端架构文档 + +## 概述 + +后端采用 Rust + Axum 构建,遵循分层架构设计,包含数据访问层、业务逻辑层、路由层和中间件层。 + +## 核心模块 + +### 1. 应用入口 (main.rs) + +**职责**: 应用启动、服务配置、中间件注册 + +**主要功能**: +- 数据库连接初始化 +- Redis 连接配置 +- CORS 跨域设置 +- 多端口服务启动 (HTTP/WebSocket/SSE) +- 日志中间件注册 + +**服务端口**: +- HTTP API: 9898 (默认) +- WebSocket: 8877 (默认) +- SSE: 8866 (默认) + +### 2. 数据库层 (db.rs) + +**职责**: 数据库连接管理和配置 + +**特性**: +- 支持多种数据库 (MySQL/PostgreSQL/SQLite) +- 连接池管理 +- 事务支持 +- 自动迁移 + +### 3. 错误处理 (error.rs) + +**职责**: 统一错误类型定义和处理 + +**错误类型**: +- `DatabaseError`: 数据库操作错误 +- `ValidationError`: 数据验证错误 +- `AuthenticationError`: 认证错误 +- `AuthorizationError`: 授权错误 +- `NotFoundError`: 资源不存在 +- `InternalServerError`: 内部服务器错误 + +### 4. 响应格式 (response.rs) + +**职责**: 统一 API 响应格式 + +**响应结构**: +```rust +pub struct ApiResponse { + pub code: i32, + pub message: String, + pub data: Option, +} + +pub struct PageResponse { + pub items: Vec, + pub total: u64, + pub page: u64, + pub page_size: u64, +} +``` + +## 业务模块 + +### Models (数据模型层) + +**位置**: `src/models/` + +**模型列表**: +- `user.rs`: 用户模型 +- `role.rs`: 角色模型 +- `menu.rs`: 菜单模型 +- `department.rs`: 部门模型 +- `position.rs`: 职位模型 +- `flow.rs`: 流程模型 +- `schedule_job.rs`: 定时任务模型 +- `flow_run_log.rs`: 流程运行日志 +- `request_log.rs`: 请求日志 +- `refresh_token.rs`: 刷新令牌 + +**特性**: +- 使用 SeaORM 宏自动生成 +- 支持关联查询 +- 自动时间戳管理 +- 软删除支持 + +### Services (业务逻辑层) + +**位置**: `src/services/` + +**服务列表**: +- `auth_service.rs`: 认证服务 +- `user_service.rs`: 用户管理服务 +- `role_service.rs`: 角色管理服务 +- `menu_service.rs`: 菜单管理服务 +- `department_service.rs`: 部门管理服务 +- `position_service.rs`: 职位管理服务 +- `flow_service.rs`: 流程管理服务 +- `schedule_job_service.rs`: 定时任务服务 +- `flow_run_log_service.rs`: 流程日志服务 +- `log_service.rs`: 系统日志服务 + +**设计原则**: +- 单一职责原则 +- 依赖注入 +- 异步处理 +- 事务管理 + +### Routes (路由层) + +**位置**: `src/routes/` + +**路由模块**: +- `auth.rs`: 认证相关路由 +- `users.rs`: 用户管理路由 +- `roles.rs`: 角色管理路由 +- `menus.rs`: 菜单管理路由 +- `departments.rs`: 部门管理路由 +- `positions.rs`: 职位管理路由 +- `flows.rs`: 流程管理路由 +- `schedule_jobs.rs`: 定时任务路由 +- `flow_run_logs.rs`: 流程日志路由 +- `logs.rs`: 系统日志路由 +- `dynamic_api.rs`: 动态 API 路由 + +**路由特性**: +- RESTful API 设计 +- 参数验证 +- 权限检查 +- 分页支持 +- 错误处理 + +### Middlewares (中间件层) + +**位置**: `src/middlewares/` + +**中间件列表**: +- `jwt.rs`: JWT 认证中间件 +- `logging.rs`: 请求日志中间件 +- `http_client.rs`: HTTP 客户端中间件 +- `ws.rs`: WebSocket 服务中间件 +- `sse.rs`: SSE 服务中间件 + +**功能特性**: +- 请求/响应拦截 +- 认证授权 +- 日志记录 +- 跨域处理 +- 实时通信 + +### Utils (工具模块) + +**位置**: `src/utils/` + +**工具列表**: +- `ids.rs`: ID 生成器 (Snowflake 算法) +- `password.rs`: 密码哈希工具 +- `scheduler.rs`: 任务调度器 + +**特性**: +- 分布式 ID 生成 +- 安全密码处理 +- 定时任务管理 + +## 流程引擎 + +**位置**: `src/flow/` + +### 核心组件 + +#### 1. 领域模型 (domain.rs) +- `ChainDef`: 流程链定义 +- `NodeDef`: 节点定义 +- `LinkDef`: 连接定义 +- `NodeKind`: 节点类型枚举 + +#### 2. DSL 解析 (dsl.rs) +- `FlowDSL`: 流程 DSL 结构 +- `NodeDSL`: 节点 DSL 结构 +- `DesignSyntax`: 设计语法结构 +- 校验和构建函数 + +#### 3. 执行引擎 (engine.rs) +- `FlowEngine`: 流程执行引擎 +- `TaskRegistry`: 任务注册表 +- `DriveOptions`: 执行选项 +- 并发执行支持 + +#### 4. 执行器 (executors/) +- `http.rs`: HTTP 请求执行器 +- `db.rs`: 数据库操作执行器 +- `condition.rs`: 条件判断执行器 +- `script_js.rs`: JavaScript 脚本执行器 +- `script_python.rs`: Python 脚本执行器 +- `script_rhai.rs`: Rhai 脚本执行器 +- `variable.rs`: 变量操作执行器 + +#### 5. 上下文管理 (context.rs) +- `StreamEvent`: 流事件定义 +- 执行上下文管理 +- 事件流处理 + +#### 6. 日志处理 (log_handler.rs) +- 流程执行日志 +- 错误日志记录 +- 性能监控 + +## 数据库设计 + +### 核心表结构 + +#### 用户权限相关 +- `users`: 用户表 +- `roles`: 角色表 +- `menus`: 菜单表 +- `departments`: 部门表 +- `positions`: 职位表 +- `user_roles`: 用户角色关联表 +- `role_menus`: 角色菜单关联表 +- `user_departments`: 用户部门关联表 +- `user_positions`: 用户职位关联表 + +#### 流程相关 +- `flows`: 流程表 +- `flow_run_logs`: 流程运行日志表 +- `schedule_jobs`: 定时任务表 + +#### 系统相关 +- `request_logs`: 请求日志表 +- `refresh_tokens`: 刷新令牌表 + +### 索引策略 +- 主键索引 +- 外键索引 +- 查询优化索引 +- 复合索引 + +## 安全机制 + +### 认证授权 +- JWT 令牌机制 +- 刷新令牌支持 +- 权限中间件验证 +- 角色基础访问控制 (RBAC) + +### 数据安全 +- Argon2 密码哈希 +- SQL 注入防护 +- XSS 防护 +- CSRF 防护 + +### 通信安全 +- HTTPS 支持 +- CORS 配置 +- 请求限流 +- 日志审计 + +## 性能优化 + +### 数据库优化 +- 连接池管理 +- 查询优化 +- 索引策略 +- 分页查询 + +### 缓存策略 +- Redis 缓存 +- 查询结果缓存 +- 会话缓存 + +### 并发处理 +- 异步 I/O +- 任务队列 +- 连接复用 +- 资源池管理 + +## 监控和日志 + +### 日志系统 +- 结构化日志 (tracing) +- 分级日志记录 +- 请求链路追踪 +- 错误堆栈记录 + +### 监控指标 +- 请求响应时间 +- 数据库连接状态 +- 内存使用情况 +- 错误率统计 + +## 部署配置 + +### 环境变量 +- 数据库连接配置 +- Redis 连接配置 +- JWT 密钥配置 +- 服务端口配置 +- CORS 配置 + +### 容器化 +- Docker 支持 +- 多阶段构建 +- 健康检查 +- 资源限制 \ No newline at end of file diff --git a/docs/DATABASE.md b/docs/DATABASE.md new file mode 100644 index 0000000..c3ed331 --- /dev/null +++ b/docs/DATABASE.md @@ -0,0 +1,1056 @@ +# UdminAI 数据库模块文档 + +## 概述 + +UdminAI 项目的数据库模块基于 SeaORM 框架构建,提供了完整的数据库抽象层和 ORM 功能。该模块负责数据库连接管理、事务处理、查询构建、迁移管理等核心功能,为整个系统提供可靠的数据持久化支持。 + +## 技术架构 + +### 核心组件 + +- **SeaORM**: 现代化的 Rust ORM 框架 +- **SQLx**: 异步 SQL 工具包 +- **PostgreSQL**: 主数据库(支持 MySQL、SQLite) +- **Redis**: 缓存和会话存储 +- **Migration**: 数据库版本管理 + +### 设计原则 + +- **类型安全**: 编译时查询验证 +- **异步优先**: 全异步数据库操作 +- **连接池**: 高效的连接管理 +- **事务支持**: ACID 事务保证 +- **迁移管理**: 版本化数据库结构 + +## 模块结构 + +``` +backend/src/ +├── db.rs # 数据库连接和配置 +├── redis.rs # Redis 连接和操作 +└── models/ # 数据模型定义 + ├── mod.rs + ├── user.rs + ├── role.rs + ├── permission.rs + ├── flow.rs + ├── schedule_job.rs + └── ... +``` + +## 数据库连接 (db.rs) + +### 功能特性 + +- 数据库连接池管理 +- 多数据库支持 +- 连接健康检查 +- 自动重连机制 +- 性能监控 + +### 实现代码 + +```rust +use sea_orm::{ + ConnectOptions, Database, DatabaseConnection, DbErr, TransactionTrait, +}; +use std::time::Duration; +use tracing::{error, info, warn}; + +/// 数据库配置 +#[derive(Debug, Clone)] +pub struct DatabaseConfig { + pub url: String, + pub max_connections: u32, + pub min_connections: u32, + pub connect_timeout: Duration, + pub idle_timeout: Duration, + pub max_lifetime: Duration, + pub sqlx_logging: bool, + pub sqlx_logging_level: tracing::Level, +} + +impl Default for DatabaseConfig { + fn default() -> Self { + Self { + url: "postgresql://localhost/udmin_ai".to_string(), + max_connections: 100, + min_connections: 5, + connect_timeout: Duration::from_secs(8), + idle_timeout: Duration::from_secs(600), + max_lifetime: Duration::from_secs(3600), + sqlx_logging: true, + sqlx_logging_level: tracing::Level::INFO, + } + } +} + +/// 数据库连接管理器 +#[derive(Debug, Clone)] +pub struct DatabaseManager { + connection: DatabaseConnection, + config: DatabaseConfig, +} + +impl DatabaseManager { + /// 创建新的数据库管理器 + pub async fn new(config: DatabaseConfig) -> Result { + info!(target = "udmin", url = %config.url, "database.connect.starting"); + + let mut opt = ConnectOptions::new(&config.url); + opt.max_connections(config.max_connections) + .min_connections(config.min_connections) + .connect_timeout(config.connect_timeout) + .idle_timeout(config.idle_timeout) + .max_lifetime(config.max_lifetime) + .sqlx_logging(config.sqlx_logging) + .sqlx_logging_level(config.sqlx_logging_level); + + let connection = Database::connect(opt).await?; + + info!(target = "udmin", "database.connect.success"); + + Ok(Self { connection, config }) + } + + /// 从环境变量创建数据库管理器 + pub async fn from_env() -> Result { + let config = DatabaseConfig { + url: std::env::var("DATABASE_URL") + .unwrap_or_else(|_| "postgresql://localhost/udmin_ai".to_string()), + max_connections: std::env::var("DB_MAX_CONNECTIONS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(100), + min_connections: std::env::var("DB_MIN_CONNECTIONS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(5), + connect_timeout: Duration::from_secs( + std::env::var("DB_CONNECT_TIMEOUT") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(8), + ), + idle_timeout: Duration::from_secs( + std::env::var("DB_IDLE_TIMEOUT") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(600), + ), + max_lifetime: Duration::from_secs( + std::env::var("DB_MAX_LIFETIME") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(3600), + ), + sqlx_logging: std::env::var("DB_SQLX_LOGGING") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(true), + sqlx_logging_level: match std::env::var("DB_SQLX_LOGGING_LEVEL") + .unwrap_or_else(|_| "info".to_string()) + .to_lowercase() + .as_str() + { + "trace" => tracing::Level::TRACE, + "debug" => tracing::Level::DEBUG, + "info" => tracing::Level::INFO, + "warn" => tracing::Level::WARN, + "error" => tracing::Level::ERROR, + _ => tracing::Level::INFO, + }, + }; + + Self::new(config).await + } + + /// 获取数据库连接 + pub fn connection(&self) -> &DatabaseConnection { + &self.connection + } + + /// 检查数据库连接健康状态 + pub async fn health_check(&self) -> Result<(), DbErr> { + use sea_orm::Statement; + + let backend = self.connection.get_database_backend(); + let stmt = Statement::from_string(backend, "SELECT 1".to_string()); + + match self.connection.execute(stmt).await { + Ok(_) => { + info!(target = "udmin", "database.health_check.success"); + Ok(()) + } + Err(e) => { + error!(target = "udmin", error = %e, "database.health_check.failed"); + Err(e) + } + } + } + + /// 获取连接池统计信息 + pub async fn pool_stats(&self) -> DatabasePoolStats { + // 注意:SeaORM 目前不直接暴露连接池统计信息 + // 这里提供一个接口,实际实现可能需要通过其他方式获取 + DatabasePoolStats { + active_connections: 0, // 需要通过底层 SQLx 获取 + idle_connections: 0, + total_connections: 0, + max_connections: self.config.max_connections, + } + } + + /// 执行事务 + pub async fn transaction(&self, f: F) -> Result + where + F: for<'c> FnOnce(&'c DatabaseConnection) -> futures::future::BoxFuture<'c, Result> + + Send, + E: From, + R: Send, + { + let txn = self.connection.begin().await.map_err(E::from)?; + + match f(&txn).await { + Ok(result) => { + txn.commit().await.map_err(E::from)?; + info!(target = "udmin", "database.transaction.committed"); + Ok(result) + } + Err(e) => { + if let Err(rollback_err) = txn.rollback().await { + error!(target = "udmin", error = %rollback_err, "database.transaction.rollback_failed"); + } + warn!(target = "udmin", "database.transaction.rolled_back"); + Err(e) + } + } + } + + /// 关闭数据库连接 + pub async fn close(&self) -> Result<(), DbErr> { + info!(target = "udmin", "database.connection.closing"); + self.connection.close().await?; + info!(target = "udmin", "database.connection.closed"); + Ok(()) + } +} + +/// 数据库连接池统计信息 +#[derive(Debug, Clone)] +pub struct DatabasePoolStats { + pub active_connections: u32, + pub idle_connections: u32, + pub total_connections: u32, + pub max_connections: u32, +} + +/// 数据库迁移管理 +pub struct MigrationManager { + connection: DatabaseConnection, +} + +impl MigrationManager { + pub fn new(connection: DatabaseConnection) -> Self { + Self { connection } + } + + /// 运行所有待执行的迁移 + pub async fn migrate_up(&self) -> Result<(), DbErr> { + info!(target = "udmin", "database.migration.starting"); + + // 这里需要根据实际的迁移框架实现 + // 例如使用 sea-orm-migration + + info!(target = "udmin", "database.migration.completed"); + Ok(()) + } + + /// 回滚迁移 + pub async fn migrate_down(&self, steps: Option) -> Result<(), DbErr> { + let steps = steps.unwrap_or(1); + info!(target = "udmin", steps = %steps, "database.migration.rollback.starting"); + + // 实现迁移回滚逻辑 + + info!(target = "udmin", steps = %steps, "database.migration.rollback.completed"); + Ok(()) + } + + /// 获取迁移状态 + pub async fn migration_status(&self) -> Result, DbErr> { + // 返回迁移状态信息 + Ok(vec![]) + } +} + +/// 迁移信息 +#[derive(Debug, Clone)] +pub struct MigrationInfo { + pub version: String, + pub name: String, + pub applied_at: Option>, + pub is_applied: bool, +} + +/// 数据库查询构建器辅助函数 +pub mod query_builder { + use sea_orm::{ + ColumnTrait, Condition, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder, QuerySelect, + Select, + }; + + /// 分页查询构建器 + pub struct PaginationBuilder { + select: Select, + page: u64, + page_size: u64, + } + + impl PaginationBuilder { + pub fn new(select: Select) -> Self { + Self { + select, + page: 1, + page_size: 20, + } + } + + pub fn page(mut self, page: u64) -> Self { + self.page = page.max(1); + self + } + + pub fn page_size(mut self, page_size: u64) -> Self { + self.page_size = page_size.clamp(1, 100); + self + } + + pub fn build(self) -> Select { + let offset = (self.page - 1) * self.page_size; + self.select.limit(self.page_size).offset(offset) + } + } + + /// 条件构建器 + pub struct ConditionBuilder { + condition: Condition, + } + + impl ConditionBuilder { + pub fn new() -> Self { + Self { + condition: Condition::all(), + } + } + + pub fn add(mut self, column_condition: C) -> Self + where + C: Into, + { + self.condition = self.condition.add(column_condition); + self + } + + pub fn add_option(mut self, condition: Option) -> Self + where + C: Into, + { + if let Some(cond) = condition { + self.condition = self.condition.add(cond); + } + self + } + + pub fn build(self) -> Condition { + self.condition + } + } + + impl Default for ConditionBuilder { + fn default() -> Self { + Self::new() + } + } +} + +/// 数据库错误处理 +pub mod error_handler { + use sea_orm::DbErr; + use crate::error::AppError; + + /// 将数据库错误转换为应用错误 + pub fn handle_db_error(err: DbErr) -> AppError { + match err { + DbErr::RecordNotFound(_) => AppError::NotFound("记录不存在".to_string()), + DbErr::Custom(msg) => AppError::DatabaseError(msg), + DbErr::Conn(msg) => AppError::DatabaseError(format!("连接错误: {}", msg)), + DbErr::Exec(msg) => AppError::DatabaseError(format!("执行错误: {}", msg)), + DbErr::Query(msg) => AppError::DatabaseError(format!("查询错误: {}", msg)), + _ => AppError::DatabaseError("未知数据库错误".to_string()), + } + } + + /// 检查是否为唯一约束违反错误 + pub fn is_unique_violation(err: &DbErr) -> bool { + match err { + DbErr::Exec(msg) | DbErr::Query(msg) => { + msg.contains("unique constraint") || msg.contains("duplicate key") + } + _ => false, + } + } + + /// 检查是否为外键约束违反错误 + pub fn is_foreign_key_violation(err: &DbErr) -> bool { + match err { + DbErr::Exec(msg) | DbErr::Query(msg) => { + msg.contains("foreign key constraint") + } + _ => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio_test; + + #[tokio::test] + async fn test_database_config() { + let config = DatabaseConfig::default(); + assert_eq!(config.max_connections, 100); + assert_eq!(config.min_connections, 5); + } + + #[tokio::test] + async fn test_condition_builder() { + use query_builder::ConditionBuilder; + use sea_orm::{ColumnTrait, Condition}; + + let condition = ConditionBuilder::new() + .add_option(Some(Condition::all())) + .build(); + + // 验证条件构建 + assert!(!format!("{:?}", condition).is_empty()); + } + + #[test] + fn test_error_handling() { + use error_handler::*; + + let db_err = DbErr::RecordNotFound("test".to_string()); + let app_err = handle_db_error(db_err); + + match app_err { + AppError::NotFound(_) => assert!(true), + _ => assert!(false, "Expected NotFound error"), + } + } +} +``` + +## Redis 连接 (redis.rs) + +### 功能特性 + +- Redis 连接池管理 +- 缓存操作封装 +- 会话存储 +- 分布式锁 +- 发布订阅 + +### 实现代码 + +```rust +use redis::{ + aio::ConnectionManager, Client, RedisError, RedisResult, AsyncCommands, +}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tracing::{error, info, warn}; +use tokio::time::timeout; + +/// Redis 配置 +#[derive(Debug, Clone)] +pub struct RedisConfig { + pub url: String, + pub max_connections: u32, + pub connect_timeout: Duration, + pub command_timeout: Duration, + pub retry_attempts: u32, + pub retry_delay: Duration, +} + +impl Default for RedisConfig { + fn default() -> Self { + Self { + url: "redis://localhost:6379".to_string(), + max_connections: 50, + connect_timeout: Duration::from_secs(5), + command_timeout: Duration::from_secs(10), + retry_attempts: 3, + retry_delay: Duration::from_millis(100), + } + } +} + +/// Redis 连接管理器 +#[derive(Debug, Clone)] +pub struct RedisManager { + connection: ConnectionManager, + config: RedisConfig, +} + +impl RedisManager { + /// 创建新的 Redis 管理器 + pub async fn new(config: RedisConfig) -> RedisResult { + info!(target = "udmin", url = %config.url, "redis.connect.starting"); + + let client = Client::open(config.url.clone())?; + let connection = timeout( + config.connect_timeout, + ConnectionManager::new(client), + ) + .await + .map_err(|_| RedisError::from((redis::ErrorKind::IoError, "连接超时")))?? + ; + + info!(target = "udmin", "redis.connect.success"); + + Ok(Self { connection, config }) + } + + /// 从环境变量创建 Redis 管理器 + pub async fn from_env() -> RedisResult { + let config = RedisConfig { + url: std::env::var("REDIS_URL") + .unwrap_or_else(|_| "redis://localhost:6379".to_string()), + max_connections: std::env::var("REDIS_MAX_CONNECTIONS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(50), + connect_timeout: Duration::from_secs( + std::env::var("REDIS_CONNECT_TIMEOUT") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(5), + ), + command_timeout: Duration::from_secs( + std::env::var("REDIS_COMMAND_TIMEOUT") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(10), + ), + retry_attempts: std::env::var("REDIS_RETRY_ATTEMPTS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(3), + retry_delay: Duration::from_millis( + std::env::var("REDIS_RETRY_DELAY_MS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(100), + ), + }; + + Self::new(config).await + } + + /// 执行带超时的 Redis 命令 + async fn execute_with_timeout(&self, f: F) -> RedisResult + where + F: futures::Future>, + { + timeout(self.config.command_timeout, f) + .await + .map_err(|_| RedisError::from((redis::ErrorKind::IoError, "命令执行超时")))?? + } + + /// 设置键值对 + pub async fn set(&self, key: K, value: V) -> RedisResult<()> + where + K: redis::ToRedisArgs + Send + Sync, + V: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.set(key, value).await + }).await + } + + /// 设置键值对(带过期时间) + pub async fn setex(&self, key: K, value: V, seconds: usize) -> RedisResult<()> + where + K: redis::ToRedisArgs + Send + Sync, + V: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.setex(key, seconds, value).await + }).await + } + + /// 获取键值 + pub async fn get(&self, key: K) -> RedisResult> + where + K: redis::ToRedisArgs + Send + Sync, + V: redis::FromRedisValue, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.get(key).await + }).await + } + + /// 删除键 + pub async fn del(&self, key: K) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + let result: i32 = self.execute_with_timeout(async move { + conn.del(key).await + }).await?; + Ok(result > 0) + } + + /// 检查键是否存在 + pub async fn exists(&self, key: K) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.exists(key).await + }).await + } + + /// 设置键的过期时间 + pub async fn expire(&self, key: K, seconds: usize) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.expire(key, seconds).await + }).await + } + + /// 获取键的剩余过期时间 + pub async fn ttl(&self, key: K) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.ttl(key).await + }).await + } + + /// 原子递增 + pub async fn incr(&self, key: K, delta: i64) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.incr(key, delta).await + }).await + } + + /// 原子递减 + pub async fn decr(&self, key: K, delta: i64) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.decr(key, delta).await + }).await + } + + /// 列表左推 + pub async fn lpush(&self, key: K, value: V) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + V: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.lpush(key, value).await + }).await + } + + /// 列表右弹 + pub async fn rpop(&self, key: K) -> RedisResult> + where + K: redis::ToRedisArgs + Send + Sync, + V: redis::FromRedisValue, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.rpop(key, None).await + }).await + } + + /// 获取列表长度 + pub async fn llen(&self, key: K) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.llen(key).await + }).await + } + + /// 哈希设置字段 + pub async fn hset(&self, key: K, field: F, value: V) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + F: redis::ToRedisArgs + Send + Sync, + V: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.hset(key, field, value).await + }).await + } + + /// 哈希获取字段 + pub async fn hget(&self, key: K, field: F) -> RedisResult> + where + K: redis::ToRedisArgs + Send + Sync, + F: redis::ToRedisArgs + Send + Sync, + V: redis::FromRedisValue, + { + let mut conn = self.connection.clone(); + self.execute_with_timeout(async move { + conn.hget(key, field).await + }).await + } + + /// 哈希删除字段 + pub async fn hdel(&self, key: K, field: F) -> RedisResult + where + K: redis::ToRedisArgs + Send + Sync, + F: redis::ToRedisArgs + Send + Sync, + { + let mut conn = self.connection.clone(); + let result: i32 = self.execute_with_timeout(async move { + conn.hdel(key, field).await + }).await?; + Ok(result > 0) + } + + /// 检查 Redis 连接健康状态 + pub async fn health_check(&self) -> RedisResult<()> { + let mut conn = self.connection.clone(); + match self.execute_with_timeout(async move { + redis::cmd("PING").query_async(&mut conn).await + }).await { + Ok(redis::Value::Status(status)) if status == "PONG" => { + info!(target = "udmin", "redis.health_check.success"); + Ok(()) + } + Ok(_) => { + error!(target = "udmin", "redis.health_check.unexpected_response"); + Err(RedisError::from((redis::ErrorKind::ResponseError, "意外的响应"))) + } + Err(e) => { + error!(target = "udmin", error = %e, "redis.health_check.failed"); + Err(e) + } + } + } +} + +/// Redis 缓存操作封装 +pub struct RedisCache { + manager: RedisManager, + key_prefix: String, + default_ttl: Duration, +} + +impl RedisCache { + pub fn new(manager: RedisManager, key_prefix: String, default_ttl: Duration) -> Self { + Self { + manager, + key_prefix, + default_ttl, + } + } + + /// 构建完整的缓存键 + fn build_key(&self, key: &str) -> String { + format!("{}:{}", self.key_prefix, key) + } + + /// 设置缓存(JSON 序列化) + pub async fn set_json(&self, key: &str, value: &T) -> RedisResult<()> + where + T: Serialize, + { + let json_value = serde_json::to_string(value) + .map_err(|e| RedisError::from((redis::ErrorKind::TypeError, e.to_string())))?; + + let full_key = self.build_key(key); + self.manager.setex(full_key, json_value, self.default_ttl.as_secs() as usize).await + } + + /// 获取缓存(JSON 反序列化) + pub async fn get_json(&self, key: &str) -> RedisResult> + where + T: for<'de> Deserialize<'de>, + { + let full_key = self.build_key(key); + let json_value: Option = self.manager.get(full_key).await?; + + match json_value { + Some(json) => { + let value = serde_json::from_str(&json) + .map_err(|e| RedisError::from((redis::ErrorKind::TypeError, e.to_string())))?; + Ok(Some(value)) + } + None => Ok(None), + } + } + + /// 删除缓存 + pub async fn delete(&self, key: &str) -> RedisResult { + let full_key = self.build_key(key); + self.manager.del(full_key).await + } + + /// 检查缓存是否存在 + pub async fn exists(&self, key: &str) -> RedisResult { + let full_key = self.build_key(key); + self.manager.exists(full_key).await + } + + /// 设置缓存过期时间 + pub async fn expire(&self, key: &str, ttl: Duration) -> RedisResult { + let full_key = self.build_key(key); + self.manager.expire(full_key, ttl.as_secs() as usize).await + } + + /// 获取或设置缓存(缓存穿透保护) + pub async fn get_or_set(&self, key: &str, fetcher: F) -> RedisResult + where + T: Serialize + for<'de> Deserialize<'de> + Clone, + F: FnOnce() -> Fut, + Fut: futures::Future>>, + { + // 先尝试从缓存获取 + if let Some(cached_value) = self.get_json::(key).await? { + return Ok(cached_value); + } + + // 缓存未命中,调用 fetcher 获取数据 + let value = fetcher().await + .map_err(|e| RedisError::from((redis::ErrorKind::TypeError, e.to_string())))?; + + // 将数据存入缓存 + if let Err(e) = self.set_json(key, &value).await { + warn!(target = "udmin", key = %key, error = %e, "redis.cache.set_failed"); + } + + Ok(value) + } +} + +/// 分布式锁 +pub struct DistributedLock { + manager: RedisManager, + key: String, + value: String, + ttl: Duration, +} + +impl DistributedLock { + /// 尝试获取锁 + pub async fn acquire( + manager: RedisManager, + key: String, + ttl: Duration, + ) -> RedisResult> { + let value = uuid::Uuid::new_v4().to_string(); + let lock_key = format!("lock:{}", key); + + let mut conn = manager.connection.clone(); + let result: Option = redis::cmd("SET") + .arg(&lock_key) + .arg(&value) + .arg("EX") + .arg(ttl.as_secs()) + .arg("NX") + .query_async(&mut conn) + .await?; + + if result.is_some() { + info!(target = "udmin", key = %key, "distributed_lock.acquired"); + Ok(Some(Self { + manager, + key: lock_key, + value, + ttl, + })) + } else { + Ok(None) + } + } + + /// 释放锁 + pub async fn release(self) -> RedisResult { + let script = r#" + if redis.call("GET", KEYS[1]) == ARGV[1] then + return redis.call("DEL", KEYS[1]) + else + return 0 + end + "#; + + let mut conn = self.manager.connection.clone(); + let result: i32 = redis::Script::new(script) + .key(&self.key) + .arg(&self.value) + .invoke_async(&mut conn) + .await?; + + let released = result > 0; + if released { + info!(target = "udmin", key = %self.key, "distributed_lock.released"); + } + Ok(released) + } + + /// 续期锁 + pub async fn renew(&self) -> RedisResult { + let script = r#" + if redis.call("GET", KEYS[1]) == ARGV[1] then + return redis.call("EXPIRE", KEYS[1], ARGV[2]) + else + return 0 + end + "#; + + let mut conn = self.manager.connection.clone(); + let result: i32 = redis::Script::new(script) + .key(&self.key) + .arg(&self.value) + .arg(self.ttl.as_secs()) + .invoke_async(&mut conn) + .await?; + + Ok(result > 0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio_test; + + #[tokio::test] + async fn test_redis_config() { + let config = RedisConfig::default(); + assert_eq!(config.max_connections, 50); + assert_eq!(config.connect_timeout, Duration::from_secs(5)); + } + + #[test] + fn test_cache_key_building() { + let manager = RedisManager { + connection: todo!(), // 在实际测试中需要模拟 + config: RedisConfig::default(), + }; + let cache = RedisCache::new(manager, "test".to_string(), Duration::from_secs(300)); + + assert_eq!(cache.build_key("user:123"), "test:user:123"); + } +} +``` + +## 性能优化 + +### 连接池优化 + +- **动态连接池**: 根据负载自动调整连接数 +- **连接复用**: 最大化连接利用率 +- **健康检查**: 定期检查连接状态 +- **超时控制**: 防止连接泄漏 + +### 查询优化 + +- **索引优化**: 合理设计数据库索引 +- **查询缓存**: Redis 缓存热点数据 +- **批量操作**: 减少数据库往返次数 +- **分页优化**: 高效的分页查询 + +### 缓存策略 + +- **多级缓存**: 内存 + Redis 缓存 +- **缓存预热**: 系统启动时预加载热点数据 +- **缓存穿透保护**: 防止恶意查询 +- **缓存雪崩保护**: 错开缓存过期时间 + +## 监控和日志 + +### 性能监控 + +- **连接池监控**: 活跃连接数、等待队列长度 +- **查询性能**: 慢查询日志、执行时间统计 +- **缓存命中率**: Redis 缓存效果监控 +- **错误率监控**: 数据库错误统计 + +### 日志记录 + +- **操作日志**: 记录所有数据库操作 +- **性能日志**: 记录查询执行时间 +- **错误日志**: 详细的错误信息和堆栈 +- **审计日志**: 敏感操作的审计记录 + +## 最佳实践 + +### 数据库设计 + +1. **规范化设计**: 避免数据冗余 +2. **索引策略**: 合理创建和维护索引 +3. **数据类型**: 选择合适的数据类型 +4. **约束设计**: 充分利用数据库约束 + +### 查询优化 + +1. **避免 N+1 查询**: 使用 JOIN 或预加载 +2. **分页查询**: 使用 LIMIT 和 OFFSET +3. **条件过滤**: 在数据库层面进行过滤 +4. **批量操作**: 合并多个操作为批量操作 + +### 事务管理 + +1. **事务边界**: 明确事务的开始和结束 +2. **隔离级别**: 根据需求选择合适的隔离级别 +3. **死锁处理**: 设计避免死锁的策略 +4. **回滚策略**: 合理的错误处理和回滚 + +### 缓存使用 + +1. **缓存键设计**: 有意义且唯一的缓存键 +2. **过期策略**: 合理设置缓存过期时间 +3. **缓存更新**: 及时更新或删除过期缓存 +4. **缓存预热**: 系统启动时预加载重要数据 + +## 总结 + +UdminAI 的数据库模块提供了完整的数据持久化解决方案,具有以下特点: + +- **高性能**: 连接池管理和查询优化 +- **高可用**: 健康检查和自动重连 +- **类型安全**: SeaORM 提供的编译时检查 +- **易维护**: 清晰的模块结构和错误处理 +- **可扩展**: 支持多种数据库和缓存策略 + +通过合理的架构设计和最佳实践,确保了系统的稳定性、性能和可维护性。 \ No newline at end of file diff --git a/docs/ERROR_HANDLING.md b/docs/ERROR_HANDLING.md new file mode 100644 index 0000000..0f835af --- /dev/null +++ b/docs/ERROR_HANDLING.md @@ -0,0 +1,878 @@ +# UdminAI 错误处理模块文档 + +## 概述 + +UdminAI 项目的错误处理模块提供了统一的错误定义、处理和响应机制。该模块基于 Rust 的 `Result` 类型和 `thiserror` 库构建,确保错误信息的一致性、可追踪性和用户友好性。 + +## 设计原则 + +### 核心理念 + +- **统一性**: 所有模块使用统一的错误类型 +- **可追踪性**: 错误包含足够的上下文信息 +- **用户友好**: 面向用户的错误消息清晰易懂 +- **开发友好**: 面向开发者的错误信息详细准确 +- **类型安全**: 编译时错误类型检查 + +### 错误分层 + +1. **应用层错误**: 业务逻辑错误 +2. **服务层错误**: 服务调用错误 +3. **数据层错误**: 数据库和缓存错误 +4. **网络层错误**: HTTP 和网络通信错误 +5. **系统层错误**: 系统资源和配置错误 + +## 错误类型定义 (error.rs) + +### 主要错误类型 + +```rust +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::{Deserialize, Serialize}; +use std::fmt; +use thiserror::Error; +use tracing::error; + +/// 应用主错误类型 +#[derive(Error, Debug, Clone, Serialize, Deserialize)] +pub enum AppError { + // 认证和授权错误 + #[error("认证失败: {message}")] + AuthenticationFailed { message: String }, + + #[error("授权失败: {message}")] + AuthorizationFailed { message: String }, + + #[error("令牌无效: {message}")] + InvalidToken { message: String }, + + #[error("令牌已过期")] + TokenExpired, + + // 验证错误 + #[error("验证失败: {field} - {message}")] + ValidationFailed { field: String, message: String }, + + #[error("请求参数无效: {message}")] + InvalidRequest { message: String }, + + #[error("必需字段缺失: {field}")] + MissingField { field: String }, + + // 资源错误 + #[error("资源未找到: {resource}")] + NotFound(String), + + #[error("资源已存在: {resource}")] + AlreadyExists(String), + + #[error("资源冲突: {message}")] + Conflict { message: String }, + + // 业务逻辑错误 + #[error("业务规则违反: {message}")] + BusinessRuleViolation { message: String }, + + #[error("操作不被允许: {message}")] + OperationNotAllowed { message: String }, + + #[error("状态无效: 当前状态 {current}, 期望状态 {expected}")] + InvalidState { current: String, expected: String }, + + // 数据库错误 + #[error("数据库错误: {0}")] + DatabaseError(String), + + #[error("数据库连接失败: {message}")] + DatabaseConnectionFailed { message: String }, + + #[error("事务失败: {message}")] + TransactionFailed { message: String }, + + // 缓存错误 + #[error("缓存错误: {0}")] + CacheError(String), + + #[error("缓存连接失败: {message}")] + CacheConnectionFailed { message: String }, + + // 网络和外部服务错误 + #[error("网络错误: {message}")] + NetworkError { message: String }, + + #[error("外部服务错误: {service} - {message}")] + ExternalServiceError { service: String, message: String }, + + #[error("HTTP 请求失败: {status} - {message}")] + HttpRequestFailed { status: u16, message: String }, + + // 文件和 I/O 错误 + #[error("文件操作失败: {message}")] + FileOperationFailed { message: String }, + + #[error("文件未找到: {path}")] + FileNotFound { path: String }, + + #[error("文件权限不足: {path}")] + FilePermissionDenied { path: String }, + + // 配置和环境错误 + #[error("配置错误: {message}")] + ConfigurationError { message: String }, + + #[error("环境变量缺失: {variable}")] + MissingEnvironmentVariable { variable: String }, + + // 序列化和反序列化错误 + #[error("序列化失败: {message}")] + SerializationFailed { message: String }, + + #[error("反序列化失败: {message}")] + DeserializationFailed { message: String }, + + #[error("JSON 格式错误: {message}")] + JsonFormatError { message: String }, + + // 流程引擎错误 + #[error("流程执行失败: {flow_id} - {message}")] + FlowExecutionFailed { flow_id: String, message: String }, + + #[error("流程解析失败: {message}")] + FlowParsingFailed { message: String }, + + #[error("节点执行失败: {node_id} - {message}")] + NodeExecutionFailed { node_id: String, message: String }, + + // 调度任务错误 + #[error("任务调度失败: {job_id} - {message}")] + JobSchedulingFailed { job_id: String, message: String }, + + #[error("Cron 表达式无效: {expression}")] + InvalidCronExpression { expression: String }, + + #[error("任务执行超时: {job_id}")] + JobExecutionTimeout { job_id: String }, + + // 系统错误 + #[error("内部服务器错误: {message}")] + InternalServerError { message: String }, + + #[error("服务不可用: {message}")] + ServiceUnavailable { message: String }, + + #[error("请求超时: {message}")] + RequestTimeout { message: String }, + + #[error("资源耗尽: {resource}")] + ResourceExhausted { resource: String }, + + // 限流和安全错误 + #[error("请求频率过高: {message}")] + RateLimitExceeded { message: String }, + + #[error("请求体过大: 当前大小 {current}, 最大允许 {max}")] + PayloadTooLarge { current: usize, max: usize }, + + #[error("不支持的媒体类型: {media_type}")] + UnsupportedMediaType { media_type: String }, +} + +/// 错误响应结构 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorResponse { + pub error: ErrorDetail, + pub timestamp: chrono::DateTime, + pub request_id: Option, +} + +/// 错误详情 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorDetail { + pub code: String, + pub message: String, + pub details: Option, + pub field: Option, +} + +impl AppError { + /// 获取错误代码 + pub fn error_code(&self) -> &'static str { + match self { + // 认证和授权 + AppError::AuthenticationFailed { .. } => "AUTH_FAILED", + AppError::AuthorizationFailed { .. } => "AUTHORIZATION_FAILED", + AppError::InvalidToken { .. } => "INVALID_TOKEN", + AppError::TokenExpired => "TOKEN_EXPIRED", + + // 验证 + AppError::ValidationFailed { .. } => "VALIDATION_FAILED", + AppError::InvalidRequest { .. } => "INVALID_REQUEST", + AppError::MissingField { .. } => "MISSING_FIELD", + + // 资源 + AppError::NotFound(_) => "NOT_FOUND", + AppError::AlreadyExists(_) => "ALREADY_EXISTS", + AppError::Conflict { .. } => "CONFLICT", + + // 业务逻辑 + AppError::BusinessRuleViolation { .. } => "BUSINESS_RULE_VIOLATION", + AppError::OperationNotAllowed { .. } => "OPERATION_NOT_ALLOWED", + AppError::InvalidState { .. } => "INVALID_STATE", + + // 数据库 + AppError::DatabaseError(_) => "DATABASE_ERROR", + AppError::DatabaseConnectionFailed { .. } => "DATABASE_CONNECTION_FAILED", + AppError::TransactionFailed { .. } => "TRANSACTION_FAILED", + + // 缓存 + AppError::CacheError(_) => "CACHE_ERROR", + AppError::CacheConnectionFailed { .. } => "CACHE_CONNECTION_FAILED", + + // 网络 + AppError::NetworkError { .. } => "NETWORK_ERROR", + AppError::ExternalServiceError { .. } => "EXTERNAL_SERVICE_ERROR", + AppError::HttpRequestFailed { .. } => "HTTP_REQUEST_FAILED", + + // 文件 + AppError::FileOperationFailed { .. } => "FILE_OPERATION_FAILED", + AppError::FileNotFound { .. } => "FILE_NOT_FOUND", + AppError::FilePermissionDenied { .. } => "FILE_PERMISSION_DENIED", + + // 配置 + AppError::ConfigurationError { .. } => "CONFIGURATION_ERROR", + AppError::MissingEnvironmentVariable { .. } => "MISSING_ENV_VAR", + + // 序列化 + AppError::SerializationFailed { .. } => "SERIALIZATION_FAILED", + AppError::DeserializationFailed { .. } => "DESERIALIZATION_FAILED", + AppError::JsonFormatError { .. } => "JSON_FORMAT_ERROR", + + // 流程引擎 + AppError::FlowExecutionFailed { .. } => "FLOW_EXECUTION_FAILED", + AppError::FlowParsingFailed { .. } => "FLOW_PARSING_FAILED", + AppError::NodeExecutionFailed { .. } => "NODE_EXECUTION_FAILED", + + // 调度任务 + AppError::JobSchedulingFailed { .. } => "JOB_SCHEDULING_FAILED", + AppError::InvalidCronExpression { .. } => "INVALID_CRON_EXPRESSION", + AppError::JobExecutionTimeout { .. } => "JOB_EXECUTION_TIMEOUT", + + // 系统 + AppError::InternalServerError { .. } => "INTERNAL_SERVER_ERROR", + AppError::ServiceUnavailable { .. } => "SERVICE_UNAVAILABLE", + AppError::RequestTimeout { .. } => "REQUEST_TIMEOUT", + AppError::ResourceExhausted { .. } => "RESOURCE_EXHAUSTED", + + // 限流和安全 + AppError::RateLimitExceeded { .. } => "RATE_LIMIT_EXCEEDED", + AppError::PayloadTooLarge { .. } => "PAYLOAD_TOO_LARGE", + AppError::UnsupportedMediaType { .. } => "UNSUPPORTED_MEDIA_TYPE", + } + } + + /// 获取 HTTP 状态码 + pub fn status_code(&self) -> StatusCode { + match self { + // 4xx 客户端错误 + AppError::AuthenticationFailed { .. } => StatusCode::UNAUTHORIZED, + AppError::AuthorizationFailed { .. } => StatusCode::FORBIDDEN, + AppError::InvalidToken { .. } => StatusCode::UNAUTHORIZED, + AppError::TokenExpired => StatusCode::UNAUTHORIZED, + + AppError::ValidationFailed { .. } => StatusCode::BAD_REQUEST, + AppError::InvalidRequest { .. } => StatusCode::BAD_REQUEST, + AppError::MissingField { .. } => StatusCode::BAD_REQUEST, + + AppError::NotFound(_) => StatusCode::NOT_FOUND, + AppError::AlreadyExists(_) => StatusCode::CONFLICT, + AppError::Conflict { .. } => StatusCode::CONFLICT, + + AppError::BusinessRuleViolation { .. } => StatusCode::BAD_REQUEST, + AppError::OperationNotAllowed { .. } => StatusCode::FORBIDDEN, + AppError::InvalidState { .. } => StatusCode::BAD_REQUEST, + + AppError::FileNotFound { .. } => StatusCode::NOT_FOUND, + AppError::FilePermissionDenied { .. } => StatusCode::FORBIDDEN, + + AppError::JsonFormatError { .. } => StatusCode::BAD_REQUEST, + AppError::InvalidCronExpression { .. } => StatusCode::BAD_REQUEST, + + AppError::RateLimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS, + AppError::PayloadTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE, + AppError::UnsupportedMediaType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE, + + // 5xx 服务器错误 + AppError::DatabaseError(_) => StatusCode::INTERNAL_SERVER_ERROR, + AppError::DatabaseConnectionFailed { .. } => StatusCode::SERVICE_UNAVAILABLE, + AppError::TransactionFailed { .. } => StatusCode::INTERNAL_SERVER_ERROR, + + AppError::CacheError(_) => StatusCode::INTERNAL_SERVER_ERROR, + AppError::CacheConnectionFailed { .. } => StatusCode::SERVICE_UNAVAILABLE, + + AppError::NetworkError { .. } => StatusCode::BAD_GATEWAY, + AppError::ExternalServiceError { .. } => StatusCode::BAD_GATEWAY, + AppError::HttpRequestFailed { .. } => StatusCode::BAD_GATEWAY, + + AppError::FileOperationFailed { .. } => StatusCode::INTERNAL_SERVER_ERROR, + + AppError::ConfigurationError { .. } => StatusCode::INTERNAL_SERVER_ERROR, + AppError::MissingEnvironmentVariable { .. } => StatusCode::INTERNAL_SERVER_ERROR, + + AppError::SerializationFailed { .. } => StatusCode::INTERNAL_SERVER_ERROR, + AppError::DeserializationFailed { .. } => StatusCode::INTERNAL_SERVER_ERROR, + + AppError::FlowExecutionFailed { .. } => StatusCode::INTERNAL_SERVER_ERROR, + AppError::FlowParsingFailed { .. } => StatusCode::BAD_REQUEST, + AppError::NodeExecutionFailed { .. } => StatusCode::INTERNAL_SERVER_ERROR, + + AppError::JobSchedulingFailed { .. } => StatusCode::INTERNAL_SERVER_ERROR, + AppError::JobExecutionTimeout { .. } => StatusCode::REQUEST_TIMEOUT, + + AppError::InternalServerError { .. } => StatusCode::INTERNAL_SERVER_ERROR, + AppError::ServiceUnavailable { .. } => StatusCode::SERVICE_UNAVAILABLE, + AppError::RequestTimeout { .. } => StatusCode::REQUEST_TIMEOUT, + AppError::ResourceExhausted { .. } => StatusCode::SERVICE_UNAVAILABLE, + } + } + + /// 获取错误字段(如果适用) + pub fn error_field(&self) -> Option { + match self { + AppError::ValidationFailed { field, .. } => Some(field.clone()), + AppError::MissingField { field } => Some(field.clone()), + _ => None, + } + } + + /// 是否为客户端错误 + pub fn is_client_error(&self) -> bool { + self.status_code().is_client_error() + } + + /// 是否为服务器错误 + pub fn is_server_error(&self) -> bool { + self.status_code().is_server_error() + } + + /// 创建错误响应 + pub fn to_error_response(&self, request_id: Option) -> ErrorResponse { + ErrorResponse { + error: ErrorDetail { + code: self.error_code().to_string(), + message: self.to_string(), + details: None, + field: self.error_field(), + }, + timestamp: chrono::Utc::now(), + request_id, + } + } + + /// 记录错误日志 + pub fn log_error(&self, request_id: Option<&str>) { + let level = if self.is_server_error() { + tracing::Level::ERROR + } else { + tracing::Level::WARN + }; + + match level { + tracing::Level::ERROR => { + error!( + target = "udmin", + error_code = %self.error_code(), + error_message = %self, + request_id = ?request_id, + "application.error.server" + ); + } + _ => { + tracing::warn!( + target = "udmin", + error_code = %self.error_code(), + error_message = %self, + request_id = ?request_id, + "application.error.client" + ); + } + } + } +} + +/// 实现 IntoResponse,使错误可以直接作为 HTTP 响应返回 +impl IntoResponse for AppError { + fn into_response(self) -> Response { + // 从请求上下文获取 request_id(实际实现中可能需要通过中间件传递) + let request_id = None; // 这里应该从上下文获取 + + // 记录错误日志 + self.log_error(request_id.as_deref()); + + // 创建错误响应 + let error_response = self.to_error_response(request_id); + let status_code = self.status_code(); + + (status_code, Json(error_response)).into_response() + } +} + +/// 应用结果类型别名 +pub type AppResult = Result; + +/// 错误转换实现 +impl From for AppError { + fn from(err: sea_orm::DbErr) -> Self { + match err { + sea_orm::DbErr::RecordNotFound(_) => AppError::NotFound("记录不存在".to_string()), + sea_orm::DbErr::Custom(msg) => AppError::DatabaseError(msg), + sea_orm::DbErr::Conn(msg) => AppError::DatabaseConnectionFailed { message: msg }, + sea_orm::DbErr::Exec(msg) => AppError::DatabaseError(msg), + sea_orm::DbErr::Query(msg) => AppError::DatabaseError(msg), + _ => AppError::DatabaseError("未知数据库错误".to_string()), + } + } +} + +impl From for AppError { + fn from(err: redis::RedisError) -> Self { + match err.kind() { + redis::ErrorKind::IoError => AppError::CacheConnectionFailed { + message: err.to_string(), + }, + redis::ErrorKind::AuthenticationFailed => AppError::CacheConnectionFailed { + message: "Redis 认证失败".to_string(), + }, + _ => AppError::CacheError(err.to_string()), + } + } +} + +impl From for AppError { + fn from(err: reqwest::Error) -> Self { + if err.is_timeout() { + AppError::RequestTimeout { + message: "HTTP 请求超时".to_string(), + } + } else if err.is_connect() { + AppError::NetworkError { + message: "网络连接失败".to_string(), + } + } else if let Some(status) = err.status() { + AppError::HttpRequestFailed { + status: status.as_u16(), + message: err.to_string(), + } + } else { + AppError::NetworkError { + message: err.to_string(), + } + } + } +} + +impl From for AppError { + fn from(err: serde_json::Error) -> Self { + if err.is_syntax() { + AppError::JsonFormatError { + message: "JSON 语法错误".to_string(), + } + } else if err.is_data() { + AppError::DeserializationFailed { + message: err.to_string(), + } + } else { + AppError::SerializationFailed { + message: err.to_string(), + } + } + } +} + +impl From for AppError { + fn from(err: std::io::Error) -> Self { + match err.kind() { + std::io::ErrorKind::NotFound => AppError::FileNotFound { + path: "未知路径".to_string(), + }, + std::io::ErrorKind::PermissionDenied => AppError::FilePermissionDenied { + path: "未知路径".to_string(), + }, + _ => AppError::FileOperationFailed { + message: err.to_string(), + }, + } + } +} + +impl From for AppError { + fn from(_: tokio::time::error::Elapsed) -> Self { + AppError::RequestTimeout { + message: "操作超时".to_string(), + } + } +} + +/// 错误构建器 +pub struct ErrorBuilder { + error: AppError, +} + +impl ErrorBuilder { + pub fn new(error: AppError) -> Self { + Self { error } + } + + pub fn with_details(mut self, details: serde_json::Value) -> Self { + // 这里可以扩展错误以包含更多详情 + self + } + + pub fn with_field(mut self, field: String) -> Self { + // 设置错误字段 + self + } + + pub fn build(self) -> AppError { + self.error + } +} + +/// 错误宏 +#[macro_export] +macro_rules! app_error { + ($error_type:ident, $($field:ident = $value:expr),*) => { + AppError::$error_type { + $($field: $value.into()),* + } + }; +} + +/// 结果扩展 trait +pub trait ResultExt { + /// 将错误转换为 AppError + fn map_app_error(self, f: F) -> AppResult + where + F: FnOnce() -> AppError; + + /// 添加错误上下文 + fn with_context(self, f: F) -> AppResult + where + F: FnOnce() -> String; +} + +impl ResultExt for Result +where + E: Into, +{ + fn map_app_error(self, f: F) -> AppResult + where + F: FnOnce() -> AppError, + { + self.map_err(|_| f()) + } + + fn with_context(self, f: F) -> AppResult + where + F: FnOnce() -> String, + { + self.map_err(|e| { + let context = f(); + match e.into() { + AppError::InternalServerError { message } => AppError::InternalServerError { + message: format!("{}: {}", context, message), + }, + other => other, + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_codes() { + let error = AppError::NotFound("用户".to_string()); + assert_eq!(error.error_code(), "NOT_FOUND"); + assert_eq!(error.status_code(), StatusCode::NOT_FOUND); + } + + #[test] + fn test_error_response() { + let error = AppError::ValidationFailed { + field: "email".to_string(), + message: "格式无效".to_string(), + }; + + let response = error.to_error_response(Some("req-123".to_string())); + assert_eq!(response.error.code, "VALIDATION_FAILED"); + assert_eq!(response.error.field, Some("email".to_string())); + assert_eq!(response.request_id, Some("req-123".to_string())); + } + + #[test] + fn test_error_conversion() { + let db_error = sea_orm::DbErr::RecordNotFound("test".to_string()); + let app_error: AppError = db_error.into(); + + match app_error { + AppError::NotFound(_) => assert!(true), + _ => assert!(false, "Expected NotFound error"), + } + } + + #[test] + fn test_error_macro() { + let error = app_error!(ValidationFailed, + field = "username", + message = "用户名已存在" + ); + + match error { + AppError::ValidationFailed { field, message } => { + assert_eq!(field, "username"); + assert_eq!(message, "用户名已存在"); + } + _ => assert!(false, "Expected ValidationFailed error"), + } + } + + #[test] + fn test_result_ext() { + let result: Result = Err("test error"); + let app_result = result.map_app_error(|| AppError::InternalServerError { + message: "转换错误".to_string(), + }); + + assert!(app_result.is_err()); + } +} +``` + +## 错误处理策略 + +### 错误传播 + +```rust +/// 错误传播示例 +pub async fn create_user(req: CreateUserReq) -> AppResult { + // 验证请求 + validate_create_user_request(&req)?; + + // 检查用户是否已存在 + if user_exists(&req.email).await? { + return Err(AppError::AlreadyExists("用户邮箱已存在".to_string())); + } + + // 创建用户 + let user = User::create(req).await + .with_context(|| "创建用户失败")?; + + Ok(user.into()) +} +``` + +### 错误恢复 + +```rust +/// 错误恢复示例 +pub async fn get_user_with_fallback(id: &str) -> AppResult { + // 首先尝试从缓存获取 + match get_user_from_cache(id).await { + Ok(user) => return Ok(user), + Err(AppError::CacheError(_)) => { + // 缓存错误,尝试从数据库获取 + tracing::warn!("缓存获取用户失败,尝试数据库"); + } + Err(e) => return Err(e), + } + + // 从数据库获取 + let user = get_user_from_db(id).await?; + + // 尝试更新缓存(忽略错误) + if let Err(e) = set_user_cache(id, &user).await { + tracing::warn!(error = %e, "更新用户缓存失败"); + } + + Ok(user) +} +``` + +### 错误聚合 + +```rust +/// 错误聚合示例 +pub struct ValidationErrors { + pub errors: Vec, +} + +impl ValidationErrors { + pub fn new() -> Self { + Self { errors: Vec::new() } + } + + pub fn add(&mut self, error: AppError) { + self.errors.push(error); + } + + pub fn is_empty(&self) -> bool { + self.errors.is_empty() + } + + pub fn into_result(self) -> AppResult<()> { + if self.errors.is_empty() { + Ok(()) + } else { + // 返回第一个错误,或者可以创建一个聚合错误类型 + Err(self.errors.into_iter().next().unwrap()) + } + } +} + +pub fn validate_user_data(data: &CreateUserReq) -> AppResult<()> { + let mut errors = ValidationErrors::new(); + + if data.email.is_empty() { + errors.add(AppError::MissingField { field: "email".to_string() }); + } + + if data.password.len() < 8 { + errors.add(AppError::ValidationFailed { + field: "password".to_string(), + message: "密码长度至少8位".to_string(), + }); + } + + errors.into_result() +} +``` + +## 中间件集成 + +### 错误处理中间件 + +```rust +use axum::{ + extract::Request, + middleware::Next, + response::Response, +}; +use uuid::Uuid; + +/// 错误处理中间件 +pub async fn error_handler_middleware( + mut request: Request, + next: Next, +) -> Response { + // 生成请求 ID + let request_id = Uuid::new_v4().to_string(); + request.extensions_mut().insert(request_id.clone()); + + // 执行请求 + let response = next.run(request).await; + + // 如果响应是错误,添加请求 ID + if response.status().is_client_error() || response.status().is_server_error() { + // 这里可以修改响应头,添加请求 ID + tracing::info!( + target = "udmin", + request_id = %request_id, + status = %response.status(), + "request.completed.error" + ); + } + + response +} +``` + +## 监控和告警 + +### 错误指标收集 + +```rust +use prometheus::{Counter, Histogram, Registry}; +use std::sync::Arc; + +/// 错误指标收集器 +#[derive(Clone)] +pub struct ErrorMetrics { + error_counter: Counter, + error_duration: Histogram, +} + +impl ErrorMetrics { + pub fn new(registry: &Registry) -> Self { + let error_counter = Counter::new( + "app_errors_total", + "Total number of application errors" + ).unwrap(); + + let error_duration = Histogram::new( + "error_handling_duration_seconds", + "Time spent handling errors" + ).unwrap(); + + registry.register(Box::new(error_counter.clone())).unwrap(); + registry.register(Box::new(error_duration.clone())).unwrap(); + + Self { + error_counter, + error_duration, + } + } + + pub fn record_error(&self, error: &AppError) { + self.error_counter.inc(); + + // 可以根据错误类型添加标签 + tracing::info!( + target = "udmin", + error_code = %error.error_code(), + error_type = %std::any::type_name::(), + "metrics.error.recorded" + ); + } +} +``` + +## 最佳实践 + +### 错误设计原则 + +1. **明确性**: 错误消息应该清楚地说明发生了什么 +2. **可操作性**: 错误消息应该告诉用户如何解决问题 +3. **一致性**: 相同类型的错误应该有一致的格式和处理方式 +4. **安全性**: 不要在错误消息中泄露敏感信息 + +### 错误处理模式 + +1. **快速失败**: 尽早检测和报告错误 +2. **优雅降级**: 在可能的情况下提供备选方案 +3. **错误隔离**: 防止错误在系统中传播 +4. **错误恢复**: 在适当的时候尝试从错误中恢复 + +### 日志记录 + +1. **结构化日志**: 使用结构化格式记录错误信息 +2. **上下文信息**: 包含足够的上下文信息用于调试 +3. **敏感信息**: 避免在日志中记录敏感信息 +4. **日志级别**: 根据错误严重程度选择合适的日志级别 + +## 总结 + +UdminAI 的错误处理模块提供了完整的错误管理解决方案,具有以下特点: + +- **类型安全**: 编译时错误类型检查 +- **统一处理**: 所有错误使用统一的类型和格式 +- **用户友好**: 清晰的错误消息和适当的 HTTP 状态码 +- **可观测性**: 完整的错误日志和指标收集 +- **可扩展性**: 易于添加新的错误类型和处理逻辑 + +通过合理的错误处理设计,确保了系统的稳定性、可维护性和用户体验。 \ No newline at end of file diff --git a/docs/FLOW_ENGINE.md b/docs/FLOW_ENGINE.md new file mode 100644 index 0000000..cf06528 --- /dev/null +++ b/docs/FLOW_ENGINE.md @@ -0,0 +1,484 @@ +# 流程引擎文档 + +## 概述 + +流程引擎是 UdminAI 的核心模块,提供可视化流程设计、执行和监控功能。支持多种节点类型、条件分支、循环控制和并发执行。 + +## 架构设计 + +### 核心组件 + +``` +flow/ +├── domain.rs # 领域模型定义 +├── dsl.rs # DSL 解析和构建 +├── engine.rs # 流程执行引擎 +├── context.rs # 执行上下文管理 +├── task.rs # 任务抽象接口 +├── log_handler.rs # 日志处理 +├── mappers.rs # 数据映射器 +├── executors/ # 执行器实现 +└── mappers/ # 具体映射器 +``` + +## 领域模型 (domain.rs) + +### 核心数据结构 + +#### ChainDef - 流程链定义 +```rust +pub struct ChainDef { + pub nodes: Vec, // 节点列表 + pub links: Vec, // 连接列表 +} +``` + +#### NodeDef - 节点定义 +```rust +pub struct NodeDef { + pub id: NodeId, // 节点唯一标识 + pub kind: NodeKind, // 节点类型 + pub data: serde_json::Value, // 节点配置数据 +} +``` + +#### NodeKind - 节点类型 +```rust +pub enum NodeKind { + Start, // 开始节点 + End, // 结束节点 + Condition, // 条件节点 + Http, // HTTP 请求节点 + Database, // 数据库操作节点 + ScriptJs, // JavaScript 脚本节点 + ScriptPython, // Python 脚本节点 + ScriptRhai, // Rhai 脚本节点 + Variable, // 变量操作节点 + Task, // 通用任务节点 +} +``` + +#### LinkDef - 连接定义 +```rust +pub struct LinkDef { + pub from: NodeId, // 源节点 + pub to: NodeId, // 目标节点 + pub condition: Option, // 连接条件 +} +``` + +## DSL 解析 (dsl.rs) + +### DSL 结构 + +#### FlowDSL - 流程 DSL +```rust +pub struct FlowDSL { + pub nodes: Vec, // 节点列表 + pub edges: Vec, // 边列表 +} +``` + +#### DesignSyntax - 设计语法 +```rust +pub struct DesignSyntax { + pub nodes: Vec, // 节点语法 + pub edges: Vec, // 边语法 +} +``` + +### 核心功能 + +#### 1. 设计验证 +```rust +pub fn validate_design(design: &DesignSyntax) -> anyhow::Result<()> +``` + +**验证规则**: +- 节点 ID 唯一性 +- 至少包含一个 start 节点 +- 至少包含一个 end 节点 +- 边的引用合法性 +- 条件节点配置完整性 + +#### 2. 链构建 +```rust +pub fn build_chain_from_design(design: &DesignSyntax) -> anyhow::Result +``` + +**构建过程**: +1. 节点类型推断 +2. 条件节点处理 +3. 边关系建立 +4. 数据完整性检查 + +#### 3. 兼容性处理 +```rust +pub fn chain_from_design_json(input: &str) -> anyhow::Result +``` + +**兼容特性**: +- 字符串/对象输入支持 +- 字段回填 +- 版本兼容 +- 错误恢复 + +## 执行引擎 (engine.rs) + +### FlowEngine - 流程执行引擎 + +#### 核心结构 +```rust +pub struct FlowEngine { + tasks: TaskRegistry, // 任务注册表 +} +``` + +#### 执行选项 +```rust +pub struct DriveOptions { + pub max_steps: Option, // 最大执行步数 + pub timeout_ms: Option, // 超时时间 + pub parallel: bool, // 并发执行 + pub stream_events: bool, // 流事件支持 +} +``` + +### 执行流程 + +#### 1. 起点选择 +- 优先选择 Start 节点 +- 其次选择入度为 0 的节点 +- 最后选择第一个节点 + +#### 2. 执行策略 +- **串行执行**: 按依赖顺序逐个执行 +- **并发执行**: 无依赖节点并行执行 +- **条件分支**: 根据条件选择执行路径 +- **循环控制**: 支持循环节点执行 + +#### 3. 状态管理 +- 节点执行状态跟踪 +- 上下文数据传递 +- 错误状态处理 +- 执行结果收集 + +### 任务注册表 (TaskRegistry) + +#### 注册机制 +```rust +pub struct TaskRegistry { + executors: HashMap>, +} +``` + +#### 执行器接口 +```rust +#[async_trait] +pub trait Executor: Send + Sync { + async fn execute( + &self, + node_id: &NodeId, + node: &NodeDef, + ctx: &mut serde_json::Value, + ) -> anyhow::Result<()>; +} +``` + +## 执行器实现 (executors/) + +### HTTP 执行器 (http.rs) + +**功能**: 执行 HTTP 请求 + +**配置参数**: +- `method`: 请求方法 (GET/POST/PUT/DELETE) +- `url`: 请求 URL +- `headers`: 请求头 +- `query`: 查询参数 +- `body`: 请求体 +- `timeout_ms`: 超时时间 +- `insecure`: 忽略 SSL 验证 + +**执行流程**: +1. 解析 HTTP 配置 +2. 构建请求参数 +3. 发送 HTTP 请求 +4. 处理响应结果 +5. 更新执行上下文 + +### 数据库执行器 (db.rs) + +**功能**: 执行数据库操作 + +**支持操作**: +- `SELECT`: 查询数据 +- `INSERT`: 插入数据 +- `UPDATE`: 更新数据 +- `DELETE`: 删除数据 +- `TRANSACTION`: 事务操作 + +**配置参数**: +- `sql`: SQL 语句 +- `params`: 参数绑定 +- `connection`: 连接配置 +- `transaction`: 事务控制 + +### 条件执行器 (condition.rs) + +**功能**: 条件判断和分支控制 + +**条件类型**: +- 简单比较 (==, !=, >, <, >=, <=) +- 逻辑运算 (AND, OR, NOT) +- 正则匹配 +- 自定义表达式 + +**执行逻辑**: +1. 解析条件表达式 +2. 从上下文获取变量值 +3. 执行条件计算 +4. 返回布尔结果 + +### 脚本执行器 + +#### JavaScript 执行器 (script_js.rs) +**功能**: 执行 JavaScript 代码 +**引擎**: V8 引擎 (通过 rusty_v8) + +#### Python 执行器 (script_python.rs) +**功能**: 执行 Python 代码 +**引擎**: Python 解释器 (通过 pyo3) + +#### Rhai 执行器 (script_rhai.rs) +**功能**: 执行 Rhai 脚本 +**引擎**: Rhai 脚本引擎 + +**通用特性**: +- 沙箱执行环境 +- 上下文变量注入 +- 执行结果获取 +- 错误处理和日志 + +### 变量执行器 (variable.rs) + +**功能**: 变量操作和数据转换 + +**操作类型**: +- `SET`: 设置变量值 +- `GET`: 获取变量值 +- `TRANSFORM`: 数据转换 +- `MERGE`: 数据合并 +- `EXTRACT`: 数据提取 + +## 上下文管理 (context.rs) + +### StreamEvent - 流事件 + +```rust +pub enum StreamEvent { + NodeStart { node_id: String }, // 节点开始 + NodeComplete { node_id: String }, // 节点完成 + NodeError { node_id: String, error: String }, // 节点错误 + FlowComplete, // 流程完成 + FlowError { error: String }, // 流程错误 +} +``` + +### 上下文结构 + +```rust +pub struct ExecutionContext { + pub variables: serde_json::Value, // 变量存储 + pub node_results: HashMap, // 节点结果 + pub execution_log: Vec, // 执行日志 + pub start_time: DateTime, // 开始时间 +} +``` + +### 上下文操作 + +- **变量管理**: 设置、获取、更新变量 +- **结果存储**: 保存节点执行结果 +- **日志记录**: 记录执行过程 +- **状态跟踪**: 跟踪执行状态 + +## 数据映射器 (mappers/) + +### HTTP 映射器 (http.rs) +**功能**: HTTP 请求/响应数据映射 + +### 数据库映射器 (db.rs) +**功能**: 数据库查询结果映射 + +### 脚本映射器 (script.rs) +**功能**: 脚本执行结果映射 + +### 变量映射器 (variable.rs) +**功能**: 变量数据类型映射 + +## 日志处理 (log_handler.rs) + +### 日志类型 + +```rust +pub enum LogLevel { + Debug, + Info, + Warn, + Error, +} + +pub struct LogEntry { + pub level: LogLevel, + pub message: String, + pub timestamp: DateTime, + pub node_id: Option, + pub context: serde_json::Value, +} +``` + +### 日志功能 + +- **执行日志**: 记录节点执行过程 +- **错误日志**: 记录执行错误信息 +- **性能日志**: 记录执行时间和资源使用 +- **调试日志**: 记录调试信息 + +## 流程执行模式 + +### 1. 同步执行 +- 阻塞式执行 +- 顺序执行节点 +- 立即返回结果 +- 适用于简单流程 + +### 2. 异步执行 +- 非阻塞执行 +- 后台执行流程 +- 通过回调获取结果 +- 适用于长时间运行的流程 + +### 3. 流式执行 +- 实时事件推送 +- 执行过程可视化 +- 支持中断和恢复 +- 适用于交互式流程 + +### 4. 批量执行 +- 批量处理多个流程 +- 资源优化 +- 并发控制 +- 适用于批处理场景 + +## 错误处理 + +### 错误类型 + +```rust +pub enum FlowError { + ParseError(String), // 解析错误 + ValidationError(String), // 验证错误 + ExecutionError(String), // 执行错误 + TimeoutError, // 超时错误 + ResourceError(String), // 资源错误 +} +``` + +### 错误处理策略 + +- **重试机制**: 自动重试失败的节点 +- **降级处理**: 执行备用逻辑 +- **错误传播**: 将错误传播到上层 +- **日志记录**: 详细记录错误信息 + +## 性能优化 + +### 执行优化 + +- **并发执行**: 无依赖节点并行执行 +- **资源池**: 复用执行器实例 +- **缓存机制**: 缓存执行结果 +- **懒加载**: 按需加载执行器 + +### 内存优化 + +- **上下文清理**: 及时清理不需要的数据 +- **流式处理**: 大数据流式处理 +- **对象池**: 复用对象实例 +- **垃圾回收**: 主动触发垃圾回收 + +### 网络优化 + +- **连接复用**: HTTP 连接复用 +- **请求合并**: 合并相似请求 +- **超时控制**: 合理设置超时时间 +- **重试策略**: 智能重试机制 + +## 监控和调试 + +### 执行监控 + +- **执行状态**: 实时监控执行状态 +- **性能指标**: 收集执行性能数据 +- **资源使用**: 监控内存和 CPU 使用 +- **错误统计**: 统计错误发生情况 + +### 调试支持 + +- **断点调试**: 支持节点断点 +- **单步执行**: 逐步执行节点 +- **变量查看**: 查看执行上下文 +- **日志输出**: 详细的执行日志 + +## 扩展机制 + +### 自定义执行器 + +```rust +#[derive(Default)] +pub struct CustomExecutor; + +#[async_trait] +impl Executor for CustomExecutor { + async fn execute( + &self, + node_id: &NodeId, + node: &NodeDef, + ctx: &mut serde_json::Value, + ) -> anyhow::Result<()> { + // 自定义执行逻辑 + Ok(()) + } +} +``` + +### 插件系统 + +- **执行器插件**: 扩展新的节点类型 +- **中间件插件**: 扩展执行过程 +- **映射器插件**: 扩展数据映射 +- **日志插件**: 扩展日志处理 + +## 最佳实践 + +### 流程设计 + +- **模块化设计**: 将复杂流程拆分为子流程 +- **错误处理**: 为关键节点添加错误处理 +- **性能考虑**: 避免不必要的数据传递 +- **可维护性**: 添加适当的注释和文档 + +### 节点配置 + +- **参数验证**: 验证节点配置参数 +- **默认值**: 为可选参数提供默认值 +- **类型安全**: 使用强类型配置 +- **版本兼容**: 保持配置向后兼容 + +### 执行优化 + +- **并发控制**: 合理设置并发度 +- **资源限制**: 设置合理的资源限制 +- **超时设置**: 为长时间运行的节点设置超时 +- **监控告警**: 添加关键指标监控 \ No newline at end of file diff --git a/docs/FRONTEND_ARCHITECTURE.md b/docs/FRONTEND_ARCHITECTURE.md new file mode 100644 index 0000000..f2f588d --- /dev/null +++ b/docs/FRONTEND_ARCHITECTURE.md @@ -0,0 +1,439 @@ +# 前端架构文档 + +## 概述 + +前端采用 React 18 + TypeScript 构建,使用现代化的组件化架构,集成了强大的流程可视化编辑器。 + +## 技术栈 + +### 核心框架 +- **React 18**: 前端框架,支持并发特性 +- **TypeScript**: 类型安全的 JavaScript 超集 +- **Vite**: 现代化构建工具 + +### UI 组件库 +- **Semi Design**: 主要 UI 组件库 +- **Ant Design**: 补充 UI 组件 +- **Styled Components**: CSS-in-JS 样式解决方案 + +### 流程编辑器 +- **@flowgram.ai/free-layout-editor**: 自由布局编辑器核心 +- **@flowgram.ai/form-materials**: 表单物料组件 +- **@flowgram.ai/runtime-js**: 流程运行时 +- **@flowgram.ai/minimap-plugin**: 小地图插件 +- **@flowgram.ai/panel-manager-plugin**: 面板管理插件 + +### 状态管理和路由 +- **React Context**: 状态管理 +- **React Router v6**: 客户端路由 +- **Axios**: HTTP 客户端 + +## 项目结构 + +``` +frontend/src/ +├── App.tsx # 应用根组件 +├── main.tsx # 应用入口 +├── vite-env.d.ts # Vite 类型声明 +├── assets/ # 静态资源 +├── components/ # 通用组件 +├── flows/ # 流程编辑器模块 +├── layouts/ # 布局组件 +├── pages/ # 页面组件 +├── styles/ # 全局样式 +└── utils/ # 工具函数 +``` + +## 核心模块 + +### 1. 应用入口 (main.tsx) + +**职责**: 应用初始化和根组件渲染 + +**功能**: +- React 18 严格模式启用 +- 路由配置 +- 全局样式导入 +- 错误边界设置 + +### 2. 应用根组件 (App.tsx) + +**职责**: 应用路由配置和布局管理 + +**功能**: +- 路由定义和保护 +- 认证状态管理 +- 全局错误处理 +- 主题配置 + +### 3. 布局系统 (layouts/) + +#### MainLayout.tsx +**职责**: 主要页面布局 + +**功能**: +- 侧边栏导航 +- 顶部导航栏 +- 面包屑导航 +- 用户信息显示 +- 响应式布局 + +**布局结构**: +```tsx + + 侧边栏 + +
顶部导航
+ 页面内容 +
页脚
+
+
+``` + +### 4. 页面组件 (pages/) + +#### 管理页面 +- `Dashboard.tsx`: 仪表板页面 +- `Users.tsx`: 用户管理页面 +- `Roles.tsx`: 角色管理页面 +- `Menus.tsx`: 菜单管理页面 +- `Departments.tsx`: 部门管理页面 +- `Positions.tsx`: 职位管理页面 +- `Permissions.tsx`: 权限管理页面 + +#### 流程相关页面 +- `FlowList.tsx`: 流程列表页面 +- `FlowRunLogs.tsx`: 流程运行日志页面 +- `ScheduleJobs.tsx`: 定时任务页面 + +#### 系统页面 +- `Login.tsx`: 登录页面 +- `Logs.tsx`: 系统日志页面 + +**页面特性**: +- 统一的 CRUD 操作 +- 表格分页和搜索 +- 表单验证 +- 权限控制 +- 响应式设计 + +### 5. 通用组件 (components/) + +#### PageHeader.tsx +**职责**: 页面头部组件 + +**功能**: +- 页面标题显示 +- 面包屑导航 +- 操作按钮区域 +- 统一样式 + +### 6. 工具函数 (utils/) + +#### axios.ts +**职责**: HTTP 客户端配置 + +**功能**: +- 请求/响应拦截器 +- 自动 Token 添加 +- 错误统一处理 +- 请求重试机制 + +#### token.ts +**职责**: 令牌管理 + +**功能**: +- Token 存储和获取 +- Token 过期检查 +- 自动刷新机制 +- 登出清理 + +#### permission.tsx +**职责**: 权限控制 + +**功能**: +- 权限检查组件 +- 路由权限保护 +- 按钮级权限控制 +- 角色权限验证 + +#### sse.ts +**职责**: 服务端推送事件 + +**功能**: +- SSE 连接管理 +- 事件监听 +- 自动重连 +- 错误处理 + +#### datetime.ts +**职责**: 日期时间处理 + +**功能**: +- 日期格式化 +- 时区转换 +- 相对时间显示 +- 日期计算 + +#### config.ts +**职责**: 应用配置 + +**功能**: +- 环境变量管理 +- API 端点配置 +- 应用常量定义 +- 功能开关 + +## 流程编辑器模块 + +**位置**: `src/flows/` + +### 核心组件 + +#### 1. 编辑器入口 (editor.tsx) +**职责**: 流程编辑器主组件 + +**功能**: +- 编辑器初始化 +- 插件注册 +- 事件处理 +- 数据同步 + +#### 2. 应用容器 (app.tsx) +**职责**: 编辑器应用容器 + +**功能**: +- 编辑器配置 +- 工具栏管理 +- 侧边栏控制 +- 快捷键支持 + +#### 3. 初始数据 (initial-data.ts) +**职责**: 编辑器初始化数据 + +**功能**: +- 默认节点配置 +- 画布初始状态 +- 工具栏配置 +- 插件配置 + +### 节点系统 (nodes/) + +#### 节点类型 +- `start/`: 开始节点 +- `end/`: 结束节点 +- `condition/`: 条件节点 +- `http/`: HTTP 请求节点 +- `db/`: 数据库操作节点 +- `code/`: 代码执行节点 +- `variable/`: 变量操作节点 +- `loop/`: 循环节点 +- `comment/`: 注释节点 +- `group/`: 分组节点 + +#### 节点特性 +- 可视化配置界面 +- 参数验证 +- 实时预览 +- 拖拽支持 +- 连接点管理 + +### 组件系统 (components/) + +#### 核心组件 +- `base-node/`: 基础节点组件 +- `node-panel/`: 节点配置面板 +- `sidebar/`: 侧边栏组件 +- `tools/`: 工具栏组件 +- `testrun/`: 测试运行组件 + +#### 交互组件 +- `add-node/`: 添加节点组件 +- `node-menu/`: 节点菜单 +- `line-add-button/`: 连线添加按钮 +- `selector-box-popover/`: 选择框弹窗 + +### 表单系统 (form-components/) + +#### 表单组件 +- `form-header/`: 表单头部 +- `form-content/`: 表单内容 +- `form-inputs/`: 表单输入组件 +- `form-item/`: 表单项组件 +- `feedback.tsx`: 反馈组件 + +### 插件系统 (plugins/) + +#### 插件列表 +- `context-menu-plugin/`: 右键菜单插件 +- `runtime-plugin/`: 运行时插件 +- `variable-panel-plugin/`: 变量面板插件 + +### 快捷键系统 (shortcuts/) + +#### 快捷键功能 +- `copy/`: 复制功能 +- `paste/`: 粘贴功能 +- `delete/`: 删除功能 +- `select-all/`: 全选功能 +- `zoom-in/`: 放大功能 +- `zoom-out/`: 缩小功能 +- `collapse/`: 折叠功能 +- `expand/`: 展开功能 + +### 上下文管理 (context/) + +#### 上下文类型 +- `node-render-context.ts`: 节点渲染上下文 +- `sidebar-context.ts`: 侧边栏上下文 + +### Hooks 系统 (hooks/) + +#### 自定义 Hooks +- `use-editor-props.tsx`: 编辑器属性 Hook +- `use-is-sidebar.ts`: 侧边栏状态 Hook +- `use-node-render-context.ts`: 节点渲染上下文 Hook +- `use-port-click.ts`: 端口点击 Hook + +### 工具函数 (utils/) + +#### 工具函数 +- `yaml.ts`: YAML 处理工具 +- `on-drag-line-end.ts`: 拖拽连线结束处理 +- `toggle-loop-expanded.ts`: 循环节点展开切换 + +## 状态管理 + +### Context 设计 + +#### 全局状态 +- 用户认证状态 +- 权限信息 +- 主题配置 +- 语言设置 + +#### 页面状态 +- 表格数据 +- 分页信息 +- 搜索条件 +- 选中项 + +#### 编辑器状态 +- 画布数据 +- 选中节点 +- 编辑模式 +- 工具栏状态 + +### 状态更新模式 +- 不可变更新 +- 批量更新 +- 异步状态处理 +- 错误状态管理 + +## 路由设计 + +### 路由结构 +``` +/ +├── /login # 登录页面 +├── /dashboard # 仪表板 +├── /users # 用户管理 +├── /roles # 角色管理 +├── /menus # 菜单管理 +├── /departments # 部门管理 +├── /positions # 职位管理 +├── /permissions # 权限管理 +├── /flows # 流程列表 +├── /flows/:id/edit # 流程编辑 +├── /flows/logs # 流程日志 +├── /schedule-jobs # 定时任务 +└── /logs # 系统日志 +``` + +### 路由保护 +- 认证检查 +- 权限验证 +- 角色控制 +- 重定向处理 + +## 样式系统 + +### CSS 架构 +- 全局样式 (`global.css`) +- 组件样式 (CSS Modules) +- 主题变量 +- 响应式断点 + +### 设计系统 +- 颜色规范 +- 字体规范 +- 间距规范 +- 组件规范 + +## 性能优化 + +### 代码分割 +- 路由级别分割 +- 组件懒加载 +- 动态导入 +- Bundle 分析 + +### 渲染优化 +- React.memo 使用 +- useMemo 缓存 +- useCallback 优化 +- 虚拟滚动 + +### 资源优化 +- 图片懒加载 +- 资源压缩 +- CDN 加速 +- 缓存策略 + +## 测试策略 + +### 测试类型 +- 单元测试 +- 集成测试 +- E2E 测试 +- 视觉回归测试 + +### 测试工具 +- Jest: 单元测试框架 +- React Testing Library: 组件测试 +- Cypress: E2E 测试 +- Storybook: 组件文档 + +## 构建和部署 + +### 构建配置 +- Vite 配置优化 +- 环境变量管理 +- 代码分割策略 +- 资源优化 + +### 部署策略 +- 静态资源部署 +- CDN 配置 +- 缓存策略 +- 版本管理 + +## 开发规范 + +### 代码规范 +- ESLint 配置 +- Prettier 格式化 +- TypeScript 严格模式 +- 提交规范 + +### 组件规范 +- 组件命名 +- Props 定义 +- 事件处理 +- 样式组织 + +### 文件组织 +- 目录结构 +- 文件命名 +- 导入导出 +- 类型定义 \ No newline at end of file diff --git a/docs/MIDDLEWARES.md b/docs/MIDDLEWARES.md new file mode 100644 index 0000000..051d243 --- /dev/null +++ b/docs/MIDDLEWARES.md @@ -0,0 +1,2227 @@ +# 中间件文档 + +## 概述 + +中间件层是 UdminAI 系统的核心组件,基于 Axum 框架实现,提供了认证授权、请求日志、错误处理、CORS、WebSocket 支持、SSE(Server-Sent Events)等功能。中间件采用洋葱模型,按顺序处理请求和响应。 + +## 架构设计 + +### 中间件模块结构 + +``` +middlewares/ +├── mod.rs # 中间件模块导出 +├── auth.rs # 认证中间件 +├── cors.rs # CORS 中间件 +├── http_client.rs # HTTP 客户端中间件 +├── logging.rs # 请求日志中间件 +├── rate_limit.rs # 限流中间件 +├── sse.rs # Server-Sent Events 中间件 +└── ws.rs # WebSocket 中间件 +``` + +### 设计原则 + +- **模块化**: 每个中间件独立实现,职责单一 +- **可组合**: 中间件可以灵活组合使用 +- **高性能**: 最小化性能开销 +- **类型安全**: 利用 Rust 类型系统确保安全 +- **可配置**: 支持灵活的配置选项 +- **可测试**: 易于单元测试和集成测试 + +## 认证中间件 (auth.rs) + +### 功能特性 + +- JWT Token 验证 +- 用户身份识别 +- 权限检查 +- Token 刷新机制 +- 多种认证策略 + +### 实现代码 + +```rust +use axum::{ + extract::{Request, State}, + http::{header, StatusCode}, + middleware::Next, + response::Response, +}; +use jsonwebtoken::{decode, DecodingKey, Validation}; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use tracing::{error, info, warn}; + +use crate::{ + error::AppError, + models::user, + services::user_service, + AppState, +}; + +/// JWT Claims 结构 +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Claims { + pub sub: String, // 用户ID + pub username: String, // 用户名 + pub roles: Vec, // 角色列表 + pub permissions: Vec, // 权限列表 + pub exp: usize, // 过期时间 + pub iat: usize, // 签发时间 + pub iss: String, // 签发者 +} + +/// 认证上下文 +#[derive(Debug, Clone)] +pub struct AuthContext { + pub user_id: String, + pub username: String, + pub roles: Vec, + pub permissions: HashSet, +} + +/// JWT 认证中间件 +pub async fn jwt_auth_middleware( + State(state): State, + mut request: Request, + next: Next, +) -> Result { + let auth_header = request + .headers() + .get(header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()); + + let token = match auth_header { + Some(header) if header.starts_with("Bearer ") => { + header.trim_start_matches("Bearer ") + } + _ => { + warn!(target = "udmin", "Missing or invalid authorization header"); + return Err(AppError::Unauthorized("缺少认证令牌".to_string())); + } + }; + + // 验证 JWT Token + let claims = match decode::( + token, + &DecodingKey::from_secret(state.config.jwt_secret.as_ref()), + &Validation::default(), + ) { + Ok(token_data) => token_data.claims, + Err(err) => { + error!(target = "udmin", error = %err, "JWT token validation failed"); + return Err(AppError::Unauthorized("无效的认证令牌".to_string())); + } + }; + + // 检查用户是否存在且活跃 + let user = user_service::find_by_id(&state.db, &claims.sub) + .await + .map_err(|_| AppError::Unauthorized("用户不存在".to_string()))?; + + if user.status != user::UserStatus::Active { + warn!(target = "udmin", user_id = %claims.sub, "Inactive user attempted access"); + return Err(AppError::Unauthorized("用户账户已被禁用".to_string())); + } + + // 创建认证上下文 + let auth_context = AuthContext { + user_id: claims.sub.clone(), + username: claims.username.clone(), + roles: claims.roles.clone(), + permissions: claims.permissions.into_iter().collect(), + }; + + // 将认证上下文添加到请求扩展中 + request.extensions_mut().insert(auth_context); + + info!( + target = "udmin", + user_id = %claims.sub, + username = %claims.username, + "User authenticated successfully" + ); + + Ok(next.run(request).await) +} + +/// 权限检查中间件 +pub fn require_permission(required_permission: &'static str) -> impl Fn(Request, Next) -> std::pin::Pin> + Send>> + Clone { + move |request: Request, next: Next| { + Box::pin(async move { + let auth_context = request + .extensions() + .get::() + .ok_or_else(|| AppError::Unauthorized("未认证的请求".to_string()))?; + + if !auth_context.permissions.contains(required_permission) { + warn!( + target = "udmin", + user_id = %auth_context.user_id, + required_permission = %required_permission, + "Permission denied" + ); + return Err(AppError::Forbidden("权限不足".to_string())); + } + + info!( + target = "udmin", + user_id = %auth_context.user_id, + permission = %required_permission, + "Permission granted" + ); + + Ok(next.run(request).await) + }) + } +} + +/// 角色检查中间件 +pub fn require_role(required_role: &'static str) -> impl Fn(Request, Next) -> std::pin::Pin> + Send>> + Clone { + move |request: Request, next: Next| { + Box::pin(async move { + let auth_context = request + .extensions() + .get::() + .ok_or_else(|| AppError::Unauthorized("未认证的请求".to_string()))?; + + if !auth_context.roles.contains(&required_role.to_string()) { + warn!( + target = "udmin", + user_id = %auth_context.user_id, + required_role = %required_role, + "Role check failed" + ); + return Err(AppError::Forbidden("角色权限不足".to_string())); + } + + info!( + target = "udmin", + user_id = %auth_context.user_id, + role = %required_role, + "Role check passed" + ); + + Ok(next.run(request).await) + }) + } +} + +/// 可选认证中间件(不强制要求认证) +pub async fn optional_auth_middleware( + State(state): State, + mut request: Request, + next: Next, +) -> Response { + let auth_header = request + .headers() + .get(header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()); + + if let Some(header) = auth_header { + if let Some(token) = header.strip_prefix("Bearer ") { + if let Ok(token_data) = decode::( + token, + &DecodingKey::from_secret(state.config.jwt_secret.as_ref()), + &Validation::default(), + ) { + let claims = token_data.claims; + + // 创建认证上下文 + let auth_context = AuthContext { + user_id: claims.sub.clone(), + username: claims.username.clone(), + roles: claims.roles.clone(), + permissions: claims.permissions.into_iter().collect(), + }; + + request.extensions_mut().insert(auth_context); + } + } + } + + next.run(request).await +} +``` + +## CORS 中间件 (cors.rs) + +### 功能特性 + +- 跨域资源共享支持 +- 可配置的允许源 +- 预检请求处理 +- 安全头设置 + +### 实现代码 + +```rust +use axum::{ + extract::Request, + http::{header, HeaderValue, Method, StatusCode}, + middleware::Next, + response::Response, +}; +use tower_http::cors::{Any, CorsLayer}; +use tracing::info; + +/// CORS 配置 +#[derive(Debug, Clone)] +pub struct CorsConfig { + pub allowed_origins: Vec, + pub allowed_methods: Vec, + pub allowed_headers: Vec, + pub max_age: u64, + pub allow_credentials: bool, +} + +impl Default for CorsConfig { + fn default() -> Self { + Self { + allowed_origins: vec![ + "http://localhost:3000".to_string(), + "http://localhost:5173".to_string(), + "http://127.0.0.1:3000".to_string(), + "http://127.0.0.1:5173".to_string(), + ], + allowed_methods: vec![ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::PATCH, + Method::OPTIONS, + ], + allowed_headers: vec![ + "content-type".to_string(), + "authorization".to_string(), + "x-requested-with".to_string(), + "x-api-key".to_string(), + ], + max_age: 3600, + allow_credentials: true, + } + } +} + +/// 创建 CORS 层 +pub fn create_cors_layer(config: CorsConfig) -> CorsLayer { + let mut cors = CorsLayer::new() + .allow_methods(config.allowed_methods) + .allow_headers( + config + .allowed_headers + .iter() + .map(|h| h.parse().unwrap()) + .collect::>(), + ) + .max_age(std::time::Duration::from_secs(config.max_age)); + + // 设置允许的源 + if config.allowed_origins.contains(&"*".to_string()) { + cors = cors.allow_origin(Any); + } else { + cors = cors.allow_origin( + config + .allowed_origins + .iter() + .map(|origin| origin.parse().unwrap()) + .collect::>(), + ); + } + + // 设置是否允许凭据 + if config.allow_credentials { + cors = cors.allow_credentials(true); + } + + info!(target = "udmin", "CORS middleware configured"); + cors +} + +/// 自定义 CORS 中间件 +pub async fn cors_middleware( + request: Request, + next: Next, +) -> Response { + let origin = request + .headers() + .get(header::ORIGIN) + .and_then(|v| v.to_str().ok()); + + let method = request.method(); + + // 处理预检请求 + if method == Method::OPTIONS { + let mut response = Response::builder() + .status(StatusCode::OK) + .body(axum::body::Body::empty()) + .unwrap(); + + let headers = response.headers_mut(); + + if let Some(origin) = origin { + headers.insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_str(origin).unwrap(), + ); + } + + headers.insert( + header::ACCESS_CONTROL_ALLOW_METHODS, + HeaderValue::from_static("GET, POST, PUT, DELETE, PATCH, OPTIONS"), + ); + + headers.insert( + header::ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_static("content-type, authorization, x-requested-with, x-api-key"), + ); + + headers.insert( + header::ACCESS_CONTROL_MAX_AGE, + HeaderValue::from_static("3600"), + ); + + headers.insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + + return response; + } + + let mut response = next.run(request).await; + + // 为实际请求添加 CORS 头 + let headers = response.headers_mut(); + + if let Some(origin) = origin { + headers.insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_str(origin).unwrap(), + ); + } + + headers.insert( + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + HeaderValue::from_static("true"), + ); + + response +} +``` + +## 请求日志中间件 (logging.rs) + +### 功能特性 + +- 请求响应日志记录 +- 性能监控 +- 错误追踪 +- 结构化日志 +- 请求ID追踪 + +### 实现代码 + +```rust +use axum::{ + extract::{MatchedPath, Request}, + http::StatusCode, + middleware::Next, + response::Response, +}; +use std::time::Instant; +use tracing::{error, info, warn}; +use uuid::Uuid; + +/// 请求日志中间件 +pub async fn request_logging_middleware( + request: Request, + next: Next, +) -> Response { + let start_time = Instant::now(); + let request_id = Uuid::new_v4().to_string(); + + // 提取请求信息 + let method = request.method().clone(); + let uri = request.uri().clone(); + let version = request.version(); + let user_agent = request + .headers() + .get("user-agent") + .and_then(|v| v.to_str().ok()) + .unwrap_or("unknown"); + let remote_addr = request + .headers() + .get("x-forwarded-for") + .and_then(|v| v.to_str().ok()) + .or_else(|| { + request + .headers() + .get("x-real-ip") + .and_then(|v| v.to_str().ok()) + }) + .unwrap_or("unknown"); + + // 获取匹配的路径模式 + let matched_path = request + .extensions() + .get::() + .map(|mp| mp.as_str()) + .unwrap_or(uri.path()); + + info!( + target = "udmin", + request_id = %request_id, + method = %method, + uri = %uri, + path = %matched_path, + version = ?version, + user_agent = %user_agent, + remote_addr = %remote_addr, + "Request started" + ); + + // 执行请求 + let response = next.run(request).await; + + // 计算请求耗时 + let duration = start_time.elapsed(); + let status = response.status(); + + // 根据状态码选择日志级别 + match status.as_u16() { + 200..=299 => { + info!( + target = "udmin", + request_id = %request_id, + method = %method, + uri = %uri, + path = %matched_path, + status = %status.as_u16(), + duration_ms = %duration.as_millis(), + "Request completed successfully" + ); + } + 300..=399 => { + info!( + target = "udmin", + request_id = %request_id, + method = %method, + uri = %uri, + path = %matched_path, + status = %status.as_u16(), + duration_ms = %duration.as_millis(), + "Request redirected" + ); + } + 400..=499 => { + warn!( + target = "udmin", + request_id = %request_id, + method = %method, + uri = %uri, + path = %matched_path, + status = %status.as_u16(), + duration_ms = %duration.as_millis(), + "Client error" + ); + } + 500..=599 => { + error!( + target = "udmin", + request_id = %request_id, + method = %method, + uri = %uri, + path = %matched_path, + status = %status.as_u16(), + duration_ms = %duration.as_millis(), + "Server error" + ); + } + _ => { + info!( + target = "udmin", + request_id = %request_id, + method = %method, + uri = %uri, + path = %matched_path, + status = %status.as_u16(), + duration_ms = %duration.as_millis(), + "Request completed" + ); + } + } + + response +} + +/// 性能监控中间件 +pub async fn performance_monitoring_middleware( + request: Request, + next: Next, +) -> Response { + let start_time = Instant::now(); + let method = request.method().clone(); + let uri = request.uri().clone(); + + let response = next.run(request).await; + let duration = start_time.elapsed(); + + // 记录慢请求 + if duration.as_millis() > 1000 { + warn!( + target = "udmin", + method = %method, + uri = %uri, + duration_ms = %duration.as_millis(), + "Slow request detected" + ); + } + + // 记录性能指标 + info!( + target = "udmin.performance", + method = %method, + uri = %uri, + status = %response.status().as_u16(), + duration_ms = %duration.as_millis(), + "Performance metrics" + ); + + response +} +``` + +## HTTP 客户端中间件 (http_client.rs) + +### 功能特性 + +- HTTP 客户端封装 +- 请求重试机制 +- 超时控制 +- 请求日志 +- 错误处理 + +### 实现代码 + +```rust +use reqwest::{Client, ClientBuilder, Response}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tracing::{error, info, warn}; +use url::Url; + +use crate::error::AppError; + +/// HTTP 客户端配置 +#[derive(Debug, Clone)] +pub struct HttpClientConfig { + pub timeout: Duration, + pub connect_timeout: Duration, + pub max_retries: u32, + pub retry_delay: Duration, + pub user_agent: String, +} + +impl Default for HttpClientConfig { + fn default() -> Self { + Self { + timeout: Duration::from_secs(30), + connect_timeout: Duration::from_secs(10), + max_retries: 3, + retry_delay: Duration::from_millis(1000), + user_agent: "UdminAI/1.0".to_string(), + } + } +} + +/// HTTP 客户端包装器 +#[derive(Debug, Clone)] +pub struct HttpClient { + client: Client, + config: HttpClientConfig, +} + +impl HttpClient { + /// 创建新的 HTTP 客户端 + pub fn new(config: HttpClientConfig) -> Result { + let client = ClientBuilder::new() + .timeout(config.timeout) + .connect_timeout(config.connect_timeout) + .user_agent(&config.user_agent) + .build() + .map_err(|e| AppError::InternalServerError(format!("创建HTTP客户端失败: {}", e)))?; + + Ok(Self { client, config }) + } + + /// 发送 GET 请求 + pub async fn get(&self, url: &str) -> Result { + self.request_with_retry("GET", url, None::<()>).await + } + + /// 发送 POST 请求 + pub async fn post(&self, url: &str, body: &T) -> Result { + self.request_with_retry("POST", url, Some(body)).await + } + + /// 发送 PUT 请求 + pub async fn put(&self, url: &str, body: &T) -> Result { + self.request_with_retry("PUT", url, Some(body)).await + } + + /// 发送 DELETE 请求 + pub async fn delete(&self, url: &str) -> Result { + self.request_with_retry("DELETE", url, None::<()>).await + } + + /// 带重试的请求 + async fn request_with_retry( + &self, + method: &str, + url: &str, + body: Option<&T>, + ) -> Result { + let parsed_url = Url::parse(url) + .map_err(|e| AppError::BadRequest(format!("无效的URL: {}", e)))?; + + let mut last_error = None; + + for attempt in 0..=self.config.max_retries { + let start_time = std::time::Instant::now(); + + info!( + target = "udmin", + method = %method, + url = %url, + attempt = %attempt, + "HTTP request started" + ); + + let result = self.make_request(method, &parsed_url, body).await; + let duration = start_time.elapsed(); + + match result { + Ok(response) => { + let status = response.status(); + + info!( + target = "udmin", + method = %method, + url = %url, + status = %status.as_u16(), + duration_ms = %duration.as_millis(), + attempt = %attempt, + "HTTP request completed" + ); + + // 如果是服务器错误且还有重试次数,则重试 + if status.is_server_error() && attempt < self.config.max_retries { + warn!( + target = "udmin", + method = %method, + url = %url, + status = %status.as_u16(), + attempt = %attempt, + "Server error, retrying" + ); + + tokio::time::sleep(self.config.retry_delay).await; + continue; + } + + return Ok(response); + } + Err(e) => { + error!( + target = "udmin", + method = %method, + url = %url, + error = %e, + duration_ms = %duration.as_millis(), + attempt = %attempt, + "HTTP request failed" + ); + + last_error = Some(e); + + // 如果还有重试次数,则等待后重试 + if attempt < self.config.max_retries { + tokio::time::sleep(self.config.retry_delay).await; + } + } + } + } + + Err(last_error.unwrap_or_else(|| { + AppError::InternalServerError("HTTP请求失败".to_string()) + })) + } + + /// 执行实际的 HTTP 请求 + async fn make_request( + &self, + method: &str, + url: &Url, + body: Option<&T>, + ) -> Result { + let mut request_builder = match method { + "GET" => self.client.get(url.clone()), + "POST" => self.client.post(url.clone()), + "PUT" => self.client.put(url.clone()), + "DELETE" => self.client.delete(url.clone()), + _ => { + return Err(AppError::BadRequest(format!( + "不支持的HTTP方法: {}", + method + ))) + } + }; + + // 添加请求体 + if let Some(body) = body { + request_builder = request_builder.json(body); + } + + // 发送请求 + let response = request_builder + .send() + .await + .map_err(|e| AppError::InternalServerError(format!("HTTP请求失败: {}", e)))?; + + Ok(response) + } + + /// 发送 JSON 请求并解析响应 + pub async fn json_request( + &self, + method: &str, + url: &str, + body: Option<&T>, + ) -> Result + where + T: Serialize, + R: for<'de> Deserialize<'de>, + { + let response = match method { + "GET" => self.get(url).await?, + "POST" => { + if let Some(body) = body { + self.post(url, body).await? + } else { + return Err(AppError::BadRequest("POST请求需要请求体".to_string())); + } + } + "PUT" => { + if let Some(body) = body { + self.put(url, body).await? + } else { + return Err(AppError::BadRequest("PUT请求需要请求体".to_string())); + } + } + "DELETE" => self.delete(url).await?, + _ => { + return Err(AppError::BadRequest(format!( + "不支持的HTTP方法: {}", + method + ))) + } + }; + + // 检查响应状态 + if !response.status().is_success() { + let status = response.status(); + let error_text = response + .text() + .await + .unwrap_or_else(|_| "无法读取错误响应".to_string()); + + return Err(AppError::InternalServerError(format!( + "HTTP请求失败: {} - {}", + status, error_text + ))); + } + + // 解析 JSON 响应 + let json_response = response + .json::() + .await + .map_err(|e| AppError::InternalServerError(format!("解析JSON响应失败: {}", e)))?; + + Ok(json_response) + } +} +``` + +## 限流中间件 (rate_limit.rs) + +### 功能特性 + +- 基于令牌桶的限流 +- 支持不同的限流策略 +- IP 级别限流 +- 用户级别限流 +- 动态配置 + +### 实现代码 + +```rust +use axum::{ + extract::{Request, State}, + http::StatusCode, + middleware::Next, + response::Response, +}; +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; +use tracing::{info, warn}; + +use crate::{error::AppError, middlewares::auth::AuthContext}; + +/// 令牌桶 +#[derive(Debug, Clone)] +struct TokenBucket { + capacity: u32, + tokens: u32, + last_refill: Instant, + refill_rate: u32, // tokens per second +} + +impl TokenBucket { + fn new(capacity: u32, refill_rate: u32) -> Self { + Self { + capacity, + tokens: capacity, + last_refill: Instant::now(), + refill_rate, + } + } + + fn try_consume(&mut self, tokens: u32) -> bool { + self.refill(); + + if self.tokens >= tokens { + self.tokens -= tokens; + true + } else { + false + } + } + + fn refill(&mut self) { + let now = Instant::now(); + let elapsed = now.duration_since(self.last_refill); + let tokens_to_add = (elapsed.as_secs_f64() * self.refill_rate as f64) as u32; + + if tokens_to_add > 0 { + self.tokens = (self.tokens + tokens_to_add).min(self.capacity); + self.last_refill = now; + } + } +} + +/// 限流器 +#[derive(Debug)] +pub struct RateLimiter { + buckets: Arc>>, + default_capacity: u32, + default_refill_rate: u32, +} + +impl RateLimiter { + pub fn new(default_capacity: u32, default_refill_rate: u32) -> Self { + Self { + buckets: Arc::new(Mutex::new(HashMap::new())), + default_capacity, + default_refill_rate, + } + } + + pub fn check_rate_limit(&self, key: &str, tokens: u32) -> bool { + let mut buckets = self.buckets.lock().unwrap(); + + let bucket = buckets + .entry(key.to_string()) + .or_insert_with(|| TokenBucket::new(self.default_capacity, self.default_refill_rate)); + + bucket.try_consume(tokens) + } + + /// 清理过期的桶 + pub fn cleanup_expired_buckets(&self, max_idle_duration: Duration) { + let mut buckets = self.buckets.lock().unwrap(); + let now = Instant::now(); + + buckets.retain(|_, bucket| { + now.duration_since(bucket.last_refill) < max_idle_duration + }); + } +} + +/// 限流配置 +#[derive(Debug, Clone)] +pub struct RateLimitConfig { + pub requests_per_second: u32, + pub burst_capacity: u32, + pub enable_ip_limit: bool, + pub enable_user_limit: bool, +} + +impl Default for RateLimitConfig { + fn default() -> Self { + Self { + requests_per_second: 100, + burst_capacity: 200, + enable_ip_limit: true, + enable_user_limit: true, + } + } +} + +/// IP 限流中间件 +pub async fn ip_rate_limit_middleware( + State(rate_limiter): State>, + request: Request, + next: Next, +) -> Result { + // 获取客户端 IP + let client_ip = request + .headers() + .get("x-forwarded-for") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.split(',').next()) + .or_else(|| { + request + .headers() + .get("x-real-ip") + .and_then(|v| v.to_str().ok()) + }) + .unwrap_or("unknown") + .trim(); + + let rate_limit_key = format!("ip:{}", client_ip); + + // 检查限流 + if !rate_limiter.check_rate_limit(&rate_limit_key, 1) { + warn!( + target = "udmin", + client_ip = %client_ip, + "IP rate limit exceeded" + ); + + return Err(AppError::TooManyRequests( + "请求过于频繁,请稍后再试".to_string(), + )); + } + + info!( + target = "udmin", + client_ip = %client_ip, + "IP rate limit check passed" + ); + + Ok(next.run(request).await) +} + +/// 用户限流中间件 +pub async fn user_rate_limit_middleware( + State(rate_limiter): State>, + request: Request, + next: Next, +) -> Result { + // 获取用户认证信息 + if let Some(auth_context) = request.extensions().get::() { + let rate_limit_key = format!("user:{}", auth_context.user_id); + + // 检查用户级别限流 + if !rate_limiter.check_rate_limit(&rate_limit_key, 1) { + warn!( + target = "udmin", + user_id = %auth_context.user_id, + "User rate limit exceeded" + ); + + return Err(AppError::TooManyRequests( + "用户请求过于频繁,请稍后再试".to_string(), + )); + } + + info!( + target = "udmin", + user_id = %auth_context.user_id, + "User rate limit check passed" + ); + } + + Ok(next.run(request).await) +} + +/// 创建限流中间件 +pub fn create_rate_limit_middleware( + config: RateLimitConfig, +) -> Arc { + let rate_limiter = Arc::new(RateLimiter::new( + config.burst_capacity, + config.requests_per_second, + )); + + // 启动清理任务 + let cleanup_limiter = rate_limiter.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(300)); // 5分钟清理一次 + + loop { + interval.tick().await; + cleanup_limiter.cleanup_expired_buckets(Duration::from_secs(3600)); // 清理1小时未使用的桶 + } + }); + + info!(target = "udmin", "Rate limiter initialized"); + rate_limiter +} +``` + +## WebSocket 中间件 (ws.rs) + +### 功能特性 + +- WebSocket 连接管理 +- 实时消息推送 +- 连接认证 +- 心跳检测 +- 广播支持 + +### 实现代码 + +```rust +use axum::{ + extract::{ws::WebSocket, Query, State, WebSocketUpgrade}, + response::Response, +}; +use futures_util::{sink::SinkExt, stream::StreamExt}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashMap, + sync::{Arc, RwLock}, + time::{Duration, Instant}, +}; +use tokio::sync::mpsc; +use tracing::{error, info, warn}; +use uuid::Uuid; + +use crate::{ + error::AppError, + services::user_service, + AppState, +}; + +/// WebSocket 连接信息 +#[derive(Debug, Clone)] +pub struct ConnectionInfo { + pub id: String, + pub user_id: Option, + pub connected_at: Instant, + pub last_ping: Instant, + pub sender: mpsc::UnboundedSender, +} + +/// WebSocket 消息类型 +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +pub enum WebSocketMessage { + /// 认证消息 + Auth { token: String }, + /// 认证成功 + AuthSuccess { user_id: String }, + /// 认证失败 + AuthError { message: String }, + /// 心跳包 + Ping, + /// 心跳响应 + Pong, + /// 流程执行状态更新 + FlowExecutionUpdate { + execution_id: String, + status: String, + progress: Option, + message: Option, + }, + /// 任务执行状态更新 + JobExecutionUpdate { + job_id: String, + execution_id: String, + status: String, + message: Option, + }, + /// 系统通知 + SystemNotification { + title: String, + message: String, + level: String, + }, + /// 用户通知 + UserNotification { + id: String, + title: String, + message: String, + created_at: String, + }, + /// 错误消息 + Error { message: String }, +} + +/// WebSocket 连接管理器 +#[derive(Debug)] +pub struct WebSocketManager { + connections: Arc>>, + user_connections: Arc>>>, // user_id -> connection_ids +} + +impl WebSocketManager { + pub fn new() -> Self { + Self { + connections: Arc::new(RwLock::new(HashMap::new())), + user_connections: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// 添加连接 + pub fn add_connection(&self, connection_info: ConnectionInfo) { + let connection_id = connection_info.id.clone(); + let user_id = connection_info.user_id.clone(); + + // 添加到连接列表 + self.connections.write().unwrap().insert(connection_id.clone(), connection_info); + + // 如果有用户ID,添加到用户连接映射 + if let Some(user_id) = user_id { + self.user_connections + .write() + .unwrap() + .entry(user_id) + .or_insert_with(Vec::new) + .push(connection_id.clone()); + } + + info!( + target = "udmin", + connection_id = %connection_id, + "WebSocket connection added" + ); + } + + /// 移除连接 + pub fn remove_connection(&self, connection_id: &str) { + let mut connections = self.connections.write().unwrap(); + + if let Some(connection_info) = connections.remove(connection_id) { + // 从用户连接映射中移除 + if let Some(user_id) = &connection_info.user_id { + let mut user_connections = self.user_connections.write().unwrap(); + if let Some(user_conn_list) = user_connections.get_mut(user_id) { + user_conn_list.retain(|id| id != connection_id); + if user_conn_list.is_empty() { + user_connections.remove(user_id); + } + } + } + + info!( + target = "udmin", + connection_id = %connection_id, + user_id = ?connection_info.user_id, + "WebSocket connection removed" + ); + } + } + + /// 向指定连接发送消息 + pub fn send_to_connection(&self, connection_id: &str, message: WebSocketMessage) -> bool { + let connections = self.connections.read().unwrap(); + + if let Some(connection_info) = connections.get(connection_id) { + match connection_info.sender.send(message) { + Ok(_) => true, + Err(e) => { + warn!( + target = "udmin", + connection_id = %connection_id, + error = %e, + "Failed to send message to connection" + ); + false + } + } + } else { + false + } + } + + /// 向指定用户的所有连接发送消息 + pub fn send_to_user(&self, user_id: &str, message: WebSocketMessage) -> usize { + let user_connections = self.user_connections.read().unwrap(); + let mut sent_count = 0; + + if let Some(connection_ids) = user_connections.get(user_id) { + for connection_id in connection_ids { + if self.send_to_connection(connection_id, message.clone()) { + sent_count += 1; + } + } + } + + info!( + target = "udmin", + user_id = %user_id, + sent_count = %sent_count, + "Message sent to user connections" + ); + + sent_count + } + + /// 广播消息给所有连接 + pub fn broadcast(&self, message: WebSocketMessage) -> usize { + let connections = self.connections.read().unwrap(); + let mut sent_count = 0; + + for (connection_id, connection_info) in connections.iter() { + match connection_info.sender.send(message.clone()) { + Ok(_) => sent_count += 1, + Err(e) => { + warn!( + target = "udmin", + connection_id = %connection_id, + error = %e, + "Failed to broadcast message" + ); + } + } + } + + info!( + target = "udmin", + sent_count = %sent_count, + total_connections = %connections.len(), + "Message broadcasted" + ); + + sent_count + } + + /// 获取连接统计信息 + pub fn get_stats(&self) -> WebSocketStats { + let connections = self.connections.read().unwrap(); + let user_connections = self.user_connections.read().unwrap(); + + WebSocketStats { + total_connections: connections.len(), + authenticated_connections: connections.values().filter(|c| c.user_id.is_some()).count(), + unique_users: user_connections.len(), + } + } + + /// 清理过期连接 + pub fn cleanup_stale_connections(&self, timeout: Duration) { + let now = Instant::now(); + let connections = self.connections.read().unwrap(); + let stale_connections: Vec = connections + .iter() + .filter(|(_, info)| now.duration_since(info.last_ping) > timeout) + .map(|(id, _)| id.clone()) + .collect(); + + drop(connections); + + for connection_id in stale_connections { + self.remove_connection(&connection_id); + warn!( + target = "udmin", + connection_id = %connection_id, + "Removed stale WebSocket connection" + ); + } + } +} + +/// WebSocket 统计信息 +#[derive(Debug, Serialize)] +pub struct WebSocketStats { + pub total_connections: usize, + pub authenticated_connections: usize, + pub unique_users: usize, +} + +/// WebSocket 查询参数 +#[derive(Debug, Deserialize)] +pub struct WebSocketQuery { + pub token: Option, +} + +/// WebSocket 升级处理器 +pub async fn websocket_handler( + ws: WebSocketUpgrade, + Query(params): Query, + State(state): State, +) -> Response { + ws.on_upgrade(move |socket| handle_websocket(socket, params.token, state)) +} + +/// 处理 WebSocket 连接 +async fn handle_websocket( + socket: WebSocket, + token: Option, + state: AppState, +) { + let connection_id = Uuid::new_v4().to_string(); + let (mut sender, mut receiver) = socket.split(); + let (tx, mut rx) = mpsc::unbounded_channel::(); + + info!( + target = "udmin", + connection_id = %connection_id, + "WebSocket connection established" + ); + + // 创建连接信息 + let connection_info = ConnectionInfo { + id: connection_id.clone(), + user_id: None, + connected_at: Instant::now(), + last_ping: Instant::now(), + sender: tx, + }; + + // 添加到连接管理器 + state.ws_manager.add_connection(connection_info); + + // 如果提供了token,尝试认证 + let mut authenticated_user_id = None; + if let Some(token) = token { + match authenticate_websocket_token(&token, &state).await { + Ok(user_id) => { + authenticated_user_id = Some(user_id.clone()); + + // 更新连接信息 + if let Some(mut connection_info) = state.connections.write().unwrap().get_mut(&connection_id) { + connection_info.user_id = Some(user_id.clone()); + } + + // 发送认证成功消息 + let _ = sender.send(axum::extract::ws::Message::Text( + serde_json::to_string(&WebSocketMessage::AuthSuccess { user_id }).unwrap() + )).await; + + info!( + target = "udmin", + connection_id = %connection_id, + user_id = %authenticated_user_id.as_ref().unwrap(), + "WebSocket connection authenticated" + ); + } + Err(e) => { + warn!( + target = "udmin", + connection_id = %connection_id, + error = %e, + "WebSocket authentication failed" + ); + + let _ = sender.send(axum::extract::ws::Message::Text( + serde_json::to_string(&WebSocketMessage::AuthError { + message: "认证失败".to_string() + }).unwrap() + )).await; + } + } + } + + // 启动消息发送任务 + let connection_id_clone = connection_id.clone(); + let ws_manager_clone = state.ws_manager.clone(); + let send_task = tokio::spawn(async move { + while let Some(message) = rx.recv().await { + let text = match serde_json::to_string(&message) { + Ok(text) => text, + Err(e) => { + error!( + target = "udmin", + connection_id = %connection_id_clone, + error = %e, + "Failed to serialize WebSocket message" + ); + continue; + } + }; + + if sender.send(axum::extract::ws::Message::Text(text)).await.is_err() { + break; + } + } + + // 移除连接 + ws_manager_clone.remove_connection(&connection_id_clone); + }); + + // 启动心跳任务 + let connection_id_clone = connection_id.clone(); + let ws_manager_clone = state.ws_manager.clone(); + let heartbeat_task = tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + + loop { + interval.tick().await; + + if !ws_manager_clone.send_to_connection( + &connection_id_clone, + WebSocketMessage::Ping, + ) { + break; + } + } + }); + + // 处理接收到的消息 + while let Some(msg) = receiver.next().await { + match msg { + Ok(axum::extract::ws::Message::Text(text)) => { + match serde_json::from_str::(&text) { + Ok(WebSocketMessage::Auth { token }) => { + // 处理认证消息 + match authenticate_websocket_token(&token, &state).await { + Ok(user_id) => { + authenticated_user_id = Some(user_id.clone()); + + let _ = state.ws_manager.send_to_connection( + &connection_id, + WebSocketMessage::AuthSuccess { user_id }, + ); + } + Err(_) => { + let _ = state.ws_manager.send_to_connection( + &connection_id, + WebSocketMessage::AuthError { + message: "认证失败".to_string(), + }, + ); + } + } + } + Ok(WebSocketMessage::Pong) => { + // 更新最后ping时间 + // 这里可以更新连接的last_ping时间 + } + Ok(_) => { + // 处理其他消息类型 + } + Err(e) => { + warn!( + target = "udmin", + connection_id = %connection_id, + error = %e, + "Failed to parse WebSocket message" + ); + } + } + } + Ok(axum::extract::ws::Message::Close(_)) => { + info!( + target = "udmin", + connection_id = %connection_id, + "WebSocket connection closed by client" + ); + break; + } + Err(e) => { + error!( + target = "udmin", + connection_id = %connection_id, + error = %e, + "WebSocket error" + ); + break; + } + _ => {} + } + } + + // 清理任务 + send_task.abort(); + heartbeat_task.abort(); + state.ws_manager.remove_connection(&connection_id); + + info!( + target = "udmin", + connection_id = %connection_id, + "WebSocket connection closed" + ); +} + +/// 认证 WebSocket Token +async fn authenticate_websocket_token( + token: &str, + state: &AppState, +) -> Result { + use jsonwebtoken::{decode, DecodingKey, Validation}; + use crate::middlewares::auth::Claims; + + // 验证 JWT Token + let claims = decode::( + token, + &DecodingKey::from_secret(state.config.jwt_secret.as_ref()), + &Validation::default(), + ) + .map_err(|_| AppError::Unauthorized("无效的认证令牌".to_string()))? + .claims; + + // 检查用户是否存在且活跃 + let user = user_service::find_by_id(&state.db, &claims.sub) + .await + .map_err(|_| AppError::Unauthorized("用户不存在".to_string()))?; + + if user.status != crate::models::user::UserStatus::Active { + return Err(AppError::Unauthorized("用户账户已被禁用".to_string())); + } + + Ok(claims.sub) +} + +/// 启动 WebSocket 管理器清理任务 +pub fn start_websocket_cleanup_task(ws_manager: Arc) { + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + + loop { + interval.tick().await; + ws_manager.cleanup_stale_connections(Duration::from_secs(300)); // 5分钟超时 + } + }); +} +``` + +## SSE 中间件 (sse.rs) + +### 功能特性 + +- Server-Sent Events 支持 +- 实时事件推送 +- 连接管理 +- 事件过滤 +- 重连支持 + +### 实现代码 + +```rust +use axum::{ + extract::{Query, State}, + http::{header, HeaderValue, StatusCode}, + response::{sse::Event, Response, Sse}, +}; +use futures_util::stream::{self, Stream}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::HashMap, + convert::Infallible, + sync::{Arc, RwLock}, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tracing::{error, info, warn}; +use uuid::Uuid; + +use crate::{ + error::AppError, + middlewares::auth::AuthContext, + AppState, +}; + +/// SSE 事件类型 +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", content = "data")] +pub enum SseEvent { + /// 流程执行状态更新 + FlowExecutionUpdate { + execution_id: String, + flow_id: String, + status: String, + progress: Option, + message: Option, + timestamp: u64, + }, + /// 任务执行状态更新 + JobExecutionUpdate { + job_id: String, + execution_id: String, + status: String, + message: Option, + timestamp: u64, + }, + /// 系统通知 + SystemNotification { + id: String, + title: String, + message: String, + level: String, + timestamp: u64, + }, + /// 用户通知 + UserNotification { + id: String, + title: String, + message: String, + timestamp: u64, + }, + /// 系统状态更新 + SystemStatus { + cpu_usage: f64, + memory_usage: f64, + disk_usage: f64, + active_connections: usize, + timestamp: u64, + }, + /// 心跳事件 + Heartbeat { + timestamp: u64, + }, +} + +/// SSE 连接信息 +#[derive(Debug)] +struct SseConnection { + id: String, + user_id: Option, + sender: mpsc::UnboundedSender>, + filters: Vec, + created_at: SystemTime, +} + +/// SSE 连接管理器 +#[derive(Debug)] +pub struct SseManager { + connections: Arc>>, + user_connections: Arc>>>, +} + +impl SseManager { + pub fn new() -> Self { + Self { + connections: Arc::new(RwLock::new(HashMap::new())), + user_connections: Arc, + sender: mpsc::UnboundedSender>, + filters: Vec, + ) { + let connection = SseConnection { + id: connection_id.clone(), + user_id: user_id.clone(), + sender, + filters, + created_at: SystemTime::now(), + }; + + // 添加到连接列表 + self.connections.write().unwrap().insert(connection_id.clone(), connection); + + // 如果有用户ID,添加到用户连接映射 + if let Some(user_id) = user_id { + self.user_connections + .write() + .unwrap() + .entry(user_id) + .or_insert_with(Vec::new) + .push(connection_id.clone()); + } + + info!( + target = "udmin", + connection_id = %connection_id, + "SSE connection added" + ); + } + + /// 移除 SSE 连接 + pub fn remove_connection(&self, connection_id: &str) { + let mut connections = self.connections.write().unwrap(); + + if let Some(connection) = connections.remove(connection_id) { + // 从用户连接映射中移除 + if let Some(user_id) = &connection.user_id { + let mut user_connections = self.user_connections.write().unwrap(); + if let Some(user_conn_list) = user_connections.get_mut(user_id) { + user_conn_list.retain(|id| id != connection_id); + if user_conn_list.is_empty() { + user_connections.remove(user_id); + } + } + } + + info!( + target = "udmin", + connection_id = %connection_id, + user_id = ?connection.user_id, + "SSE connection removed" + ); + } + } + + /// 发送事件到指定连接 + pub fn send_to_connection(&self, connection_id: &str, event: SseEvent) -> bool { + let connections = self.connections.read().unwrap(); + + if let Some(connection) = connections.get(connection_id) { + // 检查事件过滤器 + if !connection.filters.is_empty() { + let event_type = match &event { + SseEvent::FlowExecutionUpdate { .. } => "flow_execution", + SseEvent::JobExecutionUpdate { .. } => "job_execution", + SseEvent::SystemNotification { .. } => "system_notification", + SseEvent::UserNotification { .. } => "user_notification", + SseEvent::SystemStatus { .. } => "system_status", + SseEvent::Heartbeat { .. } => "heartbeat", + }; + + if !connection.filters.contains(&event_type.to_string()) { + return true; // 过滤掉,但不算失败 + } + } + + let sse_event = match create_sse_event(&event) { + Ok(event) => event, + Err(e) => { + error!( + target = "udmin", + connection_id = %connection_id, + error = %e, + "Failed to create SSE event" + ); + return false; + } + }; + + match connection.sender.send(Ok(sse_event)) { + Ok(_) => true, + Err(e) => { + warn!( + target = "udmin", + connection_id = %connection_id, + error = %e, + "Failed to send SSE event" + ); + false + } + } + } else { + false + } + } + + /// 发送事件到指定用户的所有连接 + pub fn send_to_user(&self, user_id: &str, event: SseEvent) -> usize { + let user_connections = self.user_connections.read().unwrap(); + let mut sent_count = 0; + + if let Some(connection_ids) = user_connections.get(user_id) { + for connection_id in connection_ids { + if self.send_to_connection(connection_id, event.clone()) { + sent_count += 1; + } + } + } + + info!( + target = "udmin", + user_id = %user_id, + sent_count = %sent_count, + "SSE event sent to user connections" + ); + + sent_count + } + + /// 广播事件到所有连接 + pub fn broadcast(&self, event: SseEvent) -> usize { + let connections = self.connections.read().unwrap(); + let mut sent_count = 0; + + for (connection_id, _) in connections.iter() { + if self.send_to_connection(connection_id, event.clone()) { + sent_count += 1; + } + } + + info!( + target = "udmin", + sent_count = %sent_count, + total_connections = %connections.len(), + "SSE event broadcasted" + ); + + sent_count + } + + /// 获取连接统计信息 + pub fn get_stats(&self) -> SseStats { + let connections = self.connections.read().unwrap(); + let user_connections = self.user_connections.read().unwrap(); + + SseStats { + total_connections: connections.len(), + authenticated_connections: connections.values().filter(|c| c.user_id.is_some()).count(), + unique_users: user_connections.len(), + } + } +} + +/// SSE 统计信息 +#[derive(Debug, Serialize)] +pub struct SseStats { + pub total_connections: usize, + pub authenticated_connections: usize, + pub unique_users: usize, +} + +/// SSE 查询参数 +#[derive(Debug, Deserialize)] +pub struct SseQuery { + pub filters: Option, // 逗号分隔的事件类型过滤器 + pub last_event_id: Option, +} + +/// SSE 处理器 +pub async fn sse_handler( + Query(params): Query, + State(state): State, + auth_context: Option, +) -> Result { + let connection_id = Uuid::new_v4().to_string(); + let (tx, rx) = mpsc::unbounded_channel(); + + // 解析过滤器 + let filters = params + .filters + .map(|f| f.split(',').map(|s| s.trim().to_string()).collect()) + .unwrap_or_default(); + + // 添加连接到管理器 + state.sse_manager.add_connection( + connection_id.clone(), + auth_context.as_ref().map(|ctx| ctx.user_id.clone()), + tx, + filters, + ); + + info!( + target = "udmin", + connection_id = %connection_id, + user_id = ?auth_context.as_ref().map(|ctx| &ctx.user_id), + "SSE connection established" + ); + + // 发送初始心跳 + let initial_heartbeat = SseEvent::Heartbeat { + timestamp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + state.sse_manager.send_to_connection(&connection_id, initial_heartbeat); + + // 创建事件流 + let stream = UnboundedReceiverStream::new(rx); + + // 启动心跳任务 + let connection_id_clone = connection_id.clone(); + let sse_manager_clone = state.sse_manager.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(30)); + + loop { + interval.tick().await; + + let heartbeat = SseEvent::Heartbeat { + timestamp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + if !sse_manager_clone.send_to_connection(&connection_id_clone, heartbeat) { + break; + } + } + + // 清理连接 + sse_manager_clone.remove_connection(&connection_id_clone); + }); + + let sse = Sse::new(stream).keep_alive( + axum::response::sse::KeepAlive::new() + .interval(Duration::from_secs(15)) + .text("keep-alive-text"), + ); + + Ok(sse.into_response()) +} + +/// 创建 SSE 事件 +fn create_sse_event(event: &SseEvent) -> Result { + let event_type = match event { + SseEvent::FlowExecutionUpdate { .. } => "flow_execution_update", + SseEvent::JobExecutionUpdate { .. } => "job_execution_update", + SseEvent::SystemNotification { .. } => "system_notification", + SseEvent::UserNotification { .. } => "user_notification", + SseEvent::SystemStatus { .. } => "system_status", + SseEvent::Heartbeat { .. } => "heartbeat", + }; + + let data = serde_json::to_string(event)?; + + Ok(Event::default() + .event(event_type) + .data(data)) +} + +/// 启动 SSE 管理器清理任务 +pub fn start_sse_cleanup_task(sse_manager: Arc) { + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(300)); // 5分钟清理一次 + + loop { + interval.tick().await; + + // 这里可以添加清理逻辑,比如移除长时间未活动的连接 + let stats = sse_manager.get_stats(); + info!( + target = "udmin", + total_connections = %stats.total_connections, + authenticated_connections = %stats.authenticated_connections, + unique_users = %stats.unique_users, + "SSE connection stats" + ); + } + }); +} +``` + +## 中间件组合和配置 + +### 中间件栈配置 + +```rust +use axum::{ + middleware, + Router, +}; +use std::sync::Arc; +use tower::ServiceBuilder; +use tower_http::{ + compression::CompressionLayer, + timeout::TimeoutLayer, + trace::TraceLayer, +}; + +use crate::{ + middlewares::{ + auth::{jwt_auth_middleware, optional_auth_middleware}, + cors::create_cors_layer, + logging::{request_logging_middleware, performance_monitoring_middleware}, + rate_limit::{create_rate_limit_middleware, ip_rate_limit_middleware, user_rate_limit_middleware}, + }, + AppState, +}; + +/// 创建中间件栈 +pub fn create_middleware_stack(state: AppState) -> ServiceBuilder< + impl tower::Layer< + axum::routing::Router, + Service = impl tower::Service< + http::Request, + Response = axum::response::Response, + Error = std::convert::Infallible, + > + Clone + Send + 'static, + > + Clone, +> { + // 创建限流器 + let rate_limiter = create_rate_limit_middleware(Default::default()); + + ServiceBuilder::new() + // 请求追踪 + .layer(TraceLayer::new_for_http()) + // 请求超时 + .layer(TimeoutLayer::new(std::time::Duration::from_secs(30))) + // 响应压缩 + .layer(CompressionLayer::new()) + // CORS + .layer(create_cors_layer(Default::default())) + // 请求日志 + .layer(middleware::from_fn(request_logging_middleware)) + // 性能监控 + .layer(middleware::from_fn(performance_monitoring_middleware)) + // IP 限流 + .layer(middleware::from_fn_with_state( + Arc::clone(&rate_limiter), + ip_rate_limit_middleware, + )) + // 用户限流(需要在认证之后) + .layer(middleware::from_fn_with_state( + Arc::clone(&rate_limiter), + user_rate_limit_middleware, + )) +} + +/// 为需要认证的路由创建中间件栈 +pub fn create_auth_middleware_stack(state: AppState) -> ServiceBuilder< + impl tower::Layer< + axum::routing::Router, + Service = impl tower::Service< + http::Request, + Response = axum::response::Response, + Error = std::convert::Infallible, + > + Clone + Send + 'static, + > + Clone, +> { + ServiceBuilder::new() + .layer(middleware::from_fn_with_state( + state.clone(), + jwt_auth_middleware, + )) +} + +/// 为可选认证的路由创建中间件栈 +pub fn create_optional_auth_middleware_stack(state: AppState) -> ServiceBuilder< + impl tower::Layer< + axum::routing::Router, + Service = impl tower::Service< + http::Request, + Response = axum::response::Response, + Error = std::convert::Infallible, + > + Clone + Send + 'static, + > + Clone, +> { + ServiceBuilder::new() + .layer(middleware::from_fn_with_state( + state.clone(), + optional_auth_middleware, + )) +} +``` + +## 测试支持 + +### 中间件测试 + +```rust +#[cfg(test)] +mod tests { + use super::*; + use axum::{ + body::Body, + http::{Request, StatusCode}, + }; + use tower::ServiceExt; + + #[tokio::test] + async fn test_cors_middleware() { + let app = Router::new() + .route("/test", axum::routing::get(|| async { "OK" })) + .layer(create_cors_layer(Default::default())); + + let request = Request::builder() + .method("OPTIONS") + .uri("/test") + .header("Origin", "http://localhost:3000") + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + assert!(response.headers().contains_key("access-control-allow-origin")); + } + + #[tokio::test] + async fn test_rate_limit_middleware() { + let rate_limiter = Arc::new(RateLimiter::new(1, 1)); // 1 token, 1 per second + + let app = Router::new() + .route("/test", axum::routing::get(|| async { "OK" })) + .layer(middleware::from_fn_with_state( + rate_limiter, + ip_rate_limit_middleware, + )); + + // 第一个请求应该成功 + let request1 = Request::builder() + .uri("/test") + .body(Body::empty()) + .unwrap(); + + let response1 = app.clone().oneshot(request1).await.unwrap(); + assert_eq!(response1.status(), StatusCode::OK); + + // 第二个请求应该被限流 + let request2 = Request::builder() + .uri("/test") + .body(Body::empty()) + .unwrap(); + + let response2 = app.oneshot(request2).await.unwrap(); + assert_eq!(response2.status(), StatusCode::TOO_MANY_REQUESTS); + } +} +``` + +## 最佳实践 + +### 中间件设计 + +- **单一职责**: 每个中间件只负责一个特定功能 +- **可组合性**: 中间件可以灵活组合使用 +- **性能优化**: 最小化中间件的性能开销 +- **错误处理**: 优雅地处理中间件中的错误 +- **日志记录**: 记录关键操作和错误信息 + +### 安全考虑 + +- **认证验证**: 严格验证用户身份 +- **权限检查**: 确保用户有足够的权限 +- **输入验证**: 验证所有输入数据 +- **限流保护**: 防止恶意请求和DDoS攻击 +- **CORS配置**: 正确配置跨域资源共享 + +### 性能优化 + +- **缓存策略**: 缓存认证结果和权限信息 +- **连接池**: 合理配置数据库连接池 +- **异步处理**: 使用异步操作提高并发性能 +- **资源清理**: 及时清理过期的连接和资源 + +## 总结 + +中间件层是 UdminAI 系统的重要组成部分,提供了认证授权、请求日志、错误处理、CORS、WebSocket、SSE、限流等核心功能。通过模块化设计和灵活的组合机制,中间件层为系统提供了强大的横切关注点支持,确保了系统的安全性、可靠性和性能。 + +主要特点: + +- **模块化设计**: 每个中间件独立实现,职责单一 +- **灵活组合**: 支持灵活的中间件组合和配置 +- **高性能**: 优化的实现确保最小的性能开销 +- **类型安全**: 利用 Rust 类型系统确保安全性 +- **实时通信**: 支持 WebSocket 和 SSE 实时通信 +- **安全防护**: 提供认证、授权、限流等安全机制 + +通过这套完整的中间件系统,UdminAI 能够提供安全、高效、可靠的 Web 服务,满足企业级应用的各种需求。 +``` \ No newline at end of file diff --git a/docs/MODELS.md b/docs/MODELS.md new file mode 100644 index 0000000..8c2d026 --- /dev/null +++ b/docs/MODELS.md @@ -0,0 +1,1650 @@ +# 数据模型文档 + +## 概述 + +数据模型层是 UdminAI 的数据持久化核心,基于 SeaORM 框架实现,提供类型安全的数据库操作接口。包含用户管理、权限控制、流程管理、定时任务、系统日志等核心业务实体。 + +## 架构设计 + +### 模型模块结构 + +``` +models/ +├── mod.rs # 模型模块导出 +├── user.rs # 用户模型 +├── role.rs # 角色模型 +├── permission.rs # 权限模型 +├── user_role.rs # 用户角色关联模型 +├── role_permission.rs # 角色权限关联模型 +├── flow.rs # 流程模型 +├── flow_version.rs # 流程版本模型 +├── flow_execution.rs # 流程执行记录模型 +├── schedule_job.rs # 定时任务模型 +├── job_execution.rs # 任务执行记录模型 +├── system_config.rs # 系统配置模型 +├── operation_log.rs # 操作日志模型 +├── system_log.rs # 系统日志模型 +├── notification.rs # 通知模型 +└── notification_template.rs # 通知模板模型 +``` + +### 设计原则 + +- **实体完整性**: 每个实体都有完整的字段定义 +- **关系映射**: 正确定义实体间的关联关系 +- **类型安全**: 使用强类型定义所有字段 +- **索引优化**: 为查询字段添加合适的索引 +- **软删除**: 重要数据支持软删除机制 +- **审计字段**: 包含创建时间、更新时间等审计字段 + +## 用户模型 (user.rs) + +### 实体定义 + +```rust +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "users")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + #[sea_orm(unique)] + pub username: String, + + pub password_hash: String, + + #[sea_orm(unique)] + pub email: Option, + + pub display_name: Option, + + pub avatar: Option, + + pub status: UserStatus, + + pub last_login_at: Option, + + pub login_count: i32, + + pub is_deleted: bool, + + pub created_at: DateTimeWithTimeZone, + + pub updated_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::user_role::Entity")] + UserRoles, + + #[sea_orm(has_many = "super::flow::Entity")] + Flows, + + #[sea_orm(has_many = "super::schedule_job::Entity")] + ScheduleJobs, + + #[sea_orm(has_many = "super::operation_log::Entity")] + OperationLogs, +} + +impl Related for Entity { + fn to() -> RelationDef { + super::user_role::Relation::Role.def() + } + + fn via() -> Option { + Some(super::user_role::Relation::User.def().rev()) + } +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 用户状态枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "user_status")] +pub enum UserStatus { + #[sea_orm(string_value = "active")] + Active, + + #[sea_orm(string_value = "inactive")] + Inactive, + + #[sea_orm(string_value = "suspended")] + Suspended, + + #[sea_orm(string_value = "deleted")] + Deleted, +} +``` + +### 数据库迁移 + +```rust +use sea_orm_migration::prelude::*; + +#[derive(DeriveMigrationName)] +pub struct Migration; + +#[async_trait::async_trait] +impl MigrationTrait for Migration { + async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .create_table( + Table::create() + .table(User::Table) + .if_not_exists() + .col( + ColumnDef::new(User::Id) + .string_len(32) + .not_null() + .primary_key(), + ) + .col( + ColumnDef::new(User::Username) + .string_len(50) + .not_null() + .unique_key(), + ) + .col( + ColumnDef::new(User::PasswordHash) + .string_len(255) + .not_null(), + ) + .col( + ColumnDef::new(User::Email) + .string_len(100) + .unique_key(), + ) + .col( + ColumnDef::new(User::DisplayName) + .string_len(100), + ) + .col( + ColumnDef::new(User::Avatar) + .string_len(255), + ) + .col( + ColumnDef::new(User::Status) + .enumeration( + Alias::new("user_status"), + ["active", "inactive", "suspended", "deleted"], + ) + .not_null() + .default("active"), + ) + .col( + ColumnDef::new(User::LastLoginAt) + .timestamp_with_time_zone(), + ) + .col( + ColumnDef::new(User::LoginCount) + .integer() + .not_null() + .default(0), + ) + .col( + ColumnDef::new(User::IsDeleted) + .boolean() + .not_null() + .default(false), + ) + .col( + ColumnDef::new(User::CreatedAt) + .timestamp_with_time_zone() + .not_null() + .default(Expr::current_timestamp()), + ) + .col( + ColumnDef::new(User::UpdatedAt) + .timestamp_with_time_zone() + .not_null() + .default(Expr::current_timestamp()), + ) + .to_owned(), + ) + .await?; + + // 创建索引 + manager + .create_index( + Index::create() + .if_not_exists() + .name("idx_users_email") + .table(User::Table) + .col(User::Email) + .to_owned(), + ) + .await?; + + manager + .create_index( + Index::create() + .if_not_exists() + .name("idx_users_status") + .table(User::Table) + .col(User::Status) + .to_owned(), + ) + .await?; + + Ok(()) + } + + async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { + manager + .drop_table(Table::drop().table(User::Table).to_owned()) + .await + } +} +``` + +## 角色模型 (role.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "roles")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + #[sea_orm(unique)] + pub name: String, + + pub description: Option, + + pub is_system: bool, + + pub sort_order: i32, + + pub is_deleted: bool, + + pub created_at: DateTimeWithTimeZone, + + pub updated_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::user_role::Entity")] + UserRoles, + + #[sea_orm(has_many = "super::role_permission::Entity")] + RolePermissions, +} + +impl Related for Entity { + fn to() -> RelationDef { + super::user_role::Relation::User.def() + } + + fn via() -> Option { + Some(super::user_role::Relation::Role.def().rev()) + } +} + +impl Related for Entity { + fn to() -> RelationDef { + super::role_permission::Relation::Permission.def() + } + + fn via() -> Option { + Some(super::role_permission::Relation::Role.def().rev()) + } +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +## 权限模型 (permission.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "permissions")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub name: String, + + pub resource: String, + + pub action: String, + + pub description: Option, + + pub parent_id: Option, + + pub sort_order: i32, + + pub is_system: bool, + + pub created_at: DateTimeWithTimeZone, + + pub updated_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::role_permission::Entity")] + RolePermissions, + + #[sea_orm( + belongs_to = "Entity", + from = "Column::ParentId", + to = "Column::Id" + )] + Parent, + + #[sea_orm(has_many = "Entity")] + Children, +} + +impl Related for Entity { + fn to() -> RelationDef { + super::role_permission::Relation::Role.def() + } + + fn via() -> Option { + Some(super::role_permission::Relation::Permission.def().rev()) + } +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +## 流程模型 (flow.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "flows")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub name: String, + + pub description: Option, + + pub category: String, + + pub status: FlowStatus, + + pub version: String, + + pub design: Json, + + pub input_schema: Option, + + pub output_schema: Option, + + pub tags: Option, + + pub created_by: String, + + pub updated_by: Option, + + pub published_at: Option, + + pub is_deleted: bool, + + pub created_at: DateTimeWithTimeZone, + + pub updated_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::CreatedBy", + to = "super::user::Column::Id" + )] + Creator, + + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UpdatedBy", + to = "super::user::Column::Id" + )] + Updater, + + #[sea_orm(has_many = "super::flow_version::Entity")] + FlowVersions, + + #[sea_orm(has_many = "super::flow_execution::Entity")] + FlowExecutions, + + #[sea_orm(has_many = "super::schedule_job::Entity")] + ScheduleJobs, +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 流程状态枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "flow_status")] +pub enum FlowStatus { + #[sea_orm(string_value = "draft")] + Draft, + + #[sea_orm(string_value = "published")] + Published, + + #[sea_orm(string_value = "archived")] + Archived, + + #[sea_orm(string_value = "deleted")] + Deleted, +} +``` + +## 流程执行记录模型 (flow_execution.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "flow_executions")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub flow_id: String, + + pub flow_version: String, + + pub status: ExecutionStatus, + + pub input: Option, + + pub output: Option, + + pub error: Option, + + pub context: Option, + + pub execution_log: Option, + + pub started_by: Option, + + pub trigger_type: TriggerType, + + pub trigger_source: Option, + + pub start_time: DateTimeWithTimeZone, + + pub end_time: Option, + + pub duration_ms: Option, + + pub created_at: DateTimeWithTimeZone, + + pub updated_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::flow::Entity", + from = "Column::FlowId", + to = "super::flow::Column::Id" + )] + Flow, + + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::StartedBy", + to = "super::user::Column::Id" + )] + User, +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 执行状态枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "execution_status")] +pub enum ExecutionStatus { + #[sea_orm(string_value = "pending")] + Pending, + + #[sea_orm(string_value = "running")] + Running, + + #[sea_orm(string_value = "completed")] + Completed, + + #[sea_orm(string_value = "failed")] + Failed, + + #[sea_orm(string_value = "cancelled")] + Cancelled, + + #[sea_orm(string_value = "timeout")] + Timeout, +} + +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "trigger_type")] +pub enum TriggerType { + #[sea_orm(string_value = "manual")] + Manual, + + #[sea_orm(string_value = "schedule")] + Schedule, + + #[sea_orm(string_value = "webhook")] + Webhook, + + #[sea_orm(string_value = "api")] + Api, + + #[sea_orm(string_value = "event")] + Event, +} +``` + +## 定时任务模型 (schedule_job.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "schedule_jobs")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub name: String, + + pub description: Option, + + pub cron_expression: String, + + pub timezone: String, + + pub job_type: JobType, + + pub target_id: String, + + pub target_config: Option, + + pub enabled: bool, + + pub max_retries: i32, + + pub retry_interval: i32, + + pub timeout_seconds: Option, + + pub last_run_at: Option, + + pub next_run_at: Option, + + pub run_count: i64, + + pub success_count: i64, + + pub failure_count: i64, + + pub created_by: String, + + pub updated_by: Option, + + pub is_deleted: bool, + + pub created_at: DateTimeWithTimeZone, + + pub updated_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::CreatedBy", + to = "super::user::Column::Id" + )] + Creator, + + #[sea_orm( + belongs_to = "super::flow::Entity", + from = "Column::TargetId", + to = "super::flow::Column::Id" + )] + Flow, + + #[sea_orm(has_many = "super::job_execution::Entity")] + JobExecutions, +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 任务类型枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "job_type")] +pub enum JobType { + #[sea_orm(string_value = "flow")] + Flow, + + #[sea_orm(string_value = "script")] + Script, + + #[sea_orm(string_value = "http")] + Http, + + #[sea_orm(string_value = "command")] + Command, +} +``` + +## 任务执行记录模型 (job_execution.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "job_executions")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub job_id: String, + + pub status: JobExecutionStatus, + + pub trigger_type: JobTriggerType, + + pub start_time: DateTimeWithTimeZone, + + pub end_time: Option, + + pub duration_ms: Option, + + pub output: Option, + + pub error: Option, + + pub retry_count: i32, + + pub created_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::schedule_job::Entity", + from = "Column::JobId", + to = "super::schedule_job::Column::Id" + )] + ScheduleJob, +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 执行状态枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "job_execution_status")] +pub enum JobExecutionStatus { + #[sea_orm(string_value = "running")] + Running, + + #[sea_orm(string_value = "completed")] + Completed, + + #[sea_orm(string_value = "failed")] + Failed, + + #[sea_orm(string_value = "timeout")] + Timeout, + + #[sea_orm(string_value = "cancelled")] + Cancelled, +} + +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "job_trigger_type")] +pub enum JobTriggerType { + #[sea_orm(string_value = "schedule")] + Schedule, + + #[sea_orm(string_value = "manual")] + Manual, + + #[sea_orm(string_value = "retry")] + Retry, +} +``` + +## 操作日志模型 (operation_log.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "operation_logs")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub user_id: String, + + pub operation: String, + + pub resource: String, + + pub resource_id: Option, + + pub details: Option, + + pub ip_address: Option, + + pub user_agent: Option, + + pub status: OperationStatus, + + pub error: Option, + + pub duration_ms: Option, + + pub created_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 操作状态枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "operation_status")] +pub enum OperationStatus { + #[sea_orm(string_value = "success")] + Success, + + #[sea_orm(string_value = "failed")] + Failed, + + #[sea_orm(string_value = "partial")] + Partial, +} +``` + +## 系统日志模型 (system_log.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "system_logs")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub level: LogLevel, + + pub source: String, + + pub event_type: String, + + pub message: String, + + pub details: Option, + + pub trace_id: Option, + + pub span_id: Option, + + pub created_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 日志级别枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "log_level")] +pub enum LogLevel { + #[sea_orm(string_value = "trace")] + Trace, + + #[sea_orm(string_value = "debug")] + Debug, + + #[sea_orm(string_value = "info")] + Info, + + #[sea_orm(string_value = "warn")] + Warn, + + #[sea_orm(string_value = "error")] + Error, +} +``` + +## 通知模型 (notification.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "notifications")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub title: String, + + pub content: String, + + pub notification_type: NotificationType, + + pub channel: NotificationChannel, + + pub recipient_id: String, + + pub recipient_type: RecipientType, + + pub status: NotificationStatus, + + pub priority: NotificationPriority, + + pub template_id: Option, + + pub template_data: Option, + + pub scheduled_at: Option, + + pub sent_at: Option, + + pub read_at: Option, + + pub error: Option, + + pub retry_count: i32, + + pub created_at: DateTimeWithTimeZone, + + pub updated_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::notification_template::Entity", + from = "Column::TemplateId", + to = "super::notification_template::Column::Id" + )] + Template, +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 通知相关枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "notification_type")] +pub enum NotificationType { + #[sea_orm(string_value = "system")] + System, + + #[sea_orm(string_value = "flow_execution")] + FlowExecution, + + #[sea_orm(string_value = "job_execution")] + JobExecution, + + #[sea_orm(string_value = "user_action")] + UserAction, + + #[sea_orm(string_value = "alert")] + Alert, +} + +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "notification_channel")] +pub enum NotificationChannel { + #[sea_orm(string_value = "system")] + System, + + #[sea_orm(string_value = "email")] + Email, + + #[sea_orm(string_value = "sms")] + Sms, + + #[sea_orm(string_value = "webhook")] + Webhook, + + #[sea_orm(string_value = "push")] + Push, +} + +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "notification_status")] +pub enum NotificationStatus { + #[sea_orm(string_value = "pending")] + Pending, + + #[sea_orm(string_value = "sent")] + Sent, + + #[sea_orm(string_value = "failed")] + Failed, + + #[sea_orm(string_value = "cancelled")] + Cancelled, +} + +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "notification_priority")] +pub enum NotificationPriority { + #[sea_orm(string_value = "low")] + Low, + + #[sea_orm(string_value = "normal")] + Normal, + + #[sea_orm(string_value = "high")] + High, + + #[sea_orm(string_value = "urgent")] + Urgent, +} + +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "recipient_type")] +pub enum RecipientType { + #[sea_orm(string_value = "user")] + User, + + #[sea_orm(string_value = "role")] + Role, + + #[sea_orm(string_value = "group")] + Group, + + #[sea_orm(string_value = "external")] + External, +} +``` + +## 系统配置模型 (system_config.rs) + +### 实体定义 + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "system_configs")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + #[sea_orm(unique)] + pub key: String, + + pub value: Json, + + pub description: Option, + + pub config_type: ConfigType, + + pub is_encrypted: bool, + + pub is_system: bool, + + pub created_at: DateTimeWithTimeZone, + + pub updated_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 配置类型枚举 + +```rust +#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)] +#[sea_orm(rs_type = "String", db_type = "Enum", enum_name = "config_type")] +pub enum ConfigType { + #[sea_orm(string_value = "string")] + String, + + #[sea_orm(string_value = "number")] + Number, + + #[sea_orm(string_value = "boolean")] + Boolean, + + #[sea_orm(string_value = "json")] + Json, + + #[sea_orm(string_value = "array")] + Array, +} +``` + +## 关联表模型 + +### 用户角色关联 (user_role.rs) + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "user_roles")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub user_id: String, + + pub role_id: String, + + pub assigned_by: String, + + pub assigned_at: DateTimeWithTimeZone, + + pub expires_at: Option, + + pub created_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, + + #[sea_orm( + belongs_to = "super::role::Entity", + from = "Column::RoleId", + to = "super::role::Column::Id" + )] + Role, +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +### 角色权限关联 (role_permission.rs) + +```rust +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)] +#[sea_orm(table_name = "role_permissions")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub id: String, + + pub role_id: String, + + pub permission_id: String, + + pub granted_by: String, + + pub granted_at: DateTimeWithTimeZone, + + pub created_at: DateTimeWithTimeZone, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::role::Entity", + from = "Column::RoleId", + to = "super::role::Column::Id" + )] + Role, + + #[sea_orm( + belongs_to = "super::permission::Entity", + from = "Column::PermissionId", + to = "super::permission::Column::Id" + )] + Permission, +} + +impl ActiveModelBehavior for ActiveModel {} +``` + +## 数据库连接和配置 + +### 数据库连接 + +```rust +use sea_orm::*; + +pub async fn establish_connection(database_url: &str) -> Result { + let mut opt = ConnectOptions::new(database_url.to_owned()); + opt.max_connections(100) + .min_connections(5) + .connect_timeout(Duration::from_secs(8)) + .acquire_timeout(Duration::from_secs(8)) + .idle_timeout(Duration::from_secs(8)) + .max_lifetime(Duration::from_secs(8)) + .sqlx_logging(true) + .sqlx_logging_level(log::LevelFilter::Info); + + Database::connect(opt).await +} +``` + +### 迁移管理 + +```rust +use sea_orm_migration::prelude::*; + +pub struct Migrator; + +#[async_trait::async_trait] +impl MigratorTrait for Migrator { + fn migrations() -> Vec> { + vec![ + Box::new(m20220101_000001_create_users_table::Migration), + Box::new(m20220101_000002_create_roles_table::Migration), + Box::new(m20220101_000003_create_permissions_table::Migration), + Box::new(m20220101_000004_create_user_roles_table::Migration), + Box::new(m20220101_000005_create_role_permissions_table::Migration), + Box::new(m20220101_000006_create_flows_table::Migration), + Box::new(m20220101_000007_create_flow_versions_table::Migration), + Box::new(m20220101_000008_create_flow_executions_table::Migration), + Box::new(m20220101_000009_create_schedule_jobs_table::Migration), + Box::new(m20220101_000010_create_job_executions_table::Migration), + Box::new(m20220101_000011_create_operation_logs_table::Migration), + Box::new(m20220101_000012_create_system_logs_table::Migration), + Box::new(m20220101_000013_create_notifications_table::Migration), + Box::new(m20220101_000014_create_notification_templates_table::Migration), + Box::new(m20220101_000015_create_system_configs_table::Migration), + ] + } +} +``` + +## 查询构建器 + +### 用户查询示例 + +```rust +use sea_orm::*; + +/// 查询活跃用户列表 +pub async fn find_active_users( + db: &DatabaseConnection, + page: u64, + page_size: u64, +) -> Result<(Vec, u64), DbErr> { + let paginator = user::Entity::find() + .filter(user::Column::Status.eq(UserStatus::Active)) + .filter(user::Column::IsDeleted.eq(false)) + .order_by_desc(user::Column::CreatedAt) + .paginate(db, page_size); + + let total = paginator.num_items().await?; + let users = paginator.fetch_page(page - 1).await?; + + Ok((users, total)) +} + +/// 根据用户名查询用户 +pub async fn find_user_by_username( + db: &DatabaseConnection, + username: &str, +) -> Result, DbErr> { + user::Entity::find() + .filter(user::Column::Username.eq(username)) + .filter(user::Column::IsDeleted.eq(false)) + .one(db) + .await +} + +/// 查询用户及其角色 +pub async fn find_user_with_roles( + db: &DatabaseConnection, + user_id: &str, +) -> Result)>, DbErr> { + user::Entity::find_by_id(user_id) + .find_with_related(role::Entity) + .all(db) + .await + .map(|results| { + results.into_iter().next().map(|(user, roles)| (user, roles)) + }) +} +``` + +### 流程查询示例 + +```rust +/// 查询已发布的流程 +pub async fn find_published_flows( + db: &DatabaseConnection, + category: Option<&str>, +) -> Result, DbErr> { + let mut query = flow::Entity::find() + .filter(flow::Column::Status.eq(FlowStatus::Published)) + .filter(flow::Column::IsDeleted.eq(false)); + + if let Some(cat) = category { + query = query.filter(flow::Column::Category.eq(cat)); + } + + query.order_by_desc(flow::Column::UpdatedAt).all(db).await +} + +/// 查询流程执行历史 +pub async fn find_flow_executions( + db: &DatabaseConnection, + flow_id: &str, + status: Option, + page: u64, + page_size: u64, +) -> Result<(Vec, u64), DbErr> { + let mut query = flow_execution::Entity::find() + .filter(flow_execution::Column::FlowId.eq(flow_id)); + + if let Some(s) = status { + query = query.filter(flow_execution::Column::Status.eq(s)); + } + + let paginator = query + .order_by_desc(flow_execution::Column::StartTime) + .paginate(db, page_size); + + let total = paginator.num_items().await?; + let executions = paginator.fetch_page(page - 1).await?; + + Ok((executions, total)) +} +``` + +## 事务处理 + +### 事务示例 + +```rust +use sea_orm::*; + +/// 创建用户并分配角色(事务) +pub async fn create_user_with_roles_tx( + db: &DatabaseConnection, + user_data: user::ActiveModel, + role_ids: Vec, +) -> Result { + let txn = db.begin().await?; + + // 创建用户 + let user = user_data.insert(&txn).await?; + + // 分配角色 + for role_id in role_ids { + let user_role = user_role::ActiveModel { + id: Set(crate::utils::generate_id()), + user_id: Set(user.id.clone()), + role_id: Set(role_id), + assigned_by: Set("system".to_string()), + assigned_at: Set(crate::utils::now_fixed_offset()), + expires_at: NotSet, + created_at: Set(crate::utils::now_fixed_offset()), + }; + user_role.insert(&txn).await?; + } + + txn.commit().await?; + Ok(user) +} + +/// 更新流程状态(事务) +pub async fn publish_flow_tx( + db: &DatabaseConnection, + flow_id: &str, + user_id: &str, +) -> Result { + let txn = db.begin().await?; + + // 更新流程状态 + let flow = flow::Entity::find_by_id(flow_id) + .one(&txn) + .await? + .ok_or(DbErr::RecordNotFound("流程不存在".to_string()))?; + + let mut flow: flow::ActiveModel = flow.into(); + flow.status = Set(FlowStatus::Published); + flow.published_at = Set(Some(crate::utils::now_fixed_offset())); + flow.updated_by = Set(Some(user_id.to_string())); + flow.updated_at = Set(crate::utils::now_fixed_offset()); + + let updated_flow = flow.update(&txn).await?; + + // 记录操作日志 + let log = operation_log::ActiveModel { + id: Set(crate::utils::generate_id()), + user_id: Set(user_id.to_string()), + operation: Set("publish_flow".to_string()), + resource: Set("flow".to_string()), + resource_id: Set(Some(flow_id.to_string())), + details: Set(Some(serde_json::json!({ + "flow_name": updated_flow.name, + "version": updated_flow.version + }))), + ip_address: NotSet, + user_agent: NotSet, + status: Set(OperationStatus::Success), + error: NotSet, + duration_ms: NotSet, + created_at: Set(crate::utils::now_fixed_offset()), + }; + log.insert(&txn).await?; + + txn.commit().await?; + Ok(updated_flow) +} +``` + +## 性能优化 + +### 索引策略 + +```sql +-- 用户表索引 +CREATE INDEX idx_users_username ON users(username); +CREATE INDEX idx_users_email ON users(email); +CREATE INDEX idx_users_status ON users(status); +CREATE INDEX idx_users_created_at ON users(created_at); + +-- 流程表索引 +CREATE INDEX idx_flows_status ON flows(status); +CREATE INDEX idx_flows_category ON flows(category); +CREATE INDEX idx_flows_created_by ON flows(created_by); +CREATE INDEX idx_flows_updated_at ON flows(updated_at); + +-- 流程执行表索引 +CREATE INDEX idx_flow_executions_flow_id ON flow_executions(flow_id); +CREATE INDEX idx_flow_executions_status ON flow_executions(status); +CREATE INDEX idx_flow_executions_start_time ON flow_executions(start_time); + +-- 定时任务表索引 +CREATE INDEX idx_schedule_jobs_enabled ON schedule_jobs(enabled); +CREATE INDEX idx_schedule_jobs_next_run_at ON schedule_jobs(next_run_at); +CREATE INDEX idx_schedule_jobs_created_by ON schedule_jobs(created_by); + +-- 操作日志表索引 +CREATE INDEX idx_operation_logs_user_id ON operation_logs(user_id); +CREATE INDEX idx_operation_logs_resource ON operation_logs(resource); +CREATE INDEX idx_operation_logs_created_at ON operation_logs(created_at); + +-- 系统日志表索引 +CREATE INDEX idx_system_logs_level ON system_logs(level); +CREATE INDEX idx_system_logs_source ON system_logs(source); +CREATE INDEX idx_system_logs_created_at ON system_logs(created_at); +``` + +### 查询优化 + +```rust +/// 使用预加载优化查询 +pub async fn find_users_with_roles_optimized( + db: &DatabaseConnection, + page: u64, + page_size: u64, +) -> Result)>, DbErr> { + user::Entity::find() + .filter(user::Column::Status.eq(UserStatus::Active)) + .filter(user::Column::IsDeleted.eq(false)) + .find_with_related(role::Entity) + .paginate(db, page_size) + .fetch_page(page - 1) + .await +} + +/// 使用原生 SQL 优化复杂查询 +pub async fn get_user_statistics( + db: &DatabaseConnection, +) -> Result { + let result = db + .query_one(Statement::from_sql_and_values( + DbBackend::Postgres, + r#" + SELECT + COUNT(*) as total_users, + COUNT(CASE WHEN status = 'active' THEN 1 END) as active_users, + COUNT(CASE WHEN status = 'inactive' THEN 1 END) as inactive_users, + COUNT(CASE WHEN last_login_at > NOW() - INTERVAL '30 days' THEN 1 END) as recent_active_users + FROM users + WHERE is_deleted = false + "#, + vec![], + )) + .await? + .ok_or(DbErr::RecordNotFound("统计数据不存在".to_string()))?; + + Ok(UserStatistics { + total_users: result.try_get("", "total_users")?, + active_users: result.try_get("", "active_users")?, + inactive_users: result.try_get("", "inactive_users")?, + recent_active_users: result.try_get("", "recent_active_users")?, + }) +} +``` + +## 数据验证 + +### 模型验证 + +```rust +use validator::{Validate, ValidationError}; + +#[derive(Debug, Validate)] +pub struct CreateUserRequest { + #[validate(length(min = 3, max = 50, message = "用户名长度必须在3-50个字符之间"))] + pub username: String, + + #[validate(length(min = 6, message = "密码长度至少6个字符"))] + pub password: String, + + #[validate(email(message = "邮箱格式不正确"))] + pub email: Option, + + #[validate(length(max = 100, message = "显示名称不能超过100个字符"))] + pub display_name: Option, +} + +/// 自定义验证函数 +fn validate_cron_expression(cron: &str) -> Result<(), ValidationError> { + cron::Schedule::from_str(cron) + .map_err(|_| ValidationError::new("invalid_cron_expression"))?; + Ok(()) +} + +#[derive(Debug, Validate)] +pub struct CreateScheduleJobRequest { + #[validate(length(min = 1, max = 100))] + pub name: String, + + #[validate(custom = "validate_cron_expression")] + pub cron_expression: String, + + #[validate(range(min = 0, max = 10))] + pub max_retries: i32, +} +``` + +## 测试支持 + +### 测试数据库设置 + +```rust +#[cfg(test)] +mod tests { + use super::*; + use sea_orm::*; + + async fn setup_test_db() -> DatabaseConnection { + let db = Database::connect("sqlite::memory:").await.unwrap(); + + // 运行迁移 + Migrator::up(&db, None).await.unwrap(); + + db + } + + async fn create_test_user(db: &DatabaseConnection) -> user::Model { + let user = user::ActiveModel { + id: Set(crate::utils::generate_id()), + username: Set("test_user".to_string()), + password_hash: Set("hashed_password".to_string()), + email: Set(Some("test@example.com".to_string())), + display_name: Set(Some("Test User".to_string())), + avatar: NotSet, + status: Set(UserStatus::Active), + last_login_at: NotSet, + login_count: Set(0), + is_deleted: Set(false), + created_at: Set(crate::utils::now_fixed_offset()), + updated_at: Set(crate::utils::now_fixed_offset()), + }; + + user.insert(db).await.unwrap() + } + + #[tokio::test] + async fn test_create_user() { + let db = setup_test_db().await; + let user = create_test_user(&db).await; + + assert_eq!(user.username, "test_user"); + assert_eq!(user.status, UserStatus::Active); + assert!(!user.is_deleted); + } + + #[tokio::test] + async fn test_user_role_relationship() { + let db = setup_test_db().await; + + // 创建用户和角色 + let user = create_test_user(&db).await; + let role = create_test_role(&db).await; + + // 创建用户角色关联 + let user_role = user_role::ActiveModel { + id: Set(crate::utils::generate_id()), + user_id: Set(user.id.clone()), + role_id: Set(role.id.clone()), + assigned_by: Set("system".to_string()), + assigned_at: Set(crate::utils::now_fixed_offset()), + expires_at: NotSet, + created_at: Set(crate::utils::now_fixed_offset()), + }; + user_role.insert(&db).await.unwrap(); + + // 验证关联关系 + let user_with_roles = user::Entity::find_by_id(&user.id) + .find_with_related(role::Entity) + .all(&db) + .await + .unwrap(); + + assert_eq!(user_with_roles.len(), 1); + assert_eq!(user_with_roles[0].1.len(), 1); + assert_eq!(user_with_roles[0].1[0].id, role.id); + } +} +``` + +## 最佳实践 + +### 模型设计 + +- **主键设计**: 使用字符串类型的 UUID 作为主键 +- **外键约束**: 正确定义外键关系和级联操作 +- **索引优化**: 为查询字段添加合适的索引 +- **软删除**: 重要数据使用软删除而非物理删除 +- **审计字段**: 包含创建时间、更新时间等审计信息 + +### 查询优化 + +- **预加载**: 使用 `find_with_related` 避免 N+1 查询问题 +- **分页查询**: 使用 `paginate` 进行分页查询 +- **索引使用**: 确保查询条件使用了合适的索引 +- **原生 SQL**: 复杂查询使用原生 SQL 优化性能 + +### 事务管理 + +- **原子性**: 确保相关操作在同一事务中执行 +- **一致性**: 维护数据的完整性约束 +- **隔离性**: 避免并发事务的相互干扰 +- **持久性**: 确保提交的事务持久保存 + +### 错误处理 + +- **类型安全**: 使用 `Result` 处理数据库错误 +- **错误分类**: 区分不同类型的数据库错误 +- **错误恢复**: 提供合适的错误恢复机制 +- **日志记录**: 记录详细的错误信息用于调试 + +### 缓存策略 + +- **查询缓存**: 缓存频繁查询的结果 +- **实体缓存**: 缓存常用的实体对象 +- **失效策略**: 合理的缓存失效机制 +- **一致性**: 保证缓存与数据库的一致性 + +## 总结 + +数据模型层是 UdminAI 系统的数据基础,通过 SeaORM 提供了类型安全、高性能的数据访问接口。模型设计遵循了关系型数据库的最佳实践,包括合理的索引设计、事务管理、数据验证和性能优化。 + +主要特点: + +- **类型安全**: 使用 Rust 的类型系统确保数据安全 +- **关系映射**: 正确定义实体间的关联关系 +- **性能优化**: 通过索引和查询优化提升性能 +- **事务支持**: 提供完整的事务管理机制 +- **测试友好**: 支持单元测试和集成测试 +- **扩展性**: 易于扩展和维护的模型设计 + +通过这套完整的数据模型设计,UdminAI 能够高效、安全地管理用户数据、流程数据、任务数据和系统数据,为上层业务逻辑提供可靠的数据支撑。 \ No newline at end of file diff --git a/docs/PROJECT_OVERVIEW.md b/docs/PROJECT_OVERVIEW.md new file mode 100644 index 0000000..52aefda --- /dev/null +++ b/docs/PROJECT_OVERVIEW.md @@ -0,0 +1,137 @@ +# UdminAI 项目总览 + +## 项目简介 + +UdminAI 是一个基于 Rust + React 的现代化流程管理和自动化平台,提供可视化流程编辑、定时任务调度、用户权限管理等功能。 + +## 技术架构 + +### 后端技术栈 +- **框架**: Axum (异步 Web 框架) +- **数据库**: SeaORM (支持 MySQL/PostgreSQL/SQLite) +- **缓存**: Redis +- **认证**: JWT + Argon2 密码哈希 +- **定时任务**: tokio-cron-scheduler +- **流程引擎**: 自研流程执行引擎 +- **实时通信**: WebSocket + SSE + +### 前端技术栈 +- **框架**: React 18 + TypeScript +- **UI 库**: Semi Design + Ant Design +- **流程编辑器**: @flowgram.ai 系列组件 +- **状态管理**: React Context +- **路由**: React Router v6 +- **HTTP 客户端**: Axios + +## 项目结构 + +``` +udmin_ai/ +├── backend/ # Rust 后端服务 +│ ├── src/ +│ │ ├── flow/ # 流程引擎核心 +│ │ ├── models/ # 数据模型 +│ │ ├── services/ # 业务逻辑层 +│ │ ├── routes/ # API 路由 +│ │ ├── middlewares/ # 中间件 +│ │ └── utils/ # 工具函数 +│ └── migration/ # 数据库迁移 +├── frontend/ # React 前端应用 +│ └── src/ +│ ├── flows/ # 流程编辑器 +│ ├── pages/ # 页面组件 +│ ├── components/ # 通用组件 +│ └── utils/ # 工具函数 +├── docs/ # 项目文档 +├── scripts/ # 部署脚本 +└── README.md +``` + +## 核心功能模块 + +### 1. 用户权限管理 +- 用户管理 (Users) +- 角色管理 (Roles) +- 菜单权限 (Menus) +- 部门管理 (Departments) +- 职位管理 (Positions) + +### 2. 流程管理 +- 可视化流程编辑器 +- 流程执行引擎 +- 流程运行日志 +- 多种节点类型支持 (HTTP、数据库、脚本、条件等) + +### 3. 定时任务 +- Cron 表达式支持 +- 任务调度管理 +- 执行状态监控 + +### 4. 系统监控 +- 请求日志记录 +- 系统运行状态 +- 实时通信支持 + +## 部署架构 + +### 服务端口分配 +- **HTTP API**: 9898 (可配置) +- **WebSocket**: 8877 (可配置) +- **SSE**: 8866 (可配置) +- **前端开发服务器**: 8888 + +### 环境配置 +- 开发环境: `.env` +- 生产环境: `.env.prod` +- 测试环境: `.env.staging` + +## 开发指南 + +### 后端开发 +```bash +cd backend +cargo run # 开发模式 +cargo build --release # 生产构建 +``` + +### 前端开发 +```bash +cd frontend +npm install +npm run dev # 开发服务器 +npm run build # 生产构建 +``` + +### 数据库迁移 +```bash +cd backend/migration +cargo run +``` + +## API 文档 + +项目集成了 Swagger UI,启动后端服务后可访问: +- Swagger UI: `http://localhost:9898/swagger-ui/` +- OpenAPI JSON: `http://localhost:9898/api-docs/openapi.json` + +## 安全特性 + +- JWT 令牌认证 +- Argon2 密码哈希 +- CORS 跨域保护 +- 请求日志记录 +- 权限中间件验证 + +## 扩展性设计 + +- 模块化架构,易于扩展新功能 +- 插件化流程节点,支持自定义执行器 +- 微服务友好的设计 +- 支持水平扩展 + +## 监控和日志 + +- 结构化日志 (tracing) +- 请求链路追踪 +- 性能监控 +- 错误处理和报告 \ No newline at end of file diff --git a/docs/RESPONSE.md b/docs/RESPONSE.md new file mode 100644 index 0000000..546dbb0 --- /dev/null +++ b/docs/RESPONSE.md @@ -0,0 +1,1161 @@ +# UdminAI 响应格式模块文档 + +## 概述 + +UdminAI 项目的响应格式模块定义了统一的 API 响应结构,确保所有接口返回一致的数据格式。该模块基于 Axum 框架和 Serde 序列化库构建,提供了类型安全的响应处理机制。 + +## 设计原则 + +### 核心理念 + +- **一致性**: 所有 API 接口使用统一的响应格式 +- **类型安全**: 编译时类型检查,避免运行时错误 +- **可扩展性**: 支持不同类型的响应数据 +- **用户友好**: 清晰的响应结构和错误信息 +- **标准化**: 遵循 RESTful API 设计规范 + +### 响应分类 + +1. **成功响应**: 操作成功时的数据返回 +2. **分页响应**: 列表数据的分页返回 +3. **错误响应**: 操作失败时的错误信息 +4. **空响应**: 无数据返回的成功操作 + +## 响应结构定义 (response.rs) + +### 基础响应类型 + +```rust +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// API 响应的基础结构 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ApiResponse { + /// 响应状态码 + pub code: u16, + /// 响应消息 + pub message: String, + /// 响应数据 + pub data: Option, + /// 响应时间戳 + pub timestamp: chrono::DateTime, + /// 请求追踪 ID + pub request_id: Option, +} + +/// 分页响应结构 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PageResponse { + /// 数据列表 + pub items: Vec, + /// 分页信息 + pub pagination: PaginationInfo, +} + +/// 分页信息 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PaginationInfo { + /// 当前页码(从1开始) + pub current_page: u64, + /// 每页大小 + pub page_size: u64, + /// 总记录数 + pub total_count: u64, + /// 总页数 + pub total_pages: u64, + /// 是否有下一页 + pub has_next: bool, + /// 是否有上一页 + pub has_prev: bool, +} + +/// 批量操作响应 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BatchResponse { + /// 成功处理的项目 + pub success: Vec, + /// 失败的项目及错误信息 + pub failed: Vec, + /// 成功数量 + pub success_count: usize, + /// 失败数量 + pub failed_count: usize, +} + +/// 批量操作错误 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BatchError { + /// 项目标识 + pub id: String, + /// 错误代码 + pub error_code: String, + /// 错误消息 + pub error_message: String, +} + +/// 统计响应 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StatsResponse { + /// 统计数据 + pub stats: HashMap, + /// 统计时间范围 + pub time_range: Option, + /// 更新时间 + pub updated_at: chrono::DateTime, +} + +/// 时间范围 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TimeRange { + /// 开始时间 + pub start: chrono::DateTime, + /// 结束时间 + pub end: chrono::DateTime, +} + +/// 健康检查响应 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HealthResponse { + /// 服务状态 + pub status: HealthStatus, + /// 版本信息 + pub version: String, + /// 启动时间 + pub uptime: String, + /// 依赖服务状态 + pub dependencies: HashMap, +} + +/// 健康状态 +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum HealthStatus { + Healthy, + Degraded, + Unhealthy, +} + +/// 依赖服务状态 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DependencyStatus { + /// 状态 + pub status: HealthStatus, + /// 响应时间(毫秒) + pub response_time_ms: Option, + /// 错误信息 + pub error: Option, +} + +impl ApiResponse { + /// 创建成功响应 + pub fn success(data: T) -> Self { + Self { + code: 200, + message: "操作成功".to_string(), + data: Some(data), + timestamp: chrono::Utc::now(), + request_id: None, + } + } + + /// 创建成功响应(带自定义消息) + pub fn success_with_message(data: T, message: impl Into) -> Self { + Self { + code: 200, + message: message.into(), + data: Some(data), + timestamp: chrono::Utc::now(), + request_id: None, + } + } + + /// 创建创建成功响应 + pub fn created(data: T) -> Self { + Self { + code: 201, + message: "创建成功".to_string(), + data: Some(data), + timestamp: chrono::Utc::now(), + request_id: None, + } + } + + /// 创建无内容响应 + pub fn no_content() -> ApiResponse<()> { + ApiResponse { + code: 204, + message: "操作成功".to_string(), + data: None, + timestamp: chrono::Utc::now(), + request_id: None, + } + } + + /// 设置请求 ID + pub fn with_request_id(mut self, request_id: impl Into) -> Self { + self.request_id = Some(request_id.into()); + self + } + + /// 设置状态码 + pub fn with_code(mut self, code: u16) -> Self { + self.code = code; + self + } + + /// 设置消息 + pub fn with_message(mut self, message: impl Into) -> Self { + self.message = message.into(); + self + } +} + +impl PageResponse { + /// 创建分页响应 + pub fn new( + items: Vec, + current_page: u64, + page_size: u64, + total_count: u64, + ) -> Self { + let total_pages = if total_count == 0 { + 1 + } else { + (total_count + page_size - 1) / page_size + }; + + let has_next = current_page < total_pages; + let has_prev = current_page > 1; + + Self { + items, + pagination: PaginationInfo { + current_page, + page_size, + total_count, + total_pages, + has_next, + has_prev, + }, + } + } + + /// 创建空分页响应 + pub fn empty(page_size: u64) -> Self { + Self::new(Vec::new(), 1, page_size, 0) + } +} + +impl BatchResponse { + /// 创建批量响应 + pub fn new(success: Vec, failed: Vec) -> Self { + let success_count = success.len(); + let failed_count = failed.len(); + + Self { + success, + failed, + success_count, + failed_count, + } + } + + /// 创建全部成功的批量响应 + pub fn all_success(success: Vec) -> Self { + Self::new(success, Vec::new()) + } + + /// 创建全部失败的批量响应 + pub fn all_failed(failed: Vec) -> Self { + Self::new(Vec::new(), failed) + } +} + +impl StatsResponse { + /// 创建统计响应 + pub fn new(stats: HashMap) -> Self { + Self { + stats, + time_range: None, + updated_at: chrono::Utc::now(), + } + } + + /// 设置时间范围 + pub fn with_time_range( + mut self, + start: chrono::DateTime, + end: chrono::DateTime, + ) -> Self { + self.time_range = Some(TimeRange { start, end }); + self + } +} + +impl HealthResponse { + /// 创建健康响应 + pub fn new( + status: HealthStatus, + version: impl Into, + uptime: impl Into, + ) -> Self { + Self { + status, + version: version.into(), + uptime: uptime.into(), + dependencies: HashMap::new(), + } + } + + /// 添加依赖状态 + pub fn with_dependency( + mut self, + name: impl Into, + status: DependencyStatus, + ) -> Self { + self.dependencies.insert(name.into(), status); + self + } +} + +impl DependencyStatus { + /// 创建健康的依赖状态 + pub fn healthy(response_time_ms: u64) -> Self { + Self { + status: HealthStatus::Healthy, + response_time_ms: Some(response_time_ms), + error: None, + } + } + + /// 创建不健康的依赖状态 + pub fn unhealthy(error: impl Into) -> Self { + Self { + status: HealthStatus::Unhealthy, + response_time_ms: None, + error: Some(error.into()), + } + } + + /// 创建降级的依赖状态 + pub fn degraded(response_time_ms: u64, error: impl Into) -> Self { + Self { + status: HealthStatus::Degraded, + response_time_ms: Some(response_time_ms), + error: Some(error.into()), + } + } +} + +/// 实现 IntoResponse,使响应可以直接返回 +impl IntoResponse for ApiResponse +where + T: Serialize, +{ + fn into_response(self) -> Response { + let status_code = StatusCode::from_u16(self.code) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + (status_code, Json(self)).into_response() + } +} + +/// 响应构建器 +pub struct ResponseBuilder { + response: ApiResponse, +} + +impl ResponseBuilder { + /// 创建新的响应构建器 + pub fn new(data: T) -> Self { + Self { + response: ApiResponse::success(data), + } + } + + /// 设置状态码 + pub fn code(mut self, code: u16) -> Self { + self.response.code = code; + self + } + + /// 设置消息 + pub fn message(mut self, message: impl Into) -> Self { + self.response.message = message.into(); + self + } + + /// 设置请求 ID + pub fn request_id(mut self, request_id: impl Into) -> Self { + self.response.request_id = Some(request_id.into()); + self + } + + /// 构建响应 + pub fn build(self) -> ApiResponse { + self.response + } +} + +/// 响应宏 +#[macro_export] +macro_rules! ok_response { + ($data:expr) => { + ApiResponse::success($data) + }; + ($data:expr, $message:expr) => { + ApiResponse::success_with_message($data, $message) + }; +} + +#[macro_export] +macro_rules! created_response { + ($data:expr) => { + ApiResponse::created($data) + }; +} + +#[macro_export] +macro_rules! no_content_response { + () => { + ApiResponse::no_content() + }; +} + +#[macro_export] +macro_rules! page_response { + ($items:expr, $page:expr, $size:expr, $total:expr) => { + ApiResponse::success(PageResponse::new($items, $page, $size, $total)) + }; +} + +/// 响应扩展 trait +pub trait ResponseExt { + /// 转换为 API 响应 + fn into_api_response(self) -> ApiResponse; + + /// 转换为分页响应 + fn into_page_response( + self, + current_page: u64, + page_size: u64, + total_count: u64, + ) -> ApiResponse> + where + Self: IntoIterator, + Self::IntoIter: ExactSizeIterator; +} + +impl ResponseExt for T { + fn into_api_response(self) -> ApiResponse { + ApiResponse::success(self) + } + + fn into_page_response( + self, + current_page: u64, + page_size: u64, + total_count: u64, + ) -> ApiResponse> + where + Self: IntoIterator, + Self::IntoIter: ExactSizeIterator, + { + let items: Vec = self.into_iter().collect(); + let page_response = PageResponse::new(items, current_page, page_size, total_count); + ApiResponse::success(page_response) + } +} + +impl ResponseExt for Vec { + fn into_api_response(self) -> ApiResponse> { + ApiResponse::success(self) + } + + fn into_page_response( + self, + current_page: u64, + page_size: u64, + total_count: u64, + ) -> ApiResponse> + where + Self: IntoIterator, + Self::IntoIter: ExactSizeIterator, + { + let page_response = PageResponse::new(self, current_page, page_size, total_count); + ApiResponse::success(page_response) + } +} + +/// 常用响应类型别名 +pub type JsonResponse = ApiResponse; +pub type ListResponse = ApiResponse>; +pub type PageResp = ApiResponse>; +pub type BatchResp = ApiResponse>; +pub type StatsResp = ApiResponse; +pub type HealthResp = ApiResponse; + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + struct TestData { + id: u32, + name: String, + } + + #[test] + fn test_success_response() { + let data = TestData { + id: 1, + name: "test".to_string(), + }; + + let response = ApiResponse::success(data.clone()); + assert_eq!(response.code, 200); + assert_eq!(response.message, "操作成功"); + assert_eq!(response.data, Some(data)); + } + + #[test] + fn test_created_response() { + let data = TestData { + id: 1, + name: "test".to_string(), + }; + + let response = ApiResponse::created(data.clone()); + assert_eq!(response.code, 201); + assert_eq!(response.message, "创建成功"); + assert_eq!(response.data, Some(data)); + } + + #[test] + fn test_no_content_response() { + let response = ApiResponse::no_content(); + assert_eq!(response.code, 204); + assert_eq!(response.message, "操作成功"); + assert_eq!(response.data, None); + } + + #[test] + fn test_page_response() { + let items = vec![ + TestData { id: 1, name: "test1".to_string() }, + TestData { id: 2, name: "test2".to_string() }, + ]; + + let page_response = PageResponse::new(items.clone(), 1, 10, 25); + + assert_eq!(page_response.items, items); + assert_eq!(page_response.pagination.current_page, 1); + assert_eq!(page_response.pagination.page_size, 10); + assert_eq!(page_response.pagination.total_count, 25); + assert_eq!(page_response.pagination.total_pages, 3); + assert_eq!(page_response.pagination.has_next, true); + assert_eq!(page_response.pagination.has_prev, false); + } + + #[test] + fn test_batch_response() { + let success_items = vec![ + TestData { id: 1, name: "test1".to_string() }, + TestData { id: 2, name: "test2".to_string() }, + ]; + + let failed_items = vec![ + BatchError { + id: "3".to_string(), + error_code: "VALIDATION_FAILED".to_string(), + error_message: "名称不能为空".to_string(), + }, + ]; + + let batch_response = BatchResponse::new(success_items.clone(), failed_items.clone()); + + assert_eq!(batch_response.success, success_items); + assert_eq!(batch_response.failed, failed_items); + assert_eq!(batch_response.success_count, 2); + assert_eq!(batch_response.failed_count, 1); + } + + #[test] + fn test_stats_response() { + let mut stats = HashMap::new(); + stats.insert("total_users".to_string(), json!(100)); + stats.insert("active_users".to_string(), json!(85)); + + let stats_response = StatsResponse::new(stats.clone()); + + assert_eq!(stats_response.stats, stats); + assert!(stats_response.time_range.is_none()); + } + + #[test] + fn test_health_response() { + let health_response = HealthResponse::new( + HealthStatus::Healthy, + "1.0.0", + "2 days", + ) + .with_dependency( + "database", + DependencyStatus::healthy(50), + ) + .with_dependency( + "redis", + DependencyStatus::degraded(200, "连接缓慢"), + ); + + assert!(matches!(health_response.status, HealthStatus::Healthy)); + assert_eq!(health_response.version, "1.0.0"); + assert_eq!(health_response.uptime, "2 days"); + assert_eq!(health_response.dependencies.len(), 2); + } + + #[test] + fn test_response_builder() { + let data = TestData { + id: 1, + name: "test".to_string(), + }; + + let response = ResponseBuilder::new(data.clone()) + .code(201) + .message("自定义消息") + .request_id("req-123") + .build(); + + assert_eq!(response.code, 201); + assert_eq!(response.message, "自定义消息"); + assert_eq!(response.request_id, Some("req-123".to_string())); + assert_eq!(response.data, Some(data)); + } + + #[test] + fn test_response_macros() { + let data = TestData { + id: 1, + name: "test".to_string(), + }; + + let ok_resp = ok_response!(data.clone()); + assert_eq!(ok_resp.code, 200); + + let ok_resp_with_msg = ok_response!(data.clone(), "自定义消息"); + assert_eq!(ok_resp_with_msg.message, "自定义消息"); + + let created_resp = created_response!(data.clone()); + assert_eq!(created_resp.code, 201); + + let no_content_resp = no_content_response!(); + assert_eq!(no_content_resp.code, 204); + + let items = vec![data.clone()]; + let page_resp = page_response!(items, 1, 10, 1); + assert_eq!(page_resp.code, 200); + } + + #[test] + fn test_response_ext() { + let data = TestData { + id: 1, + name: "test".to_string(), + }; + + let response = data.clone().into_api_response(); + assert_eq!(response.code, 200); + assert_eq!(response.data, Some(data)); + + let items = vec![ + TestData { id: 1, name: "test1".to_string() }, + TestData { id: 2, name: "test2".to_string() }, + ]; + + let page_response = items.clone().into_page_response(1, 10, 25); + assert_eq!(page_response.code, 200); + + if let Some(page_data) = page_response.data { + assert_eq!(page_data.items, items); + assert_eq!(page_data.pagination.total_count, 25); + } + } +} +``` + +## 使用示例 + +### 基础响应 + +```rust +use axum::{extract::Path, Json}; +use crate::response::{ApiResponse, ok_response, created_response}; + +/// 获取用户信息 +pub async fn get_user(Path(id): Path) -> impl IntoResponse { + match user_service::get_by_id(&id).await { + Ok(user) => ok_response!(user, "获取用户信息成功"), + Err(e) => e.into_response(), + } +} + +/// 创建用户 +pub async fn create_user(Json(req): Json) -> impl IntoResponse { + match user_service::create(req).await { + Ok(user) => created_response!(user), + Err(e) => e.into_response(), + } +} + +/// 删除用户 +pub async fn delete_user(Path(id): Path) -> impl IntoResponse { + match user_service::delete(&id).await { + Ok(_) => no_content_response!(), + Err(e) => e.into_response(), + } +} +``` + +### 分页响应 + +```rust +use axum::extract::Query; +use crate::response::{PageResponse, page_response}; + +/// 获取用户列表 +pub async fn list_users(Query(params): Query) -> impl IntoResponse { + match user_service::list(params).await { + Ok((users, total)) => { + page_response!(users, params.page, params.page_size, total) + } + Err(e) => e.into_response(), + } +} +``` + +### 批量操作响应 + +```rust +use crate::response::{BatchResponse, BatchError}; + +/// 批量创建用户 +pub async fn batch_create_users( + Json(req): Json, +) -> impl IntoResponse { + let mut success = Vec::new(); + let mut failed = Vec::new(); + + for (index, user_req) in req.users.into_iter().enumerate() { + match user_service::create(user_req).await { + Ok(user) => success.push(user), + Err(e) => failed.push(BatchError { + id: index.to_string(), + error_code: e.error_code().to_string(), + error_message: e.to_string(), + }), + } + } + + let batch_response = BatchResponse::new(success, failed); + ApiResponse::success(batch_response) +} +``` + +### 统计响应 + +```rust +use std::collections::HashMap; +use serde_json::json; +use crate::response::StatsResponse; + +/// 获取用户统计 +pub async fn get_user_stats() -> impl IntoResponse { + let mut stats = HashMap::new(); + + match user_service::get_stats().await { + Ok(user_stats) => { + stats.insert("total_users".to_string(), json!(user_stats.total)); + stats.insert("active_users".to_string(), json!(user_stats.active)); + stats.insert("new_users_today".to_string(), json!(user_stats.new_today)); + + let stats_response = StatsResponse::new(stats) + .with_time_range( + chrono::Utc::now() - chrono::Duration::days(30), + chrono::Utc::now(), + ); + + ApiResponse::success(stats_response) + } + Err(e) => e.into_response(), + } +} +``` + +### 健康检查响应 + +```rust +use crate::response::{HealthResponse, HealthStatus, DependencyStatus}; + +/// 健康检查 +pub async fn health_check() -> impl IntoResponse { + let mut health_response = HealthResponse::new( + HealthStatus::Healthy, + env!("CARGO_PKG_VERSION"), + format_uptime(), + ); + + // 检查数据库连接 + match check_database().await { + Ok(response_time) => { + health_response = health_response.with_dependency( + "database", + DependencyStatus::healthy(response_time), + ); + } + Err(e) => { + health_response = health_response.with_dependency( + "database", + DependencyStatus::unhealthy(e.to_string()), + ); + } + } + + // 检查 Redis 连接 + match check_redis().await { + Ok(response_time) => { + health_response = health_response.with_dependency( + "redis", + DependencyStatus::healthy(response_time), + ); + } + Err(e) => { + health_response = health_response.with_dependency( + "redis", + DependencyStatus::unhealthy(e.to_string()), + ); + } + } + + ApiResponse::success(health_response) +} +``` + +## 中间件集成 + +### 响应处理中间件 + +```rust +use axum::{ + extract::Request, + middleware::Next, + response::Response, +}; +use uuid::Uuid; + +/// 响应处理中间件 +pub async fn response_middleware( + mut request: Request, + next: Next, +) -> Response { + // 生成请求 ID + let request_id = Uuid::new_v4().to_string(); + request.extensions_mut().insert(request_id.clone()); + + // 执行请求 + let mut response = next.run(request).await; + + // 添加响应头 + response.headers_mut().insert( + "X-Request-ID", + request_id.parse().unwrap(), + ); + + response.headers_mut().insert( + "X-Response-Time", + chrono::Utc::now().timestamp_millis().to_string().parse().unwrap(), + ); + + response +} +``` + +### 响应压缩中间件 + +```rust +use axum::{ + body::Body, + http::{HeaderValue, header}, + response::Response, +}; +use tower_http::compression::CompressionLayer; + +/// 响应压缩配置 +pub fn compression_layer() -> CompressionLayer { + CompressionLayer::new() + .gzip(true) + .deflate(true) + .br(true) +} +``` + +## 性能优化 + +### 响应缓存 + +```rust +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use chrono::{DateTime, Utc, Duration}; + +/// 响应缓存 +#[derive(Clone)] +pub struct ResponseCache { + cache: Arc>>, +} + +#[derive(Clone)] +struct CachedResponse { + data: Vec, + expires_at: DateTime, + content_type: String, +} + +impl ResponseCache { + pub fn new() -> Self { + Self { + cache: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// 获取缓存的响应 + pub async fn get(&self, key: &str) -> Option> { + let cache = self.cache.read().await; + + if let Some(cached) = cache.get(key) { + if cached.expires_at > Utc::now() { + let mut response = Response::new(Body::from(cached.data.clone())); + response.headers_mut().insert( + header::CONTENT_TYPE, + HeaderValue::from_str(&cached.content_type).unwrap(), + ); + response.headers_mut().insert( + "X-Cache", + HeaderValue::from_static("HIT"), + ); + return Some(response); + } + } + + None + } + + /// 设置响应缓存 + pub async fn set( + &self, + key: String, + data: Vec, + content_type: String, + ttl: Duration, + ) { + let mut cache = self.cache.write().await; + + cache.insert(key, CachedResponse { + data, + expires_at: Utc::now() + ttl, + content_type, + }); + } + + /// 清理过期缓存 + pub async fn cleanup_expired(&self) { + let mut cache = self.cache.write().await; + let now = Utc::now(); + + cache.retain(|_, cached| cached.expires_at > now); + } +} +``` + +### 响应流式处理 + +```rust +use axum::{ + body::Body, + response::{IntoResponse, Response}, +}; +use futures::stream::{self, StreamExt}; +use tokio_util::codec::{FramedWrite, LinesCodec}; + +/// 流式响应 +pub struct StreamResponse { + items: Vec, +} + +impl StreamResponse +where + T: Serialize + Send + 'static, +{ + pub fn new(items: Vec) -> Self { + Self { items } + } +} + +impl IntoResponse for StreamResponse +where + T: Serialize + Send + 'static, +{ + fn into_response(self) -> Response { + let stream = stream::iter(self.items) + .map(|item| { + serde_json::to_string(&item) + .map(|json| format!("{}\n", json)) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) + }) + .map(|result| result.map(axum::body::Bytes::from)); + + let body = Body::from_stream(stream); + + Response::builder() + .header("Content-Type", "application/x-ndjson") + .header("Transfer-Encoding", "chunked") + .body(body) + .unwrap() + } +} +``` + +## 测试支持 + +### 响应测试工具 + +```rust +use axum::{ + body::Body, + http::{Request, StatusCode}, +}; +use tower::ServiceExt; +use serde::de::DeserializeOwned; + +/// 响应测试助手 +pub struct ResponseTester; + +impl ResponseTester { + /// 测试成功响应 + pub async fn assert_success_response( + response: axum::response::Response, + expected_code: u16, + ) -> T + where + T: DeserializeOwned, + { + assert_eq!(response.status(), StatusCode::from_u16(expected_code).unwrap()); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let api_response: ApiResponse = serde_json::from_slice(&body).unwrap(); + + assert_eq!(api_response.code, expected_code); + api_response.data.unwrap() + } + + /// 测试分页响应 + pub async fn assert_page_response( + response: axum::response::Response, + expected_total: u64, + ) -> PageResponse + where + T: DeserializeOwned, + { + let page_data: PageResponse = Self::assert_success_response(response, 200).await; + assert_eq!(page_data.pagination.total_count, expected_total); + page_data + } + + /// 测试错误响应 + pub async fn assert_error_response( + response: axum::response::Response, + expected_status: StatusCode, + expected_error_code: &str, + ) { + assert_eq!(response.status(), expected_status); + + let body = hyper::body::to_bytes(response.into_body()).await.unwrap(); + let error_response: crate::error::ErrorResponse = + serde_json::from_slice(&body).unwrap(); + + assert_eq!(error_response.error.code, expected_error_code); + } +} + +#[cfg(test)] +mod response_tests { + use super::*; + use axum::{ + routing::get, + Router, + }; + + async fn test_handler() -> impl IntoResponse { + ok_response!("test data") + } + + #[tokio::test] + async fn test_success_response_integration() { + let app = Router::new().route("/test", get(test_handler)); + + let request = Request::builder() + .uri("/test") + .body(Body::empty()) + .unwrap(); + + let response = app.oneshot(request).await.unwrap(); + + let data: String = ResponseTester::assert_success_response(response, 200).await; + assert_eq!(data, "test data"); + } +} +``` + +## 最佳实践 + +### 响应设计原则 + +1. **一致性**: 所有接口使用统一的响应格式 +2. **可预测性**: 响应结构应该是可预测的 +3. **信息完整性**: 包含足够的信息用于客户端处理 +4. **向后兼容性**: 新增字段不应破坏现有客户端 + +### 性能考虑 + +1. **序列化优化**: 使用高效的序列化库 +2. **响应压缩**: 对大响应启用压缩 +3. **缓存策略**: 对适当的响应启用缓存 +4. **流式处理**: 对大数据集使用流式响应 + +### 安全考虑 + +1. **敏感信息**: 避免在响应中暴露敏感信息 +2. **错误信息**: 错误响应不应泄露系统内部信息 +3. **响应头**: 设置适当的安全响应头 +4. **数据验证**: 确保响应数据的完整性 + +## 总结 + +UdminAI 的响应格式模块提供了完整的 API 响应解决方案,具有以下特点: + +- **统一格式**: 所有接口使用一致的响应结构 +- **类型安全**: 编译时类型检查,避免运行时错误 +- **功能丰富**: 支持多种响应类型(成功、分页、批量、统计等) +- **易于使用**: 提供宏和扩展 trait 简化使用 +- **高性能**: 支持缓存、压缩和流式处理 +- **可测试**: 提供完整的测试支持工具 + +通过统一的响应格式设计,确保了 API 的一致性、可维护性和用户体验。 \ No newline at end of file diff --git a/docs/ROUTES.md b/docs/ROUTES.md new file mode 100644 index 0000000..5c5c0e7 --- /dev/null +++ b/docs/ROUTES.md @@ -0,0 +1,1201 @@ +# 路由层文档 + +## 概述 + +路由层是 UdminAI 的 HTTP API 接口层,负责处理客户端请求、参数验证、权限检查、调用服务层业务逻辑,并返回统一格式的响应。基于 Axum 框架构建,提供 RESTful API 接口。 + +## 架构设计 + +### 路由模块结构 + +``` +routes/ +├── mod.rs # 路由模块导出和总路由配置 +├── auth.rs # 认证相关路由 +├── user.rs # 用户管理路由 +├── role.rs # 角色管理路由 +├── permission.rs # 权限管理路由 +├── flow.rs # 流程管理路由 +├── schedule_job.rs # 定时任务路由 +├── system.rs # 系统管理路由 +├── log.rs # 日志查询路由 +├── notification.rs # 通知管理路由 +└── websocket.rs # WebSocket 路由 +``` + +### 设计原则 + +- **RESTful 设计**: 遵循 REST API 设计规范 +- **统一响应格式**: 所有接口返回统一的响应格式 +- **参数验证**: 在路由层进行请求参数验证 +- **权限控制**: 集成权限中间件进行访问控制 +- **错误处理**: 统一的错误处理和响应 +- **文档支持**: 支持 OpenAPI/Swagger 文档生成 + +## 总路由配置 (mod.rs) + +### 路由树结构 + +```rust +pub fn create_routes() -> Router { + Router::new() + // 认证路由 (无需认证) + .nest("/auth", auth::routes()) + + // API 路由 (需要认证) + .nest("/api", + Router::new() + .nest("/users", user::routes()) + .nest("/roles", role::routes()) + .nest("/permissions", permission::routes()) + .nest("/flows", flow::routes()) + .nest("/jobs", schedule_job::routes()) + .nest("/system", system::routes()) + .nest("/logs", log::routes()) + .nest("/notifications", notification::routes()) + .layer(AuthMiddleware::new()) + ) + + // WebSocket 路由 + .nest("/ws", websocket::routes()) + + // 静态文件路由 + .nest("/static", static_files::routes()) + + // 健康检查 + .route("/health", get(health_check)) + + // 中间件层 + .layer(CorsLayer::permissive()) + .layer(TraceLayer::new_for_http()) + .layer(CompressionLayer::new()) +} +``` + +### 中间件配置 + +```rust +/// 应用中间件 +pub fn apply_middlewares(app: Router) -> Router { + app + // 请求追踪 + .layer(TraceLayer::new_for_http()) + + // CORS 支持 + .layer(CorsLayer::new() + .allow_origin(Any) + .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE]) + .allow_headers([AUTHORIZATION, CONTENT_TYPE]) + ) + + // 请求压缩 + .layer(CompressionLayer::new()) + + // 请求限流 + .layer(RateLimitLayer::new(100, Duration::from_secs(60))) + + // 请求日志 + .layer(RequestLogLayer::new()) +} +``` + +## 认证路由 (auth.rs) + +### 路由定义 + +```rust +pub fn routes() -> Router { + Router::new() + .route("/login", post(login)) + .route("/logout", post(logout)) + .route("/refresh", post(refresh_token)) + .route("/register", post(register)) + .route("/forgot-password", post(forgot_password)) + .route("/reset-password", post(reset_password)) + .route("/verify-email", post(verify_email)) +} +``` + +### 接口实现 + +#### 用户登录 +```rust +/// POST /auth/login +/// 用户登录接口 +pub async fn login( + State(app_state): State, + Json(req): Json, +) -> Result>, AppError> { + // 参数验证 + req.validate().map_err(|e| AppError::BadRequest(e.to_string()))?; + + // 用户认证 + let user = user_service::authenticate( + &app_state.db, + &req.username, + &req.password, + ).await?; + + // 生成 JWT Token + let token = jwt::generate_token(&user.id, &app_state.jwt_secret)?; + let refresh_token = jwt::generate_refresh_token(&user.id, &app_state.jwt_secret)?; + + // 记录登录日志 + log_service::log_operation( + &app_state.db, + OperationLog { + user_id: user.id.clone(), + operation: "login".to_string(), + resource: "auth".to_string(), + resource_id: None, + details: json!({ "username": req.username }), + ip_address: extract_client_ip(&req), + user_agent: extract_user_agent(&req), + timestamp: now_fixed_offset(), + }, + ).await?; + + Ok(Json(ApiResponse::success(LoginResp { + user, + token, + refresh_token, + expires_in: 3600, + }))) +} +``` + +#### 请求/响应结构 + +```rust +#[derive(Debug, Deserialize, Validate)] +pub struct LoginReq { + #[validate(length(min = 3, max = 50))] + pub username: String, + #[validate(length(min = 6))] + pub password: String, + pub remember_me: Option, +} + +#[derive(Debug, Serialize)] +pub struct LoginResp { + pub user: UserDoc, + pub token: String, + pub refresh_token: String, + pub expires_in: u64, +} +``` + +## 用户管理路由 (user.rs) + +### 路由定义 + +```rust +pub fn routes() -> Router { + Router::new() + .route("/", get(list_users).post(create_user)) + .route("/:id", get(get_user).put(update_user).delete(delete_user)) + .route("/:id/roles", get(get_user_roles).put(update_user_roles)) + .route("/:id/permissions", get(get_user_permissions)) + .route("/:id/password", put(change_password)) + .route("/:id/status", put(update_user_status)) + .route("/profile", get(get_current_user).put(update_profile)) + .route("/avatar", post(upload_avatar)) +} +``` + +### 接口实现 + +#### 用户列表查询 +```rust +/// GET /api/users +/// 分页查询用户列表 +pub async fn list_users( + State(app_state): State, + Query(params): Query, + Extension(current_user): Extension, +) -> Result>>, AppError> { + // 权限检查 + permission_service::check_user_permission( + &app_state.db, + ¤t_user.id, + "user", + "read", + ).await?; + + // 构建过滤条件 + let filters = UserFilters { + username: params.username, + email: params.email, + status: params.status, + role_ids: params.role_ids, + created_after: params.created_after, + created_before: params.created_before, + }; + + // 查询用户列表 + let result = user_service::list_users( + &app_state.db, + params.page.unwrap_or(1), + params.page_size.unwrap_or(20), + Some(filters), + ).await?; + + Ok(Json(ApiResponse::success(result))) +} +``` + +#### 创建用户 +```rust +/// POST /api/users +/// 创建新用户 +pub async fn create_user( + State(app_state): State, + Extension(current_user): Extension, + Json(req): Json, +) -> Result>, AppError> { + // 参数验证 + req.validate().map_err(|e| AppError::BadRequest(e.to_string()))?; + + // 权限检查 + permission_service::check_user_permission( + &app_state.db, + ¤t_user.id, + "user", + "create", + ).await?; + + // 创建用户 + let user = user_service::create_user(&app_state.db, req).await?; + + // 记录操作日志 + log_service::log_operation( + &app_state.db, + OperationLog { + user_id: current_user.id, + operation: "create_user".to_string(), + resource: "user".to_string(), + resource_id: Some(user.id.clone()), + details: json!({ "username": user.username }), + ip_address: None, + user_agent: None, + timestamp: now_fixed_offset(), + }, + ).await?; + + Ok(Json(ApiResponse::success(user))) +} +``` + +### 查询参数结构 + +```rust +#[derive(Debug, Deserialize)] +pub struct ListUsersQuery { + pub page: Option, + pub page_size: Option, + pub username: Option, + pub email: Option, + pub status: Option, + pub role_ids: Option>, + pub created_after: Option>, + pub created_before: Option>, + pub sort_by: Option, + pub sort_order: Option, +} +``` + +## 流程管理路由 (flow.rs) + +### 路由定义 + +```rust +pub fn routes() -> Router { + Router::new() + .route("/", get(list_flows).post(create_flow)) + .route("/:id", get(get_flow).put(update_flow).delete(delete_flow)) + .route("/:id/versions", get(list_flow_versions).post(create_flow_version)) + .route("/:id/publish", post(publish_flow)) + .route("/:id/execute", post(execute_flow)) + .route("/:id/executions", get(list_flow_executions)) + .route("/executions/:execution_id", get(get_execution_detail)) + .route("/executions/:execution_id/stop", post(stop_execution)) + .route("/categories", get(list_flow_categories)) + .route("/templates", get(list_flow_templates)) +} +``` + +### 接口实现 + +#### 执行流程 +```rust +/// POST /api/flows/:id/execute +/// 执行指定流程 +pub async fn execute_flow( + State(app_state): State, + Path(flow_id): Path, + Extension(current_user): Extension, + Json(req): Json, +) -> Result>, AppError> { + // 权限检查 + permission_service::check_user_permission( + &app_state.db, + ¤t_user.id, + "flow", + "execute", + ).await?; + + // 获取流程信息 + let flow = flow_service::get_flow_by_id(&app_state.db, &flow_id).await? + .ok_or_else(|| AppError::NotFound("流程不存在".to_string()))?; + + // 检查流程状态 + if flow.status != FlowStatus::Published { + return Err(AppError::BadRequest("只能执行已发布的流程".to_string())); + } + + // 执行选项 + let options = ExecuteOptions { + max_steps: req.max_steps, + timeout_ms: req.timeout_ms, + parallel: req.parallel.unwrap_or(false), + stream_events: req.stream_events.unwrap_or(false), + }; + + // 执行流程 + let result = flow_service::execute_flow( + &app_state.db, + &flow_id, + req.input.unwrap_or(json!({})), + options, + ).await?; + + // 记录执行日志 + log_service::log_operation( + &app_state.db, + OperationLog { + user_id: current_user.id, + operation: "execute_flow".to_string(), + resource: "flow".to_string(), + resource_id: Some(flow_id), + details: json!({ + "execution_id": result.execution_id, + "status": result.status + }), + ip_address: None, + user_agent: None, + timestamp: now_fixed_offset(), + }, + ).await?; + + Ok(Json(ApiResponse::success(result))) +} +``` + +#### 请求/响应结构 + +```rust +#[derive(Debug, Deserialize, Validate)] +pub struct ExecuteFlowReq { + pub input: Option, + pub max_steps: Option, + pub timeout_ms: Option, + pub parallel: Option, + pub stream_events: Option, +} + +#[derive(Debug, Serialize)] +pub struct ExecutionResult { + pub execution_id: String, + pub status: ExecutionStatus, + pub result: Option, + pub error: Option, + pub start_time: DateTime, + pub end_time: Option>, + pub duration_ms: Option, +} +``` + +## 定时任务路由 (schedule_job.rs) + +### 路由定义 + +```rust +pub fn routes() -> Router { + Router::new() + .route("/", get(list_jobs).post(create_job)) + .route("/:id", get(get_job).put(update_job).delete(delete_job)) + .route("/:id/enable", post(enable_job)) + .route("/:id/disable", post(disable_job)) + .route("/:id/trigger", post(trigger_job)) + .route("/:id/executions", get(list_job_executions)) + .route("/executions/:execution_id", get(get_execution_detail)) + .route("/cron/validate", post(validate_cron)) + .route("/cron/next-runs", post(get_next_runs)) +} +``` + +### 接口实现 + +#### 创建定时任务 +```rust +/// POST /api/jobs +/// 创建定时任务 +pub async fn create_job( + State(app_state): State, + Extension(current_user): Extension, + Json(req): Json, +) -> Result>, AppError> { + // 参数验证 + req.validate().map_err(|e| AppError::BadRequest(e.to_string()))?; + + // 权限检查 + permission_service::check_user_permission( + &app_state.db, + ¤t_user.id, + "job", + "create", + ).await?; + + // 验证 Cron 表达式 + cron::Schedule::from_str(&req.cron_expression) + .map_err(|_| AppError::BadRequest("无效的 Cron 表达式".to_string()))?; + + // 创建任务 + let job = schedule_job_service::create_schedule_job(&app_state.db, req).await?; + + // 注册到调度器 + if job.enabled { + schedule_job_service::register_job_to_scheduler( + &app_state.scheduler, + &job, + ).await?; + } + + Ok(Json(ApiResponse::success(job))) +} +``` + +#### Cron 表达式验证 +```rust +/// POST /api/jobs/cron/validate +/// 验证 Cron 表达式 +pub async fn validate_cron( + Json(req): Json, +) -> Result>, AppError> { + let schedule = cron::Schedule::from_str(&req.cron_expression) + .map_err(|e| AppError::BadRequest(format!("无效的 Cron 表达式: {}", e)))?; + + // 计算接下来的执行时间 + let now = Utc::now(); + let next_runs: Vec> = schedule + .upcoming(Utc) + .take(5) + .collect(); + + Ok(Json(ApiResponse::success(ValidateCronResp { + valid: true, + next_runs, + description: describe_cron(&req.cron_expression), + }))) +} +``` + +## WebSocket 路由 (websocket.rs) + +### 路由定义 + +```rust +pub fn routes() -> Router { + Router::new() + .route("/flow-execution/:execution_id", get(flow_execution_ws)) + .route("/system-monitor", get(system_monitor_ws)) + .route("/notifications", get(notifications_ws)) + .route("/logs", get(logs_ws)) +} +``` + +### WebSocket 处理 + +#### 流程执行监控 +```rust +/// GET /ws/flow-execution/:execution_id +/// 流程执行实时监控 +pub async fn flow_execution_ws( + ws: WebSocketUpgrade, + Path(execution_id): Path, + Extension(current_user): Extension, + State(app_state): State, +) -> Result { + // 权限检查 + permission_service::check_user_permission( + &app_state.db, + ¤t_user.id, + "flow", + "read", + ).await?; + + Ok(ws.on_upgrade(move |socket| { + handle_flow_execution_ws(socket, execution_id, current_user, app_state) + })) +} + +async fn handle_flow_execution_ws( + socket: WebSocket, + execution_id: String, + current_user: CurrentUser, + app_state: AppState, +) { + let (mut sender, mut receiver) = socket.split(); + + // 订阅执行事件 + let mut event_stream = app_state.event_bus + .subscribe(&format!("flow_execution:{}", execution_id)) + .await; + + // 发送事件到客户端 + tokio::spawn(async move { + while let Ok(event) = event_stream.recv().await { + let message = serde_json::to_string(&event).unwrap(); + if sender.send(Message::Text(message)).await.is_err() { + break; + } + } + }); + + // 处理客户端消息 + while let Some(msg) = receiver.next().await { + match msg { + Ok(Message::Text(text)) => { + // 处理客户端命令 + if let Ok(command) = serde_json::from_str::(&text) { + handle_ws_command(command, &app_state).await; + } + } + Ok(Message::Close(_)) => break, + _ => {} + } + } +} +``` + +## 系统管理路由 (system.rs) + +### 路由定义 + +```rust +pub fn routes() -> Router { + Router::new() + .route("/info", get(get_system_info)) + .route("/status", get(get_system_status)) + .route("/health", get(health_check)) + .route("/metrics", get(get_metrics)) + .route("/config", get(get_system_config).put(update_system_config)) + .route("/cache/clear", post(clear_cache)) + .route("/backup", post(create_backup)) + .route("/restore", post(restore_backup)) +} +``` + +### 接口实现 + +#### 系统状态监控 +```rust +/// GET /api/system/status +/// 获取系统运行状态 +pub async fn get_system_status( + State(app_state): State, + Extension(current_user): Extension, +) -> Result>, AppError> { + // 权限检查 + permission_service::check_user_permission( + &app_state.db, + ¤t_user.id, + "system", + "read", + ).await?; + + // 获取系统状态 + let status = system_service::get_system_status( + &app_state.db, + &app_state.redis, + ).await?; + + Ok(Json(ApiResponse::success(status))) +} +``` + +#### 健康检查 +```rust +/// GET /health +/// 系统健康检查 +pub async fn health_check( + State(app_state): State, +) -> Result, AppError> { + let mut checks = HashMap::new(); + + // 数据库健康检查 + let db_health = system_service::check_database_health(&app_state.db).await + .unwrap_or(HealthStatus::Unhealthy); + checks.insert("database".to_string(), db_health); + + // Redis 健康检查 + let redis_health = system_service::check_redis_health(&app_state.redis).await + .unwrap_or(HealthStatus::Unhealthy); + checks.insert("redis".to_string(), redis_health); + + // 整体健康状态 + let overall_status = if checks.values().all(|&status| status == HealthStatus::Healthy) { + HealthStatus::Healthy + } else { + HealthStatus::Unhealthy + }; + + Ok(Json(HealthCheckResp { + status: overall_status, + timestamp: Utc::now(), + checks, + })) +} +``` + +## 响应格式 + +### 统一响应结构 + +```rust +#[derive(Debug, Serialize)] +pub struct ApiResponse { + pub success: bool, + pub data: Option, + pub message: String, + pub code: u32, + pub timestamp: DateTime, +} + +impl ApiResponse { + pub fn success(data: T) -> Self { + Self { + success: true, + data: Some(data), + message: "操作成功".to_string(), + code: 200, + timestamp: Utc::now(), + } + } + + pub fn error(message: String, code: u32) -> ApiResponse<()> { + ApiResponse { + success: false, + data: None, + message, + code, + timestamp: Utc::now(), + } + } +} +``` + +### 分页响应结构 + +```rust +#[derive(Debug, Serialize)] +pub struct PageResp { + pub items: Vec, + pub total: u64, + pub page: u64, + pub page_size: u64, + pub total_pages: u64, + pub has_next: bool, + pub has_prev: bool, +} +``` + +## 错误处理 + +### 错误类型定义 + +```rust +#[derive(Debug, thiserror::Error)] +pub enum AppError { + #[error("请求参数错误: {0}")] + BadRequest(String), + + #[error("未授权访问")] + Unauthorized, + + #[error("权限不足: {0}")] + Forbidden(String), + + #[error("资源不存在: {0}")] + NotFound(String), + + #[error("请求冲突: {0}")] + Conflict(String), + + #[error("请求过于频繁")] + TooManyRequests, + + #[error("内部服务器错误: {0}")] + InternalServerError(String), +} +``` + +### 错误响应处理 + +```rust +impl IntoResponse for AppError { + fn into_response(self) -> Response { + let (status, code, message) = match self { + AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, 400, msg), + AppError::Unauthorized => (StatusCode::UNAUTHORIZED, 401, "未授权访问".to_string()), + AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, 403, msg), + AppError::NotFound(msg) => (StatusCode::NOT_FOUND, 404, msg), + AppError::Conflict(msg) => (StatusCode::CONFLICT, 409, msg), + AppError::TooManyRequests => (StatusCode::TOO_MANY_REQUESTS, 429, "请求过于频繁".to_string()), + AppError::InternalServerError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, 500, msg), + }; + + let response = ApiResponse::<()>::error(message, code); + (status, Json(response)).into_response() + } +} +``` + +## 中间件 + +### 认证中间件 + +```rust +#[derive(Clone)] +pub struct AuthMiddleware { + jwt_secret: String, +} + +impl AuthMiddleware { + pub fn new(jwt_secret: String) -> Self { + Self { jwt_secret } + } +} + +#[async_trait] +impl FromRequestParts for CurrentUser +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> Result { + // 从请求头获取 Token + let auth_header = parts.headers + .get(AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .ok_or(AppError::Unauthorized)?; + + // 验证 Bearer Token 格式 + let token = auth_header + .strip_prefix("Bearer ") + .ok_or(AppError::Unauthorized)?; + + // 验证 JWT Token + let claims = jwt::verify_token(token, &jwt_secret) + .map_err(|_| AppError::Unauthorized)?; + + // 获取用户信息 + let app_state = parts.extensions.get::() + .ok_or(AppError::InternalServerError("应用状态未找到".to_string()))?; + + let user = user_service::get_user_by_id(&app_state.db, &claims.user_id).await + .map_err(|e| AppError::InternalServerError(e.to_string()))? + .ok_or(AppError::Unauthorized)?; + + Ok(CurrentUser { + id: user.id, + username: user.username, + roles: user.roles, + }) + } +} +``` + +### 权限中间件 + +```rust +pub struct PermissionMiddleware { + resource: String, + action: String, +} + +impl PermissionMiddleware { + pub fn new(resource: &str, action: &str) -> Self { + Self { + resource: resource.to_string(), + action: action.to_string(), + } + } +} + +#[async_trait] +impl FromRequestParts for PermissionGuard +where + S: Send + Sync, +{ + type Rejection = AppError; + + async fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> Result { + // 获取当前用户 + let current_user = CurrentUser::from_request_parts(parts, state).await?; + + // 获取权限要求 + let permission_req = parts.extensions.get::() + .ok_or(AppError::InternalServerError("权限要求未设置".to_string()))?; + + // 检查权限 + let app_state = parts.extensions.get::() + .ok_or(AppError::InternalServerError("应用状态未找到".to_string()))?; + + let has_permission = permission_service::check_user_permission( + &app_state.db, + ¤t_user.id, + &permission_req.resource, + &permission_req.action, + ).await + .map_err(|e| AppError::InternalServerError(e.to_string()))?; + + if !has_permission { + return Err(AppError::Forbidden("权限不足".to_string())); + } + + Ok(PermissionGuard { current_user }) + } +} +``` + +### 请求日志中间件 + +```rust +pub struct RequestLogMiddleware; + +impl Layer for RequestLogMiddleware { + type Service = RequestLogService; + + fn layer(&self, inner: S) -> Self::Service { + RequestLogService { inner } + } +} + +pub struct RequestLogService { + inner: S, +} + +impl Service> for RequestLogService +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let start_time = Instant::now(); + let method = request.method().clone(); + let uri = request.uri().clone(); + let user_agent = request.headers() + .get(USER_AGENT) + .and_then(|h| h.to_str().ok()) + .unwrap_or("") + .to_string(); + + let future = self.inner.call(request); + + Box::pin(async move { + let response = future.await?; + let duration = start_time.elapsed(); + + info!( + target = "udmin", + method = %method, + uri = %uri, + status = %response.status(), + duration_ms = %duration.as_millis(), + user_agent = %user_agent, + "http.request.completed" + ); + + Ok(response) + }) + } +} +``` + +## API 文档 + +### OpenAPI 集成 + +```rust +use utoipa::{OpenApi, ToSchema}; +use utoipa_swagger_ui::SwaggerUi; + +#[derive(OpenApi)] +#[openapi( + paths( + auth::login, + auth::logout, + user::list_users, + user::create_user, + flow::execute_flow, + ), + components( + schemas( + LoginReq, + LoginResp, + UserDoc, + CreateUserReq, + ExecuteFlowReq, + ExecutionResult, + ) + ), + tags( + (name = "auth", description = "认证相关接口"), + (name = "user", description = "用户管理接口"), + (name = "flow", description = "流程管理接口"), + ) +)] +struct ApiDoc; + +pub fn create_swagger_routes() -> Router { + Router::new() + .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi())) +} +``` + +### 接口文档注解 + +```rust +#[utoipa::path( + post, + path = "/auth/login", + tag = "auth", + summary = "用户登录", + description = "用户名密码登录,返回 JWT Token", + request_body = LoginReq, + responses( + (status = 200, description = "登录成功", body = ApiResponse), + (status = 400, description = "请求参数错误", body = ApiResponse<()>), + (status = 401, description = "用户名或密码错误", body = ApiResponse<()>), + ) +)] +pub async fn login( + State(app_state): State, + Json(req): Json, +) -> Result>, AppError> { + // 实现代码... +} +``` + +## 性能优化 + +### 请求缓存 + +```rust +pub struct CacheMiddleware { + redis: RedisConnection, + ttl: Duration, +} + +impl CacheMiddleware { + pub fn new(redis: RedisConnection, ttl: Duration) -> Self { + Self { redis, ttl } + } + + async fn get_cache_key(&self, request: &Request) -> String { + let method = request.method(); + let uri = request.uri(); + let query = uri.query().unwrap_or(""); + format!("api_cache:{}:{}:{}", method, uri.path(), query) + } +} +``` + +### 响应压缩 + +```rust +pub fn create_compression_layer() -> CompressionLayer { + CompressionLayer::new() + .gzip(true) + .br(true) + .deflate(true) + .quality(CompressionLevel::Default) +} +``` + +### 请求限流 + +```rust +pub struct RateLimitMiddleware { + redis: RedisConnection, + max_requests: u32, + window: Duration, +} + +impl RateLimitMiddleware { + pub async fn check_rate_limit( + &self, + client_id: &str, + ) -> Result { + let key = format!("rate_limit:{}", client_id); + let current_count: u32 = self.redis.get(&key).await.unwrap_or(0); + + if current_count >= self.max_requests { + return Ok(false); + } + + let _: () = self.redis.incr(&key, 1).await?; + let _: () = self.redis.expire(&key, self.window.as_secs() as usize).await?; + + Ok(true) + } +} +``` + +## 测试策略 + +### 单元测试 + +```rust +#[cfg(test)] +mod tests { + use super::*; + use axum_test::TestServer; + + #[tokio::test] + async fn test_login_success() { + let app = create_test_app().await; + let server = TestServer::new(app).unwrap(); + + let response = server + .post("/auth/login") + .json(&json!({ + "username": "admin", + "password": "password123" + })) + .await; + + response.assert_status_ok(); + + let body: ApiResponse = response.json(); + assert!(body.success); + assert!(body.data.is_some()); + } + + #[tokio::test] + async fn test_login_invalid_credentials() { + let app = create_test_app().await; + let server = TestServer::new(app).unwrap(); + + let response = server + .post("/auth/login") + .json(&json!({ + "username": "admin", + "password": "wrong_password" + })) + .await; + + response.assert_status(StatusCode::UNAUTHORIZED); + } +} +``` + +### 集成测试 + +```rust +#[tokio::test] +async fn test_user_crud_flow() { + let app = create_test_app().await; + let server = TestServer::new(app).unwrap(); + + // 登录获取 Token + let login_response = server + .post("/auth/login") + .json(&json!({ + "username": "admin", + "password": "password123" + })) + .await; + + let login_body: ApiResponse = login_response.json(); + let token = login_body.data.unwrap().token; + + // 创建用户 + let create_response = server + .post("/api/users") + .add_header(AUTHORIZATION, format!("Bearer {}", token)) + .json(&json!({ + "username": "test_user", + "password": "password123", + "email": "test@example.com" + })) + .await; + + create_response.assert_status_ok(); + + let create_body: ApiResponse = create_response.json(); + let user_id = create_body.data.unwrap().id; + + // 获取用户 + let get_response = server + .get(&format!("/api/users/{}", user_id)) + .add_header(AUTHORIZATION, format!("Bearer {}", token)) + .await; + + get_response.assert_status_ok(); + + // 删除用户 + let delete_response = server + .delete(&format!("/api/users/{}", user_id)) + .add_header(AUTHORIZATION, format!("Bearer {}", token)) + .await; + + delete_response.assert_status_ok(); +} +``` + +## 最佳实践 + +### 路由设计 + +- **RESTful 风格**: 遵循 REST API 设计原则 +- **资源命名**: 使用复数名词表示资源 +- **HTTP 方法**: 正确使用 GET、POST、PUT、DELETE +- **状态码**: 返回合适的 HTTP 状态码 + +### 参数验证 + +- **输入验证**: 使用 validator 进行参数验证 +- **类型安全**: 使用强类型结构体 +- **错误信息**: 提供清晰的验证错误信息 +- **安全过滤**: 过滤恶意输入 + +### 错误处理 + +- **统一格式**: 使用统一的错误响应格式 +- **错误分类**: 合理分类不同类型的错误 +- **日志记录**: 记录详细的错误日志 +- **用户友好**: 提供用户友好的错误信息 + +### 安全考虑 + +- **认证授权**: 实现完善的认证授权机制 +- **输入验证**: 严格验证所有输入参数 +- **HTTPS**: 在生产环境使用 HTTPS +- **CORS**: 正确配置 CORS 策略 \ No newline at end of file diff --git a/docs/SERVICES.md b/docs/SERVICES.md new file mode 100644 index 0000000..4a31824 --- /dev/null +++ b/docs/SERVICES.md @@ -0,0 +1,853 @@ +# 服务层文档 + +## 概述 + +服务层是 UdminAI 的业务逻辑核心,负责处理各种业务操作,包括用户管理、权限控制、流程管理、定时任务、系统监控等功能。 + +## 架构设计 + +### 服务模块结构 + +``` +services/ +├── mod.rs # 服务模块导出 +├── user_service.rs # 用户服务 +├── role_service.rs # 角色服务 +├── permission_service.rs # 权限服务 +├── flow_service.rs # 流程服务 +├── schedule_job_service.rs # 定时任务服务 +├── system_service.rs # 系统服务 +├── log_service.rs # 日志服务 +└── notification_service.rs # 通知服务 +``` + +### 设计原则 + +- **单一职责**: 每个服务专注于特定业务领域 +- **依赖注入**: 通过参数传递数据库连接等依赖 +- **错误处理**: 统一的错误处理和返回格式 +- **异步支持**: 所有服务方法都是异步的 +- **事务支持**: 支持数据库事务操作 + +## 用户服务 (user_service.rs) + +### 核心功能 + +#### 1. 用户认证 +```rust +/// 用户登录验证 +pub async fn authenticate( + db: &DatabaseConnection, + username: &str, + password: &str, +) -> Result +``` + +**功能**: +- 用户名/密码验证 +- 密码哈希比较 +- 登录状态更新 +- 登录日志记录 + +#### 2. 用户管理 +```rust +/// 创建用户 +pub async fn create_user( + db: &DatabaseConnection, + req: CreateUserReq, +) -> Result + +/// 更新用户信息 +pub async fn update_user( + db: &DatabaseConnection, + id: &str, + req: UpdateUserReq, +) -> Result + +/// 删除用户 +pub async fn delete_user( + db: &DatabaseConnection, + id: &str, +) -> Result<(), AppError> +``` + +#### 3. 用户查询 +```rust +/// 分页查询用户列表 +pub async fn list_users( + db: &DatabaseConnection, + page: u64, + page_size: u64, + filters: Option, +) -> Result, AppError> + +/// 根据ID获取用户 +pub async fn get_user_by_id( + db: &DatabaseConnection, + id: &str, +) -> Result, AppError> +``` + +### 数据传输对象 + +#### UserDoc - 用户文档 +```rust +#[derive(Debug, Serialize, Deserialize)] +pub struct UserDoc { + pub id: String, + pub username: String, + pub email: Option, + pub display_name: Option, + pub avatar: Option, + pub status: UserStatus, + pub roles: Vec, + pub created_at: DateTime, + pub updated_at: DateTime, +} +``` + +#### CreateUserReq - 创建用户请求 +```rust +#[derive(Debug, Deserialize, Validate)] +pub struct CreateUserReq { + #[validate(length(min = 3, max = 50))] + pub username: String, + #[validate(length(min = 6))] + pub password: String, + #[validate(email)] + pub email: Option, + pub display_name: Option, + pub roles: Vec, +} +``` + +### 安全特性 + +- **密码加密**: 使用 bcrypt 进行密码哈希 +- **输入验证**: 使用 validator 进行参数验证 +- **权限检查**: 集成权限服务进行访问控制 +- **审计日志**: 记录用户操作日志 + +## 角色服务 (role_service.rs) + +### 核心功能 + +#### 1. 角色管理 +```rust +/// 创建角色 +pub async fn create_role( + db: &DatabaseConnection, + req: CreateRoleReq, +) -> Result + +/// 更新角色 +pub async fn update_role( + db: &DatabaseConnection, + id: &str, + req: UpdateRoleReq, +) -> Result + +/// 删除角色 +pub async fn delete_role( + db: &DatabaseConnection, + id: &str, +) -> Result<(), AppError> +``` + +#### 2. 权限分配 +```rust +/// 为角色分配权限 +pub async fn assign_permissions( + db: &DatabaseConnection, + role_id: &str, + permission_ids: Vec, +) -> Result<(), AppError> + +/// 移除角色权限 +pub async fn remove_permissions( + db: &DatabaseConnection, + role_id: &str, + permission_ids: Vec, +) -> Result<(), AppError> +``` + +#### 3. 用户角色管理 +```rust +/// 为用户分配角色 +pub async fn assign_user_roles( + db: &DatabaseConnection, + user_id: &str, + role_ids: Vec, +) -> Result<(), AppError> + +/// 获取用户角色 +pub async fn get_user_roles( + db: &DatabaseConnection, + user_id: &str, +) -> Result, AppError> +``` + +### 数据传输对象 + +#### RoleDoc - 角色文档 +```rust +#[derive(Debug, Serialize, Deserialize)] +pub struct RoleDoc { + pub id: String, + pub name: String, + pub description: Option, + pub permissions: Vec, + pub is_system: bool, + pub created_at: DateTime, + pub updated_at: DateTime, +} +``` + +## 权限服务 (permission_service.rs) + +### 核心功能 + +#### 1. 权限检查 +```rust +/// 检查用户权限 +pub async fn check_user_permission( + db: &DatabaseConnection, + user_id: &str, + resource: &str, + action: &str, +) -> Result + +/// 检查角色权限 +pub async fn check_role_permission( + db: &DatabaseConnection, + role_id: &str, + resource: &str, + action: &str, +) -> Result +``` + +#### 2. 权限管理 +```rust +/// 创建权限 +pub async fn create_permission( + db: &DatabaseConnection, + req: CreatePermissionReq, +) -> Result + +/// 获取权限树 +pub async fn get_permission_tree( + db: &DatabaseConnection, +) -> Result, AppError> +``` + +### 权限模型 + +#### 资源-动作模型 +- **资源 (Resource)**: 系统中的实体 (user, role, flow, job) +- **动作 (Action)**: 对资源的操作 (create, read, update, delete) +- **权限 (Permission)**: 资源和动作的组合 + +#### 权限继承 +- 角色权限继承 +- 用户权限继承 +- 权限组合计算 + +## 流程服务 (flow_service.rs) + +### 核心功能 + +#### 1. 流程管理 +```rust +/// 创建流程 +pub async fn create_flow( + db: &DatabaseConnection, + req: CreateFlowReq, +) -> Result + +/// 更新流程 +pub async fn update_flow( + db: &DatabaseConnection, + id: &str, + req: UpdateFlowReq, +) -> Result + +/// 发布流程 +pub async fn publish_flow( + db: &DatabaseConnection, + id: &str, +) -> Result +``` + +#### 2. 流程执行 +```rust +/// 执行流程 +pub async fn execute_flow( + db: &DatabaseConnection, + flow_id: &str, + input: serde_json::Value, + options: ExecuteOptions, +) -> Result + +/// 获取执行状态 +pub async fn get_execution_status( + db: &DatabaseConnection, + execution_id: &str, +) -> Result +``` + +#### 3. 流程版本管理 +```rust +/// 创建流程版本 +pub async fn create_flow_version( + db: &DatabaseConnection, + flow_id: &str, + version_data: FlowVersionData, +) -> Result + +/// 获取流程版本列表 +pub async fn list_flow_versions( + db: &DatabaseConnection, + flow_id: &str, +) -> Result, AppError> +``` + +### 数据传输对象 + +#### FlowDoc - 流程文档 +```rust +#[derive(Debug, Serialize, Deserialize)] +pub struct FlowDoc { + pub id: String, + pub name: String, + pub description: Option, + pub category: String, + pub status: FlowStatus, + pub version: String, + pub design: serde_json::Value, + pub created_by: String, + pub created_at: DateTime, + pub updated_at: DateTime, +} +``` + +#### ExecutionResult - 执行结果 +```rust +#[derive(Debug, Serialize, Deserialize)] +pub struct ExecutionResult { + pub execution_id: String, + pub status: ExecutionStatus, + pub result: Option, + pub error: Option, + pub start_time: DateTime, + pub end_time: Option>, + pub duration_ms: Option, +} +``` + +## 定时任务服务 (schedule_job_service.rs) + +### 核心功能 + +#### 1. 任务管理 +```rust +/// 创建定时任务 +pub async fn create_schedule_job( + db: &DatabaseConnection, + req: CreateScheduleJobReq, +) -> Result + +/// 更新任务 +pub async fn update_schedule_job( + db: &DatabaseConnection, + id: &str, + req: UpdateScheduleJobReq, +) -> Result + +/// 启用/禁用任务 +pub async fn toggle_job_status( + db: &DatabaseConnection, + id: &str, + enabled: bool, +) -> Result +``` + +#### 2. 任务调度 +```rust +/// 注册任务到调度器 +pub async fn register_job_to_scheduler( + scheduler: &JobScheduler, + job: &ScheduleJobDoc, +) -> Result<(), AppError> + +/// 从调度器移除任务 +pub async fn unregister_job_from_scheduler( + scheduler: &JobScheduler, + job_id: &str, +) -> Result<(), AppError> +``` + +#### 3. 执行历史 +```rust +/// 记录任务执行 +pub async fn record_job_execution( + db: &DatabaseConnection, + execution: JobExecutionRecord, +) -> Result<(), AppError> + +/// 获取执行历史 +pub async fn get_job_execution_history( + db: &DatabaseConnection, + job_id: &str, + page: u64, + page_size: u64, +) -> Result, AppError> +``` + +### 调度特性 + +- **Cron 表达式**: 支持标准 Cron 表达式 +- **时区支持**: 支持不同时区的任务调度 +- **并发控制**: 防止任务重复执行 +- **失败重试**: 支持任务失败重试 +- **执行超时**: 支持任务执行超时控制 + +## 系统服务 (system_service.rs) + +### 核心功能 + +#### 1. 系统信息 +```rust +/// 获取系统信息 +pub async fn get_system_info() -> Result + +/// 获取系统状态 +pub async fn get_system_status( + db: &DatabaseConnection, + redis: &RedisConnection, +) -> Result +``` + +#### 2. 健康检查 +```rust +/// 数据库健康检查 +pub async fn check_database_health( + db: &DatabaseConnection, +) -> Result + +/// Redis 健康检查 +pub async fn check_redis_health( + redis: &RedisConnection, +) -> Result +``` + +#### 3. 系统配置 +```rust +/// 获取系统配置 +pub async fn get_system_config( + db: &DatabaseConnection, +) -> Result + +/// 更新系统配置 +pub async fn update_system_config( + db: &DatabaseConnection, + config: SystemConfig, +) -> Result<(), AppError> +``` + +### 监控指标 + +#### SystemStatus - 系统状态 +```rust +#[derive(Debug, Serialize, Deserialize)] +pub struct SystemStatus { + pub uptime: u64, + pub memory_usage: MemoryUsage, + pub cpu_usage: f64, + pub disk_usage: DiskUsage, + pub database_status: HealthStatus, + pub redis_status: HealthStatus, + pub active_connections: u32, + pub request_count: u64, + pub error_count: u64, +} +``` + +## 日志服务 (log_service.rs) + +### 核心功能 + +#### 1. 日志记录 +```rust +/// 记录操作日志 +pub async fn log_operation( + db: &DatabaseConnection, + log: OperationLog, +) -> Result<(), AppError> + +/// 记录系统日志 +pub async fn log_system_event( + db: &DatabaseConnection, + event: SystemEvent, +) -> Result<(), AppError> +``` + +#### 2. 日志查询 +```rust +/// 查询操作日志 +pub async fn query_operation_logs( + db: &DatabaseConnection, + filters: LogFilters, + page: u64, + page_size: u64, +) -> Result, AppError> + +/// 查询系统日志 +pub async fn query_system_logs( + db: &DatabaseConnection, + filters: LogFilters, + page: u64, + page_size: u64, +) -> Result, AppError> +``` + +#### 3. 日志分析 +```rust +/// 获取日志统计 +pub async fn get_log_statistics( + db: &DatabaseConnection, + time_range: TimeRange, +) -> Result + +/// 获取错误日志趋势 +pub async fn get_error_log_trend( + db: &DatabaseConnection, + time_range: TimeRange, +) -> Result, AppError> +``` + +### 日志类型 + +#### OperationLog - 操作日志 +```rust +#[derive(Debug, Serialize, Deserialize)] +pub struct OperationLog { + pub user_id: String, + pub operation: String, + pub resource: String, + pub resource_id: Option, + pub details: serde_json::Value, + pub ip_address: Option, + pub user_agent: Option, + pub timestamp: DateTime, +} +``` + +#### SystemEvent - 系统事件 +```rust +#[derive(Debug, Serialize, Deserialize)] +pub struct SystemEvent { + pub event_type: String, + pub level: LogLevel, + pub message: String, + pub details: serde_json::Value, + pub source: String, + pub timestamp: DateTime, +} +``` + +## 通知服务 (notification_service.rs) + +### 核心功能 + +#### 1. 通知发送 +```rust +/// 发送邮件通知 +pub async fn send_email_notification( + config: &EmailConfig, + notification: EmailNotification, +) -> Result<(), AppError> + +/// 发送短信通知 +pub async fn send_sms_notification( + config: &SmsConfig, + notification: SmsNotification, +) -> Result<(), AppError> + +/// 发送系统通知 +pub async fn send_system_notification( + db: &DatabaseConnection, + notification: SystemNotification, +) -> Result<(), AppError> +``` + +#### 2. 通知模板 +```rust +/// 创建通知模板 +pub async fn create_notification_template( + db: &DatabaseConnection, + template: NotificationTemplate, +) -> Result + +/// 渲染通知模板 +pub async fn render_notification_template( + template: &NotificationTemplate, + variables: &serde_json::Value, +) -> Result +``` + +#### 3. 通知历史 +```rust +/// 记录通知历史 +pub async fn record_notification_history( + db: &DatabaseConnection, + history: NotificationHistory, +) -> Result<(), AppError> + +/// 查询通知历史 +pub async fn query_notification_history( + db: &DatabaseConnection, + filters: NotificationFilters, + page: u64, + page_size: u64, +) -> Result, AppError> +``` + +### 通知渠道 + +- **邮件通知**: SMTP 邮件发送 +- **短信通知**: SMS 短信发送 +- **系统通知**: 站内消息通知 +- **Webhook**: HTTP 回调通知 +- **推送通知**: 移动端推送 + +## 服务集成 + +### 依赖注入 + +```rust +// 服务依赖注入示例 +pub struct ServiceContainer { + pub db: DatabaseConnection, + pub redis: RedisConnection, + pub scheduler: JobScheduler, + pub email_config: EmailConfig, + pub sms_config: SmsConfig, +} + +impl ServiceContainer { + pub fn user_service(&self) -> UserService { + UserService::new(&self.db) + } + + pub fn flow_service(&self) -> FlowService { + FlowService::new(&self.db, &self.redis) + } +} +``` + +### 事务管理 + +```rust +/// 事务执行示例 +pub async fn create_user_with_roles( + db: &DatabaseConnection, + user_req: CreateUserReq, + role_ids: Vec, +) -> Result { + let txn = db.begin().await?; + + // 创建用户 + let user = user_service::create_user(&txn, user_req).await?; + + // 分配角色 + role_service::assign_user_roles(&txn, &user.id, role_ids).await?; + + txn.commit().await?; + Ok(user) +} +``` + +### 缓存策略 + +```rust +/// 缓存使用示例 +pub async fn get_user_with_cache( + db: &DatabaseConnection, + redis: &RedisConnection, + user_id: &str, +) -> Result, AppError> { + // 先从缓存获取 + if let Some(cached_user) = redis.get(&format!("user:{}", user_id)).await? { + return Ok(Some(serde_json::from_str(&cached_user)?)); + } + + // 从数据库获取 + if let Some(user) = user_service::get_user_by_id(db, user_id).await? { + // 写入缓存 + redis.setex( + &format!("user:{}", user_id), + 3600, // 1小时过期 + &serde_json::to_string(&user)?, + ).await?; + Ok(Some(user)) + } else { + Ok(None) + } +} +``` + +## 错误处理 + +### 统一错误类型 + +```rust +#[derive(Debug, thiserror::Error)] +pub enum ServiceError { + #[error("数据库错误: {0}")] + DatabaseError(#[from] sea_orm::DbErr), + + #[error("验证错误: {0}")] + ValidationError(String), + + #[error("权限不足")] + PermissionDenied, + + #[error("资源不存在: {0}")] + ResourceNotFound(String), + + #[error("业务逻辑错误: {0}")] + BusinessLogicError(String), +} +``` + +### 错误处理模式 + +```rust +/// 统一错误处理 +pub async fn handle_service_result( + result: Result, +) -> Result { + match result { + Ok(value) => Ok(value), + Err(ServiceError::ValidationError(msg)) => { + Err(AppError::BadRequest(msg)) + }, + Err(ServiceError::PermissionDenied) => { + Err(AppError::Forbidden("权限不足".to_string())) + }, + Err(ServiceError::ResourceNotFound(resource)) => { + Err(AppError::NotFound(format!("{}不存在", resource))) + }, + Err(e) => Err(AppError::InternalServerError(e.to_string())), + } +} +``` + +## 性能优化 + +### 数据库优化 + +- **连接池**: 使用数据库连接池 +- **查询优化**: 优化 SQL 查询语句 +- **索引使用**: 合理使用数据库索引 +- **批量操作**: 使用批量插入/更新 + +### 缓存优化 + +- **热点数据缓存**: 缓存频繁访问的数据 +- **查询结果缓存**: 缓存复杂查询结果 +- **缓存预热**: 系统启动时预加载缓存 +- **缓存更新**: 及时更新过期缓存 + +### 并发优化 + +- **异步处理**: 使用异步编程模型 +- **并发控制**: 合理控制并发数量 +- **锁优化**: 减少锁的使用和持有时间 +- **无锁设计**: 使用无锁数据结构 + +## 测试策略 + +### 单元测试 + +```rust +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_create_user() { + let db = setup_test_db().await; + let req = CreateUserReq { + username: "test_user".to_string(), + password: "password123".to_string(), + email: Some("test@example.com".to_string()), + display_name: None, + roles: vec![], + }; + + let result = create_user(&db, req).await; + assert!(result.is_ok()); + + let user = result.unwrap(); + assert_eq!(user.username, "test_user"); + } +} +``` + +### 集成测试 + +```rust +#[tokio::test] +async fn test_user_role_integration() { + let container = setup_test_container().await; + + // 创建角色 + let role = create_test_role(&container.db).await; + + // 创建用户 + let user = create_test_user(&container.db).await; + + // 分配角色 + let result = role_service::assign_user_roles( + &container.db, + &user.id, + vec![role.id.clone()], + ).await; + + assert!(result.is_ok()); + + // 验证权限 + let has_permission = permission_service::check_user_permission( + &container.db, + &user.id, + "user", + "read", + ).await.unwrap(); + + assert!(has_permission); +} +``` + +## 最佳实践 + +### 服务设计 + +- **接口设计**: 设计清晰的服务接口 +- **参数验证**: 严格验证输入参数 +- **返回值**: 统一返回值格式 +- **文档注释**: 为公开方法添加文档注释 + +### 数据处理 + +- **数据验证**: 在服务层进行数据验证 +- **数据转换**: 合理进行数据类型转换 +- **数据清理**: 及时清理无用数据 +- **数据备份**: 重要操作前备份数据 + +### 安全考虑 + +- **权限检查**: 在服务层进行权限检查 +- **输入过滤**: 过滤恶意输入 +- **敏感数据**: 保护敏感数据不泄露 +- **审计日志**: 记录重要操作的审计日志 \ No newline at end of file diff --git a/docs/UTILS.md b/docs/UTILS.md new file mode 100644 index 0000000..cfe981e --- /dev/null +++ b/docs/UTILS.md @@ -0,0 +1,1271 @@ +# UdminAI 工具函数模块文档 + +## 概述 + +UdminAI 项目的工具函数模块(Utils)提供了一系列通用的工具函数和辅助功能,支持系统的各个组件。这些工具函数涵盖了 ID 生成、时间处理、加密解密、验证、格式化、文件操作等多个方面,为整个系统提供了基础的功能支持。 + +## 模块结构 + +``` +backend/src/utils/ +├── mod.rs # 模块导出 +├── id.rs # ID 生成工具 +├── time.rs # 时间处理工具 +├── crypto.rs # 加密解密工具 +├── validation.rs # 验证工具 +├── format.rs # 格式化工具 +├── file.rs # 文件操作工具 +├── hash.rs # 哈希工具 +├── jwt.rs # JWT 工具 +├── password.rs # 密码处理工具 +├── email.rs # 邮件工具 +├── config.rs # 配置工具 +└── error.rs # 错误处理工具 +``` + +## 核心工具模块 + +### ID 生成工具 (id.rs) + +#### 功能特性 + +- 唯一 ID 生成 +- 多种 ID 格式支持 +- 高性能生成 +- 分布式友好 + +#### 实现代码 + +```rust +use chrono::{DateTime, Utc}; +use rand::{thread_rng, Rng}; +use std::sync::atomic::{AtomicU64, Ordering}; +use uuid::Uuid; + +/// 全局序列号计数器 +static SEQUENCE_COUNTER: AtomicU64 = AtomicU64::new(0); + +/// ID 生成器类型 +#[derive(Debug, Clone)] +pub enum IdType { + /// UUID v4 + Uuid, + /// 雪花算法 ID + Snowflake, + /// 时间戳 + 随机数 + Timestamp, + /// 自定义前缀 + UUID + Prefixed(String), +} + +/// 生成唯一 ID +pub fn generate_id() -> String { + generate_id_with_type(IdType::Uuid) +} + +/// 根据类型生成 ID +pub fn generate_id_with_type(id_type: IdType) -> String { + match id_type { + IdType::Uuid => Uuid::new_v4().to_string(), + IdType::Snowflake => generate_snowflake_id(), + IdType::Timestamp => generate_timestamp_id(), + IdType::Prefixed(prefix) => format!("{}-{}", prefix, Uuid::new_v4()), + } +} + +/// 生成雪花算法 ID +fn generate_snowflake_id() -> String { + let timestamp = Utc::now().timestamp_millis() as u64; + let sequence = SEQUENCE_COUNTER.fetch_add(1, Ordering::SeqCst) & 0xFFF; // 12位序列号 + let machine_id = 1u64; // 机器ID,实际使用时应该从配置获取 + + let id = (timestamp << 22) | (machine_id << 12) | sequence; + id.to_string() +} + +/// 生成时间戳 ID +fn generate_timestamp_id() -> String { + let timestamp = Utc::now().timestamp_millis(); + let random: u32 = thread_rng().gen(); + format!("{}{:08x}", timestamp, random) +} + +/// 生成短 ID(8位) +pub fn generate_short_id() -> String { + let chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + let mut rng = thread_rng(); + (0..8) + .map(|_| { + let idx = rng.gen_range(0..chars.len()); + chars.chars().nth(idx).unwrap() + }) + .collect() +} + +/// 生成数字 ID +pub fn generate_numeric_id(length: usize) -> String { + let mut rng = thread_rng(); + (0..length) + .map(|_| rng.gen_range(0..10).to_string()) + .collect() +} + +/// 验证 ID 格式 +pub fn validate_id(id: &str, id_type: IdType) -> bool { + match id_type { + IdType::Uuid => Uuid::parse_str(id).is_ok(), + IdType::Snowflake => id.parse::().is_ok(), + IdType::Timestamp => id.len() >= 13 && id[..13].parse::().is_ok(), + IdType::Prefixed(prefix) => { + id.starts_with(&format!("{}-", prefix)) && + id.len() > prefix.len() + 1 && + Uuid::parse_str(&id[prefix.len() + 1..]).is_ok() + } + } +} + +/// 从雪花 ID 提取时间戳 +pub fn extract_timestamp_from_snowflake(id: &str) -> Option> { + if let Ok(snowflake_id) = id.parse::() { + let timestamp = (snowflake_id >> 22) as i64; + DateTime::from_timestamp_millis(timestamp) + } else { + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_uuid() { + let id = generate_id(); + assert!(validate_id(&id, IdType::Uuid)); + } + + #[test] + fn test_generate_snowflake() { + let id = generate_id_with_type(IdType::Snowflake); + assert!(validate_id(&id, IdType::Snowflake)); + } + + #[test] + fn test_generate_prefixed_id() { + let id = generate_id_with_type(IdType::Prefixed("user".to_string())); + assert!(id.starts_with("user-")); + assert!(validate_id(&id, IdType::Prefixed("user".to_string()))); + } + + #[test] + fn test_short_id() { + let id = generate_short_id(); + assert_eq!(id.len(), 8); + } +} +``` + +### 时间处理工具 (time.rs) + +#### 功能特性 + +- 时间格式化 +- 时区转换 +- 时间计算 +- 相对时间 + +#### 实现代码 + +```rust +use chrono::{ + DateTime, Duration, FixedOffset, Local, NaiveDateTime, TimeZone, Utc, +}; +use serde::{Deserialize, Serialize}; +use std::fmt; + +/// 时间格式常量 +pub const DEFAULT_DATETIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S"; +pub const ISO_DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.3fZ"; +pub const DATE_FORMAT: &str = "%Y-%m-%d"; +pub const TIME_FORMAT: &str = "%H:%M:%S"; + +/// 时区偏移量(东八区) +pub const CHINA_OFFSET: i32 = 8 * 3600; + +/// 获取当前时间(UTC) +pub fn now_utc() -> DateTime { + Utc::now() +} + +/// 获取当前时间(带固定偏移量) +pub fn now_fixed_offset() -> DateTime { + Utc::now().with_timezone(&FixedOffset::east_opt(0).unwrap()) +} + +/// 获取当前时间(中国时区) +pub fn now_china() -> DateTime { + Utc::now().with_timezone(&FixedOffset::east_opt(CHINA_OFFSET).unwrap()) +} + +/// 获取当前本地时间 +pub fn now_local() -> DateTime { + Local::now() +} + +/// 格式化时间 +pub fn format_datetime(dt: &DateTime, format: &str) -> String { + dt.format(format).to_string() +} + +/// 格式化时间(默认格式) +pub fn format_datetime_default(dt: &DateTime) -> String { + format_datetime(dt, DEFAULT_DATETIME_FORMAT) +} + +/// 格式化时间(ISO 格式) +pub fn format_datetime_iso(dt: &DateTime) -> String { + format_datetime(dt, ISO_DATETIME_FORMAT) +} + +/// 解析时间字符串 +pub fn parse_datetime(s: &str, format: &str) -> Result, chrono::ParseError> { + let naive = NaiveDateTime::parse_from_str(s, format)?; + Ok(Utc.from_utc_datetime(&naive)) +} + +/// 解析 ISO 时间字符串 +pub fn parse_datetime_iso(s: &str) -> Result, chrono::ParseError> { + DateTime::parse_from_rfc3339(s).map(|dt| dt.with_timezone(&Utc)) +} + +/// 时间戳转换为 DateTime +pub fn timestamp_to_datetime(timestamp: i64) -> Option> { + DateTime::from_timestamp(timestamp, 0) +} + +/// 毫秒时间戳转换为 DateTime +pub fn timestamp_millis_to_datetime(timestamp: i64) -> Option> { + DateTime::from_timestamp_millis(timestamp) +} + +/// DateTime 转换为时间戳 +pub fn datetime_to_timestamp(dt: &DateTime) -> i64 { + dt.timestamp() +} + +/// DateTime 转换为毫秒时间戳 +pub fn datetime_to_timestamp_millis(dt: &DateTime) -> i64 { + dt.timestamp_millis() +} + +/// 计算时间差 +pub fn time_diff(start: &DateTime, end: &DateTime) -> Duration { + *end - *start +} + +/// 计算时间差(秒) +pub fn time_diff_seconds(start: &DateTime, end: &DateTime) -> i64 { + time_diff(start, end).num_seconds() +} + +/// 计算时间差(毫秒) +pub fn time_diff_millis(start: &DateTime, end: &DateTime) -> i64 { + time_diff(start, end).num_milliseconds() +} + +/// 时间加法 +pub fn add_duration(dt: &DateTime, duration: Duration) -> DateTime { + *dt + duration +} + +/// 时间减法 +pub fn sub_duration(dt: &DateTime, duration: Duration) -> DateTime { + *dt - duration +} + +/// 添加秒数 +pub fn add_seconds(dt: &DateTime, seconds: i64) -> DateTime { + add_duration(dt, Duration::seconds(seconds)) +} + +/// 添加分钟 +pub fn add_minutes(dt: &DateTime, minutes: i64) -> DateTime { + add_duration(dt, Duration::minutes(minutes)) +} + +/// 添加小时 +pub fn add_hours(dt: &DateTime, hours: i64) -> DateTime { + add_duration(dt, Duration::hours(hours)) +} + +/// 添加天数 +pub fn add_days(dt: &DateTime, days: i64) -> DateTime { + add_duration(dt, Duration::days(days)) +} + +/// 获取今天开始时间 +pub fn today_start() -> DateTime { + let now = now_utc(); + now.date_naive().and_hms_opt(0, 0, 0).unwrap().and_utc() +} + +/// 获取今天结束时间 +pub fn today_end() -> DateTime { + let now = now_utc(); + now.date_naive().and_hms_opt(23, 59, 59).unwrap().and_utc() +} + +/// 获取本周开始时间(周一) +pub fn week_start() -> DateTime { + let now = now_utc(); + let weekday = now.weekday().num_days_from_monday(); + let start_date = now.date_naive() - Duration::days(weekday as i64); + start_date.and_hms_opt(0, 0, 0).unwrap().and_utc() +} + +/// 获取本月开始时间 +pub fn month_start() -> DateTime { + let now = now_utc(); + let start_date = now.date_naive().with_day(1).unwrap(); + start_date.and_hms_opt(0, 0, 0).unwrap().and_utc() +} + +/// 相对时间描述 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RelativeTime { + pub value: i64, + pub unit: TimeUnit, + pub description: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum TimeUnit { + Second, + Minute, + Hour, + Day, + Week, + Month, + Year, +} + +impl fmt::Display for TimeUnit { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + TimeUnit::Second => "秒", + TimeUnit::Minute => "分钟", + TimeUnit::Hour => "小时", + TimeUnit::Day => "天", + TimeUnit::Week => "周", + TimeUnit::Month => "月", + TimeUnit::Year => "年", + }; + write!(f, "{}", s) + } +} + +/// 计算相对时间 +pub fn relative_time(dt: &DateTime) -> RelativeTime { + let now = now_utc(); + let diff = time_diff(dt, &now); + + let abs_seconds = diff.num_seconds().abs(); + let is_past = diff.num_seconds() < 0; + + let (value, unit) = if abs_seconds < 60 { + (abs_seconds, TimeUnit::Second) + } else if abs_seconds < 3600 { + (abs_seconds / 60, TimeUnit::Minute) + } else if abs_seconds < 86400 { + (abs_seconds / 3600, TimeUnit::Hour) + } else if abs_seconds < 604800 { + (abs_seconds / 86400, TimeUnit::Day) + } else if abs_seconds < 2592000 { + (abs_seconds / 604800, TimeUnit::Week) + } else if abs_seconds < 31536000 { + (abs_seconds / 2592000, TimeUnit::Month) + } else { + (abs_seconds / 31536000, TimeUnit::Year) + }; + + let description = if is_past { + format!("{}{} 前", value, unit) + } else { + format!("{}{} 后", value, unit) + }; + + RelativeTime { + value, + unit, + description, + } +} + +/// 判断是否为同一天 +pub fn is_same_day(dt1: &DateTime, dt2: &DateTime) -> bool { + dt1.date_naive() == dt2.date_naive() +} + +/// 判断是否为今天 +pub fn is_today(dt: &DateTime) -> bool { + is_same_day(dt, &now_utc()) +} + +/// 判断是否为昨天 +pub fn is_yesterday(dt: &DateTime) -> bool { + let yesterday = sub_duration(&now_utc(), Duration::days(1)); + is_same_day(dt, &yesterday) +} + +/// 判断是否为本周 +pub fn is_this_week(dt: &DateTime) -> bool { + let week_start = week_start(); + let week_end = add_days(&week_start, 7); + dt >= &week_start && dt < &week_end +} + +/// 判断是否为本月 +pub fn is_this_month(dt: &DateTime) -> bool { + let now = now_utc(); + dt.year() == now.year() && dt.month() == now.month() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_time_formatting() { + let dt = Utc::now(); + let formatted = format_datetime_default(&dt); + assert!(!formatted.is_empty()); + } + + #[test] + fn test_time_parsing() { + let time_str = "2023-12-25 10:30:00"; + let parsed = parse_datetime(time_str, DEFAULT_DATETIME_FORMAT); + assert!(parsed.is_ok()); + } + + #[test] + fn test_timestamp_conversion() { + let dt = Utc::now(); + let timestamp = datetime_to_timestamp(&dt); + let converted = timestamp_to_datetime(timestamp).unwrap(); + assert_eq!(dt.timestamp(), converted.timestamp()); + } + + #[test] + fn test_relative_time() { + let past = sub_duration(&now_utc(), Duration::hours(2)); + let rel_time = relative_time(&past); + assert!(rel_time.description.contains("前")); + } + + #[test] + fn test_date_checks() { + let now = now_utc(); + assert!(is_today(&now)); + + let yesterday = sub_duration(&now, Duration::days(1)); + assert!(is_yesterday(&yesterday)); + } +} +``` + +### 加密解密工具 (crypto.rs) + +#### 功能特性 + +- AES 加密解密 +- RSA 加密解密 +- 数字签名 +- 密钥生成 + +#### 实现代码 + +```rust +use aes_gcm::{ + aead::{Aead, KeyInit, OsRng}, + Aes256Gcm, Key, Nonce, +}; +use base64::{engine::general_purpose, Engine as _}; +use rand::{thread_rng, RngCore}; +use rsa::{ + pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey, EncodeRsaPrivateKey, EncodeRsaPublicKey}, + pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey}, + Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey, +}; +use sha2::{Digest, Sha256}; +use std::error::Error; +use thiserror::Error; + +/// 加密错误类型 +#[derive(Error, Debug)] +pub enum CryptoError { + #[error("加密失败: {0}")] + EncryptionFailed(String), + #[error("解密失败: {0}")] + DecryptionFailed(String), + #[error("密钥生成失败: {0}")] + KeyGenerationFailed(String), + #[error("签名失败: {0}")] + SigningFailed(String), + #[error("验证失败: {0}")] + VerificationFailed(String), + #[error("编码失败: {0}")] + EncodingFailed(String), +} + +/// AES 加密工具 +pub struct AesEncryption { + cipher: Aes256Gcm, +} + +impl AesEncryption { + /// 创建新的 AES 加密器 + pub fn new(key: &[u8; 32]) -> Self { + let key = Key::::from_slice(key); + let cipher = Aes256Gcm::new(key); + Self { cipher } + } + + /// 从密码创建 AES 加密器 + pub fn from_password(password: &str) -> Self { + let key = Self::derive_key_from_password(password); + Self::new(&key) + } + + /// 从密码派生密钥 + fn derive_key_from_password(password: &str) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(password.as_bytes()); + hasher.finalize().into() + } + + /// 加密数据 + pub fn encrypt(&self, plaintext: &[u8]) -> Result, CryptoError> { + let mut nonce_bytes = [0u8; 12]; + thread_rng().fill_bytes(&mut nonce_bytes); + let nonce = Nonce::from_slice(&nonce_bytes); + + let ciphertext = self + .cipher + .encrypt(nonce, plaintext) + .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?; + + // 将 nonce 和密文组合 + let mut result = Vec::with_capacity(12 + ciphertext.len()); + result.extend_from_slice(&nonce_bytes); + result.extend_from_slice(&ciphertext); + + Ok(result) + } + + /// 解密数据 + pub fn decrypt(&self, ciphertext: &[u8]) -> Result, CryptoError> { + if ciphertext.len() < 12 { + return Err(CryptoError::DecryptionFailed( + "密文长度不足".to_string(), + )); + } + + let (nonce_bytes, encrypted_data) = ciphertext.split_at(12); + let nonce = Nonce::from_slice(nonce_bytes); + + self.cipher + .decrypt(nonce, encrypted_data) + .map_err(|e| CryptoError::DecryptionFailed(e.to_string())) + } + + /// 加密字符串并返回 Base64 编码 + pub fn encrypt_string(&self, plaintext: &str) -> Result { + let encrypted = self.encrypt(plaintext.as_bytes())?; + Ok(general_purpose::STANDARD.encode(encrypted)) + } + + /// 解密 Base64 编码的字符串 + pub fn decrypt_string(&self, ciphertext: &str) -> Result { + let decoded = general_purpose::STANDARD + .decode(ciphertext) + .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?; + + let decrypted = self.decrypt(&decoded)?; + String::from_utf8(decrypted) + .map_err(|e| CryptoError::DecryptionFailed(e.to_string())) + } +} + +/// RSA 密钥对 +#[derive(Debug, Clone)] +pub struct RsaKeyPair { + pub private_key: RsaPrivateKey, + pub public_key: RsaPublicKey, +} + +impl RsaKeyPair { + /// 生成新的 RSA 密钥对 + pub fn generate(bits: usize) -> Result { + let mut rng = OsRng; + let private_key = RsaPrivateKey::new(&mut rng, bits) + .map_err(|e| CryptoError::KeyGenerationFailed(e.to_string()))?; + let public_key = RsaPublicKey::from(&private_key); + + Ok(Self { + private_key, + public_key, + }) + } + + /// 从 PEM 格式加载私钥 + pub fn from_private_key_pem(pem: &str) -> Result { + let private_key = RsaPrivateKey::from_pkcs8_pem(pem) + .map_err(|e| CryptoError::KeyGenerationFailed(e.to_string()))?; + let public_key = RsaPublicKey::from(&private_key); + + Ok(Self { + private_key, + public_key, + }) + } + + /// 导出私钥为 PEM 格式 + pub fn private_key_to_pem(&self) -> Result { + self.private_key + .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF) + .map_err(|e| CryptoError::EncodingFailed(e.to_string())) + .map(|s| s.to_string()) + } + + /// 导出公钥为 PEM 格式 + pub fn public_key_to_pem(&self) -> Result { + self.public_key + .to_public_key_pem(rsa::pkcs8::LineEnding::LF) + .map_err(|e| CryptoError::EncodingFailed(e.to_string())) + } + + /// 使用公钥加密 + pub fn encrypt(&self, plaintext: &[u8]) -> Result, CryptoError> { + let mut rng = OsRng; + self.public_key + .encrypt(&mut rng, Pkcs1v15Encrypt, plaintext) + .map_err(|e| CryptoError::EncryptionFailed(e.to_string())) + } + + /// 使用私钥解密 + pub fn decrypt(&self, ciphertext: &[u8]) -> Result, CryptoError> { + self.private_key + .decrypt(Pkcs1v15Encrypt, ciphertext) + .map_err(|e| CryptoError::DecryptionFailed(e.to_string())) + } + + /// 加密字符串并返回 Base64 编码 + pub fn encrypt_string(&self, plaintext: &str) -> Result { + let encrypted = self.encrypt(plaintext.as_bytes())?; + Ok(general_purpose::STANDARD.encode(encrypted)) + } + + /// 解密 Base64 编码的字符串 + pub fn decrypt_string(&self, ciphertext: &str) -> Result { + let decoded = general_purpose::STANDARD + .decode(ciphertext) + .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?; + + let decrypted = self.decrypt(&decoded)?; + String::from_utf8(decrypted) + .map_err(|e| CryptoError::DecryptionFailed(e.to_string())) + } +} + +/// 生成随机密钥 +pub fn generate_random_key(length: usize) -> Vec { + let mut key = vec![0u8; length]; + thread_rng().fill_bytes(&mut key); + key +} + +/// 生成 AES-256 密钥 +pub fn generate_aes_key() -> [u8; 32] { + let mut key = [0u8; 32]; + thread_rng().fill_bytes(&mut key); + key +} + +/// 计算 SHA-256 哈希 +pub fn sha256_hash(data: &[u8]) -> Vec { + let mut hasher = Sha256::new(); + hasher.update(data); + hasher.finalize().to_vec() +} + +/// 计算字符串的 SHA-256 哈希并返回十六进制字符串 +pub fn sha256_hex(data: &str) -> String { + let hash = sha256_hash(data.as_bytes()); + hex::encode(hash) +} + +/// 验证哈希 +pub fn verify_hash(data: &[u8], expected_hash: &[u8]) -> bool { + let actual_hash = sha256_hash(data); + actual_hash == expected_hash +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_aes_encryption() { + let key = generate_aes_key(); + let aes = AesEncryption::new(&key); + + let plaintext = "Hello, World!"; + let encrypted = aes.encrypt_string(plaintext).unwrap(); + let decrypted = aes.decrypt_string(&encrypted).unwrap(); + + assert_eq!(plaintext, decrypted); + } + + #[test] + fn test_rsa_encryption() { + let keypair = RsaKeyPair::generate(2048).unwrap(); + + let plaintext = "Hello, RSA!"; + let encrypted = keypair.encrypt_string(plaintext).unwrap(); + let decrypted = keypair.decrypt_string(&encrypted).unwrap(); + + assert_eq!(plaintext, decrypted); + } + + #[test] + fn test_sha256_hash() { + let data = "test data"; + let hash1 = sha256_hex(data); + let hash2 = sha256_hex(data); + + assert_eq!(hash1, hash2); + assert_eq!(hash1.len(), 64); // SHA-256 产生 64 个十六进制字符 + } + + #[test] + fn test_key_generation() { + let key1 = generate_aes_key(); + let key2 = generate_aes_key(); + + assert_ne!(key1, key2); // 密钥应该是随机的 + assert_eq!(key1.len(), 32); // AES-256 密钥长度 + } +} +``` + +### 验证工具 (validation.rs) + +#### 功能特性 + +- 邮箱验证 +- 手机号验证 +- URL 验证 +- 密码强度验证 +- 自定义验证规则 + +#### 实现代码 + +```rust +use regex::Regex; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use thiserror::Error; + +/// 验证错误类型 +#[derive(Error, Debug, Clone, Serialize, Deserialize)] +pub enum ValidationError { + #[error("字段 '{field}' 是必需的")] + Required { field: String }, + #[error("字段 '{field}' 长度必须在 {min} 到 {max} 之间")] + Length { field: String, min: usize, max: usize }, + #[error("字段 '{field}' 格式无效: {message}")] + Format { field: String, message: String }, + #[error("字段 '{field}' 值无效: {message}")] + Invalid { field: String, message: String }, + #[error("自定义验证失败: {message}")] + Custom { message: String }, +} + +/// 验证结果 +pub type ValidationResult = Result; + +/// 验证器特征 +pub trait Validator { + fn validate(&self, value: &T) -> ValidationResult<()>; +} + +/// 字符串验证器 +#[derive(Debug, Clone)] +pub struct StringValidator { + pub field_name: String, + pub required: bool, + pub min_length: Option, + pub max_length: Option, + pub pattern: Option, + pub custom_validators: Vec ValidationResult<()> + Send + Sync>>, +} + +impl StringValidator { + pub fn new(field_name: &str) -> Self { + Self { + field_name: field_name.to_string(), + required: false, + min_length: None, + max_length: None, + pattern: None, + custom_validators: Vec::new(), + } + } + + pub fn required(mut self) -> Self { + self.required = true; + self + } + + pub fn min_length(mut self, min: usize) -> Self { + self.min_length = Some(min); + self + } + + pub fn max_length(mut self, max: usize) -> Self { + self.max_length = Some(max); + self + } + + pub fn length_range(mut self, min: usize, max: usize) -> Self { + self.min_length = Some(min); + self.max_length = Some(max); + self + } + + pub fn pattern(mut self, pattern: &str) -> Result { + self.pattern = Some(Regex::new(pattern)?); + Ok(self) + } + + pub fn email(self) -> Result { + self.pattern(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") + } + + pub fn phone(self) -> Result { + self.pattern(r"^1[3-9]\d{9}$") + } + + pub fn url(self) -> Result { + self.pattern(r"^https?://[^\s/$.?#].[^\s]*$") + } +} + +impl Validator> for StringValidator { + fn validate(&self, value: &Option) -> ValidationResult<()> { + match value { + None => { + if self.required { + Err(ValidationError::Required { + field: self.field_name.clone(), + }) + } else { + Ok(()) + } + } + Some(s) => self.validate(s), + } + } +} + +impl Validator for StringValidator { + fn validate(&self, value: &String) -> ValidationResult<()> { + // 检查是否为空且必需 + if value.is_empty() && self.required { + return Err(ValidationError::Required { + field: self.field_name.clone(), + }); + } + + // 如果为空且不是必需的,跳过其他验证 + if value.is_empty() && !self.required { + return Ok(()); + } + + // 检查长度 + if let Some(min) = self.min_length { + if value.len() < min { + return Err(ValidationError::Length { + field: self.field_name.clone(), + min, + max: self.max_length.unwrap_or(usize::MAX), + }); + } + } + + if let Some(max) = self.max_length { + if value.len() > max { + return Err(ValidationError::Length { + field: self.field_name.clone(), + min: self.min_length.unwrap_or(0), + max, + }); + } + } + + // 检查正则表达式 + if let Some(pattern) = &self.pattern { + if !pattern.is_match(value) { + return Err(ValidationError::Format { + field: self.field_name.clone(), + message: "格式不匹配".to_string(), + }); + } + } + + // 执行自定义验证 + for validator in &self.custom_validators { + validator(value)?; + } + + Ok(()) + } +} + +/// 数字验证器 +#[derive(Debug, Clone)] +pub struct NumberValidator { + pub field_name: String, + pub required: bool, + pub min_value: Option, + pub max_value: Option, +} + +impl NumberValidator +where + T: PartialOrd + Copy, +{ + pub fn new(field_name: &str) -> Self { + Self { + field_name: field_name.to_string(), + required: false, + min_value: None, + max_value: None, + } + } + + pub fn required(mut self) -> Self { + self.required = true; + self + } + + pub fn min_value(mut self, min: T) -> Self { + self.min_value = Some(min); + self + } + + pub fn max_value(mut self, max: T) -> Self { + self.max_value = Some(max); + self + } + + pub fn range(mut self, min: T, max: T) -> Self { + self.min_value = Some(min); + self.max_value = Some(max); + self + } +} + +impl Validator> for NumberValidator +where + T: PartialOrd + Copy + std::fmt::Display, +{ + fn validate(&self, value: &Option) -> ValidationResult<()> { + match value { + None => { + if self.required { + Err(ValidationError::Required { + field: self.field_name.clone(), + }) + } else { + Ok(()) + } + } + Some(v) => self.validate(v), + } + } +} + +impl Validator for NumberValidator +where + T: PartialOrd + Copy + std::fmt::Display, +{ + fn validate(&self, value: &T) -> ValidationResult<()> { + if let Some(min) = self.min_value { + if *value < min { + return Err(ValidationError::Invalid { + field: self.field_name.clone(), + message: format!("值必须大于等于 {}", min), + }); + } + } + + if let Some(max) = self.max_value { + if *value > max { + return Err(ValidationError::Invalid { + field: self.field_name.clone(), + message: format!("值必须小于等于 {}", max), + }); + } + } + + Ok(()) + } +} + +/// 密码强度等级 +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum PasswordStrength { + Weak, + Medium, + Strong, + VeryStrong, +} + +/// 密码验证结果 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PasswordValidation { + pub is_valid: bool, + pub strength: PasswordStrength, + pub score: u8, + pub feedback: Vec, +} + +/// 验证邮箱地址 +pub fn validate_email(email: &str) -> bool { + let email_regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap(); + email_regex.is_match(email) +} + +/// 验证中国手机号 +pub fn validate_phone(phone: &str) -> bool { + let phone_regex = Regex::new(r"^1[3-9]\d{9}$").unwrap(); + phone_regex.is_match(phone) +} + +/// 验证 URL +pub fn validate_url(url: &str) -> bool { + let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap(); + url_regex.is_match(url) +} + +/// 验证 IP 地址 +pub fn validate_ip(ip: &str) -> bool { + ip.parse::().is_ok() +} + +/// 验证密码强度 +pub fn validate_password(password: &str) -> PasswordValidation { + let mut score = 0u8; + let mut feedback = Vec::new(); + + // 长度检查 + if password.len() >= 8 { + score += 1; + } else { + feedback.push("密码长度至少需要8位".to_string()); + } + + if password.len() >= 12 { + score += 1; + } + + // 包含小写字母 + if password.chars().any(|c| c.is_ascii_lowercase()) { + score += 1; + } else { + feedback.push("密码应包含小写字母".to_string()); + } + + // 包含大写字母 + if password.chars().any(|c| c.is_ascii_uppercase()) { + score += 1; + } else { + feedback.push("密码应包含大写字母".to_string()); + } + + // 包含数字 + if password.chars().any(|c| c.is_ascii_digit()) { + score += 1; + } else { + feedback.push("密码应包含数字".to_string()); + } + + // 包含特殊字符 + if password.chars().any(|c| "!@#$%^&*()_+-=[]{}|;:,.<>?".contains(c)) { + score += 1; + } else { + feedback.push("密码应包含特殊字符".to_string()); + } + + // 不包含常见弱密码 + let weak_passwords = [ + "password", "123456", "123456789", "qwerty", "abc123", + "password123", "admin", "root", "user", "guest", + ]; + + if weak_passwords.iter().any(|&weak| password.to_lowercase().contains(weak)) { + score = score.saturating_sub(2); + feedback.push("密码不能包含常见弱密码".to_string()); + } + + let strength = match score { + 0..=2 => PasswordStrength::Weak, + 3..=4 => PasswordStrength::Medium, + 5..=6 => PasswordStrength::Strong, + _ => PasswordStrength::VeryStrong, + }; + + let is_valid = score >= 3 && password.len() >= 8; + + if is_valid && feedback.is_empty() { + feedback.push("密码强度良好".to_string()); + } + + PasswordValidation { + is_valid, + strength, + score, + feedback, + } +} + +/// 验证 JSON 格式 +pub fn validate_json(json_str: &str) -> bool { + serde_json::from_str::(json_str).is_ok() +} + +/// 验证 UUID 格式 +pub fn validate_uuid(uuid_str: &str) -> bool { + uuid::Uuid::parse_str(uuid_str).is_ok() +} + +/// 批量验证器 +#[derive(Debug)] +pub struct BatchValidator { + errors: Vec, +} + +impl BatchValidator { + pub fn new() -> Self { + Self { + errors: Vec::new(), + } + } + + pub fn validate(&mut self, validator: &V, value: &T) -> &mut Self + where + V: Validator, + { + if let Err(error) = validator.validate(value) { + self.errors.push(error); + } + self + } + + pub fn is_valid(&self) -> bool { + self.errors.is_empty() + } + + pub fn errors(&self) -> &[ValidationError] { + &self.errors + } + + pub fn into_result(self) -> ValidationResult<()> { + if self.errors.is_empty() { + Ok(()) + } else { + // 返回第一个错误,实际使用中可能需要返回所有错误 + Err(self.errors.into_iter().next().unwrap()) + } + } +} + +impl Default for BatchValidator { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_email_validation() { + assert!(validate_email("test@example.com")); + assert!(validate_email("user.name+tag@domain.co.uk")); + assert!(!validate_email("invalid.email")); + assert!(!validate_email("@domain.com")); + } + + #[test] + fn test_phone_validation() { + assert!(validate_phone("13812345678")); + assert!(validate_phone("15987654321")); + assert!(!validate_phone("12345678901")); + assert!(!validate_phone("1381234567")); + } + + #[test] + fn test_password_validation() { + let weak = validate_password("123"); + assert_eq!(weak.strength, PasswordStrength::Weak); + assert!(!weak.is_valid); + + let strong = validate_password("MyStr0ng!Pass"); + assert!(strong.is_valid); + assert!(matches!(strong.strength, PasswordStrength::Strong | PasswordStrength::VeryStrong)); + } + + #[test] + fn test_string_validator() { + let validator = StringValidator::new("username") + .required() + .length_range(3, 20) + .pattern(r"^[a-zA-Z0-9_]+$") + .unwrap(); + + assert!(validator.validate(&"valid_user123".to_string()).is_ok()); + assert!(validator.validate(&"ab".to_string()).is_err()); // 太短 + assert!(validator.validate(&"invalid-user".to_string()).is_err()); // 包含非法字符 + } + + #[test] + fn test_number_validator() { + let validator = NumberValidator::new("age") + .required() + .range(0, 150); + + assert!(validator.validate(&25).is_ok()); + assert!(validator.validate(&-1).is_err()); // 小于最小值 + assert!(validator.validate(&200).is_err()); // 大于最大值 + } + + #[test] + fn test_batch_validator() { + let mut batch = BatchValidator::new(); + + let email_validator = StringValidator::new("email").required().email().unwrap(); + let age_validator = NumberValidator::new("age").required().range(0, 150); + + batch + .validate(&email_validator, &"invalid.email".to_string()) + .validate(&age_validator, &200); + + assert!(!batch.is_valid()); + assert_eq!(batch.errors().len(), 2); + } +} +``` \ No newline at end of file