Files
udmin/docs/UTILS.md
ayou a3f2f99a68 docs: 添加项目文档包括总览、架构、流程引擎和服务层
新增以下文档文件:
- PROJECT_OVERVIEW.md 项目总览文档
- BACKEND_ARCHITECTURE.md 后端架构文档
- FRONTEND_ARCHITECTURE.md 前端架构文档
- FLOW_ENGINE.md 流程引擎文档
- SERVICES.md 服务层文档
- ERROR_HANDLING.md 错误处理模块文档

文档内容涵盖项目整体介绍、技术架构、核心模块设计和实现细节
2025-09-24 20:21:45 +08:00

34 KiB
Raw Permalink Blame History

UdminAI 工具函数模块文档

概述

UdminAI 项目的工具函数模块（Utils）提供了一系列通用的工具函数和辅助功能，用于支持系统的各个组件。这些工具函数涵盖了 ID 生成、时间处理、加密解密、验证、格式化、文件操作等多个方面，为整个系统提供了基础的功能支持。

模块结构

backend/src/utils/
├── mod.rs              # 模块导出
├── id.rs               # ID 生成工具
├── time.rs             # 时间处理工具
├── crypto.rs           # 加密解密工具
├── validation.rs       # 验证工具
├── format.rs           # 格式化工具
├── file.rs             # 文件操作工具
├── hash.rs             # 哈希工具
├── jwt.rs              # JWT 工具
├── password.rs         # 密码处理工具
├── email.rs            # 邮件工具
├── config.rs           # 配置工具
└── error.rs            # 错误处理工具

核心工具模块

ID 生成工具 (id.rs)

功能特性

  • 唯一 ID 生成
  • 多种 ID 格式支持
  • 高性能生成
  • 分布式友好（注意：雪花 ID 的机器 ID 目前硬编码为 1，多实例部署前需改为从配置读取）

实现代码

use chrono::{DateTime, Utc};
use rand::{thread_rng, Rng};
use std::sync::atomic::{AtomicU64, Ordering};
use uuid::Uuid;

/// Process-wide sequence counter used by `generate_snowflake_id`.
/// Only the low 12 bits of each fetched value end up in an ID.
static SEQUENCE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Selects the ID format used by `generate_id_with_type` and `validate_id`.
#[derive(Debug, Clone)]
pub enum IdType {
    /// Random UUID v4 in hyphenated string form.
    Uuid,
    /// Snowflake-style 64-bit integer rendered as a decimal string.
    Snowflake,
    /// Millisecond timestamp followed by 8 lowercase hex digits of randomness.
    Timestamp,
    /// Caller-supplied prefix, a `-`, then a UUID v4.
    Prefixed(String),
}

/// Generates a unique ID in the default format (UUID v4).
pub fn generate_id() -> String {
    let default_kind = IdType::Uuid;
    generate_id_with_type(default_kind)
}

/// Generates an ID in the requested format.
///
/// `IdType::Prefixed(p)` yields `"{p}-{uuid}"`; the other variants delegate to the
/// matching generator.
pub fn generate_id_with_type(id_type: IdType) -> String {
    match id_type {
        IdType::Prefixed(prefix) => format!("{}-{}", prefix, Uuid::new_v4()),
        IdType::Timestamp => generate_timestamp_id(),
        IdType::Snowflake => generate_snowflake_id(),
        IdType::Uuid => Uuid::new_v4().to_string(),
    }
}

/// Builds a snowflake-style ID: `unix_millis << 22 | machine_id << 12 | sequence`.
///
/// The timestamp is the raw Unix epoch in milliseconds (no custom epoch). The machine ID
/// is hard-coded to 1. The sequence is the low 12 bits of `SEQUENCE_COUNTER`.
/// NOTE(review): the sequence wraps every 4096 calls, so generating more than 4096 IDs
/// within one millisecond can produce duplicates.
fn generate_snowflake_id() -> String {
    const SEQUENCE_MASK: u64 = 0xFFF; // 12-bit sequence field
    const MACHINE_ID: u64 = 1; // TODO: read the machine ID from configuration

    let millis = Utc::now().timestamp_millis() as u64;
    let seq = SEQUENCE_COUNTER.fetch_add(1, Ordering::SeqCst) & SEQUENCE_MASK;

    ((millis << 22) | (MACHINE_ID << 12) | seq).to_string()
}

/// Builds an ID from the current Unix millisecond timestamp followed by a random `u32`
/// rendered as exactly 8 lowercase hex digits.
fn generate_timestamp_id() -> String {
    let millis = Utc::now().timestamp_millis();
    let suffix = thread_rng().gen::<u32>();
    format!("{millis}{suffix:08x}")
}

/// Generates a random 8-character ID drawn uniformly from `[a-zA-Z0-9]`.
pub fn generate_short_id() -> String {
    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    const LEN: usize = 8;

    let mut rng = thread_rng();
    // The charset is pure ASCII, so every byte is a whole char. Indexing the byte slice is
    // O(1), whereas `chars().nth(idx)` rescanned the string for every output character.
    (0..LEN)
        .map(|_| char::from(CHARSET[rng.gen_range(0..CHARSET.len())]))
        .collect()
}

/// Generates a string of `length` random decimal digits. Leading zeros are possible, so
/// treat the result as text rather than as a number.
pub fn generate_numeric_id(length: usize) -> String {
    let mut rng = thread_rng();
    let mut id = String::with_capacity(length);
    // Push chars directly instead of allocating a one-character `String` per digit.
    id.extend((0..length).map(|_| char::from(b'0' + rng.gen_range(0..10u8))));
    id
}

/// Checks whether `id` has the shape produced by the given `IdType`.
///
/// * `Uuid`: parses as a UUID.
/// * `Snowflake`: parses as a `u64`.
/// * `Timestamp`: the first 13 bytes parse as an `i64` millisecond timestamp.
/// * `Prefixed(p)`: `"{p}-"` followed by a valid UUID.
///
/// Never panics. The 13-byte timestamp prefix is taken with `str::get`, so input that is
/// too short, or whose 13th byte falls inside a multi-byte character, is rejected.
/// Previously, `id[..13]` would panic on such input.
pub fn validate_id(id: &str, id_type: IdType) -> bool {
    match id_type {
        IdType::Uuid => Uuid::parse_str(id).is_ok(),
        IdType::Snowflake => id.parse::<u64>().is_ok(),
        IdType::Timestamp => id
            .get(..13)
            .is_some_and(|millis| millis.parse::<i64>().is_ok()),
        IdType::Prefixed(prefix) => id
            .strip_prefix(prefix.as_str())
            .and_then(|rest| rest.strip_prefix('-'))
            .is_some_and(|uuid| !uuid.is_empty() && Uuid::parse_str(uuid).is_ok()),
    }
}

