新增以下文档文件:

- PROJECT_OVERVIEW.md:项目总览文档
- BACKEND_ARCHITECTURE.md:后端架构文档
- FRONTEND_ARCHITECTURE.md:前端架构文档
- FLOW_ENGINE.md:流程引擎文档
- SERVICES.md:服务层文档
- ERROR_HANDLING.md:错误处理模块文档

文档内容涵盖项目整体介绍、技术架构、核心模块设计和实现细节。
34 KiB
# UdminAI 工具函数模块文档

## 概述

UdminAI 项目的工具函数模块(Utils)提供了一系列通用的工具函数和辅助功能,为系统的各个组件提供支持。这些工具函数涵盖 ID 生成、时间处理、加密解密、验证、格式化、文件操作等多个方面,是整个系统的基础功能层。

## 模块结构
backend/src/utils/
├── mod.rs # 模块导出
├── id.rs # ID 生成工具
├── time.rs # 时间处理工具
├── crypto.rs # 加密解密工具
├── validation.rs # 验证工具
├── format.rs # 格式化工具
├── file.rs # 文件操作工具
├── hash.rs # 哈希工具
├── jwt.rs # JWT 工具
├── password.rs # 密码处理工具
├── email.rs # 邮件工具
├── config.rs # 配置工具
└── error.rs # 错误处理工具
## 核心工具模块

### ID 生成工具 (id.rs)

#### 功能特性

- 唯一 ID 生成
- 多种 ID 格式支持(UUID v4、雪花 ID、时间戳 + 随机数、自定义前缀 + UUID)
- 高性能生成
- 分布式友好(注意:雪花 ID 中的机器 ID 目前硬编码为 1,多实例部署前需改为从配置读取)

#### 实现代码
use chrono::{DateTime, Utc};
use rand::{thread_rng, Rng};
use std::sync::atomic::{AtomicU64, Ordering};
use uuid::Uuid;
/// Process-wide counter that feeds the sequence field of snowflake IDs.
///
/// It is never reset. `generate_snowflake_id` keeps only its low 12 bits, so the
/// effective sequence wraps modulo 4096.
static SEQUENCE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Selects the format produced by `generate_id_with_type` and checked by `validate_id`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare ID kinds or use them as
/// map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum IdType {
    /// Random UUID v4.
    Uuid,
    /// Snowflake-style 64-bit ID rendered as a decimal string.
    Snowflake,
    /// Millisecond timestamp followed by 8 hex digits of randomness.
    Timestamp,
    /// `"{prefix}-{uuid-v4}"`.
    Prefixed(String),
}
/// Generates a new identifier in the default format (UUID v4).
pub fn generate_id() -> String {
    let default_kind = IdType::Uuid;
    generate_id_with_type(default_kind)
}
/// Generates an identifier in the requested `id_type` format.
///
/// `Prefixed(p)` yields `"{p}-{uuid-v4}"`. The other variants delegate to their
/// format-specific helpers.
pub fn generate_id_with_type(id_type: IdType) -> String {
    match id_type {
        IdType::Prefixed(prefix) => {
            let uuid = Uuid::new_v4();
            format!("{}-{}", prefix, uuid)
        }
        IdType::Timestamp => generate_timestamp_id(),
        IdType::Snowflake => generate_snowflake_id(),
        IdType::Uuid => Uuid::new_v4().to_string(),
    }
}
/// Builds a snowflake-style ID: `timestamp_ms << 22 | machine_id << 12 | sequence`.
///
/// The sequence is the low 12 bits of the global `SEQUENCE_COUNTER`. On one
/// process, IDs are distinct as long as at most 4096 are generated within the same
/// millisecond. The timestamp is raw Unix milliseconds (no custom epoch).
///
/// NOTE(review): `MACHINE_ID` is hard-coded to 1, so several instances can emit
/// colliding IDs until it is read from configuration.
fn generate_snowflake_id() -> String {
    const SEQUENCE_MASK: u64 = 0xFFF; // 12-bit sequence field
    const MACHINE_ID: u64 = 1; // should come from configuration in real deployments

    let millis = Utc::now().timestamp_millis() as u64;
    let seq = SEQUENCE_COUNTER.fetch_add(1, Ordering::SeqCst) & SEQUENCE_MASK;
    let packed = (millis << 22) | (MACHINE_ID << 12) | seq;
    packed.to_string()
}
/// Returns the current Unix time in milliseconds immediately followed by 8
/// lowercase hex digits of randomness.
fn generate_timestamp_id() -> String {
    let suffix: u32 = thread_rng().gen();
    let millis = Utc::now().timestamp_millis();
    format!("{}{:08x}", millis, suffix)
}
/// Generates an 8-character random alphanumeric ID (`[a-zA-Z0-9]`).
///
/// There are 62^8 ≈ 2.2e14 possible values, so by the birthday bound collisions
/// become plausible after roughly 10^7 IDs. Check uniqueness before using this as
/// a key.
pub fn generate_short_id() -> String {
    // A byte table makes each pick an O(1) index instead of an O(n) `chars().nth()`.
    const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    const LEN: usize = 8;

    let mut rng = thread_rng();
    (0..LEN)
        .map(|_| char::from(CHARSET[rng.gen_range(0..CHARSET.len())]))
        .collect()
}
/// Generates a string of `length` random decimal digits.
///
/// The result may start with `0`, so treat it as a string, not a number.
/// Returns an empty string when `length` is 0.
pub fn generate_numeric_id(length: usize) -> String {
    let mut rng = thread_rng();
    // Push chars directly rather than allocating a temporary `String` per digit.
    (0..length)
        .map(|_| char::from(b'0' + rng.gen_range(0..10u8)))
        .collect()
}
/// Checks whether `id` is well-formed for the given `id_type`.
///
/// * `Uuid`: any string `Uuid::parse_str` accepts.
/// * `Snowflake`: parses as a `u64`.
/// * `Timestamp`: the first 13 bytes parse as an `i64` (the millisecond
///   timestamp). The random suffix is not checked.
/// * `Prefixed(p)`: `"{p}-"` followed by a valid UUID.
///
/// Never panics. Input whose 13th byte falls inside a multi-byte character is
/// rejected instead of being sliced.
pub fn validate_id(id: &str, id_type: IdType) -> bool {
    match id_type {
        IdType::Uuid => Uuid::parse_str(id).is_ok(),
        IdType::Snowflake => id.parse::<u64>().is_ok(),
        // `str::get` returns None (instead of panicking like `id[..13]`) when the
        // string is shorter than 13 bytes or byte 13 is not a char boundary.
        IdType::Timestamp => id
            .get(..13)
            .map_or(false, |head| head.parse::<i64>().is_ok()),
        IdType::Prefixed(prefix) => id
            .strip_prefix(prefix.as_str())
            .and_then(|rest| rest.strip_prefix('-'))
            .map_or(false, |uuid| !uuid.is_empty() && Uuid::parse_str(uuid).is_ok()),
    }
}
/// Recovers the creation time embedded in a snowflake ID (bits 22 and above hold
/// Unix milliseconds, matching `generate_snowflake_id`).
///
/// Returns `None` if `id` is not a decimal `u64` or the timestamp is outside
/// chrono's representable range.
pub fn extract_timestamp_from_snowflake(id: &str) -> Option<DateTime<Utc>> {
    let raw = id.parse::<u64>().ok()?;
    let millis = (raw >> 22) as i64;
    DateTime::from_timestamp_millis(millis)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_uuid() {
        let id = generate_id();
        assert!(validate_id(&id, IdType::Uuid));
    }

    #[test]
    fn test_generate_snowflake() {
        let id = generate_id_with_type(IdType::Snowflake);
        assert!(validate_id(&id, IdType::Snowflake));
    }

    #[test]
    fn test_snowflake_timestamp_roundtrip() {
        let before = Utc::now().timestamp_millis();
        let id = generate_id_with_type(IdType::Snowflake);
        let after = Utc::now().timestamp_millis();
        let embedded = extract_timestamp_from_snowflake(&id)
            .expect("generated snowflake must decode")
            .timestamp_millis();
        assert!(before <= embedded && embedded <= after);
        assert!(extract_timestamp_from_snowflake("not-a-number").is_none());
    }

    #[test]
    fn test_generate_timestamp_id() {
        let id = generate_id_with_type(IdType::Timestamp);
        assert!(validate_id(&id, IdType::Timestamp));
        assert!(!validate_id("abc", IdType::Timestamp));
    }

    #[test]
    fn test_generate_prefixed_id() {
        let id = generate_id_with_type(IdType::Prefixed("user".to_string()));
        assert!(id.starts_with("user-"));
        assert!(validate_id(&id, IdType::Prefixed("user".to_string())));
        assert!(!validate_id("user-not-a-uuid", IdType::Prefixed("user".to_string())));
        assert!(!validate_id(&id, IdType::Prefixed("order".to_string())));
    }

    #[test]
    fn test_short_id() {
        let id = generate_short_id();
        assert_eq!(id.len(), 8);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_numeric_id() {
        let id = generate_numeric_id(6);
        assert_eq!(id.len(), 6);
        assert!(id.chars().all(|c| c.is_ascii_digit()));
        assert!(generate_numeric_id(0).is_empty());
    }
}
### 时间处理工具 (time.rs)

#### 功能特性

- 时间格式化与解析
- 时区转换(UTC、东八区、本地时区)
- 时间计算
- 相对时间描述

#### 实现代码
use chrono::{
    DateTime, Datelike, Duration, FixedOffset, Local, NaiveDateTime, TimeZone, Utc,
};
use serde::{Deserialize, Serialize};
use std::fmt;
/// Default human-readable datetime layout, e.g. `2024-01-31 13:45:00`.
pub const DEFAULT_DATETIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S";
/// ISO-8601 layout with milliseconds and a literal `Z`; only meaningful for UTC values.
pub const ISO_DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.3fZ";
/// Date-only layout, e.g. `2024-01-31`.
pub const DATE_FORMAT: &str = "%Y-%m-%d";
/// Time-only layout, e.g. `13:45:00`.
pub const TIME_FORMAT: &str = "%H:%M:%S";

/// UTC+8 (China Standard Time) as seconds east of UTC.
pub const CHINA_OFFSET: i32 = 8 * 60 * 60;
/// Returns the current instant in UTC.
pub fn now_utc() -> DateTime<Utc> {
    Utc::now()
}

/// Returns the current instant as a `DateTime<FixedOffset>` with a zero offset.
pub fn now_fixed_offset() -> DateTime<FixedOffset> {
    let zero_offset = FixedOffset::east_opt(0).unwrap();
    Utc::now().with_timezone(&zero_offset)
}

/// Returns the current instant in China Standard Time (UTC+8, see `CHINA_OFFSET`).
pub fn now_china() -> DateTime<FixedOffset> {
    // `east_opt` only fails outside ±24h; `CHINA_OFFSET` is 8h.
    let cst = FixedOffset::east_opt(CHINA_OFFSET).unwrap();
    Utc::now().with_timezone(&cst)
}

/// Returns the current instant in the host machine's local time zone.
pub fn now_local() -> DateTime<Local> {
    Local::now()
}
/// Renders `dt` with a chrono strftime-style `format` string.
///
/// NOTE(review): chrono reports invalid specifiers as a formatting error, which
/// `to_string()` turns into a panic. Pass only trusted format strings.
pub fn format_datetime(dt: &DateTime<Utc>, format: &str) -> String {
    let rendered = dt.format(format);
    rendered.to_string()
}

/// Renders `dt` using `DEFAULT_DATETIME_FORMAT` (`YYYY-MM-DD HH:MM:SS`).
pub fn format_datetime_default(dt: &DateTime<Utc>) -> String {
    let layout = DEFAULT_DATETIME_FORMAT;
    format_datetime(dt, layout)
}

/// Renders `dt` using `ISO_DATETIME_FORMAT` (milliseconds plus a trailing `Z`).
pub fn format_datetime_iso(dt: &DateTime<Utc>) -> String {
    let layout = ISO_DATETIME_FORMAT;
    format_datetime(dt, layout)
}
/// Parses `s` with `format` as a zone-less datetime and interprets it as UTC.
///
/// `format` must contain both date and time fields. A date-only format fails,
/// because `NaiveDateTime` requires a time component.
pub fn parse_datetime(s: &str, format: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
    NaiveDateTime::parse_from_str(s, format).map(|naive| Utc.from_utc_datetime(&naive))
}

/// Parses an RFC 3339 string (which carries its own offset) and converts it to UTC.
pub fn parse_datetime_iso(s: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
    let with_offset = DateTime::parse_from_rfc3339(s)?;
    Ok(with_offset.with_timezone(&Utc))
}
/// Converts Unix seconds to a UTC `DateTime`. Returns `None` if out of chrono's range.
pub fn timestamp_to_datetime(timestamp: i64) -> Option<DateTime<Utc>> {
    let whole_seconds_only = 0;
    DateTime::from_timestamp(timestamp, whole_seconds_only)
}

/// Converts Unix milliseconds to a UTC `DateTime`. Returns `None` if out of range.
pub fn timestamp_millis_to_datetime(timestamp: i64) -> Option<DateTime<Utc>> {
    DateTime::from_timestamp_millis(timestamp)
}

/// Returns `dt` as whole Unix seconds (sub-second part dropped).
pub fn datetime_to_timestamp(dt: &DateTime<Utc>) -> i64 {
    let instant = *dt;
    instant.timestamp()
}

/// Returns `dt` as Unix milliseconds.
pub fn datetime_to_timestamp_millis(dt: &DateTime<Utc>) -> i64 {
    let instant = *dt;
    instant.timestamp_millis()
}
/// Returns the signed span `end - start` (negative when `end` precedes `start`).
pub fn time_diff(start: &DateTime<Utc>, end: &DateTime<Utc>) -> Duration {
    end.signed_duration_since(*start)
}

/// Returns `end - start` in whole seconds (truncated toward zero).
pub fn time_diff_seconds(start: &DateTime<Utc>, end: &DateTime<Utc>) -> i64 {
    let span = time_diff(start, end);
    span.num_seconds()
}

/// Returns `end - start` in whole milliseconds (truncated toward zero).
pub fn time_diff_millis(start: &DateTime<Utc>, end: &DateTime<Utc>) -> i64 {
    let span = time_diff(start, end);
    span.num_milliseconds()
}

/// Returns `dt` shifted forward by `duration` (backward if negative).
pub fn add_duration(dt: &DateTime<Utc>, duration: Duration) -> DateTime<Utc> {
    let base = *dt;
    base + duration
}

/// Returns `dt` shifted backward by `duration`.
pub fn sub_duration(dt: &DateTime<Utc>, duration: Duration) -> DateTime<Utc> {
    let base = *dt;
    base - duration
}

/// Returns `dt` plus `seconds` seconds.
pub fn add_seconds(dt: &DateTime<Utc>, seconds: i64) -> DateTime<Utc> {
    let delta = Duration::seconds(seconds);
    add_duration(dt, delta)
}

/// Returns `dt` plus `minutes` minutes.
pub fn add_minutes(dt: &DateTime<Utc>, minutes: i64) -> DateTime<Utc> {
    let delta = Duration::minutes(minutes);
    add_duration(dt, delta)
}

/// Returns `dt` plus `hours` hours.
pub fn add_hours(dt: &DateTime<Utc>, hours: i64) -> DateTime<Utc> {
    let delta = Duration::hours(hours);
    add_duration(dt, delta)
}

/// Returns `dt` plus `days` 24-hour days.
pub fn add_days(dt: &DateTime<Utc>, days: i64) -> DateTime<Utc> {
    let delta = Duration::days(days);
    add_duration(dt, delta)
}
/// Returns midnight (00:00:00.000) of the current UTC day.
pub fn today_start() -> DateTime<Utc> {
    now_utc()
        .date_naive()
        .and_hms_opt(0, 0, 0)
        .expect("00:00:00 is always a valid time of day")
        .and_utc()
}

/// Returns the last representable instant of the current UTC day (23:59:59.999999999).
///
/// Including the sub-second part means every instant of today, including fractional
/// instants after 23:59:59, satisfies `dt <= today_end()`.
///
/// For database range queries, a half-open range `[today_start(), tomorrow_start)`
/// is usually more robust. Some stores round sub-microsecond values.
pub fn today_end() -> DateTime<Utc> {
    now_utc()
        .date_naive()
        .and_hms_nano_opt(23, 59, 59, 999_999_999)
        .expect("23:59:59.999999999 is always a valid time of day")
        .and_utc()
}
/// Returns midnight (UTC) of the Monday that begins the current week.
pub fn week_start() -> DateTime<Utc> {
    let today = now_utc().date_naive();
    let days_since_monday = i64::from(today.weekday().num_days_from_monday());
    let monday = today - Duration::days(days_since_monday);
    monday.and_hms_opt(0, 0, 0).unwrap().and_utc()
}

/// Returns midnight (UTC) of the first day of the current month.
pub fn month_start() -> DateTime<Utc> {
    // Day 1 exists in every month, so `with_day(1)` cannot fail.
    let first_of_month = now_utc().date_naive().with_day(1).unwrap();
    first_of_month.and_hms_opt(0, 0, 0).unwrap().and_utc()
}
/// Human-oriented description of how far an instant is from now.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelativeTime {
    /// Magnitude in `unit`, truncated toward zero.
    pub value: i64,
    /// Unit `value` is expressed in.
    pub unit: TimeUnit,
    /// Ready-to-display text (Chinese), e.g. `"2小时 前"`.
    pub description: String,
}

/// Units used by `RelativeTime`, from smallest to largest.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TimeUnit {
    Second,
    Minute,
    Hour,
    Day,
    Week,
    Month,
    Year,
}

impl fmt::Display for TimeUnit {
    /// Writes the Chinese name of the unit.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Second => "秒",
            Self::Minute => "分钟",
            Self::Hour => "小时",
            Self::Day => "天",
            Self::Week => "周",
            Self::Month => "月",
            Self::Year => "年",
        };
        f.write_str(label)
    }
}
/// Describes `dt` relative to now, e.g. `"2小时 前"` (2 hours ago) or `"3天 后"` (in 3 days).
///
/// Uses fixed-length approximations: a month is 30 days and a year is 365 days.
/// The value is truncated toward zero within the chosen unit. An instant equal
/// to now (to the second) is reported as `"0秒 前"`.
pub fn relative_time(dt: &DateTime<Utc>) -> RelativeTime {
    const MINUTE: i64 = 60;
    const HOUR: i64 = 60 * MINUTE;
    const DAY: i64 = 24 * HOUR;
    const WEEK: i64 = 7 * DAY;
    const MONTH: i64 = 30 * DAY;
    const YEAR: i64 = 365 * DAY;

    // `time_diff(dt, now)` is `now - dt`, which is non-negative exactly when `dt`
    // lies in the past. Testing `< 0` here would invert past and future.
    let elapsed = time_diff(dt, &now_utc()).num_seconds();
    let is_past = elapsed >= 0;
    let abs_seconds = elapsed.abs();

    let (value, unit) = if abs_seconds < MINUTE {
        (abs_seconds, TimeUnit::Second)
    } else if abs_seconds < HOUR {
        (abs_seconds / MINUTE, TimeUnit::Minute)
    } else if abs_seconds < DAY {
        (abs_seconds / HOUR, TimeUnit::Hour)
    } else if abs_seconds < WEEK {
        (abs_seconds / DAY, TimeUnit::Day)
    } else if abs_seconds < MONTH {
        (abs_seconds / WEEK, TimeUnit::Week)
    } else if abs_seconds < YEAR {
        (abs_seconds / MONTH, TimeUnit::Month)
    } else {
        (abs_seconds / YEAR, TimeUnit::Year)
    };

    let description = if is_past {
        format!("{}{} 前", value, unit)
    } else {
        format!("{}{} 后", value, unit)
    };

    RelativeTime {
        value,
        unit,
        description,
    }
}
/// Returns `true` if both instants fall on the same UTC calendar date.
pub fn is_same_day(dt1: &DateTime<Utc>, dt2: &DateTime<Utc>) -> bool {
    let (first, second) = (dt1.date_naive(), dt2.date_naive());
    first == second
}

/// Returns `true` if `dt` falls on today's UTC date.
pub fn is_today(dt: &DateTime<Utc>) -> bool {
    let now = now_utc();
    is_same_day(dt, &now)
}

/// Returns `true` if `dt` falls on the UTC date 24 hours before now.
pub fn is_yesterday(dt: &DateTime<Utc>) -> bool {
    let one_day_ago = now_utc() - Duration::days(1);
    is_same_day(dt, &one_day_ago)
}

/// Returns `true` if `dt` lies in `[week_start(), week_start() + 7 days)`.
pub fn is_this_week(dt: &DateTime<Utc>) -> bool {
    let monday = week_start();
    let next_monday = monday + Duration::days(7);
    (monday..next_monday).contains(dt)
}

/// Returns `true` if `dt` shares the current UTC year and month.
pub fn is_this_month(dt: &DateTime<Utc>) -> bool {
    let now = now_utc();
    (dt.year(), dt.month()) == (now.year(), now.month())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_time_formatting() {
        let dt = Utc::now();
        let formatted = format_datetime_default(&dt);
        assert!(!formatted.is_empty());
    }

    #[test]
    fn test_time_parsing() {
        let time_str = "2023-12-25 10:30:00";
        let parsed = parse_datetime(time_str, DEFAULT_DATETIME_FORMAT);
        assert!(parsed.is_ok());
    }

    #[test]
    fn test_iso_roundtrip() {
        let dt = parse_datetime_iso("2023-12-25T10:30:00.123Z").unwrap();
        assert_eq!(format_datetime_iso(&dt), "2023-12-25T10:30:00.123Z");
        assert_eq!(format_datetime_default(&dt), "2023-12-25 10:30:00");
        // Explicit offsets are normalised to UTC.
        let shifted = parse_datetime_iso("2023-12-25T18:30:00+08:00").unwrap();
        assert_eq!(format_datetime_default(&shifted), "2023-12-25 10:30:00");
    }

    #[test]
    fn test_timestamp_conversion() {
        let dt = Utc::now();
        let timestamp = datetime_to_timestamp(&dt);
        let converted = timestamp_to_datetime(timestamp).unwrap();
        assert_eq!(dt.timestamp(), converted.timestamp());
    }

    #[test]
    fn test_time_arithmetic() {
        let start = parse_datetime("2023-12-25 10:30:00", DEFAULT_DATETIME_FORMAT).unwrap();
        let end = add_hours(&start, 2);
        assert_eq!(time_diff_seconds(&start, &end), 7_200);
        assert_eq!(time_diff_millis(&end, &start), -7_200_000);
        assert_eq!(sub_duration(&end, Duration::hours(2)), start);
        assert!(is_same_day(&start, &end));
        assert!(!is_same_day(&start, &add_days(&start, 1)));
    }

    #[test]
    fn test_relative_time() {
        let past = sub_duration(&now_utc(), Duration::hours(2));
        let rel_time = relative_time(&past);
        assert!(rel_time.description.contains("前"));
    }

    #[test]
    fn test_date_checks() {
        let now = now_utc();
        assert!(is_today(&now));
        let yesterday = sub_duration(&now, Duration::days(1));
        assert!(is_yesterday(&yesterday));
    }

    #[test]
    fn test_period_starts_are_midnight() {
        for start in &[today_start(), week_start(), month_start()] {
            assert_eq!(format_datetime(start, TIME_FORMAT), "00:00:00");
            assert!(*start <= now_utc());
        }
    }
}
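### 加密解密工具 (crypto.rs)

#### 功能特性

- AES-256-GCM 加密解密
- RSA 加密解密(PKCS#1 v1.5 填充)
- 数字签名(尚未实现:目前仅预留了 `SigningFailed` / `VerificationFailed` 错误类型)
- 密钥生成
- SHA-256 哈希计算与校验

#### 实现代码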
use aes_gcm::{
aead::{Aead, KeyInit, OsRng},
Aes256Gcm, Key, Nonce,
};
use base64::{engine::general_purpose, Engine as _};
use rand::{thread_rng, RngCore};
use rsa::{
pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey, EncodeRsaPrivateKey, EncodeRsaPublicKey},
pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey},
Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey,
};
use sha2::{Digest, Sha256};
use std::error::Error;
use thiserror::Error;
/// Errors returned by the crypto helpers in this module.
///
/// Each variant carries a message, usually the underlying library error's text.
/// The `Display` output is Chinese. `SigningFailed` and `VerificationFailed` are
/// not produced by any function shown in this module.
#[derive(Error, Debug)]
pub enum CryptoError {
    /// The cipher or padding scheme rejected the plaintext.
    #[error("加密失败: {0}")]
    EncryptionFailed(String),
    /// Malformed, truncated, tampered or non-UTF-8 ciphertext.
    #[error("解密失败: {0}")]
    DecryptionFailed(String),
    /// Key generation failed, or a key could not be parsed.
    #[error("密钥生成失败: {0}")]
    KeyGenerationFailed(String),
    /// Reserved for signature creation.
    #[error("签名失败: {0}")]
    SigningFailed(String),
    /// Reserved for signature verification.
    #[error("验证失败: {0}")]
    VerificationFailed(String),
    /// A key could not be serialized (e.g. to PEM).
    #[error("编码失败: {0}")]
    EncodingFailed(String),
}
/// AES-256-GCM authenticated encryption with a fresh random 96-bit nonce per message.
///
/// Wire layout produced by `encrypt` and expected by `decrypt`:
/// `nonce (12 bytes) || ciphertext-with-GCM-tag`.
pub struct AesEncryption {
    cipher: Aes256Gcm,
}

impl AesEncryption {
    /// Length in bytes of the GCM nonce prepended to every ciphertext.
    const NONCE_LEN: usize = 12;

    /// Builds a cipher from a raw 256-bit key.
    pub fn new(key: &[u8; 32]) -> Self {
        let key = Key::<Aes256Gcm>::from_slice(key);
        Self {
            cipher: Aes256Gcm::new(key),
        }
    }

    /// Builds a cipher whose key is the SHA-256 digest of `password`.
    ///
    /// SECURITY(review): a single unsalted SHA-256 is not a password KDF. It is fast
    /// to brute-force, and equal passwords yield equal keys. Prefer `new` with a key
    /// from `generate_aes_key`, or move to a salted KDF (Argon2/PBKDF2). Changing the
    /// derivation invalidates existing ciphertexts.
    pub fn from_password(password: &str) -> Self {
        Self::new(&Self::derive_key_from_password(password))
    }

    /// Returns SHA-256(`password`) as a 32-byte key. See the warning on `from_password`.
    fn derive_key_from_password(password: &str) -> [u8; 32] {
        Sha256::digest(password.as_bytes()).into()
    }

    /// Encrypts `plaintext` and returns `nonce || ciphertext || tag`.
    ///
    /// # Errors
    /// `EncryptionFailed` if the AEAD rejects the input.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        let mut nonce_bytes = [0u8; Self::NONCE_LEN];
        thread_rng().fill_bytes(&mut nonce_bytes);

        let sealed = self
            .cipher
            .encrypt(Nonce::from_slice(&nonce_bytes), plaintext)
            .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;

        // Prefix the nonce so `decrypt` can recover it.
        Ok([nonce_bytes.as_slice(), sealed.as_slice()].concat())
    }

    /// Reverses `encrypt`: splits off the 12-byte nonce, then authenticates and decrypts.
    ///
    /// # Errors
    /// `DecryptionFailed` if the input is shorter than the nonce, or if
    /// authentication fails (wrong key or tampered data).
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
        if ciphertext.len() < Self::NONCE_LEN {
            return Err(CryptoError::DecryptionFailed("密文长度不足".to_string()));
        }
        let (nonce_bytes, body) = ciphertext.split_at(Self::NONCE_LEN);
        self.cipher
            .decrypt(Nonce::from_slice(nonce_bytes), body)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    }

    /// Encrypts a UTF-8 string and returns the sealed bytes as standard Base64.
    pub fn encrypt_string(&self, plaintext: &str) -> Result<String, CryptoError> {
        self.encrypt(plaintext.as_bytes())
            .map(|sealed| general_purpose::STANDARD.encode(sealed))
    }

    /// Decodes standard Base64, decrypts, and interprets the result as UTF-8.
    ///
    /// # Errors
    /// `DecryptionFailed` for invalid Base64, authentication failure or non-UTF-8 output.
    pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, CryptoError> {
        let raw = general_purpose::STANDARD
            .decode(ciphertext)
            .map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
        let plain = self.decrypt(&raw)?;
        String::from_utf8(plain).map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
    }
}
/// RSA 密钥对
#[derive(Debug, Clone)]
pub struct RsaKeyPair {
pub private_key: RsaPrivateKey,
pub public_key: RsaPublicKey,
}
impl RsaKeyPair {
/// 生成新的 RSA 密钥对
pub fn generate(bits: usize) -> Result<Self, CryptoError> {
let mut rng = OsRng;
let private_key = RsaPrivateKey::new(&mut rng, bits)
.map_err(|e| CryptoError::KeyGenerationFailed(e.to_string()))?;
let public_key = RsaPublicKey::from(&private_key);
Ok(Self {
private_key,
public_key,
})
}
/// 从 PEM 格式加载私钥
pub fn from_private_key_pem(pem: &str) -> Result<Self, CryptoError> {
let private_key = RsaPrivateKey::from_pkcs8_pem(pem)
.map_err(|e| CryptoError::KeyGenerationFailed(e.to_string()))?;
let public_key = RsaPublicKey::from(&private_key);
Ok(Self {
private_key,
public_key,
})
}
/// 导出私钥为 PEM 格式
pub fn private_key_to_pem(&self) -> Result<String, CryptoError> {
self.private_key
.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
.map_err(|e| CryptoError::EncodingFailed(e.to_string()))
.map(|s| s.to_string())
}
/// 导出公钥为 PEM 格式
pub fn public_key_to_pem(&self) -> Result<String, CryptoError> {
self.public_key
.to_public_key_pem(rsa::pkcs8::LineEnding::LF)
.map_err(|e| CryptoError::EncodingFailed(e.to_string()))
}
/// 使用公钥加密
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
let mut rng = OsRng;
self.public_key
.encrypt(&mut rng, Pkcs1v15Encrypt, plaintext)
.map_err(|e| CryptoError::EncryptionFailed(e.to_string()))
}
/// 使用私钥解密
pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
self.private_key
.decrypt(Pkcs1v15Encrypt, ciphertext)
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
}
/// 加密字符串并返回 Base64 编码
pub fn encrypt_string(&self, plaintext: &str) -> Result<String, CryptoError> {
let encrypted = self.encrypt(plaintext.as_bytes())?;
Ok(general_purpose::STANDARD.encode(encrypted))
}
/// 解密 Base64 编码的字符串
pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, CryptoError> {
let decoded = general_purpose::STANDARD
.decode(ciphertext)
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
let decrypted = self.decrypt(&decoded)?;
String::from_utf8(decrypted)
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
}
}
/// Returns `length` random bytes from rand's thread-local generator.
pub fn generate_random_key(length: usize) -> Vec<u8> {
    let mut buf = vec![0u8; length];
    thread_rng().fill_bytes(buf.as_mut_slice());
    buf
}

/// Returns a fresh random 256-bit key suitable for `AesEncryption::new`.
pub fn generate_aes_key() -> [u8; 32] {
    let mut key = [0u8; 32];
    thread_rng().fill_bytes(&mut key[..]);
    key
}
/// Returns the 32-byte SHA-256 digest of `data`.
pub fn sha256_hash(data: &[u8]) -> Vec<u8> {
    Sha256::digest(data).to_vec()
}

/// Returns the SHA-256 digest of the UTF-8 bytes of `data` as 64 lowercase hex chars.
pub fn sha256_hex(data: &str) -> String {
    let digest = sha256_hash(data.as_bytes());
    hex::encode(digest)
}
/// Returns `true` if SHA-256(`data`) equals `expected_hash`.
///
/// Every byte is compared even after a mismatch, so the running time does not
/// depend on how many leading bytes matched. This avoids the timing leak of an
/// early-exit `==`. A length mismatch returns `false` immediately, since the
/// digest length is public.
pub fn verify_hash(data: &[u8], expected_hash: &[u8]) -> bool {
    let actual_hash = sha256_hash(data);
    if actual_hash.len() != expected_hash.len() {
        return false;
    }
    // OR together the XOR of each byte pair; the result is zero only if all match.
    actual_hash
        .iter()
        .zip(expected_hash)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b))
        == 0
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aes_encryption() {
        let key = generate_aes_key();
        let aes = AesEncryption::new(&key);
        let plaintext = "Hello, World!";
        let encrypted = aes.encrypt_string(plaintext).unwrap();
        let decrypted = aes.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[test]
    fn test_aes_rejects_tampered_or_short_input() {
        let aes = AesEncryption::new(&generate_aes_key());
        let mut sealed = aes.encrypt(b"payload").unwrap();
        let last = sealed.len() - 1;
        sealed[last] ^= 0x01; // flip one tag bit
        assert!(aes.decrypt(&sealed).is_err());
        assert!(aes.decrypt(&[0u8; 5]).is_err());
    }

    #[test]
    fn test_aes_from_password_is_deterministic() {
        let sealed = AesEncryption::from_password("secret")
            .encrypt_string("hi")
            .unwrap();
        let opened = AesEncryption::from_password("secret")
            .decrypt_string(&sealed)
            .unwrap();
        assert_eq!(opened, "hi");
        assert!(AesEncryption::from_password("other").decrypt_string(&sealed).is_err());
    }

    #[test]
    fn test_rsa_encryption() {
        let keypair = RsaKeyPair::generate(2048).unwrap();
        let plaintext = "Hello, RSA!";
        let encrypted = keypair.encrypt_string(plaintext).unwrap();
        let decrypted = keypair.decrypt_string(&encrypted).unwrap();
        assert_eq!(plaintext, decrypted);

        // A PEM round trip must reproduce the same key pair.
        let pem = keypair.private_key_to_pem().unwrap();
        let restored = RsaKeyPair::from_private_key_pem(&pem).unwrap();
        assert_eq!(restored.public_key, keypair.public_key);
        assert!(RsaKeyPair::from_private_key_pem("not a pem").is_err());
    }

    #[test]
    fn test_sha256_hash() {
        let data = "test data";
        let hash1 = sha256_hex(data);
        let hash2 = sha256_hex(data);
        assert_eq!(hash1, hash2);
        assert_eq!(hash1.len(), 64); // SHA-256 yields 64 hex characters
        // Known test vector from FIPS 180-2.
        assert_eq!(
            sha256_hex("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_verify_hash() {
        let digest = sha256_hash(b"abc");
        assert!(verify_hash(b"abc", &digest));
        assert!(!verify_hash(b"abd", &digest));
        assert!(!verify_hash(b"abc", &digest[..16]));
    }

    #[test]
    fn test_key_generation() {
        let key1 = generate_aes_key();
        let key2 = generate_aes_key();
        assert_ne!(key1, key2); // keys must be random
        assert_eq!(key1.len(), 32); // AES-256 key length
        assert_eq!(generate_random_key(16).len(), 16);
    }
}
### 验证工具 (validation.rs)

#### 功能特性

- 邮箱验证
- 手机号验证(中国大陆 11 位手机号)
- URL 验证
- IP 地址、JSON、UUID 格式验证
- 密码强度验证
- 自定义验证规则与批量验证

#### 实现代码
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Reasons a value can fail validation.
///
/// Derives `Serialize`/`Deserialize` alongside `thiserror::Error`. The `Display`
/// messages are Chinese user-facing text.
#[derive(Error, Debug, Clone, Serialize, Deserialize)]
pub enum ValidationError {
    /// A required field was missing or empty.
    #[error("字段 '{field}' 是必需的")]
    Required { field: String },

    /// The value's length fell outside `[min, max]`.
    #[error("字段 '{field}' 长度必须在 {min} 到 {max} 之间")]
    Length { field: String, min: usize, max: usize },

    /// The value did not match the expected pattern.
    #[error("字段 '{field}' 格式无效: {message}")]
    Format { field: String, message: String },

    /// The value was out of range or otherwise unacceptable.
    #[error("字段 '{field}' 值无效: {message}")]
    Invalid { field: String, message: String },

    /// Failure reported by a caller-supplied rule.
    #[error("自定义验证失败: {message}")]
    Custom { message: String },
}
/// Outcome of a validation: `Ok(T)` or the `ValidationError` that was found.
pub type ValidationResult<T> = Result<T, ValidationError>;

/// A reusable rule that checks values of type `T`.
pub trait Validator<T> {
    /// Returns `Ok(())` if `value` is acceptable; otherwise returns the reason it was rejected.
    fn validate(&self, value: &T) -> ValidationResult<()>;
}
/// Shared, thread-safe custom rule run by `StringValidator` after its built-in checks.
pub type StringRule = std::sync::Arc<dyn Fn(&str) -> ValidationResult<()> + Send + Sync>;

/// Builder-style validator for string fields.
///
/// Checks run in this order: required/empty, length bounds, regex pattern, then
/// each custom rule. The first failure is returned.
///
/// `Debug` is implemented by hand and `Clone` is derived. Custom rules are stored
/// as `Arc` because `Box<dyn Fn>` is neither `Debug` nor `Clone`, so deriving
/// both on a boxed field does not compile.
#[derive(Clone)]
pub struct StringValidator {
    pub field_name: String,
    pub required: bool,
    pub min_length: Option<usize>,
    pub max_length: Option<usize>,
    pub pattern: Option<Regex>,
    /// Extra rules applied after the built-in checks; add them with `custom`.
    pub custom_validators: Vec<StringRule>,
}

impl std::fmt::Debug for StringValidator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StringValidator")
            .field("field_name", &self.field_name)
            .field("required", &self.required)
            .field("min_length", &self.min_length)
            .field("max_length", &self.max_length)
            .field("pattern", &self.pattern)
            // Closures have no Debug representation; report how many there are.
            .field("custom_validators", &self.custom_validators.len())
            .finish()
    }
}

impl StringValidator {
    /// Creates an optional, unconstrained validator for `field_name`.
    pub fn new(field_name: &str) -> Self {
        Self {
            field_name: field_name.to_string(),
            required: false,
            min_length: None,
            max_length: None,
            pattern: None,
            custom_validators: Vec::new(),
        }
    }

    /// Rejects missing or empty values.
    pub fn required(mut self) -> Self {
        self.required = true;
        self
    }

    /// Sets the minimum allowed length.
    pub fn min_length(mut self, min: usize) -> Self {
        self.min_length = Some(min);
        self
    }

    /// Sets the maximum allowed length.
    pub fn max_length(mut self, max: usize) -> Self {
        self.max_length = Some(max);
        self
    }

    /// Sets both length bounds (inclusive).
    pub fn length_range(mut self, min: usize, max: usize) -> Self {
        self.min_length = Some(min);
        self.max_length = Some(max);
        self
    }

    /// Requires non-empty values to match `pattern`. Fails if the regex is invalid.
    pub fn pattern(mut self, pattern: &str) -> Result<Self, regex::Error> {
        self.pattern = Some(Regex::new(pattern)?);
        Ok(self)
    }

    /// Uses a pragmatic e-mail pattern (not full RFC 5322).
    pub fn email(self) -> Result<Self, regex::Error> {
        self.pattern(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
    }

    /// Uses the mainland-China 11-digit mobile number pattern (`1[3-9]` + 9 digits).
    pub fn phone(self) -> Result<Self, regex::Error> {
        self.pattern(r"^1[3-9]\d{9}$")
    }

    /// Uses a loose `http`/`https` URL pattern.
    pub fn url(self) -> Result<Self, regex::Error> {
        self.pattern(r"^https?://[^\s/$.?#].[^\s]*$")
    }

    /// Appends a custom rule that runs after all built-in checks pass.
    pub fn custom<F>(mut self, rule: F) -> Self
    where
        F: Fn(&str) -> ValidationResult<()> + Send + Sync + 'static,
    {
        self.custom_validators.push(std::sync::Arc::new(rule));
        self
    }
}
impl Validator<Option<String>> for StringValidator {
    /// Treats `None` as absent, which is an error only when the field is required.
    /// `Some(text)` is validated exactly like a plain `String`.
    fn validate(&self, value: &Option<String>) -> ValidationResult<()> {
        match value {
            Some(text) => <Self as Validator<String>>::validate(self, text),
            None if self.required => Err(ValidationError::Required {
                field: self.field_name.clone(),
            }),
            None => Ok(()),
        }
    }
}
impl Validator<String> for StringValidator {
    /// Validates `value` against every configured rule and returns the first failure.
    ///
    /// Order: empty/required, length bounds, pattern, custom rules. An empty value
    /// that is not required is accepted without further checks.
    ///
    /// Lengths are counted in characters (Unicode scalar values), not UTF-8 bytes.
    /// For example, "张三丰" has length 3, not 9.
    fn validate(&self, value: &String) -> ValidationResult<()> {
        if value.is_empty() {
            return if self.required {
                Err(ValidationError::Required {
                    field: self.field_name.clone(),
                })
            } else {
                Ok(())
            };
        }

        let char_count = value.chars().count();
        let too_short = self.min_length.map_or(false, |min| char_count < min);
        let too_long = self.max_length.map_or(false, |max| char_count > max);
        if too_short || too_long {
            // Report the full configured range; unset bounds become 0 / usize::MAX.
            return Err(ValidationError::Length {
                field: self.field_name.clone(),
                min: self.min_length.unwrap_or(0),
                max: self.max_length.unwrap_or(usize::MAX),
            });
        }

        if let Some(pattern) = &self.pattern {
            if !pattern.is_match(value) {
                return Err(ValidationError::Format {
                    field: self.field_name.clone(),
                    message: "格式不匹配".to_string(),
                });
            }
        }

        for rule in &self.custom_validators {
            rule(value.as_str())?;
        }

        Ok(())
    }
}
/// Builder-style inclusive range validator for numbers, or any `PartialOrd + Copy` type.
#[derive(Debug, Clone)]
pub struct NumberValidator<T> {
    pub field_name: String,
    pub required: bool,
    pub min_value: Option<T>,
    pub max_value: Option<T>,
}

impl<T> NumberValidator<T>
where
    T: PartialOrd + Copy,
{
    /// Creates an optional, unbounded validator for `field_name`.
    pub fn new(field_name: &str) -> Self {
        NumberValidator {
            field_name: field_name.to_owned(),
            required: false,
            min_value: None,
            max_value: None,
        }
    }

    /// Marks the field as required. This only affects `Option<T>` inputs.
    pub fn required(self) -> Self {
        Self { required: true, ..self }
    }

    /// Sets the inclusive lower bound.
    pub fn min_value(self, min: T) -> Self {
        Self {
            min_value: Some(min),
            ..self
        }
    }

    /// Sets the inclusive upper bound.
    pub fn max_value(self, max: T) -> Self {
        Self {
            max_value: Some(max),
            ..self
        }
    }

    /// Sets both inclusive bounds.
    pub fn range(self, min: T, max: T) -> Self {
        self.min_value(min).max_value(max)
    }
}
impl<T> Validator<Option<T>> for NumberValidator<T>
where
    T: PartialOrd + Copy + std::fmt::Display,
{
    /// `None` fails only when the field is required; `Some(v)` is range-checked.
    fn validate(&self, value: &Option<T>) -> ValidationResult<()> {
        match value {
            Some(number) => <Self as Validator<T>>::validate(self, number),
            None if self.required => Err(ValidationError::Required {
                field: self.field_name.clone(),
            }),
            None => Ok(()),
        }
    }
}
impl<T> Validator<T> for NumberValidator<T>
where
    T: PartialOrd + Copy + std::fmt::Display,
{
    /// Checks `min_value <= value <= max_value`, applying each bound only if it is set.
    ///
    /// A value that is unordered relative to a bound (e.g. `f64::NAN`) is rejected.
    /// Plain `<`/`>` comparisons are both false for NaN, so NaN would otherwise
    /// pass. `required` is not consulted here; it only matters for `Option<T>` input.
    fn validate(&self, value: &T) -> ValidationResult<()> {
        use std::cmp::Ordering;

        if let Some(min) = self.min_value {
            if matches!(value.partial_cmp(&min), None | Some(Ordering::Less)) {
                return Err(ValidationError::Invalid {
                    field: self.field_name.clone(),
                    message: format!("值必须大于等于 {}", min),
                });
            }
        }

        if let Some(max) = self.max_value {
            if matches!(value.partial_cmp(&max), None | Some(Ordering::Greater)) {
                return Err(ValidationError::Invalid {
                    field: self.field_name.clone(),
                    message: format!("值必须小于等于 {}", max),
                });
            }
        }

        Ok(())
    }
}
/// Coarse strength rating returned by `validate_password`, from weakest to strongest.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum PasswordStrength {
    Weak,
    Medium,
    Strong,
    VeryStrong,
}

/// Full result of `validate_password`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordValidation {
    /// Whether the password meets the minimum acceptance bar.
    pub is_valid: bool,
    /// Rating derived from `score`.
    pub strength: PasswordStrength,
    /// Raw score the rating was derived from.
    pub score: u8,
    /// Chinese hints on what is missing, or a single positive note.
    pub feedback: Vec<String>,
}
/// 验证邮箱地址
pub fn validate_email(email: &str) -> bool {
let email_regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap();
email_regex.is_match(email)
}
/// 验证中国手机号
pub fn validate_phone(phone: &str) -> bool {
let phone_regex = Regex::new(r"^1[3-9]\d{9}$").unwrap();
phone_regex.is_match(phone)
}
/// 验证 URL
pub fn validate_url(url: &str) -> bool {
let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();
url_regex.is_match(url)
}
/// 验证 IP 地址
pub fn validate_ip(ip: &str) -> bool {
ip.parse::<std::net::IpAddr>().is_ok()
}
/// Scores `password` and reports its strength together with feedback for the user.
///
/// Scoring (maximum 6):
/// * +1 for at least 8 characters, and +1 more for at least 12.
/// * +1 each for containing an ASCII lowercase letter, an uppercase letter, a
///   digit, and an ASCII punctuation symbol.
/// * -2 (saturating) if the lowercased password contains a common weak password.
///
/// Ratings: 0–2 `Weak`, 3–4 `Medium`, 5 `Strong`, 6 `VeryStrong`. `is_valid`
/// requires a score of at least 3 and at least 8 characters.
///
/// Lengths are counted in characters, not UTF-8 bytes. "Special character" means
/// any ASCII punctuation, which includes `~`, `/`, quotes and backtick as well as
/// the usual `!@#$%^&*` set.
pub fn validate_password(password: &str) -> PasswordValidation {
    const WEAK_PASSWORDS: [&str; 10] = [
        "password", "123456", "123456789", "qwerty", "abc123",
        "password123", "admin", "root", "user", "guest",
    ];
    // Each entry: (character class test, feedback shown when the class is missing).
    let classes: [(fn(char) -> bool, &str); 4] = [
        (|c| c.is_ascii_lowercase(), "密码应包含小写字母"),
        (|c| c.is_ascii_uppercase(), "密码应包含大写字母"),
        (|c| c.is_ascii_digit(), "密码应包含数字"),
        (|c| c.is_ascii_punctuation(), "密码应包含特殊字符"),
    ];

    let char_count = password.chars().count();
    let mut score = 0u8;
    let mut feedback = Vec::new();

    if char_count >= 8 {
        score += 1;
    } else {
        feedback.push("密码长度至少需要8位".to_string());
    }
    if char_count >= 12 {
        score += 1;
    }

    for &(has_class, message) in &classes {
        if password.chars().any(has_class) {
            score += 1;
        } else {
            feedback.push(message.to_string());
        }
    }

    // Lowercase once, not once per dictionary entry.
    let lowered = password.to_lowercase();
    if WEAK_PASSWORDS.iter().any(|&weak| lowered.contains(weak)) {
        score = score.saturating_sub(2);
        feedback.push("密码不能包含常见弱密码".to_string());
    }

    // The maximum score is 6, so `VeryStrong` must map to 6 to be reachable.
    let strength = match score {
        0..=2 => PasswordStrength::Weak,
        3..=4 => PasswordStrength::Medium,
        5 => PasswordStrength::Strong,
        _ => PasswordStrength::VeryStrong,
    };

    let is_valid = score >= 3 && char_count >= 8;
    if is_valid && feedback.is_empty() {
        feedback.push("密码强度良好".to_string());
    }

    PasswordValidation {
        is_valid,
        strength,
        score,
        feedback,
    }
}
/// Returns `true` if `json_str` is syntactically valid JSON (any top-level value).
pub fn validate_json(json_str: &str) -> bool {
    let parsed = serde_json::from_str::<serde_json::Value>(json_str);
    parsed.is_ok()
}

/// Returns `true` if `uuid_str` parses as a UUID in any textual form `uuid` accepts.
pub fn validate_uuid(uuid_str: &str) -> bool {
    let parsed = uuid::Uuid::parse_str(uuid_str);
    parsed.is_ok()
}
/// Runs several validators and collects every failure instead of stopping at the first.
#[derive(Debug)]
pub struct BatchValidator {
    errors: Vec<ValidationError>,
}

impl BatchValidator {
    /// Creates an empty batch with no recorded errors.
    pub fn new() -> Self {
        Self { errors: Vec::new() }
    }

    /// Applies `validator` to `value` and records any error. Returns `self` so calls can be chained.
    pub fn validate<T, V>(&mut self, validator: &V, value: &T) -> &mut Self
    where
        V: Validator<T>,
    {
        if let Err(error) = validator.validate(value) {
            self.errors.push(error);
        }
        self
    }

    /// Returns `true` if no validator has failed so far.
    pub fn is_valid(&self) -> bool {
        self.errors.is_empty()
    }

    /// Returns the errors recorded so far, in the order they occurred.
    pub fn errors(&self) -> &[ValidationError] {
        &self.errors
    }

    /// Consumes the batch and returns only the first recorded error.
    ///
    /// Use `into_all_errors` when callers need every failure.
    pub fn into_result(self) -> ValidationResult<()> {
        match self.errors.into_iter().next() {
            Some(first) => Err(first),
            None => Ok(()),
        }
    }

    /// Consumes the batch and returns every recorded error in the order it occurred.
    pub fn into_all_errors(self) -> Result<(), Vec<ValidationError>> {
        if self.errors.is_empty() {
            Ok(())
        } else {
            Err(self.errors)
        }
    }
}

impl Default for BatchValidator {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_email_validation() {
        assert!(validate_email("test@example.com"));
        assert!(validate_email("user.name+tag@domain.co.uk"));
        assert!(!validate_email("invalid.email"));
        assert!(!validate_email("@domain.com"));
    }

    #[test]
    fn test_phone_validation() {
        assert!(validate_phone("13812345678"));
        assert!(validate_phone("15987654321"));
        assert!(!validate_phone("12345678901"));
        assert!(!validate_phone("1381234567"));
    }

    #[test]
    fn test_misc_format_validators() {
        assert!(validate_url("https://example.com/path"));
        assert!(!validate_url("ftp://example.com"));
        assert!(validate_ip("127.0.0.1"));
        assert!(validate_ip("::1"));
        assert!(!validate_ip("256.0.0.1"));
        assert!(validate_json(r#"{"a": 1}"#));
        assert!(!validate_json("{a: 1}"));
        assert!(validate_uuid("550e8400-e29b-41d4-a716-446655440000"));
        assert!(!validate_uuid("not-a-uuid"));
    }

    #[test]
    fn test_password_validation() {
        let weak = validate_password("123");
        assert_eq!(weak.strength, PasswordStrength::Weak);
        assert!(!weak.is_valid);

        let strong = validate_password("MyStr0ng!Pass");
        assert!(strong.is_valid);
        assert!(matches!(strong.strength, PasswordStrength::Strong | PasswordStrength::VeryStrong));
    }

    #[test]
    fn test_string_validator() {
        let validator = StringValidator::new("username")
            .required()
            .length_range(3, 20)
            .pattern(r"^[a-zA-Z0-9_]+$")
            .unwrap();

        assert!(validator.validate(&"valid_user123".to_string()).is_ok());
        assert!(validator.validate(&"ab".to_string()).is_err()); // too short
        assert!(validator.validate(&"invalid-user".to_string()).is_err()); // illegal character
    }

    #[test]
    fn test_optional_string_validator() {
        let optional = StringValidator::new("nickname").min_length(2);
        assert!(optional.validate(&None::<String>).is_ok());
        assert!(optional.validate(&String::new()).is_ok());

        let required = StringValidator::new("nickname").required();
        assert!(required.validate(&None::<String>).is_err());
        assert!(required.validate(&String::new()).is_err());
    }

    #[test]
    fn test_number_validator() {
        let validator = NumberValidator::new("age").required().range(0, 150);

        assert!(validator.validate(&25).is_ok());
        assert!(validator.validate(&-1).is_err()); // below minimum
        assert!(validator.validate(&200).is_err()); // above maximum
    }

    #[test]
    fn test_batch_validator() {
        let mut batch = BatchValidator::new();

        let email_validator = StringValidator::new("email").required().email().unwrap();
        let age_validator = NumberValidator::new("age").required().range(0, 150);

        batch
            .validate(&email_validator, &"invalid.email".to_string())
            .validate(&age_validator, &200);

        assert!(!batch.is_valid());
        assert_eq!(batch.errors().len(), 2);
        assert!(batch.into_result().is_err());
    }
}