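ayou a3f2f99a68 docs: add project documentation covering the overview, architecture, flow engine, and service layer
Adds the following documentation files:
- PROJECT_OVERVIEW.md: project overview
- BACKEND_ARCHITECTURE.md: backend architecture
- FRONTEND_ARCHITECTURE.md: frontend architecture
- FLOW_ENGINE.md: flow engine
- SERVICES.md: service layer
- ERROR_HANDLING.md: error-handling module

The documentation covers the project introduction, technical architecture, core module design, and implementation details.
2025-09-24 20:21:45 +08:00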
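# UdminAI Utility Module Documentation
## Overview
The UdminAI utility module (Utils) provides general-purpose helper functions used by the system's other components. They cover ID generation, time handling, encryption and decryption, validation, formatting, file operations, and more, and form the base layer the rest of the system builds on.
## Module Structure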
```
backend/src/utils/
├── mod.rs          # Module exports
├── id.rs           # ID generation
├── time.rs         # Time handling
├── crypto.rs       # Encryption and decryption
├── validation.rs   # Validation
├── format.rs       # Formatting
├── file.rs         # File operations
├── hash.rs         # Hashing
├── jwt.rs          # JWT helpers
├── password.rs     # Password handling
├── email.rs        # Email helpers
├── config.rs       # Configuration helpers
└── error.rs        # Error handling
```
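## Core Utility Modules
### ID Generation (id.rs)
#### Features
- Unique ID generation
- Multiple ID formats: UUID v4, snowflake-style, timestamp + random, and prefixed UUID
- Lock-free generation (snowflake sequence numbers come from a single atomic counter)
- Snowflake IDs embed a machine ID, so separate nodes can generate IDs independently once each node is configured with a distinct machine ID
#### Implementation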
```rust
use chrono::{DateTime, Utc};
use rand::{thread_rng, Rng};
use std::sync::atomic::{AtomicU64, Ordering};
use uuid::Uuid;

/// Global sequence counter for snowflake IDs
static SEQUENCE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// ID generator type
#[derive(Debug, Clone)]
pub enum IdType {
    /// UUID v4
    Uuid,
    /// Snowflake-style ID
    Snowflake,
    /// Timestamp + random number
    Timestamp,
    /// Custom prefix + UUID
    Prefixed(String),
}

/// Generate a unique ID (UUID v4)
pub fn generate_id() -> String {
    generate_id_with_type(IdType::Uuid)
}

/// Generate an ID of the given type
pub fn generate_id_with_type(id_type: IdType) -> String {
    match id_type {
        IdType::Uuid => Uuid::new_v4().to_string(),
        IdType::Snowflake => generate_snowflake_id(),
        IdType::Timestamp => generate_timestamp_id(),
        IdType::Prefixed(prefix) => format!("{}-{}", prefix, Uuid::new_v4()),
    }
}

/// Generate a snowflake-style ID:
/// `timestamp_ms << 22 | machine_id (10 bits) << 12 | sequence (12 bits)`
fn generate_snowflake_id() -> String {
    let timestamp = Utc::now().timestamp_millis() as u64;
    // 12-bit sequence. The counter is global and never reset per millisecond, so IDs
    // from one process stay unique while fewer than 4096 are issued per millisecond
    // and the system clock does not move backwards.
    let sequence = SEQUENCE_COUNTER.fetch_add(1, Ordering::SeqCst) & 0xFFF;
    // Machine ID; in practice it should come from configuration. Masked to 10 bits
    // so an out-of-range value cannot spill into the timestamp bits.
    let machine_id = 1u64 & 0x3FF;
    let id = (timestamp << 22) | (machine_id << 12) | sequence;
    id.to_string()
}

/// Generate a timestamp ID: 13-digit millisecond timestamp followed by 8 random hex digits
fn generate_timestamp_id() -> String {
    let timestamp = Utc::now().timestamp_millis();
    let random: u32 = thread_rng().gen();
    format!("{}{:08x}", timestamp, random)
}

/// Generate a short ID (8 alphanumeric characters; short enough that collisions are possible)
pub fn generate_short_id() -> String {
    const CHARS: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    let mut rng = thread_rng();
    (0..8)
        .map(|_| CHARS[rng.gen_range(0..CHARS.len())] as char)
        .collect()
}

/// Generate a numeric ID of `length` digits (it may start with 0)
pub fn generate_numeric_id(length: usize) -> String {
    let mut rng = thread_rng();
    (0..length)
        .map(|_| rng.gen_range(0..10).to_string())
        .collect()
}

/// Check that an ID matches the format of the given type
pub fn validate_id(id: &str, id_type: IdType) -> bool {
    match id_type {
        IdType::Uuid => Uuid::parse_str(id).is_ok(),
        IdType::Snowflake => id.parse::<u64>().is_ok(),
        // `get` returns None instead of panicking when byte 13 is not a char boundary
        IdType::Timestamp => id.get(..13).map_or(false, |ts| ts.parse::<i64>().is_ok()),
        IdType::Prefixed(prefix) => {
            id.starts_with(&format!("{}-", prefix))
                && id.len() > prefix.len() + 1
                && Uuid::parse_str(&id[prefix.len() + 1..]).is_ok()
        }
    }
}

/// Extract the creation time from a snowflake ID
pub fn extract_timestamp_from_snowflake(id: &str) -> Option<DateTime<Utc>> {
    let snowflake_id = id.parse::<u64>().ok()?;
    let timestamp = (snowflake_id >> 22) as i64;
    DateTime::from_timestamp_millis(timestamp)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_uuid() {
        let id = generate_id();
        assert!(validate_id(&id, IdType::Uuid));
    }

    #[test]
    fn test_generate_snowflake() {
        let id = generate_id_with_type(IdType::Snowflake);
        assert!(validate_id(&id, IdType::Snowflake));
    }

    #[test]
    fn test_generate_prefixed_id() {
        let id = generate_id_with_type(IdType::Prefixed("user".to_string()));
        assert!(id.starts_with("user-"));
        assert!(validate_id(&id, IdType::Prefixed("user".to_string())));
    }

    #[test]
    fn test_short_id() {
        let id = generate_short_id();
        assert_eq!(id.len(), 8);
    }
}
```
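The shifts in `generate_snowflake_id` imply a fixed layout. The millisecond timestamp sits above bit 22, a 10-bit machine ID occupies bits 12 to 21, and a 12-bit sequence fills bits 0 to 11. The stdlib-only sketch below packs and unpacks that layout so the fields of an ID can be inspected. The helpers `pack_snowflake` and `unpack_snowflake` are hypothetical and are not part of `id.rs`.

```rust
/// Bit widths implied by the shifts in `generate_snowflake_id` (assumed layout).
const SEQUENCE_BITS: u32 = 12;
const MACHINE_BITS: u32 = 10;
const SEQUENCE_MASK: u64 = (1 << SEQUENCE_BITS) - 1; // 0xFFF
const MACHINE_MASK: u64 = (1 << MACHINE_BITS) - 1; // 0x3FF

/// Hypothetical helper: combine the three fields the same way the shifts above do.
pub fn pack_snowflake(timestamp_ms: u64, machine_id: u64, sequence: u64) -> u64 {
    (timestamp_ms << (MACHINE_BITS + SEQUENCE_BITS))
        | ((machine_id & MACHINE_MASK) << SEQUENCE_BITS)
        | (sequence & SEQUENCE_MASK)
}

/// Hypothetical helper: split an ID back into (timestamp_ms, machine_id, sequence).
pub fn unpack_snowflake(id: u64) -> (u64, u64, u64) {
    let timestamp_ms = id >> (MACHINE_BITS + SEQUENCE_BITS);
    let machine_id = (id >> SEQUENCE_BITS) & MACHINE_MASK;
    let sequence = id & SEQUENCE_MASK;
    (timestamp_ms, machine_id, sequence)
}
```

Masking each field keeps an out-of-range machine ID or sequence from spilling into neighbouring bits. With the Unix epoch and a `u64`, the 42 bits left for the timestamp last until roughly the year 2109. If these IDs are ever stored as a signed 64-bit integer (for example a `BIGINT` column), only 41 bits are usable, and that overflows around 2039. Many snowflake implementations subtract a custom epoch from the timestamp for this reason.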
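### Time Handling (time.rs)
#### Features
- Time formatting and parsing
- Time zone conversion (UTC, UTC+8, server-local time)
- Time arithmetic
- Relative time descriptions
#### Implementation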
```rust
use chrono::{
    DateTime, Datelike, Duration, FixedOffset, Local, NaiveDateTime, TimeZone, Utc,
};
use serde::{Deserialize, Serialize};
use std::fmt;

/// Time format constants
pub const DEFAULT_DATETIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S";
pub const ISO_DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.3fZ";
pub const DATE_FORMAT: &str = "%Y-%m-%d";
pub const TIME_FORMAT: &str = "%H:%M:%S";

/// UTC offset in seconds for China Standard Time (UTC+8)
pub const CHINA_OFFSET: i32 = 8 * 3600;

/// Current time (UTC)
pub fn now_utc() -> DateTime<Utc> {
    Utc::now()
}

/// Current time as a `DateTime<FixedOffset>` with a zero (UTC) offset
pub fn now_fixed_offset() -> DateTime<FixedOffset> {
    Utc::now().with_timezone(&FixedOffset::east_opt(0).unwrap())
}

/// Current time in China Standard Time (UTC+8)
pub fn now_china() -> DateTime<FixedOffset> {
    Utc::now().with_timezone(&FixedOffset::east_opt(CHINA_OFFSET).unwrap())
}

/// Current local time (the server's time zone)
pub fn now_local() -> DateTime<Local> {
    Local::now()
}

/// Format a time with a custom format string
pub fn format_datetime(dt: &DateTime<Utc>, format: &str) -> String {
    dt.format(format).to_string()
}

/// Format a time with the default format
pub fn format_datetime_default(dt: &DateTime<Utc>) -> String {
    format_datetime(dt, DEFAULT_DATETIME_FORMAT)
}

/// Format a time as ISO 8601 with millisecond precision
pub fn format_datetime_iso(dt: &DateTime<Utc>) -> String {
    format_datetime(dt, ISO_DATETIME_FORMAT)
}

/// Parse a time string; the input is interpreted as UTC
pub fn parse_datetime(s: &str, format: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
    let naive = NaiveDateTime::parse_from_str(s, format)?;
    Ok(Utc.from_utc_datetime(&naive))
}

/// Parse an RFC 3339 (ISO 8601) time string and convert it to UTC
pub fn parse_datetime_iso(s: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
    DateTime::parse_from_rfc3339(s).map(|dt| dt.with_timezone(&Utc))
}

/// Convert a Unix timestamp (seconds) to a DateTime
pub fn timestamp_to_datetime(timestamp: i64) -> Option<DateTime<Utc>> {
    DateTime::from_timestamp(timestamp, 0)
}

/// Convert a Unix timestamp (milliseconds) to a DateTime
pub fn timestamp_millis_to_datetime(timestamp: i64) -> Option<DateTime<Utc>> {
    DateTime::from_timestamp_millis(timestamp)
}

/// Convert a DateTime to a Unix timestamp (seconds)
pub fn datetime_to_timestamp(dt: &DateTime<Utc>) -> i64 {
    dt.timestamp()
}

/// Convert a DateTime to a Unix timestamp (milliseconds)
pub fn datetime_to_timestamp_millis(dt: &DateTime<Utc>) -> i64 {
    dt.timestamp_millis()
}

/// Difference between two times (`end - start`)
pub fn time_diff(start: &DateTime<Utc>, end: &DateTime<Utc>) -> Duration {
    *end - *start
}

/// Difference between two times, in seconds
pub fn time_diff_seconds(start: &DateTime<Utc>, end: &DateTime<Utc>) -> i64 {
    time_diff(start, end).num_seconds()
}

/// Difference between two times, in milliseconds
pub fn time_diff_millis(start: &DateTime<Utc>, end: &DateTime<Utc>) -> i64 {
    time_diff(start, end).num_milliseconds()
}

/// Add a duration
pub fn add_duration(dt: &DateTime<Utc>, duration: Duration) -> DateTime<Utc> {
    *dt + duration
}

/// Subtract a duration
pub fn sub_duration(dt: &DateTime<Utc>, duration: Duration) -> DateTime<Utc> {
    *dt - duration
}

/// Add seconds
pub fn add_seconds(dt: &DateTime<Utc>, seconds: i64) -> DateTime<Utc> {
    add_duration(dt, Duration::seconds(seconds))
}

/// Add minutes
pub fn add_minutes(dt: &DateTime<Utc>, minutes: i64) -> DateTime<Utc> {
    add_duration(dt, Duration::minutes(minutes))
}

/// Add hours
pub fn add_hours(dt: &DateTime<Utc>, hours: i64) -> DateTime<Utc> {
    add_duration(dt, Duration::hours(hours))
}

/// Add days
pub fn add_days(dt: &DateTime<Utc>, days: i64) -> DateTime<Utc> {
    add_duration(dt, Duration::days(days))
}

/// Start of today: 00:00:00 of the current UTC day (not the UTC+8 day)
pub fn today_start() -> DateTime<Utc> {
    let now = now_utc();
    now.date_naive().and_hms_opt(0, 0, 0).unwrap().and_utc()
}

/// End of today: 23:59:59 of the current UTC day. Sub-second times after this
/// are excluded, so range queries are safer as `< tomorrow's start`.
pub fn today_end() -> DateTime<Utc> {
    let now = now_utc();
    now.date_naive().and_hms_opt(23, 59, 59).unwrap().and_utc()
}

/// Start of this week (Monday 00:00:00, UTC)
pub fn week_start() -> DateTime<Utc> {
    let now = now_utc();
    let weekday = now.weekday().num_days_from_monday(); // requires `Datelike`
    let start_date = now.date_naive() - Duration::days(weekday as i64);
    start_date.and_hms_opt(0, 0, 0).unwrap().and_utc()
}

/// Start of this month (day 1, 00:00:00, UTC)
pub fn month_start() -> DateTime<Utc> {
    let now = now_utc();
    let start_date = now.date_naive().with_day(1).unwrap();
    start_date.and_hms_opt(0, 0, 0).unwrap().and_utc()
}

/// Relative time description
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelativeTime {
    pub value: i64,
    pub unit: TimeUnit,
    pub description: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TimeUnit {
    Second,
    Minute,
    Hour,
    Day,
    Week,
    Month,
    Year,
}

// Units are rendered in Chinese for the UI
impl fmt::Display for TimeUnit {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = match self {
            TimeUnit::Second => "秒",
            TimeUnit::Minute => "分钟",
            TimeUnit::Hour => "小时",
            TimeUnit::Day => "天",
            TimeUnit::Week => "周",
            TimeUnit::Month => "月",
            TimeUnit::Year => "年",
        };
        write!(f, "{}", s)
    }
}

/// Compute a relative-time description such as "2小时 前" (2 hours ago)
pub fn relative_time(dt: &DateTime<Utc>) -> RelativeTime {
    let now = now_utc();
    // time_diff(dt, &now) = now - dt, so a past `dt` gives a non-negative difference
    let diff = time_diff(dt, &now);
    let abs_seconds = diff.num_seconds().abs();
    let is_past = diff.num_seconds() >= 0;
    // A month is approximated as 30 days and a year as 365 days
    let (value, unit) = if abs_seconds < 60 {
        (abs_seconds, TimeUnit::Second)
    } else if abs_seconds < 3600 {
        (abs_seconds / 60, TimeUnit::Minute)
    } else if abs_seconds < 86400 {
        (abs_seconds / 3600, TimeUnit::Hour)
    } else if abs_seconds < 604800 {
        (abs_seconds / 86400, TimeUnit::Day)
    } else if abs_seconds < 2592000 {
        (abs_seconds / 604800, TimeUnit::Week)
    } else if abs_seconds < 31536000 {
        (abs_seconds / 2592000, TimeUnit::Month)
    } else {
        (abs_seconds / 31536000, TimeUnit::Year)
    };
    // "前" = ago, "后" = from now
    let description = if is_past {
        format!("{}{} 前", value, unit)
    } else {
        format!("{}{} 后", value, unit)
    };
    RelativeTime {
        value,
        unit,
        description,
    }
}

/// Are both times on the same UTC calendar day?
pub fn is_same_day(dt1: &DateTime<Utc>, dt2: &DateTime<Utc>) -> bool {
    dt1.date_naive() == dt2.date_naive()
}

/// Is the time today (UTC)?
pub fn is_today(dt: &DateTime<Utc>) -> bool {
    is_same_day(dt, &now_utc())
}

/// Is the time yesterday (UTC)?
pub fn is_yesterday(dt: &DateTime<Utc>) -> bool {
    let yesterday = sub_duration(&now_utc(), Duration::days(1));
    is_same_day(dt, &yesterday)
}

/// Is the time within this week (Monday to Sunday, UTC)?
pub fn is_this_week(dt: &DateTime<Utc>) -> bool {
    let week_start = week_start();
    let week_end = add_days(&week_start, 7);
    dt >= &week_start && dt < &week_end
}

/// Is the time within this month (UTC)?
pub fn is_this_month(dt: &DateTime<Utc>) -> bool {
    let now = now_utc();
    dt.year() == now.year() && dt.month() == now.month()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_time_formatting() {
        let dt = Utc::now();
        let formatted = format_datetime_default(&dt);
        assert!(!formatted.is_empty());
    }

    #[test]
    fn test_time_parsing() {
        let time_str = "2023-12-25 10:30:00";
        let parsed = parse_datetime(time_str, DEFAULT_DATETIME_FORMAT);
        assert!(parsed.is_ok());
    }

    #[test]
    fn test_timestamp_conversion() {
        let dt = Utc::now();
        let timestamp = datetime_to_timestamp(&dt);
        let converted = timestamp_to_datetime(timestamp).unwrap();
        assert_eq!(dt.timestamp(), converted.timestamp());
    }

    #[test]
    fn test_relative_time() {
        let past = sub_duration(&now_utc(), Duration::hours(2));
        let rel_time = relative_time(&past);
        assert!(rel_time.description.contains("前"));
    }

    #[test]
    fn test_date_checks() {
        let now = now_utc();
        assert!(is_today(&now));
        let yesterday = sub_duration(&now, Duration::days(1));
        assert!(is_yesterday(&yesterday));
    }
}
```
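`relative_time` picks a unit by comparing the absolute difference against fixed thresholds, treating a month as 30 days and a year as 365 days. The stdlib-only sketch below isolates that bucketing so it can be checked with fixed inputs instead of the wall clock. `bucket_seconds` is a hypothetical helper, and its English unit names are for illustration only. It assumes the same sign convention as `time_diff(dt, &now)`, where the difference is `now - dt`.

```rust
/// Hypothetical helper mirroring the thresholds in `relative_time`.
/// `diff_secs` is `now - dt` in seconds, so a non-negative value means the past.
pub fn bucket_seconds(diff_secs: i64) -> (i64, &'static str, bool) {
    let abs = diff_secs.abs();
    let is_past = diff_secs >= 0;
    let (value, unit) = if abs < 60 {
        (abs, "second")
    } else if abs < 3_600 {
        (abs / 60, "minute")
    } else if abs < 86_400 {
        (abs / 3_600, "hour")
    } else if abs < 604_800 {
        (abs / 86_400, "day")
    } else if abs < 2_592_000 {
        // below 30 days: count whole weeks
        (abs / 604_800, "week")
    } else if abs < 31_536_000 {
        // below 365 days: count 30-day months
        (abs / 2_592_000, "month")
    } else {
        (abs / 31_536_000, "year")
    };
    (value, unit, is_past)
}
```

For example, `bucket_seconds(7_200)` yields two hours in the past, and `bucket_seconds(-90)` yields one minute in the future. Because the division truncates, 59 minutes 59 seconds is still reported as 59 minutes, not rounded up to an hour.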
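### Encryption and Decryption (crypto.rs)
#### Features
- AES-256-GCM encryption and decryption
- RSA encryption and decryption (PKCS#1 v1.5 padding)
- Digital signatures: the `SigningFailed` and `VerificationFailed` error variants are defined, but the code below does not include a signing API
- Key generation and SHA-256 hashing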
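One practical limit applies to the RSA encryption listed above. PKCS#1 v1.5 encryption, which the `RsaKeyPair` implementation uses, accepts at most k − 11 bytes of plaintext for a k-byte modulus (RFC 8017, RSAES-PKCS1-v1_5). A 2048-bit key therefore encrypts at most 256 − 11 = 245 bytes per call, which is why RSA normally wraps a symmetric key rather than bulk data. The helper name `max_pkcs1v15_plaintext_len` below is hypothetical.

```rust
/// Hypothetical helper: the largest plaintext, in bytes, that RSA PKCS#1 v1.5
/// encryption accepts for a modulus of `modulus_bits` bits (RFC 8017).
pub fn max_pkcs1v15_plaintext_len(modulus_bits: usize) -> usize {
    // k = modulus length in bytes, rounded up
    let k = (modulus_bits + 7) / 8;
    // 11 bytes go to the 0x00 0x02 header, at least 8 padding bytes, and a 0x00 separator
    k.saturating_sub(11)
}
```

For comparison, OAEP with SHA-256 allows k − 2·32 − 2 = 190 bytes with the same 2048-bit key.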
#### Implementation
```rust
use aes_gcm::{
    aead::{Aead, KeyInit, OsRng},
    Aes256Gcm, Key, Nonce,
};
use base64::{engine::general_purpose, Engine as _};
use rand::{thread_rng, RngCore};
use rsa::{
    pkcs8::{DecodePrivateKey, EncodePrivateKey, EncodePublicKey},
    Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey,
};
use sha2::{Digest, Sha256};
use thiserror::Error;

/// Errors returned by the crypto helpers
#[derive(Error, Debug)]
pub enum CryptoError {
    #[error("加密失败: {0}")]
    EncryptionFailed(String),
    #[error("解密失败: {0}")]
    DecryptionFailed(String),
    #[error("密钥生成失败: {0}")]
    KeyGenerationFailed(String),
    #[error("签名失败: {0}")]
    SigningFailed(String),
    #[error("验证失败: {0}")]
    VerificationFailed(String),
    #[error("编码失败: {0}")]
    EncodingFailed(String),
}

/// AES-256-GCM encryption helper
pub struct AesEncryption {
    cipher: Aes256Gcm,
}

impl AesEncryption {
    /// Create an encryptor from a 32-byte key
    pub fn new(key: &[u8; 32]) -> Self {
        let key = Key::<Aes256Gcm>::from_slice(key);
        let cipher = Aes256Gcm::new(key);
        Self { cipher }
    }

    /// Create an encryptor from a password
    pub fn from_password(password: &str) -> Self {
        let key = Self::derive_key_from_password(password);
        Self::new(&key)
    }

    /// Derive a key from a password.
    /// Note: a single unsalted SHA-256 pass is fast to brute-force; for
    /// user-chosen passwords a slow, salted KDF (e.g. Argon2 or PBKDF2) is safer.
    fn derive_key_from_password(password: &str) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(password.as_bytes());
        hasher.finalize().into()
    }

    /// Encrypt data. Output layout: `nonce (12 bytes) || ciphertext || tag (16 bytes)`
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        // Fresh random 96-bit nonce per message; reusing a nonce with the same key breaks GCM
        let mut nonce_bytes = [0u8; 12];
        thread_rng().fill_bytes(&mut nonce_bytes);
        let nonce = Nonce::from_slice(&nonce_bytes);
        let ciphertext = self
            .cipher
            .encrypt(nonce, plaintext)
            .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
        // Prepend the nonce so `decrypt` can recover it
        let mut result = Vec::with_capacity(12 + ciphertext.len());
        result.extend_from_slice(&nonce_bytes);
        result.extend_from_slice(&ciphertext);
        Ok(result)
    }

    /// Decrypt data produced by `encrypt`
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if ciphertext.len() < 12 {
            return Err(CryptoError::DecryptionFailed(
                "密文长度不足".to_string(),
            ));
        }
        let (nonce_bytes, encrypted_data) = ciphertext.split_at(12);
        let nonce = Nonce::from_slice(nonce_bytes);
        self.cipher
            .decrypt(nonce, encrypted_data)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    }

    /// Encrypt a string and return it Base64-encoded
    pub fn encrypt_string(&self, plaintext: &str) -> Result<String, CryptoError> {
        let encrypted = self.encrypt(plaintext.as_bytes())?;
        Ok(general_purpose::STANDARD.encode(encrypted))
    }

    /// Decrypt a Base64-encoded string
    pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, CryptoError> {
        let decoded = general_purpose::STANDARD
            .decode(ciphertext)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
        let decrypted = self.decrypt(&decoded)?;
        String::from_utf8(decrypted)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    }
}

/// RSA key pair
#[derive(Debug, Clone)]
pub struct RsaKeyPair {
    pub private_key: RsaPrivateKey,
    pub public_key: RsaPublicKey,
}

impl RsaKeyPair {
    /// Generate a new RSA key pair (2048 bits or more is recommended)
    pub fn generate(bits: usize) -> Result<Self, CryptoError> {
        let mut rng = OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, bits)
            .map_err(|e| CryptoError::KeyGenerationFailed(e.to_string()))?;
        let public_key = RsaPublicKey::from(&private_key);
        Ok(Self {
            private_key,
            public_key,
        })
    }

    /// Load a key pair from a PKCS#8 PEM private key
    pub fn from_private_key_pem(pem: &str) -> Result<Self, CryptoError> {
        let private_key = RsaPrivateKey::from_pkcs8_pem(pem)
            .map_err(|e| CryptoError::KeyGenerationFailed(e.to_string()))?;
        let public_key = RsaPublicKey::from(&private_key);
        Ok(Self {
            private_key,
            public_key,
        })
    }

    /// Export the private key as PKCS#8 PEM
    pub fn private_key_to_pem(&self) -> Result<String, CryptoError> {
        self.private_key
            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .map_err(|e| CryptoError::EncodingFailed(e.to_string()))
            .map(|s| s.to_string())
    }

    /// Export the public key as SPKI (SubjectPublicKeyInfo) PEM
    pub fn public_key_to_pem(&self) -> Result<String, CryptoError> {
        self.public_key
            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
            .map_err(|e| CryptoError::EncodingFailed(e.to_string()))
    }

    /// Encrypt with the public key (PKCS#1 v1.5 padding; OAEP is preferred for new designs)
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let mut rng = OsRng;
        self.public_key
            .encrypt(&mut rng, Pkcs1v15Encrypt, plaintext)
            .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))
    }

    /// Decrypt with the private key
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        self.private_key
            .decrypt(Pkcs1v15Encrypt, ciphertext)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    }

    /// Encrypt a string and return it Base64-encoded
    pub fn encrypt_string(&self, plaintext: &str) -> Result<String, CryptoError> {
        let encrypted = self.encrypt(plaintext.as_bytes())?;
        Ok(general_purpose::STANDARD.encode(encrypted))
    }

    /// Decrypt a Base64-encoded string
    pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, CryptoError> {
        let decoded = general_purpose::STANDARD
            .decode(ciphertext)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
        let decrypted = self.decrypt(&decoded)?;
        String::from_utf8(decrypted)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    }
}

/// Generate a random key of `length` bytes
pub fn generate_random_key(length: usize) -> Vec<u8> {
    let mut key = vec![0u8; length];
    thread_rng().fill_bytes(&mut key);
    key
}

/// Generate an AES-256 key
pub fn generate_aes_key() -> [u8; 32] {
    let mut key = [0u8; 32];
    thread_rng().fill_bytes(&mut key);
    key
}

/// Compute a SHA-256 hash
pub fn sha256_hash(data: &[u8]) -> Vec<u8> {
    let mut hasher = Sha256::new();
    hasher.update(data);
    hasher.finalize().to_vec()
}

/// Compute the SHA-256 hash of a string and return it as lowercase hex
pub fn sha256_hex(data: &str) -> String {
    let hash = sha256_hash(data.as_bytes());
    hex::encode(hash)
}

/// Verify data against an expected SHA-256 hash.
/// Note: `==` stops at the first differing byte; use a constant-time
/// comparison when the hash protects a secret.
pub fn verify_hash(data: &[u8], expected_hash: &[u8]) -> bool {
    let actual_hash = sha256_hash(data);
    actual_hash == expected_hash
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aes_encryption() {
        let key = generate_aes_key();
        let aes = AesEncryption::new(&key);
        let plaintext = "Hello, World!";
        let encrypted = aes.encrypt_string(plaintext).unwrap();
        let decrypted = aes.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_rsa_encryption() {
        let keypair = RsaKeyPair::generate(2048).unwrap();
        let plaintext = "Hello, RSA!";
        let encrypted = keypair.encrypt_string(plaintext).unwrap();
        let decrypted = keypair.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_sha256_hash() {
        let data = "test data";
        let hash1 = sha256_hex(data);
        let hash2 = sha256_hex(data);
        assert_eq!(hash1, hash2);
        assert_eq!(hash1.len(), 64); // SHA-256 yields 64 hex characters
    }

    #[test]
    fn test_key_generation() {
        let key1 = generate_aes_key();
        let key2 = generate_aes_key();
        assert_ne!(key1, key2); // keys should be random
        assert_eq!(key1.len(), 32); // AES-256 key length
    }
}
```
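`verify_hash` compares digests with `==`, and that comparison can return as soon as the first byte differs. When the digest guards something secret, such as an API token, the usual choice is a comparison whose running time does not depend on where the bytes differ. The stdlib-only sketch below shows the idea. `constant_time_eq` is a hypothetical name, and because the compiler may still optimize such loops, a vetted crate such as `subtle` is the safer choice in production.

```rust
use std::hint::black_box;

/// Hypothetical helper: compare two byte slices without exiting at the first mismatch.
/// Only the lengths can leak, and SHA-256 digests always have the same public length.
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair, so every byte is always visited
    let mut diff: u8 = 0;
    for (x, y) in a.iter().zip(b) {
        diff |= x ^ y;
    }
    // black_box is a best-effort hint that discourages the optimizer from
    // reintroducing an early exit; it is not a hard guarantee
    black_box(diff) == 0
}
```

In `verify_hash`, this helper would replace the final `actual_hash == expected_hash` with `constant_time_eq(&actual_hash, expected_hash)`.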
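### Validation (validation.rs)
#### Features
- Email validation
- Phone number validation (mainland China mobile numbers)
- URL validation
- Password strength checks
- Custom validation rules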
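The phone check in `validation.rs` uses the pattern `^1[3-9]\d{9}$`. That pattern accepts 11-digit mainland China mobile numbers whose first digit is 1 and whose second digit is 3 to 9. The stdlib-only sketch below spells out the same rule without a regex, which can make the intent easier to review. `is_cn_mobile` is a hypothetical name. One difference is worth knowing: in the `regex` crate, `\d` matches any Unicode decimal digit by default, whereas this sketch accepts ASCII digits only.

```rust
/// Hypothetical helper: the rule behind `^1[3-9]\d{9}$`, restricted to ASCII digits.
pub fn is_cn_mobile(phone: &str) -> bool {
    let bytes = phone.as_bytes();
    // Short-circuiting keeps the indexing below in bounds
    bytes.len() == 11
        && bytes[0] == b'1'
        && (b'3'..=b'9').contains(&bytes[1])
        && bytes.iter().all(u8::is_ascii_digit)
}
```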
#### Implementation
```rust
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::sync::{Arc, OnceLock};
use thiserror::Error;

/// Validation error types
#[derive(Error, Debug, Clone, Serialize, Deserialize)]
pub enum ValidationError {
    #[error("字段 '{field}' 是必需的")]
    Required { field: String },
    #[error("字段 '{field}' 长度必须在 {min} 到 {max} 之间")]
    Length { field: String, min: usize, max: usize },
    #[error("字段 '{field}' 格式无效: {message}")]
    Format { field: String, message: String },
    #[error("字段 '{field}' 值无效: {message}")]
    Invalid { field: String, message: String },
    #[error("自定义验证失败: {message}")]
    Custom { message: String },
}

/// Validation result
pub type ValidationResult<T> = Result<T, ValidationError>;

/// Validator trait
pub trait Validator<T> {
    fn validate(&self, value: &T) -> ValidationResult<()>;
}

/// Custom rule. `Arc` (not `Box`) lets `StringValidator` implement `Clone`.
pub type CustomValidator = Arc<dyn Fn(&str) -> ValidationResult<()> + Send + Sync>;

/// Patterns shared by the builder methods and the free functions below
const EMAIL_PATTERN: &str = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$";
const PHONE_PATTERN: &str = r"^1[3-9]\d{9}$"; // mainland China mobile numbers
const URL_PATTERN: &str = r"^https?://[^\s/$.?#].[^\s]*$";

/// String validator
#[derive(Clone)]
pub struct StringValidator {
    pub field_name: String,
    pub required: bool,
    pub min_length: Option<usize>,
    pub max_length: Option<usize>,
    pub pattern: Option<Regex>,
    pub custom_validators: Vec<CustomValidator>,
}

// Closures do not implement Debug, so only their count is printed
impl fmt::Debug for StringValidator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("StringValidator")
            .field("field_name", &self.field_name)
            .field("required", &self.required)
            .field("min_length", &self.min_length)
            .field("max_length", &self.max_length)
            .field("pattern", &self.pattern)
            .field("custom_validators", &self.custom_validators.len())
            .finish()
    }
}

impl StringValidator {
    pub fn new(field_name: &str) -> Self {
        Self {
            field_name: field_name.to_string(),
            required: false,
            min_length: None,
            max_length: None,
            pattern: None,
            custom_validators: Vec::new(),
        }
    }

    pub fn required(mut self) -> Self {
        self.required = true;
        self
    }

    pub fn min_length(mut self, min: usize) -> Self {
        self.min_length = Some(min);
        self
    }

    pub fn max_length(mut self, max: usize) -> Self {
        self.max_length = Some(max);
        self
    }

    pub fn length_range(mut self, min: usize, max: usize) -> Self {
        self.min_length = Some(min);
        self.max_length = Some(max);
        self
    }

    pub fn pattern(mut self, pattern: &str) -> Result<Self, regex::Error> {
        self.pattern = Some(Regex::new(pattern)?);
        Ok(self)
    }

    pub fn email(self) -> Result<Self, regex::Error> {
        self.pattern(EMAIL_PATTERN)
    }

    pub fn phone(self) -> Result<Self, regex::Error> {
        self.pattern(PHONE_PATTERN)
    }

    pub fn url(self) -> Result<Self, regex::Error> {
        self.pattern(URL_PATTERN)
    }
}

impl Validator<Option<String>> for StringValidator {
    fn validate(&self, value: &Option<String>) -> ValidationResult<()> {
        match value {
            None => {
                if self.required {
                    Err(ValidationError::Required {
                        field: self.field_name.clone(),
                    })
                } else {
                    Ok(())
                }
            }
            // Delegate explicitly to the `String` impl
            Some(s) => <Self as Validator<String>>::validate(self, s),
        }
    }
}

impl Validator<String> for StringValidator {
    fn validate(&self, value: &String) -> ValidationResult<()> {
        // Empty and required: fail
        if value.is_empty() && self.required {
            return Err(ValidationError::Required {
                field: self.field_name.clone(),
            });
        }
        // Empty and optional: skip the remaining checks
        if value.is_empty() && !self.required {
            return Ok(());
        }
        // Length is counted in characters, not UTF-8 bytes, so CJK input is measured as users expect
        let len = value.chars().count();
        if let Some(min) = self.min_length {
            if len < min {
                return Err(ValidationError::Length {
                    field: self.field_name.clone(),
                    min,
                    max: self.max_length.unwrap_or(usize::MAX),
                });
            }
        }
        if let Some(max) = self.max_length {
            if len > max {
                return Err(ValidationError::Length {
                    field: self.field_name.clone(),
                    min: self.min_length.unwrap_or(0),
                    max,
                });
            }
        }
        // Regular-expression check
        if let Some(pattern) = &self.pattern {
            if !pattern.is_match(value) {
                return Err(ValidationError::Format {
                    field: self.field_name.clone(),
                    message: "格式不匹配".to_string(),
                });
            }
        }
        // Run custom rules
        for validator in &self.custom_validators {
            validator(value.as_str())?;
        }
        Ok(())
    }
}

/// Number validator
#[derive(Debug, Clone)]
pub struct NumberValidator<T> {
    pub field_name: String,
    pub required: bool,
    pub min_value: Option<T>,
    pub max_value: Option<T>,
}

impl<T> NumberValidator<T>
where
    T: PartialOrd + Copy,
{
    pub fn new(field_name: &str) -> Self {
        Self {
            field_name: field_name.to_string(),
            required: false,
            min_value: None,
            max_value: None,
        }
    }

    pub fn required(mut self) -> Self {
        self.required = true;
        self
    }

    pub fn min_value(mut self, min: T) -> Self {
        self.min_value = Some(min);
        self
    }

    pub fn max_value(mut self, max: T) -> Self {
        self.max_value = Some(max);
        self
    }

    pub fn range(mut self, min: T, max: T) -> Self {
        self.min_value = Some(min);
        self.max_value = Some(max);
        self
    }
}

impl<T> Validator<Option<T>> for NumberValidator<T>
where
    T: PartialOrd + Copy + std::fmt::Display,
{
    fn validate(&self, value: &Option<T>) -> ValidationResult<()> {
        match value {
            None => {
                if self.required {
                    Err(ValidationError::Required {
                        field: self.field_name.clone(),
                    })
                } else {
                    Ok(())
                }
            }
            // Delegate explicitly to the `T` impl
            Some(v) => <Self as Validator<T>>::validate(self, v),
        }
    }
}

impl<T> Validator<T> for NumberValidator<T>
where
    T: PartialOrd + Copy + std::fmt::Display,
{
    fn validate(&self, value: &T) -> ValidationResult<()> {
        if let Some(min) = self.min_value {
            if *value < min {
                return Err(ValidationError::Invalid {
                    field: self.field_name.clone(),
                    message: format!("值必须大于等于 {}", min),
                });
            }
        }
        if let Some(max) = self.max_value {
            if *value > max {
                return Err(ValidationError::Invalid {
                    field: self.field_name.clone(),
                    message: format!("值必须小于等于 {}", max),
                });
            }
        }
        Ok(())
    }
}

/// Password strength level
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum PasswordStrength {
    Weak,
    Medium,
    Strong,
    VeryStrong,
}

/// Password validation result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordValidation {
    pub is_valid: bool,
    pub strength: PasswordStrength,
    pub score: u8,
    pub feedback: Vec<String>,
}

/// Compile a built-in pattern once and reuse it on later calls
fn cached_regex(cell: &'static OnceLock<Regex>, pattern: &str) -> &'static Regex {
    cell.get_or_init(|| Regex::new(pattern).expect("built-in pattern is valid"))
}

/// Validate an email address
pub fn validate_email(email: &str) -> bool {
    static RE: OnceLock<Regex> = OnceLock::new();
    cached_regex(&RE, EMAIL_PATTERN).is_match(email)
}

/// Validate a mainland China mobile number
pub fn validate_phone(phone: &str) -> bool {
    static RE: OnceLock<Regex> = OnceLock::new();
    cached_regex(&RE, PHONE_PATTERN).is_match(phone)
}

/// Validate an http(s) URL
pub fn validate_url(url: &str) -> bool {
    static RE: OnceLock<Regex> = OnceLock::new();
    cached_regex(&RE, URL_PATTERN).is_match(url)
}

/// Validate an IPv4 or IPv6 address
pub fn validate_ip(ip: &str) -> bool {
    ip.parse::<std::net::IpAddr>().is_ok()
}

/// Score password strength (0 to 6) and collect feedback
pub fn validate_password(password: &str) -> PasswordValidation {
    let mut score = 0u8;
    let mut feedback = Vec::new();
    // Length checks
    if password.len() >= 8 {
        score += 1;
    } else {
        feedback.push("密码长度至少需要8位".to_string());
    }
    if password.len() >= 12 {
        score += 1;
    }
    // Contains a lowercase letter
    if password.chars().any(|c| c.is_ascii_lowercase()) {
        score += 1;
    } else {
        feedback.push("密码应包含小写字母".to_string());
    }
    // Contains an uppercase letter
    if password.chars().any(|c| c.is_ascii_uppercase()) {
        score += 1;
    } else {
        feedback.push("密码应包含大写字母".to_string());
    }
    // Contains a digit
    if password.chars().any(|c| c.is_ascii_digit()) {
        score += 1;
    } else {
        feedback.push("密码应包含数字".to_string());
    }
    // Contains a special character
    if password.chars().any(|c| "!@#$%^&*()_+-=[]{}|;:,.<>?".contains(c)) {
        score += 1;
    } else {
        feedback.push("密码应包含特殊字符".to_string());
    }
    // Penalize common weak passwords
    let weak_passwords = [
        "password", "123456", "123456789", "qwerty", "abc123",
        "password123", "admin", "root", "user", "guest",
    ];
    let lowered = password.to_lowercase();
    if weak_passwords.iter().any(|&weak| lowered.contains(weak)) {
        score = score.saturating_sub(2);
        feedback.push("密码不能包含常见弱密码".to_string());
    }
    // The maximum score is 6 (two length checks + four character classes),
    // so 6 maps to VeryStrong; a `5..=6` arm would leave VeryStrong unreachable
    let strength = match score {
        0..=2 => PasswordStrength::Weak,
        3..=4 => PasswordStrength::Medium,
        5 => PasswordStrength::Strong,
        _ => PasswordStrength::VeryStrong,
    };
    let is_valid = score >= 3 && password.len() >= 8;
    if is_valid && feedback.is_empty() {
        feedback.push("密码强度良好".to_string());
    }
    PasswordValidation {
        is_valid,
        strength,
        score,
        feedback,
    }
}

/// Check that a string is valid JSON
pub fn validate_json(json_str: &str) -> bool {
    serde_json::from_str::<serde_json::Value>(json_str).is_ok()
}

/// Check that a string is a valid UUID
pub fn validate_uuid(uuid_str: &str) -> bool {
    uuid::Uuid::parse_str(uuid_str).is_ok()
}

/// Batch validator: runs several validators and collects every error
#[derive(Debug)]
pub struct BatchValidator {
    errors: Vec<ValidationError>,
}

impl BatchValidator {
    pub fn new() -> Self {
        Self {
            errors: Vec::new(),
        }
    }

    pub fn validate<T, V>(&mut self, validator: &V, value: &T) -> &mut Self
    where
        V: Validator<T>,
    {
        if let Err(error) = validator.validate(value) {
            self.errors.push(error);
        }
        self
    }

    pub fn is_valid(&self) -> bool {
        self.errors.is_empty()
    }

    pub fn errors(&self) -> &[ValidationError] {
        &self.errors
    }

    pub fn into_result(self) -> ValidationResult<()> {
        if self.errors.is_empty() {
            Ok(())
        } else {
            // Returns only the first error; use `errors()` when every error is needed
            Err(self.errors.into_iter().next().unwrap())
        }
    }
}

impl Default for BatchValidator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_email_validation() {
        assert!(validate_email("test@example.com"));
        assert!(validate_email("user.name+tag@domain.co.uk"));
        assert!(!validate_email("invalid.email"));
        assert!(!validate_email("@domain.com"));
    }

    #[test]
    fn test_phone_validation() {
        assert!(validate_phone("13812345678"));
        assert!(validate_phone("15987654321"));
        assert!(!validate_phone("12345678901"));
        assert!(!validate_phone("1381234567"));
    }

    #[test]
    fn test_password_validation() {
        let weak = validate_password("123");
        assert_eq!(weak.strength, PasswordStrength::Weak);
        assert!(!weak.is_valid);
        let strong = validate_password("MyStr0ng!Pass");
        assert!(strong.is_valid);
        assert!(matches!(strong.strength, PasswordStrength::Strong | PasswordStrength::VeryStrong));
    }

    #[test]
    fn test_string_validator() {
        let validator = StringValidator::new("username")
            .required()
            .length_range(3, 20)
            .pattern(r"^[a-zA-Z0-9_]+$")
            .unwrap();
        assert!(validator.validate(&"valid_user123".to_string()).is_ok());
        assert!(validator.validate(&"ab".to_string()).is_err()); // too short
        assert!(validator.validate(&"invalid-user".to_string()).is_err()); // contains an illegal character
    }

    #[test]
    fn test_number_validator() {
        let validator = NumberValidator::new("age")
            .required()
            .range(0, 150);
        assert!(validator.validate(&25).is_ok());
        assert!(validator.validate(&-1).is_err()); // below the minimum
        assert!(validator.validate(&200).is_err()); // above the maximum
    }

    #[test]
    fn test_batch_validator() {
        let mut batch = BatchValidator::new();
        let email_validator = StringValidator::new("email").required().email().unwrap();
        let age_validator = NumberValidator::new("age").required().range(0, 150);
        batch
            .validate(&email_validator, &"invalid.email".to_string())
            .validate(&age_validator, &200);
        assert!(!batch.is_valid());
        assert_eq!(batch.errors().len(), 2);
    }
}
```
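Length limits on user-facing text are usually meant in characters, but `str::len` counts UTF-8 bytes. The difference matters for Chinese input, because each common CJK character takes 3 bytes: a two-character name like "张三" is 6 bytes long. The stdlib-only sketch below checks a length rule in characters. `char_len_in_range` is a hypothetical helper.

```rust
/// Hypothetical helper: check a length rule in characters rather than bytes.
pub fn char_len_in_range(s: &str, min: usize, max: usize) -> bool {
    // `chars()` yields Unicode scalar values, not UTF-8 bytes
    let n = s.chars().count();
    n >= min && n <= max
}
```

Note that `chars().count()` counts Unicode scalar values, not user-perceived characters. An emoji with a skin-tone modifier, for instance, counts as two. Grapheme-accurate limits would need a segmentation library such as `unicode-segmentation`.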