/// Recovers the creation time encoded in the bits above bit 22 of a snowflake ID.
///
/// Returns `None` if `id` is not a valid `u64` or the timestamp is out of range.
pub fn extract_timestamp_from_snowflake(id: &str) -> Option<DateTime<Utc>> {
    let raw: u64 = id.parse().ok()?;
    let millis = (raw >> 22) as i64;
    DateTime::from_timestamp_millis(millis)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_uuid() {
        let generated = generate_id();
        assert!(validate_id(&generated, IdType::Uuid));
    }

    #[test]
    fn test_generate_snowflake() {
        let generated = generate_id_with_type(IdType::Snowflake);
        assert!(validate_id(&generated, IdType::Snowflake));
    }

    #[test]
    fn test_generate_prefixed_id() {
        let user_kind = || IdType::Prefixed(String::from("user"));
        let generated = generate_id_with_type(user_kind());
        assert!(generated.starts_with("user-"));
        assert!(validate_id(&generated, user_kind()));
    }

    #[test]
    fn test_short_id() {
        assert_eq!(generate_short_id().len(), 8);
    }
}

时间处理工具 (time.rs)

功能特性

  • 时间格式化
  • 时区转换
  • 时间计算
  • 相对时间

实现代码

use chrono::{
    DateTime, Datelike, Duration, FixedOffset, Local, NaiveDateTime, TimeZone, Utc,
};
use serde::{Deserialize, Serialize};
use std::fmt;

/// Datetime format constants (chrono strftime syntax).
/// Default human-readable datetime, e.g. `2023-12-25 10:30:00`.
pub const DEFAULT_DATETIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S";
/// ISO-8601 with millisecond precision and a literal `Z` suffix. The `Z` is hard-coded,
/// so this is only correct for values that are already in UTC.
pub const ISO_DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.3fZ";
/// Date only, e.g. `2023-12-25`.
pub const DATE_FORMAT: &str = "%Y-%m-%d";
/// Time only, e.g. `10:30:00`.
pub const TIME_FORMAT: &str = "%H:%M:%S";

/// China Standard Time (UTC+8) offset, in seconds.
pub const CHINA_OFFSET: i32 = 8 * 3600;

/// Returns the current instant in UTC.
pub fn now_utc() -> DateTime<Utc> {
    Utc::now()
}

/// Returns the current instant as a `FixedOffset` datetime with a zero (UTC) offset.
pub fn now_fixed_offset() -> DateTime<FixedOffset> {
    let zero_offset = FixedOffset::east_opt(0).unwrap();
    Utc::now().with_timezone(&zero_offset)
}

/// Returns the current instant in China Standard Time (UTC+8).
pub fn now_china() -> DateTime<FixedOffset> {
    let china_offset = FixedOffset::east_opt(CHINA_OFFSET).unwrap();
    Utc::now().with_timezone(&china_offset)
}

/// Returns the current instant in the host's local time zone.
pub fn now_local() -> DateTime<Local> {
    Local::now()
}

/// Formats `dt` with a chrono strftime-style `format` string.
///
/// NOTE(review): chrono's `to_string` panics if `format` contains an invalid specifier.
/// Validate `format` first if it can come from user input.
pub fn format_datetime(dt: &DateTime<Utc>, format: &str) -> String {
    let rendered = dt.format(format);
    rendered.to_string()
}

/// Formats `dt` as `YYYY-MM-DD HH:MM:SS`.
pub fn format_datetime_default(dt: &DateTime<Utc>) -> String {
    format_datetime(dt, DEFAULT_DATETIME_FORMAT)
}

/// Formats `dt` as ISO-8601 with milliseconds and a `Z` suffix.
pub fn format_datetime_iso(dt: &DateTime<Utc>) -> String {
    format_datetime(dt, ISO_DATETIME_FORMAT)
}

/// Parses `s` with `format` as a naive datetime and interprets the result as UTC.
///
/// # Errors
/// Returns `chrono::ParseError` when `s` does not match `format`.
pub fn parse_datetime(s: &str, format: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
    NaiveDateTime::parse_from_str(s, format).map(|naive| Utc.from_utc_datetime(&naive))
}

/// Parses an RFC 3339 string (any offset) and converts it to UTC.
///
/// # Errors
/// Returns `chrono::ParseError` when `s` is not valid RFC 3339.
pub fn parse_datetime_iso(s: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
    let with_offset = DateTime::parse_from_rfc3339(s)?;
    Ok(with_offset.with_timezone(&Utc))
}

/// Converts a Unix timestamp in whole seconds to UTC.
/// Returns `None` when the value is outside chrono's representable range.
pub fn timestamp_to_datetime(timestamp: i64) -> Option<DateTime<Utc>> {
    DateTime::<Utc>::from_timestamp(timestamp, 0)
}

/// Converts a Unix timestamp in milliseconds to UTC.
/// Returns `None` when the value is outside chrono's representable range.
pub fn timestamp_millis_to_datetime(timestamp: i64) -> Option<DateTime<Utc>> {
    DateTime::<Utc>::from_timestamp_millis(timestamp)
}

/// Returns the Unix timestamp of `dt` in whole seconds.
pub fn datetime_to_timestamp(dt: &DateTime<Utc>) -> i64 {
    dt.timestamp()
}

/// Returns the Unix timestamp of `dt` in milliseconds.
pub fn datetime_to_timestamp_millis(dt: &DateTime<Utc>) -> i64 {
    dt.timestamp_millis()
}

/// Returns `end - start`. The result is negative when `end` precedes `start`.
pub fn time_diff(start: &DateTime<Utc>, end: &DateTime<Utc>) -> Duration {
    end.signed_duration_since(*start)
}

/// Returns `end - start` in whole seconds (truncated toward zero).
pub fn time_diff_seconds(start: &DateTime<Utc>, end: &DateTime<Utc>) -> i64 {
    let elapsed = time_diff(start, end);
    elapsed.num_seconds()
}

/// Returns `end - start` in whole milliseconds (truncated toward zero).
pub fn time_diff_millis(start: &DateTime<Utc>, end: &DateTime<Utc>) -> i64 {
    let elapsed = time_diff(start, end);
    elapsed.num_milliseconds()
}

/// 时间加法
pub fn add_duration(dt: &DateTime<Utc>, duration: Duration) -> DateTime<Utc> {
    *dt + duration
}

/// 时间减法
pub fn sub_duration(dt: &DateTime<Utc>, duration: Duration) -> DateTime<Utc> {
    *dt - duration
}

/// 添加秒数
pub fn add_seconds(dt: &DateTime<Utc>, seconds: i64) -> DateTime<Utc> {
    add_duration(dt, Duration::seconds(seconds))
}

/// 添加分钟
pub fn add_minutes(dt: &DateTime<Utc>, minutes: i64) -> DateTime<Utc> {
    add_duration(dt, Duration::minutes(minutes))
}

/// 添加小时
pub fn add_hours(dt: &DateTime<Utc>, hours: i64) -> DateTime<Utc> {
    add_duration(dt, Duration::hours(hours))
}

/// 添加天数
pub fn add_days(dt: &DateTime<Utc>, days: i64) -> DateTime<Utc> {
    add_duration(dt, Duration::days(days))
}

/// Start of the current UTC day (00:00:00).
///
/// NOTE(review): these period boundaries are all computed in UTC, not China local time
/// (UTC+8, see `now_china`). Confirm this is what callers expect.
pub fn today_start() -> DateTime<Utc> {
    let today = now_utc().date_naive();
    today.and_hms_opt(0, 0, 0).unwrap().and_utc()
}

/// 23:59:59 on the current UTC day, at second precision. Instants inside the final
/// second (e.g. 23:59:59.5) are later than this value.
pub fn today_end() -> DateTime<Utc> {
    let today = now_utc().date_naive();
    today.and_hms_opt(23, 59, 59).unwrap().and_utc()
}

/// 00:00:00 UTC on the Monday of the current UTC week.
pub fn week_start() -> DateTime<Utc> {
    let today = now_utc().date_naive();
    let days_since_monday = i64::from(today.weekday().num_days_from_monday());
    let monday = today - Duration::days(days_since_monday);
    monday.and_hms_opt(0, 0, 0).unwrap().and_utc()
}

/// 00:00:00 UTC on the first day of the current UTC month.
pub fn month_start() -> DateTime<Utc> {
    let first_of_month = now_utc().date_naive().with_day(1).unwrap();
    first_of_month.and_hms_opt(0, 0, 0).unwrap().and_utc()
}

/// A humanised "ago / from now" value: a magnitude, its unit, and display text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelativeTime {
    /// Magnitude expressed in `unit`.
    pub value: i64,
    /// Unit that `value` is expressed in.
    pub unit: TimeUnit,
    /// Ready-to-display text, e.g. `"2小时 前"`.
    pub description: String,
}

/// Unit used by `RelativeTime`. `Display` renders the Chinese label.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TimeUnit {
    Second,
    Minute,
    Hour,
    Day,
    Week,
    Month,
    Year,
}

impl fmt::Display for TimeUnit {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            TimeUnit::Second => "秒",
            TimeUnit::Minute => "分钟",
            TimeUnit::Hour => "小时",
            TimeUnit::Day => "天",
            TimeUnit::Week => "周",
            TimeUnit::Month => "月",
            TimeUnit::Year => "年",
        };
        f.write_str(label)
    }
}

/// Describes `dt` relative to the current time.
///
/// The unit is chosen by magnitude (seconds, minutes, hours, days, weeks, months, years).
/// Months are approximated as 30 days and years as 365 days, and `value` is truncated
/// toward zero. The description is `"{value}{unit} 前"` when `dt` is now or in the past,
/// and `"{value}{unit} 后"` when `dt` is in the future.
pub fn relative_time(dt: &DateTime<Utc>) -> RelativeTime {
    const MINUTE: i64 = 60;
    const HOUR: i64 = 60 * MINUTE;
    const DAY: i64 = 24 * HOUR;
    const WEEK: i64 = 7 * DAY;
    const MONTH: i64 = 30 * DAY;
    const YEAR: i64 = 365 * DAY;

    let now = now_utc();
    // `time_diff(dt, now)` is `now - dt`, so it is non-negative exactly when `dt` is in
    // the past (or now). Testing `< 0` here labelled past instants as future ("后").
    let elapsed_seconds = time_diff(dt, &now).num_seconds();
    let is_past = elapsed_seconds >= 0;
    let abs_seconds = elapsed_seconds.abs();

    let (value, unit) = if abs_seconds < MINUTE {
        (abs_seconds, TimeUnit::Second)
    } else if abs_seconds < HOUR {
        (abs_seconds / MINUTE, TimeUnit::Minute)
    } else if abs_seconds < DAY {
        (abs_seconds / HOUR, TimeUnit::Hour)
    } else if abs_seconds < WEEK {
        (abs_seconds / DAY, TimeUnit::Day)
    } else if abs_seconds < MONTH {
        (abs_seconds / WEEK, TimeUnit::Week)
    } else if abs_seconds < YEAR {
        (abs_seconds / MONTH, TimeUnit::Month)
    } else {
        (abs_seconds / YEAR, TimeUnit::Year)
    };

    let description = if is_past {
        format!("{}{} 前", value, unit)
    } else {
        format!("{}{} 后", value, unit)
    };

    RelativeTime {
        value,
        unit,
        description,
    }
}

/// `true` when both instants fall on the same UTC calendar date.
pub fn is_same_day(dt1: &DateTime<Utc>, dt2: &DateTime<Utc>) -> bool {
    let (day1, day2) = (dt1.date_naive(), dt2.date_naive());
    day1 == day2
}

/// `true` when `dt` falls on the current UTC date.
pub fn is_today(dt: &DateTime<Utc>) -> bool {
    let now = now_utc();
    is_same_day(dt, &now)
}

/// `true` when `dt` falls on the UTC date of the instant 24 hours before now.
pub fn is_yesterday(dt: &DateTime<Utc>) -> bool {
    let yesterday = now_utc() - Duration::days(1);
    is_same_day(dt, &yesterday)
}

/// `true` when `dt` lies in `[week_start(), week_start() + 7 days)`, i.e. the current
/// Monday-based UTC week.
pub fn is_this_week(dt: &DateTime<Utc>) -> bool {
    let start = week_start();
    let end = add_days(&start, 7);
    (start..end).contains(dt)
}

/// `true` when `dt` falls in the current UTC calendar month.
pub fn is_this_month(dt: &DateTime<Utc>) -> bool {
    let now = now_utc();
    (dt.year(), dt.month()) == (now.year(), now.month())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_time_formatting() {
        let rendered = format_datetime_default(&Utc::now());
        assert!(!rendered.is_empty());
    }

    #[test]
    fn test_time_parsing() {
        let result = parse_datetime("2023-12-25 10:30:00", DEFAULT_DATETIME_FORMAT);
        assert!(result.is_ok());
    }

    #[test]
    fn test_timestamp_conversion() {
        let original = Utc::now();
        let secs = datetime_to_timestamp(&original);
        let restored = timestamp_to_datetime(secs).unwrap();
        assert_eq!(original.timestamp(), restored.timestamp());
    }

    #[test]
    fn test_relative_time() {
        let two_hours_ago = sub_duration(&now_utc(), Duration::hours(2));
        let described = relative_time(&two_hours_ago);
        assert!(described.description.contains("前"));
    }

    #[test]
    fn test_date_checks() {
        let current = now_utc();
        assert!(is_today(&current));

        let a_day_earlier = sub_duration(&current, Duration::days(1));
        assert!(is_yesterday(&a_day_earlier));
    }
}

加密解密工具 (crypto.rs)

功能特性

  • AES 加密解密
  • RSA 加密解密
  • 数字签名（尚未实现：目前仅预留了 `SigningFailed` / `VerificationFailed` 错误类型）
  • 密钥生成

实现代码

use aes_gcm::{
    aead::{Aead, KeyInit, OsRng},
    Aes256Gcm, Key, Nonce,
};
use base64::{engine::general_purpose, Engine as _};
use rand::{thread_rng, RngCore};
use rsa::{
    pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey, EncodeRsaPrivateKey, EncodeRsaPublicKey},
    pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey},
    Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey,
};
use sha2::{Digest, Sha256};
use std::error::Error;
use thiserror::Error;

/// Errors returned by the encryption helpers in this module.
/// The `#[error]` strings are the (Chinese) display messages.
///
/// NOTE(review): no function in this file produces `SigningFailed` or
/// `VerificationFailed`. They appear to be reserved for future signing support.
#[derive(Error, Debug)]
pub enum CryptoError {
    /// Encryption failed.
    #[error("加密失败: {0}")]
    EncryptionFailed(String),
    /// Decryption, Base64 decoding, or UTF-8 conversion of the plaintext failed.
    #[error("解密失败: {0}")]
    DecryptionFailed(String),
    /// Key generation failed. `RsaKeyPair::from_private_key_pem` also uses this variant
    /// for PEM parse errors.
    #[error("密钥生成失败: {0}")]
    KeyGenerationFailed(String),
    /// Signing failed.
    #[error("签名失败: {0}")]
    SigningFailed(String),
    /// Signature verification failed.
    #[error("验证失败: {0}")]
    VerificationFailed(String),
    /// Exporting a key to PEM failed.
    #[error("编码失败: {0}")]
    EncodingFailed(String),
}

/// AES-256-GCM authenticated encryption helper.
///
/// `encrypt` returns `nonce (12 bytes) || GCM output`, drawing a fresh random nonce from
/// `thread_rng` on every call. `decrypt` expects the same layout.
pub struct AesEncryption {
    cipher: Aes256Gcm,
}

impl AesEncryption {
    /// Size of the random nonce prepended to each ciphertext.
    const NONCE_LEN: usize = 12;

    /// Creates an encryptor from a raw 256-bit key.
    pub fn new(key: &[u8; 32]) -> Self {
        let cipher = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key));
        Self { cipher }
    }

    /// Creates an encryptor whose key is derived from `password`.
    ///
    /// SECURITY NOTE(review): the key is a single, unsalted SHA-256 of the password.
    /// That is cheap to brute-force, and identical passwords always give identical keys.
    /// Prefer `new` with a random key, or a salted slow KDF (PBKDF2 / Argon2) for
    /// passwords. Switching KDFs would make existing ciphertexts undecryptable.
    pub fn from_password(password: &str) -> Self {
        Self::new(&Self::derive_key_from_password(password))
    }

    /// SHA-256(password), used directly as the AES key (see the note on `from_password`).
    fn derive_key_from_password(password: &str) -> [u8; 32] {
        Sha256::digest(password.as_bytes()).into()
    }

    /// Encrypts `plaintext` and returns `nonce || ciphertext-with-tag`.
    ///
    /// # Errors
    /// `CryptoError::EncryptionFailed` if the AEAD operation fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let mut nonce = [0u8; Self::NONCE_LEN];
        thread_rng().fill_bytes(&mut nonce);

        let sealed = self
            .cipher
            .encrypt(Nonce::from_slice(&nonce), plaintext)
            .map_err(|err| CryptoError::EncryptionFailed(err.to_string()))?;

        let mut framed = Vec::with_capacity(Self::NONCE_LEN + sealed.len());
        framed.extend_from_slice(&nonce);
        framed.extend(sealed);
        Ok(framed)
    }

    /// Decrypts data produced by `encrypt`.
    ///
    /// # Errors
    /// `CryptoError::DecryptionFailed` if the input is shorter than a nonce or fails
    /// authentication (wrong key, or tampered data).
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if ciphertext.len() < Self::NONCE_LEN {
            return Err(CryptoError::DecryptionFailed(
                "密文长度不足".to_string(),
            ));
        }

        let (nonce, body) = ciphertext.split_at(Self::NONCE_LEN);
        self.cipher
            .decrypt(Nonce::from_slice(nonce), body)
            .map_err(|err| CryptoError::DecryptionFailed(err.to_string()))
    }

    /// Encrypts a UTF-8 string and returns the framed ciphertext as standard Base64.
    pub fn encrypt_string(&self, plaintext: &str) -> Result<String, CryptoError> {
        self.encrypt(plaintext.as_bytes())
            .map(|framed| general_purpose::STANDARD.encode(framed))
    }

    /// Reverses `encrypt_string`: Base64-decodes, decrypts, then validates UTF-8.
    ///
    /// # Errors
    /// `CryptoError::DecryptionFailed` for bad Base64, failed decryption, or non-UTF-8
    /// plaintext.
    pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, CryptoError> {
        let framed = general_purpose::STANDARD
            .decode(ciphertext)
            .map_err(|err| CryptoError::DecryptionFailed(err.to_string()))?;

        let plain = self.decrypt(&framed)?;
        String::from_utf8(plain).map_err(|err| CryptoError::DecryptionFailed(err.to_string()))
    }
}

/// An RSA private key together with its derived public key.
///
/// SECURITY NOTE(review): `encrypt` / `decrypt` use PKCS#1 v1.5 padding.
/// - PKCS#1 v1.5 is open to padding-oracle (Bleichenbacher) attacks when decryption
///   failures are observable. OAEP (`rsa::Oaep`) is the recommended scheme for new data.
/// - Switching schemes would break existing ciphertexts, so it is left unchanged here.
/// - Per RFC 8017, the plaintext can be at most `modulus_bytes - 11` bytes long.
#[derive(Debug, Clone)]
pub struct RsaKeyPair {
    pub private_key: RsaPrivateKey,
    pub public_key: RsaPublicKey,
}

impl RsaKeyPair {
    /// Pairs `private_key` with the public key derived from it.
    fn with_derived_public_key(private_key: RsaPrivateKey) -> Self {
        let public_key = RsaPublicKey::from(&private_key);
        Self {
            private_key,
            public_key,
        }
    }

    /// Generates a fresh key pair with a `bits`-bit modulus, using `OsRng`.
    ///
    /// # Errors
    /// `CryptoError::KeyGenerationFailed` if the `rsa` crate rejects the request.
    pub fn generate(bits: usize) -> Result<Self, CryptoError> {
        RsaPrivateKey::new(&mut OsRng, bits)
            .map(Self::with_derived_public_key)
            .map_err(|err| CryptoError::KeyGenerationFailed(err.to_string()))
    }

    /// Loads a key pair from a PKCS#8 PEM-encoded private key.
    ///
    /// # Errors
    /// Parse failures are reported as `CryptoError::KeyGenerationFailed`.
    pub fn from_private_key_pem(pem: &str) -> Result<Self, CryptoError> {
        RsaPrivateKey::from_pkcs8_pem(pem)
            .map(Self::with_derived_public_key)
            .map_err(|err| CryptoError::KeyGenerationFailed(err.to_string()))
    }

    /// Exports the private key as PKCS#8 PEM with LF line endings.
    pub fn private_key_to_pem(&self) -> Result<String, CryptoError> {
        let pem = self
            .private_key
            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
            .map_err(|err| CryptoError::EncodingFailed(err.to_string()))?;
        // Copy the text out of the zeroizing wrapper. The returned `String` itself is
        // not wiped on drop.
        Ok(pem.to_string())
    }

    /// Exports the public key as SubjectPublicKeyInfo PEM with LF line endings.
    pub fn public_key_to_pem(&self) -> Result<String, CryptoError> {
        self.public_key
            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
            .map_err(|err| CryptoError::EncodingFailed(err.to_string()))
    }

    /// Encrypts `plaintext` with the public key (PKCS#1 v1.5; see the type-level note).
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        self.public_key
            .encrypt(&mut OsRng, Pkcs1v15Encrypt, plaintext)
            .map_err(|err| CryptoError::EncryptionFailed(err.to_string()))
    }

    /// Decrypts `ciphertext` with the private key (PKCS#1 v1.5).
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        self.private_key
            .decrypt(Pkcs1v15Encrypt, ciphertext)
            .map_err(|err| CryptoError::DecryptionFailed(err.to_string()))
    }

    /// Encrypts a UTF-8 string and returns the ciphertext as standard Base64.
    pub fn encrypt_string(&self, plaintext: &str) -> Result<String, CryptoError> {
        self.encrypt(plaintext.as_bytes())
            .map(|sealed| general_purpose::STANDARD.encode(sealed))
    }

    /// Reverses `encrypt_string`: Base64-decodes, decrypts, then validates UTF-8.
    pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, CryptoError> {
        let sealed = general_purpose::STANDARD
            .decode(ciphertext)
            .map_err(|err| CryptoError::DecryptionFailed(err.to_string()))?;

        let plain = self.decrypt(&sealed)?;
        String::from_utf8(plain).map_err(|err| CryptoError::DecryptionFailed(err.to_string()))
    }
}

/// Returns `length` random bytes from `thread_rng`.
pub fn generate_random_key(length: usize) -> Vec<u8> {
    let mut buf = vec![0u8; length];
    thread_rng().fill_bytes(buf.as_mut_slice());
    buf
}

/// Returns a fresh random 256-bit key suitable for `AesEncryption::new`.
pub fn generate_aes_key() -> [u8; 32] {
    let mut key = [0u8; 32];
    thread_rng().fill_bytes(&mut key[..]);
    key
}

/// Returns the 32-byte SHA-256 digest of `data`.
pub fn sha256_hash(data: &[u8]) -> Vec<u8> {
    Sha256::digest(data).to_vec()
}

/// Returns the SHA-256 digest of the UTF-8 bytes of `data` as 64 lowercase hex chars.
pub fn sha256_hex(data: &str) -> String {
    hex::encode(sha256_hash(data.as_bytes()))
}

/// Returns `true` if SHA-256(`data`) equals `expected_hash` (raw digest bytes).
///
/// The byte comparison does not stop at the first mismatch, so the running time does not
/// reveal how many leading bytes of `expected_hash` were correct. A plain `==` on slices
/// exits early and leaks that through timing. A length mismatch is rejected immediately;
/// the length of a SHA-256 digest is public anyway.
pub fn verify_hash(data: &[u8], expected_hash: &[u8]) -> bool {
    let actual_hash = sha256_hash(data);
    if actual_hash.len() != expected_hash.len() {
        return false;
    }

    // OR together the XOR of every byte pair; the result is zero only if all bytes match.
    let difference = actual_hash
        .iter()
        .zip(expected_hash)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    difference == 0
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aes_encryption() {
        let aes = AesEncryption::new(&generate_aes_key());
        let message = "Hello, World!";

        let sealed = aes.encrypt_string(message).unwrap();
        let opened = aes.decrypt_string(&sealed).unwrap();

        assert_eq!(message, opened);
    }

    #[test]
    fn test_rsa_encryption() {
        let keys = RsaKeyPair::generate(2048).unwrap();
        let message = "Hello, RSA!";

        let sealed = keys.encrypt_string(message).unwrap();
        let opened = keys.decrypt_string(&sealed).unwrap();

        assert_eq!(message, opened);
    }

    #[test]
    fn test_sha256_hash() {
        let first = sha256_hex("test data");
        let second = sha256_hex("test data");

        assert_eq!(first, second);
        // A SHA-256 digest is 32 bytes, i.e. 64 hex characters.
        assert_eq!(first.len(), 64);
    }

    #[test]
    fn test_key_generation() {
        let a = generate_aes_key();
        let b = generate_aes_key();

        // Two independently generated random keys should differ.
        assert_ne!(a, b);
        // AES-256 uses a 32-byte key.
        assert_eq!(a.len(), 32);
    }
}

验证工具 (validation.rs)

功能特性

  • 邮箱验证
  • 手机号验证
  • URL 验证
  • 密码强度验证
  • 自定义验证规则

实现代码

use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;

/// A validation failure. The `#[error]` strings are the user-facing (Chinese) messages.
#[derive(Error, Debug, Clone, Serialize, Deserialize)]
pub enum ValidationError {
    /// A required field was missing or empty.
    #[error("字段 '{field}' 是必需的")]
    Required { field: String },
    /// Length outside the inclusive range `[min, max]`. When no limit is configured,
    /// `min` is 0 and `max` is `usize::MAX`.
    #[error("字段 '{field}' 长度必须在 {min} 到 {max} 之间")]
    Length { field: String, min: usize, max: usize },
    /// The value did not match the expected pattern.
    #[error("字段 '{field}' 格式无效: {message}")]
    Format { field: String, message: String },
    /// The value is outside the allowed range or otherwise unacceptable.
    #[error("字段 '{field}' 值无效: {message}")]
    Invalid { field: String, message: String },
    /// Free-form failure from a custom rule.
    #[error("自定义验证失败: {message}")]
    Custom { message: String },
}

/// Result alias used by the validators in this module.
pub type ValidationResult<T> = Result<T, ValidationError>;

/// Common interface for field validators: returns `Ok(())` when `value` passes, otherwise
/// a `ValidationError` describing the failure.
pub trait Validator<T> {
    fn validate(&self, value: &T) -> ValidationResult<()>;
}

/// Builder-style validator for string fields.
///
/// Checks run in this order: required/empty, length limits, regex `pattern`, then each
/// entry of `custom_validators`.
pub struct StringValidator {
    /// Field name reported in validation errors.
    pub field_name: String,
    /// Whether a missing (`None`) or empty value is an error.
    pub required: bool,
    /// Inclusive minimum length.
    pub min_length: Option<usize>,
    /// Inclusive maximum length.
    pub max_length: Option<usize>,
    /// Regex the value must match. `Regex::is_match` is unanchored, so use `^...$` to
    /// require a full match.
    pub pattern: Option<Regex>,
    /// Extra rules, run after the built-in checks.
    pub custom_validators: Vec<Box<dyn Fn(&str) -> ValidationResult<()> + Send + Sync>>,
}

// `#[derive(Debug, Clone)]` cannot compile for this struct: `Box<dyn Fn ..>` implements
// neither `Debug` nor `Clone`. `Debug` is therefore written by hand, showing the custom
// rules only as a count, and `Clone` is not provided.
impl std::fmt::Debug for StringValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StringValidator")
            .field("field_name", &self.field_name)
            .field("required", &self.required)
            .field("min_length", &self.min_length)
            .field("max_length", &self.max_length)
            .field("pattern", &self.pattern)
            .field("custom_validators", &self.custom_validators.len())
            .finish()
    }
}

impl StringValidator {
    /// Creates an optional, unconstrained validator for `field_name`.
    pub fn new(field_name: &str) -> Self {
        Self {
            field_name: field_name.to_string(),
            required: false,
            min_length: None,
            max_length: None,
            pattern: None,
            custom_validators: Vec::new(),
        }
    }

    /// Marks the field as required: `None` and `""` are rejected.
    pub fn required(mut self) -> Self {
        self.required = true;
        self
    }

    /// Sets the inclusive minimum length.
    pub fn min_length(mut self, min: usize) -> Self {
        self.min_length = Some(min);
        self
    }

    /// Sets the inclusive maximum length.
    pub fn max_length(mut self, max: usize) -> Self {
        self.max_length = Some(max);
        self
    }

    /// Sets both inclusive length limits.
    pub fn length_range(mut self, min: usize, max: usize) -> Self {
        self.min_length = Some(min);
        self.max_length = Some(max);
        self
    }

    /// Requires the value to match `pattern`.
    ///
    /// # Errors
    /// Returns the `regex::Error` if `pattern` does not compile.
    pub fn pattern(mut self, pattern: &str) -> Result<Self, regex::Error> {
        self.pattern = Some(Regex::new(pattern)?);
        Ok(self)
    }

    /// Requires a simple `local@domain.tld` email shape.
    pub fn email(self) -> Result<Self, regex::Error> {
        self.pattern(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
    }

    /// Requires an 11-digit mainland-China mobile number whose second digit is 3-9.
    ///
    /// Uses `[0-9]` rather than `\d`. In the `regex` crate, `\d` is Unicode-aware and
    /// would also accept full-width digits such as `１２３`.
    pub fn phone(self) -> Result<Self, regex::Error> {
        self.pattern(r"^1[3-9][0-9]{9}$")
    }

    /// Requires an `http://` or `https://` URL without whitespace.
    pub fn url(self) -> Result<Self, regex::Error> {
        self.pattern(r"^https?://[^\s/$.?#].[^\s]*$")
    }

    /// Appends a custom rule. Rules run after the built-in checks, and only for
    /// non-empty values.
    pub fn custom<F>(mut self, rule: F) -> Self
    where
        F: Fn(&str) -> ValidationResult<()> + Send + Sync + 'static,
    {
        self.custom_validators.push(Box::new(rule));
        self
    }
}

impl Validator<Option<String>> for StringValidator {
    /// `None` fails only when the field is required. `Some(s)` is validated exactly like
    /// a plain `String`.
    fn validate(&self, value: &Option<String>) -> ValidationResult<()> {
        if let Some(s) = value {
            return self.validate(s);
        }
        if self.required {
            return Err(ValidationError::Required {
                field: self.field_name.clone(),
            });
        }
        Ok(())
    }
}

impl Validator<String> for StringValidator {
    /// Validates a present string value. Stops at the first failure.
    ///
    /// * Empty string: a `Required` error if the field is required. Otherwise it is
    ///   accepted without running any further checks.
    /// * Length limits (`min_length` / `max_length`, inclusive) count characters (Unicode
    ///   scalar values), not UTF-8 bytes.
    /// * Then the regex `pattern`, then each custom validator in insertion order.
    fn validate(&self, value: &String) -> ValidationResult<()> {
        if value.is_empty() {
            return if self.required {
                Err(ValidationError::Required {
                    field: self.field_name.clone(),
                })
            } else {
                Ok(())
            };
        }

        // `len()` counts bytes: "张三" is 6 bytes but 2 characters. Byte counts made the
        // limits roughly 3x stricter for CJK input.
        let char_count = value.chars().count();
        let too_short = self.min_length.is_some_and(|min| char_count < min);
        let too_long = self.max_length.is_some_and(|max| char_count > max);
        if too_short || too_long {
            return Err(ValidationError::Length {
                field: self.field_name.clone(),
                min: self.min_length.unwrap_or(0),
                max: self.max_length.unwrap_or(usize::MAX),
            });
        }

        if let Some(pattern) = &self.pattern {
            if !pattern.is_match(value) {
                return Err(ValidationError::Format {
                    field: self.field_name.clone(),
                    message: "格式不匹配".to_string(),
                });
            }
        }

        for validator in &self.custom_validators {
            validator(value)?;
        }

        Ok(())
    }
}

/// Builder-style validator for numeric fields with optional inclusive bounds.
#[derive(Debug, Clone)]
pub struct NumberValidator<T> {
    /// Field name reported in validation errors.
    pub field_name: String,
    /// Whether a missing (`None`) value is an error.
    pub required: bool,
    /// Inclusive lower bound.
    pub min_value: Option<T>,
    /// Inclusive upper bound.
    pub max_value: Option<T>,
}

impl<T> NumberValidator<T>
where
    T: PartialOrd + Copy,
{
    /// Creates an optional, unbounded validator for `field_name`.
    pub fn new(field_name: &str) -> Self {
        Self {
            field_name: String::from(field_name),
            required: false,
            min_value: None,
            max_value: None,
        }
    }

    /// Marks the field as required.
    pub fn required(self) -> Self {
        Self {
            required: true,
            ..self
        }
    }

    /// Sets the inclusive lower bound.
    pub fn min_value(self, min: T) -> Self {
        Self {
            min_value: Some(min),
            ..self
        }
    }

    /// Sets the inclusive upper bound.
    pub fn max_value(self, max: T) -> Self {
        Self {
            max_value: Some(max),
            ..self
        }
    }

    /// Sets both inclusive bounds.
    pub fn range(self, min: T, max: T) -> Self {
        self.min_value(min).max_value(max)
    }
}

impl<T> Validator<Option<T>> for NumberValidator<T>
where
    T: PartialOrd + Copy + std::fmt::Display,
{
    /// `None` fails only when the field is required. `Some(v)` is checked against the
    /// bounds.
    fn validate(&self, value: &Option<T>) -> ValidationResult<()> {
        match (value, self.required) {
            (Some(v), _) => self.validate(v),
            (None, true) => Err(ValidationError::Required {
                field: self.field_name.clone(),
            }),
            (None, false) => Ok(()),
        }
    }
}

impl<T> Validator<T> for NumberValidator<T>
where
    T: PartialOrd + Copy + std::fmt::Display,
{
    /// Checks `value` against the inclusive bounds. The lower bound is checked first.
    /// Values that are unordered with a bound (e.g. NaN) pass that bound's check.
    fn validate(&self, value: &T) -> ValidationResult<()> {
        if let Some(min) = self.min_value.filter(|min| value < min) {
            return Err(ValidationError::Invalid {
                field: self.field_name.clone(),
                message: format!("值必须大于等于 {}", min),
            });
        }

        if let Some(max) = self.max_value.filter(|max| value > max) {
            return Err(ValidationError::Invalid {
                field: self.field_name.clone(),
                message: format!("值必须小于等于 {}", max),
            });
        }

        Ok(())
    }
}

/// Strength bucket that `validate_password` derives from its score.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum PasswordStrength {
    Weak,
    Medium,
    Strong,
    VeryStrong,
}

/// Outcome of `validate_password`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordValidation {
    /// Whether the password meets the minimum acceptance criteria.
    pub is_valid: bool,
    /// Strength bucket for `score`.
    pub strength: PasswordStrength,
    /// Raw score accumulated by the individual checks.
    pub score: u8,
    /// Human-readable (Chinese) improvement hints, or a single success message.
    pub feedback: Vec<String>,
}

/// 验证邮箱地址
pub fn validate_email(email: &str) -> bool {
    let email_regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap();
    email_regex.is_match(email)
}

/// 验证中国手机号
pub fn validate_phone(phone: &str) -> bool {
    let phone_regex = Regex::new(r"^1[3-9]\d{9}$").unwrap();
    phone_regex.is_match(phone)
}

/// 验证 URL
pub fn validate_url(url: &str) -> bool {
    let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();
    url_regex.is_match(url)
}

/// Returns `true` if `ip` is a valid IPv4 or IPv6 address literal (no port, no brackets).
pub fn validate_ip(ip: &str) -> bool {
    ip.parse::<std::net::Ipv4Addr>().is_ok() || ip.parse::<std::net::Ipv6Addr>().is_ok()
}

/// Scores a password and reports its strength plus human-readable (Chinese) feedback.
///
/// Scoring:
/// - One point each for: at least 8 characters, at least 12 characters, an ASCII
///   lowercase letter, an ASCII uppercase letter, an ASCII digit, and a character from
///   `!@#$%^&*()_+-=[]{}|;:,.<>?`.
/// - Two points are deducted (saturating at 0) if the lowercased password contains a
///   common weak password.
/// - Length is measured in characters, not bytes.
///
/// Strength: 0-2 `Weak`, 3-4 `Medium`, 5 `Strong`, 6 `VeryStrong`. The password is valid
/// when the score is at least 3 and it has at least 8 characters.
pub fn validate_password(password: &str) -> PasswordValidation {
    const SPECIAL_CHARS: &str = "!@#$%^&*()_+-=[]{}|;:,.<>?";
    const WEAK_PASSWORDS: [&str; 10] = [
        "password", "123456", "123456789", "qwerty", "abc123",
        "password123", "admin", "root", "user", "guest",
    ];

    // Count characters, not bytes: a 4-character Chinese password is 12 bytes in UTF-8
    // and would otherwise satisfy both length checks.
    let length = password.chars().count();
    let mut score = 0u8;
    let mut feedback = Vec::new();

    if length >= 8 {
        score += 1;
    } else {
        feedback.push("密码长度至少需要8位".to_string());
    }

    if length >= 12 {
        score += 1;
    }

    // Character-class checks, each paired with its hint. The order fixes the order of
    // the feedback messages.
    let character_classes: [(fn(char) -> bool, &str); 4] = [
        (|c: char| c.is_ascii_lowercase(), "密码应包含小写字母"),
        (|c: char| c.is_ascii_uppercase(), "密码应包含大写字母"),
        (|c: char| c.is_ascii_digit(), "密码应包含数字"),
        (|c: char| SPECIAL_CHARS.contains(c), "密码应包含特殊字符"),
    ];
    for (has_class, hint) in character_classes {
        if password.chars().any(has_class) {
            score += 1;
        } else {
            feedback.push(hint.to_string());
        }
    }

    // Lowercase once instead of once per weak-password entry.
    let lowered = password.to_lowercase();
    if WEAK_PASSWORDS.iter().any(|&weak| lowered.contains(weak)) {
        score = score.saturating_sub(2);
        feedback.push("密码不能包含常见弱密码".to_string());
    }

    // The maximum score is 6. With the old `5..=6 => Strong` arm, `VeryStrong` could
    // never be produced.
    let strength = match score {
        0..=2 => PasswordStrength::Weak,
        3..=4 => PasswordStrength::Medium,
        5 => PasswordStrength::Strong,
        _ => PasswordStrength::VeryStrong,
    };

    let is_valid = score >= 3 && length >= 8;

    if is_valid && feedback.is_empty() {
        feedback.push("密码强度良好".to_string());
    }

    PasswordValidation {
        is_valid,
        strength,
        score,
        feedback,
    }
}

/// Returns `true` if `json_str` is syntactically valid JSON (any top-level value).
pub fn validate_json(json_str: &str) -> bool {
    let parsed: Result<serde_json::Value, _> = serde_json::from_str(json_str);
    parsed.is_ok()
}

/// Returns `true` if `uuid_str` parses as a UUID.
pub fn validate_uuid(uuid_str: &str) -> bool {
    let parsed = uuid::Uuid::parse_str(uuid_str);
    parsed.is_ok()
}

/// Runs several validators and collects every failure, instead of stopping at the first.
#[derive(Debug)]
pub struct BatchValidator {
    errors: Vec<ValidationError>,
}

impl BatchValidator {
    /// Creates an empty batch with no recorded errors.
    pub fn new() -> Self {
        Self { errors: vec![] }
    }

    /// Runs `validator` on `value` and records the error if it fails. Returns `self` so
    /// calls can be chained.
    pub fn validate<T, V>(&mut self, validator: &V, value: &T) -> &mut Self
    where
        V: Validator<T>,
    {
        self.errors.extend(validator.validate(value).err());
        self
    }

    /// `true` when no validator has failed so far.
    pub fn is_valid(&self) -> bool {
        self.errors.is_empty()
    }

    /// All recorded errors, in the order the validators ran.
    pub fn errors(&self) -> &[ValidationError] {
        &self.errors
    }

    /// Collapses the batch into a single result. Only the first recorded error is
    /// returned; use `errors()` beforehand to inspect all of them.
    pub fn into_result(self) -> ValidationResult<()> {
        match self.errors.into_iter().next() {
            Some(first) => Err(first),
            None => Ok(()),
        }
    }
}

impl Default for BatchValidator {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_email_validation() {
        for good in ["test@example.com", "user.name+tag@domain.co.uk"] {
            assert!(validate_email(good), "expected valid: {good}");
        }
        for bad in ["invalid.email", "@domain.com"] {
            assert!(!validate_email(bad), "expected invalid: {bad}");
        }
    }

    #[test]
    fn test_phone_validation() {
        for good in ["13812345678", "15987654321"] {
            assert!(validate_phone(good), "expected valid: {good}");
        }
        // The first number's second digit (2) is not a valid mobile prefix; the second
        // number is one digit short.
        for bad in ["12345678901", "1381234567"] {
            assert!(!validate_phone(bad), "expected invalid: {bad}");
        }
    }

    #[test]
    fn test_password_validation() {
        let weak = validate_password("123");
        assert_eq!(weak.strength, PasswordStrength::Weak);
        assert!(!weak.is_valid);

        let strong = validate_password("MyStr0ng!Pass");
        assert!(strong.is_valid);
        assert!(matches!(
            strong.strength,
            PasswordStrength::Strong | PasswordStrength::VeryStrong
        ));
    }

    #[test]
    fn test_string_validator() {
        let username = StringValidator::new("username")
            .required()
            .length_range(3, 20)
            .pattern(r"^[a-zA-Z0-9_]+$")
            .unwrap();

        assert!(username.validate(&String::from("valid_user123")).is_ok());
        // Too short.
        assert!(username.validate(&String::from("ab")).is_err());
        // Contains a character outside the allowed pattern.
        assert!(username.validate(&String::from("invalid-user")).is_err());
    }

    #[test]
    fn test_number_validator() {
        let age = NumberValidator::new("age").required().range(0, 150);

        assert!(age.validate(&25).is_ok());
        // Below the minimum.
        assert!(age.validate(&-1).is_err());
        // Above the maximum.
        assert!(age.validate(&200).is_err());
    }

    #[test]
    fn test_batch_validator() {
        let email = StringValidator::new("email").required().email().unwrap();
        let age = NumberValidator::new("age").required().range(0, 150);

        let mut batch = BatchValidator::new();
        batch
            .validate(&email, &String::from("invalid.email"))
            .validate(&age, &200);

        assert!(!batch.is_valid());
        assert_eq!(batch.errors().len(), 2);
    }
}