新增以下文档文件: - PROJECT_OVERVIEW.md 项目总览文档 - BACKEND_ARCHITECTURE.md 后端架构文档 - FRONTEND_ARCHITECTURE.md 前端架构文档 - FLOW_ENGINE.md 流程引擎文档 - SERVICES.md 服务层文档 - ERROR_HANDLING.md 错误处理模块文档 文档内容涵盖项目整体介绍、技术架构、核心模块设计和实现细节
1271 lines
34 KiB
Markdown
1271 lines
34 KiB
Markdown
# UdminAI 工具函数模块文档

## 概述

UdminAI 项目的工具函数模块(Utils)提供了一系列通用的工具函数和辅助功能,为系统的各个组件提供支持。这些工具函数涵盖 ID 生成、时间处理、加密解密、验证、格式化、文件操作等多个方面,构成了整个系统的基础功能层。
## 模块结构

```
backend/src/utils/
├── mod.rs          # 模块导出
├── id.rs           # ID 生成工具
├── time.rs         # 时间处理工具
├── crypto.rs       # 加密解密工具
├── validation.rs   # 验证工具
├── format.rs       # 格式化工具
├── file.rs         # 文件操作工具
├── hash.rs         # 哈希工具
├── jwt.rs          # JWT 工具
├── password.rs     # 密码处理工具
├── email.rs        # 邮件工具
├── config.rs       # 配置工具
└── error.rs        # 错误处理工具
```
## 核心工具模块

### ID 生成工具 (id.rs)

#### 功能特性

- 唯一 ID 生成
- 多种 ID 格式支持(UUID v4、雪花算法、时间戳 + 随机数、自定义前缀 + UUID)
- 高性能生成
- 分布式友好(注意:雪花 ID 中的机器 ID 目前硬编码为 1,多实例部署时需改为从配置读取)

#### 实现代码
```rust
|
||
use chrono::{DateTime, Utc};
|
||
use rand::{thread_rng, Rng};
|
||
use std::sync::atomic::{AtomicU64, Ordering};
|
||
use uuid::Uuid;
|
||
|
||
/// Process-wide counter that supplies the low 12 bits of snowflake IDs.
static SEQUENCE_COUNTER: AtomicU64 = AtomicU64::new(0);

/// The ID formats understood by `generate_id_with_type` and `validate_id`.
#[derive(Debug, Clone)]
pub enum IdType {
    /// Random UUID v4 in hyphenated form.
    Uuid,
    /// Snowflake-style 64-bit integer rendered in decimal.
    Snowflake,
    /// Millisecond Unix timestamp followed by a random value.
    Timestamp,
    /// `"{prefix}-{uuid v4}"`.
    Prefixed(String),
}
|
||
|
||
/// 生成唯一 ID
|
||
pub fn generate_id() -> String {
|
||
generate_id_with_type(IdType::Uuid)
|
||
}
|
||
|
||
/// 根据类型生成 ID
|
||
pub fn generate_id_with_type(id_type: IdType) -> String {
|
||
match id_type {
|
||
IdType::Uuid => Uuid::new_v4().to_string(),
|
||
IdType::Snowflake => generate_snowflake_id(),
|
||
IdType::Timestamp => generate_timestamp_id(),
|
||
IdType::Prefixed(prefix) => format!("{}-{}", prefix, Uuid::new_v4()),
|
||
}
|
||
}
|
||
|
||
/// 生成雪花算法 ID
|
||
fn generate_snowflake_id() -> String {
|
||
let timestamp = Utc::now().timestamp_millis() as u64;
|
||
let sequence = SEQUENCE_COUNTER.fetch_add(1, Ordering::SeqCst) & 0xFFF; // 12位序列号
|
||
let machine_id = 1u64; // 机器ID,实际使用时应该从配置获取
|
||
|
||
let id = (timestamp << 22) | (machine_id << 12) | sequence;
|
||
id.to_string()
|
||
}
|
||
|
||
/// 生成时间戳 ID
|
||
fn generate_timestamp_id() -> String {
|
||
let timestamp = Utc::now().timestamp_millis();
|
||
let random: u32 = thread_rng().gen();
|
||
format!("{}{:08x}", timestamp, random)
|
||
}
|
||
|
||
/// 生成短 ID(8位)
|
||
pub fn generate_short_id() -> String {
|
||
let chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
|
||
let mut rng = thread_rng();
|
||
(0..8)
|
||
.map(|_| {
|
||
let idx = rng.gen_range(0..chars.len());
|
||
chars.chars().nth(idx).unwrap()
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
/// 生成数字 ID
|
||
pub fn generate_numeric_id(length: usize) -> String {
|
||
let mut rng = thread_rng();
|
||
(0..length)
|
||
.map(|_| rng.gen_range(0..10).to_string())
|
||
.collect()
|
||
}
|
||
|
||
/// 验证 ID 格式
|
||
pub fn validate_id(id: &str, id_type: IdType) -> bool {
|
||
match id_type {
|
||
IdType::Uuid => Uuid::parse_str(id).is_ok(),
|
||
IdType::Snowflake => id.parse::<u64>().is_ok(),
|
||
IdType::Timestamp => id.len() >= 13 && id[..13].parse::<i64>().is_ok(),
|
||
IdType::Prefixed(prefix) => {
|
||
id.starts_with(&format!("{}-", prefix)) &&
|
||
id.len() > prefix.len() + 1 &&
|
||
Uuid::parse_str(&id[prefix.len() + 1..]).is_ok()
|
||
}
|
||
}
|
||
}
|
||
|
||
/// 从雪花 ID 提取时间戳
|
||
pub fn extract_timestamp_from_snowflake(id: &str) -> Option<DateTime<Utc>> {
|
||
if let Ok(snowflake_id) = id.parse::<u64>() {
|
||
let timestamp = (snowflake_id >> 22) as i64;
|
||
DateTime::from_timestamp_millis(timestamp)
|
||
} else {
|
||
None
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// The default generator yields a string accepted as a UUID.
    #[test]
    fn test_generate_uuid() {
        let uuid_id = generate_id();
        assert!(validate_id(&uuid_id, IdType::Uuid));
    }

    /// Snowflake IDs validate as snowflake IDs.
    #[test]
    fn test_generate_snowflake() {
        let snowflake_id = generate_id_with_type(IdType::Snowflake);
        assert!(validate_id(&snowflake_id, IdType::Snowflake));
    }

    /// Prefixed IDs carry the prefix and validate against the same prefix.
    #[test]
    fn test_generate_prefixed_id() {
        let prefix = "user";
        let prefixed = generate_id_with_type(IdType::Prefixed(prefix.to_string()));
        assert!(prefixed.starts_with("user-"));
        assert!(validate_id(&prefixed, IdType::Prefixed(prefix.to_string())));
    }

    /// Short IDs are exactly 8 characters long.
    #[test]
    fn test_short_id() {
        assert_eq!(generate_short_id().len(), 8);
    }
}
|
||
```
|
||
|
||
### 时间处理工具 (time.rs)

#### 功能特性

- 时间格式化与解析
- 时区转换(UTC、东八区、本地时区)
- 时间计算
- 相对时间描述

注意:`today_start`、`week_start`、`month_start` 以及 `is_today` 等日期判断均以 UTC 日期为边界,而非东八区日期。

#### 实现代码
```rust
|
||
use std::fmt;

use chrono::{DateTime, Datelike, Duration, FixedOffset, Local, NaiveDateTime, TimeZone, Utc};
use serde::{Deserialize, Serialize};
|
||
|
||
// Format strings use chrono's strftime-style syntax.

/// Example: `2023-12-25 10:30:00`.
pub const DEFAULT_DATETIME_FORMAT: &str = "%Y-%m-%d %H:%M:%S";
/// Example: `2023-12-25T10:30:00.000Z` (millisecond precision, literal `Z` suffix).
pub const ISO_DATETIME_FORMAT: &str = "%Y-%m-%dT%H:%M:%S%.3fZ";
/// Example: `2023-12-25`.
pub const DATE_FORMAT: &str = "%Y-%m-%d";
/// Example: `10:30:00`.
pub const TIME_FORMAT: &str = "%H:%M:%S";

/// UTC offset of China Standard Time (UTC+8), in seconds.
pub const CHINA_OFFSET: i32 = 8 * 60 * 60;
|
||
|
||
/// 获取当前时间(UTC)
|
||
pub fn now_utc() -> DateTime<Utc> {
|
||
Utc::now()
|
||
}
|
||
|
||
/// 获取当前时间(带固定偏移量)
|
||
pub fn now_fixed_offset() -> DateTime<FixedOffset> {
|
||
Utc::now().with_timezone(&FixedOffset::east_opt(0).unwrap())
|
||
}
|
||
|
||
/// 获取当前时间(中国时区)
|
||
pub fn now_china() -> DateTime<FixedOffset> {
|
||
Utc::now().with_timezone(&FixedOffset::east_opt(CHINA_OFFSET).unwrap())
|
||
}
|
||
|
||
/// 获取当前本地时间
|
||
pub fn now_local() -> DateTime<Local> {
|
||
Local::now()
|
||
}
|
||
|
||
/// 格式化时间
|
||
pub fn format_datetime(dt: &DateTime<Utc>, format: &str) -> String {
|
||
dt.format(format).to_string()
|
||
}
|
||
|
||
/// 格式化时间(默认格式)
|
||
pub fn format_datetime_default(dt: &DateTime<Utc>) -> String {
|
||
format_datetime(dt, DEFAULT_DATETIME_FORMAT)
|
||
}
|
||
|
||
/// 格式化时间(ISO 格式)
|
||
pub fn format_datetime_iso(dt: &DateTime<Utc>) -> String {
|
||
format_datetime(dt, ISO_DATETIME_FORMAT)
|
||
}
|
||
|
||
/// 解析时间字符串
|
||
pub fn parse_datetime(s: &str, format: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
|
||
let naive = NaiveDateTime::parse_from_str(s, format)?;
|
||
Ok(Utc.from_utc_datetime(&naive))
|
||
}
|
||
|
||
/// 解析 ISO 时间字符串
|
||
pub fn parse_datetime_iso(s: &str) -> Result<DateTime<Utc>, chrono::ParseError> {
|
||
DateTime::parse_from_rfc3339(s).map(|dt| dt.with_timezone(&Utc))
|
||
}
|
||
|
||
/// 时间戳转换为 DateTime
|
||
pub fn timestamp_to_datetime(timestamp: i64) -> Option<DateTime<Utc>> {
|
||
DateTime::from_timestamp(timestamp, 0)
|
||
}
|
||
|
||
/// 毫秒时间戳转换为 DateTime
|
||
pub fn timestamp_millis_to_datetime(timestamp: i64) -> Option<DateTime<Utc>> {
|
||
DateTime::from_timestamp_millis(timestamp)
|
||
}
|
||
|
||
/// DateTime 转换为时间戳
|
||
pub fn datetime_to_timestamp(dt: &DateTime<Utc>) -> i64 {
|
||
dt.timestamp()
|
||
}
|
||
|
||
/// DateTime 转换为毫秒时间戳
|
||
pub fn datetime_to_timestamp_millis(dt: &DateTime<Utc>) -> i64 {
|
||
dt.timestamp_millis()
|
||
}
|
||
|
||
/// 计算时间差
|
||
pub fn time_diff(start: &DateTime<Utc>, end: &DateTime<Utc>) -> Duration {
|
||
*end - *start
|
||
}
|
||
|
||
/// 计算时间差(秒)
|
||
pub fn time_diff_seconds(start: &DateTime<Utc>, end: &DateTime<Utc>) -> i64 {
|
||
time_diff(start, end).num_seconds()
|
||
}
|
||
|
||
/// 计算时间差(毫秒)
|
||
pub fn time_diff_millis(start: &DateTime<Utc>, end: &DateTime<Utc>) -> i64 {
|
||
time_diff(start, end).num_milliseconds()
|
||
}
|
||
|
||
/// 时间加法
|
||
pub fn add_duration(dt: &DateTime<Utc>, duration: Duration) -> DateTime<Utc> {
|
||
*dt + duration
|
||
}
|
||
|
||
/// 时间减法
|
||
pub fn sub_duration(dt: &DateTime<Utc>, duration: Duration) -> DateTime<Utc> {
|
||
*dt - duration
|
||
}
|
||
|
||
/// 添加秒数
|
||
pub fn add_seconds(dt: &DateTime<Utc>, seconds: i64) -> DateTime<Utc> {
|
||
add_duration(dt, Duration::seconds(seconds))
|
||
}
|
||
|
||
/// 添加分钟
|
||
pub fn add_minutes(dt: &DateTime<Utc>, minutes: i64) -> DateTime<Utc> {
|
||
add_duration(dt, Duration::minutes(minutes))
|
||
}
|
||
|
||
/// 添加小时
|
||
pub fn add_hours(dt: &DateTime<Utc>, hours: i64) -> DateTime<Utc> {
|
||
add_duration(dt, Duration::hours(hours))
|
||
}
|
||
|
||
/// 添加天数
|
||
pub fn add_days(dt: &DateTime<Utc>, days: i64) -> DateTime<Utc> {
|
||
add_duration(dt, Duration::days(days))
|
||
}
|
||
|
||
/// 获取今天开始时间
|
||
pub fn today_start() -> DateTime<Utc> {
|
||
let now = now_utc();
|
||
now.date_naive().and_hms_opt(0, 0, 0).unwrap().and_utc()
|
||
}
|
||
|
||
/// 获取今天结束时间
|
||
pub fn today_end() -> DateTime<Utc> {
|
||
let now = now_utc();
|
||
now.date_naive().and_hms_opt(23, 59, 59).unwrap().and_utc()
|
||
}
|
||
|
||
/// 获取本周开始时间(周一)
|
||
pub fn week_start() -> DateTime<Utc> {
|
||
let now = now_utc();
|
||
let weekday = now.weekday().num_days_from_monday();
|
||
let start_date = now.date_naive() - Duration::days(weekday as i64);
|
||
start_date.and_hms_opt(0, 0, 0).unwrap().and_utc()
|
||
}
|
||
|
||
/// 获取本月开始时间
|
||
pub fn month_start() -> DateTime<Utc> {
|
||
let now = now_utc();
|
||
let start_date = now.date_naive().with_day(1).unwrap();
|
||
start_date.and_hms_opt(0, 0, 0).unwrap().and_utc()
|
||
}
|
||
|
||
/// 相对时间描述
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct RelativeTime {
|
||
pub value: i64,
|
||
pub unit: TimeUnit,
|
||
pub description: String,
|
||
}
|
||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub enum TimeUnit {
|
||
Second,
|
||
Minute,
|
||
Hour,
|
||
Day,
|
||
Week,
|
||
Month,
|
||
Year,
|
||
}
|
||
|
||
impl fmt::Display for TimeUnit {
|
||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||
let s = match self {
|
||
TimeUnit::Second => "秒",
|
||
TimeUnit::Minute => "分钟",
|
||
TimeUnit::Hour => "小时",
|
||
TimeUnit::Day => "天",
|
||
TimeUnit::Week => "周",
|
||
TimeUnit::Month => "月",
|
||
TimeUnit::Year => "年",
|
||
};
|
||
write!(f, "{}", s)
|
||
}
|
||
}
|
||
|
||
/// 计算相对时间
|
||
pub fn relative_time(dt: &DateTime<Utc>) -> RelativeTime {
|
||
let now = now_utc();
|
||
let diff = time_diff(dt, &now);
|
||
|
||
let abs_seconds = diff.num_seconds().abs();
|
||
let is_past = diff.num_seconds() < 0;
|
||
|
||
let (value, unit) = if abs_seconds < 60 {
|
||
(abs_seconds, TimeUnit::Second)
|
||
} else if abs_seconds < 3600 {
|
||
(abs_seconds / 60, TimeUnit::Minute)
|
||
} else if abs_seconds < 86400 {
|
||
(abs_seconds / 3600, TimeUnit::Hour)
|
||
} else if abs_seconds < 604800 {
|
||
(abs_seconds / 86400, TimeUnit::Day)
|
||
} else if abs_seconds < 2592000 {
|
||
(abs_seconds / 604800, TimeUnit::Week)
|
||
} else if abs_seconds < 31536000 {
|
||
(abs_seconds / 2592000, TimeUnit::Month)
|
||
} else {
|
||
(abs_seconds / 31536000, TimeUnit::Year)
|
||
};
|
||
|
||
let description = if is_past {
|
||
format!("{}{} 前", value, unit)
|
||
} else {
|
||
format!("{}{} 后", value, unit)
|
||
};
|
||
|
||
RelativeTime {
|
||
value,
|
||
unit,
|
||
description,
|
||
}
|
||
}
|
||
|
||
/// 判断是否为同一天
|
||
pub fn is_same_day(dt1: &DateTime<Utc>, dt2: &DateTime<Utc>) -> bool {
|
||
dt1.date_naive() == dt2.date_naive()
|
||
}
|
||
|
||
/// 判断是否为今天
|
||
pub fn is_today(dt: &DateTime<Utc>) -> bool {
|
||
is_same_day(dt, &now_utc())
|
||
}
|
||
|
||
/// 判断是否为昨天
|
||
pub fn is_yesterday(dt: &DateTime<Utc>) -> bool {
|
||
let yesterday = sub_duration(&now_utc(), Duration::days(1));
|
||
is_same_day(dt, &yesterday)
|
||
}
|
||
|
||
/// 判断是否为本周
|
||
pub fn is_this_week(dt: &DateTime<Utc>) -> bool {
|
||
let week_start = week_start();
|
||
let week_end = add_days(&week_start, 7);
|
||
dt >= &week_start && dt < &week_end
|
||
}
|
||
|
||
/// 判断是否为本月
|
||
pub fn is_this_month(dt: &DateTime<Utc>) -> bool {
|
||
let now = now_utc();
|
||
dt.year() == now.year() && dt.month() == now.month()
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Formatting with the default pattern yields a non-empty string.
    #[test]
    fn test_time_formatting() {
        let text = format_datetime_default(&Utc::now());
        assert!(!text.is_empty());
    }

    /// A string in `DEFAULT_DATETIME_FORMAT` parses successfully.
    #[test]
    fn test_time_parsing() {
        let parsed = parse_datetime("2023-12-25 10:30:00", DEFAULT_DATETIME_FORMAT);
        assert!(parsed.is_ok());
    }

    /// A round trip through a Unix timestamp preserves whole seconds.
    #[test]
    fn test_timestamp_conversion() {
        let original = Utc::now();
        let round_tripped = timestamp_to_datetime(datetime_to_timestamp(&original)).unwrap();
        assert_eq!(original.timestamp(), round_tripped.timestamp());
    }

    /// An instant two hours ago is described as being in the past ("前").
    #[test]
    fn test_relative_time() {
        let two_hours_ago = sub_duration(&now_utc(), Duration::hours(2));
        assert!(relative_time(&two_hours_ago).description.contains("前"));
    }

    /// `is_today` / `is_yesterday` on "now" and "now - 1 day".
    ///
    /// NOTE(review): these helpers read the clock again, so this can flake if
    /// run exactly across a UTC midnight.
    #[test]
    fn test_date_checks() {
        let now = now_utc();
        assert!(is_today(&now));
        assert!(is_yesterday(&sub_duration(&now, Duration::days(1))));
    }
}
|
||
```
|
||
|
||
### 加密解密工具 (crypto.rs)

#### 功能特性

- AES-256-GCM 加密解密
- RSA 加密解密(PKCS#1 v1.5 填充)
- 数字签名(已预留 `SigningFailed` / `VerificationFailed` 错误类型,但下方代码尚未实现签名与验签函数)
- 密钥生成
- SHA-256 哈希计算与校验

#### 实现代码
```rust
|
||
use aes_gcm::{
|
||
aead::{Aead, KeyInit, OsRng},
|
||
Aes256Gcm, Key, Nonce,
|
||
};
|
||
use base64::{engine::general_purpose, Engine as _};
|
||
use rand::{thread_rng, RngCore};
|
||
use rsa::{
|
||
pkcs1::{DecodeRsaPrivateKey, DecodeRsaPublicKey, EncodeRsaPrivateKey, EncodeRsaPublicKey},
|
||
pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey},
|
||
Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey,
|
||
};
|
||
use sha2::{Digest, Sha256};
|
||
use std::error::Error;
|
||
use thiserror::Error;
|
||
|
||
/// 加密错误类型
|
||
#[derive(Error, Debug)]
|
||
pub enum CryptoError {
|
||
#[error("加密失败: {0}")]
|
||
EncryptionFailed(String),
|
||
#[error("解密失败: {0}")]
|
||
DecryptionFailed(String),
|
||
#[error("密钥生成失败: {0}")]
|
||
KeyGenerationFailed(String),
|
||
#[error("签名失败: {0}")]
|
||
SigningFailed(String),
|
||
#[error("验证失败: {0}")]
|
||
VerificationFailed(String),
|
||
#[error("编码失败: {0}")]
|
||
EncodingFailed(String),
|
||
}
|
||
|
||
/// AES 加密工具
|
||
pub struct AesEncryption {
|
||
cipher: Aes256Gcm,
|
||
}
|
||
|
||
impl AesEncryption {
|
||
/// 创建新的 AES 加密器
|
||
pub fn new(key: &[u8; 32]) -> Self {
|
||
let key = Key::<Aes256Gcm>::from_slice(key);
|
||
let cipher = Aes256Gcm::new(key);
|
||
Self { cipher }
|
||
}
|
||
|
||
/// 从密码创建 AES 加密器
|
||
pub fn from_password(password: &str) -> Self {
|
||
let key = Self::derive_key_from_password(password);
|
||
Self::new(&key)
|
||
}
|
||
|
||
/// 从密码派生密钥
|
||
fn derive_key_from_password(password: &str) -> [u8; 32] {
|
||
let mut hasher = Sha256::new();
|
||
hasher.update(password.as_bytes());
|
||
hasher.finalize().into()
|
||
}
|
||
|
||
/// 加密数据
|
||
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
|
||
let mut nonce_bytes = [0u8; 12];
|
||
thread_rng().fill_bytes(&mut nonce_bytes);
|
||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||
|
||
let ciphertext = self
|
||
.cipher
|
||
.encrypt(nonce, plaintext)
|
||
.map_err(|e| CryptoError::EncryptionFailed(e.to_string()))?;
|
||
|
||
// 将 nonce 和密文组合
|
||
let mut result = Vec::with_capacity(12 + ciphertext.len());
|
||
result.extend_from_slice(&nonce_bytes);
|
||
result.extend_from_slice(&ciphertext);
|
||
|
||
Ok(result)
|
||
}
|
||
|
||
/// 解密数据
|
||
pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
|
||
if ciphertext.len() < 12 {
|
||
return Err(CryptoError::DecryptionFailed(
|
||
"密文长度不足".to_string(),
|
||
));
|
||
}
|
||
|
||
let (nonce_bytes, encrypted_data) = ciphertext.split_at(12);
|
||
let nonce = Nonce::from_slice(nonce_bytes);
|
||
|
||
self.cipher
|
||
.decrypt(nonce, encrypted_data)
|
||
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
|
||
}
|
||
|
||
/// 加密字符串并返回 Base64 编码
|
||
pub fn encrypt_string(&self, plaintext: &str) -> Result<String, CryptoError> {
|
||
let encrypted = self.encrypt(plaintext.as_bytes())?;
|
||
Ok(general_purpose::STANDARD.encode(encrypted))
|
||
}
|
||
|
||
/// 解密 Base64 编码的字符串
|
||
pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, CryptoError> {
|
||
let decoded = general_purpose::STANDARD
|
||
.decode(ciphertext)
|
||
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
|
||
|
||
let decrypted = self.decrypt(&decoded)?;
|
||
String::from_utf8(decrypted)
|
||
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
|
||
}
|
||
}
|
||
|
||
/// RSA 密钥对
|
||
#[derive(Debug, Clone)]
|
||
pub struct RsaKeyPair {
|
||
pub private_key: RsaPrivateKey,
|
||
pub public_key: RsaPublicKey,
|
||
}
|
||
|
||
impl RsaKeyPair {
|
||
/// 生成新的 RSA 密钥对
|
||
pub fn generate(bits: usize) -> Result<Self, CryptoError> {
|
||
let mut rng = OsRng;
|
||
let private_key = RsaPrivateKey::new(&mut rng, bits)
|
||
.map_err(|e| CryptoError::KeyGenerationFailed(e.to_string()))?;
|
||
let public_key = RsaPublicKey::from(&private_key);
|
||
|
||
Ok(Self {
|
||
private_key,
|
||
public_key,
|
||
})
|
||
}
|
||
|
||
/// 从 PEM 格式加载私钥
|
||
pub fn from_private_key_pem(pem: &str) -> Result<Self, CryptoError> {
|
||
let private_key = RsaPrivateKey::from_pkcs8_pem(pem)
|
||
.map_err(|e| CryptoError::KeyGenerationFailed(e.to_string()))?;
|
||
let public_key = RsaPublicKey::from(&private_key);
|
||
|
||
Ok(Self {
|
||
private_key,
|
||
public_key,
|
||
})
|
||
}
|
||
|
||
/// 导出私钥为 PEM 格式
|
||
pub fn private_key_to_pem(&self) -> Result<String, CryptoError> {
|
||
self.private_key
|
||
.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
|
||
.map_err(|e| CryptoError::EncodingFailed(e.to_string()))
|
||
.map(|s| s.to_string())
|
||
}
|
||
|
||
/// 导出公钥为 PEM 格式
|
||
pub fn public_key_to_pem(&self) -> Result<String, CryptoError> {
|
||
self.public_key
|
||
.to_public_key_pem(rsa::pkcs8::LineEnding::LF)
|
||
.map_err(|e| CryptoError::EncodingFailed(e.to_string()))
|
||
}
|
||
|
||
/// 使用公钥加密
|
||
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, CryptoError> {
|
||
let mut rng = OsRng;
|
||
self.public_key
|
||
.encrypt(&mut rng, Pkcs1v15Encrypt, plaintext)
|
||
.map_err(|e| CryptoError::EncryptionFailed(e.to_string()))
|
||
}
|
||
|
||
/// 使用私钥解密
|
||
pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
|
||
self.private_key
|
||
.decrypt(Pkcs1v15Encrypt, ciphertext)
|
||
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
|
||
}
|
||
|
||
/// 加密字符串并返回 Base64 编码
|
||
pub fn encrypt_string(&self, plaintext: &str) -> Result<String, CryptoError> {
|
||
let encrypted = self.encrypt(plaintext.as_bytes())?;
|
||
Ok(general_purpose::STANDARD.encode(encrypted))
|
||
}
|
||
|
||
/// 解密 Base64 编码的字符串
|
||
pub fn decrypt_string(&self, ciphertext: &str) -> Result<String, CryptoError> {
|
||
let decoded = general_purpose::STANDARD
|
||
.decode(ciphertext)
|
||
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))?;
|
||
|
||
let decrypted = self.decrypt(&decoded)?;
|
||
String::from_utf8(decrypted)
|
||
.map_err(|e| CryptoError::DecryptionFailed(e.to_string()))
|
||
}
|
||
}
|
||
|
||
/// 生成随机密钥
|
||
pub fn generate_random_key(length: usize) -> Vec<u8> {
|
||
let mut key = vec![0u8; length];
|
||
thread_rng().fill_bytes(&mut key);
|
||
key
|
||
}
|
||
|
||
/// 生成 AES-256 密钥
|
||
pub fn generate_aes_key() -> [u8; 32] {
|
||
let mut key = [0u8; 32];
|
||
thread_rng().fill_bytes(&mut key);
|
||
key
|
||
}
|
||
|
||
/// 计算 SHA-256 哈希
|
||
pub fn sha256_hash(data: &[u8]) -> Vec<u8> {
|
||
let mut hasher = Sha256::new();
|
||
hasher.update(data);
|
||
hasher.finalize().to_vec()
|
||
}
|
||
|
||
/// 计算字符串的 SHA-256 哈希并返回十六进制字符串
|
||
pub fn sha256_hex(data: &str) -> String {
|
||
let hash = sha256_hash(data.as_bytes());
|
||
hex::encode(hash)
|
||
}
|
||
|
||
/// 验证哈希
|
||
pub fn verify_hash(data: &[u8], expected_hash: &[u8]) -> bool {
|
||
let actual_hash = sha256_hash(data);
|
||
actual_hash == expected_hash
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// AES-GCM round trip of a string with a random key.
    #[test]
    fn test_aes_encryption() {
        let aes = AesEncryption::new(&generate_aes_key());

        let message = "Hello, World!";
        let sealed = aes.encrypt_string(message).unwrap();
        let opened = aes.decrypt_string(&sealed).unwrap();

        assert_eq!(message, opened);
    }

    /// RSA round trip of a string with a freshly generated 2048-bit key.
    ///
    /// NOTE(review): 2048-bit key generation may be slow in unoptimized test builds.
    #[test]
    fn test_rsa_encryption() {
        let keypair = RsaKeyPair::generate(2048).unwrap();

        let message = "Hello, RSA!";
        let sealed = keypair.encrypt_string(message).unwrap();
        let opened = keypair.decrypt_string(&sealed).unwrap();

        assert_eq!(message, opened);
    }

    /// Hashing is deterministic and yields 64 hex characters.
    #[test]
    fn test_sha256_hash() {
        let first = sha256_hex("test data");
        let second = sha256_hex("test data");

        assert_eq!(first, second);
        assert_eq!(first.len(), 64); // 32 digest bytes as 64 hex characters
    }

    /// Two generated AES keys differ and have the AES-256 size.
    #[test]
    fn test_key_generation() {
        let (first, second) = (generate_aes_key(), generate_aes_key());

        assert_ne!(first, second); // keys must be random
        assert_eq!(first.len(), 32); // AES-256 key length
    }
}
|
||
```
|
||
|
||
### 验证工具 (validation.rs)

#### 功能特性

- 邮箱验证
- 手机号验证(中国大陆手机号)
- URL 验证
- IP 地址、JSON、UUID 格式验证
- 密码强度验证
- 自定义验证规则(构建器风格的字符串 / 数字验证器,以及批量验证)

#### 实现代码
```rust
|
||
use regex::Regex;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::collections::HashMap;
|
||
use thiserror::Error;
|
||
|
||
/// 验证错误类型
|
||
#[derive(Error, Debug, Clone, Serialize, Deserialize)]
|
||
pub enum ValidationError {
|
||
#[error("字段 '{field}' 是必需的")]
|
||
Required { field: String },
|
||
#[error("字段 '{field}' 长度必须在 {min} 到 {max} 之间")]
|
||
Length { field: String, min: usize, max: usize },
|
||
#[error("字段 '{field}' 格式无效: {message}")]
|
||
Format { field: String, message: String },
|
||
#[error("字段 '{field}' 值无效: {message}")]
|
||
Invalid { field: String, message: String },
|
||
#[error("自定义验证失败: {message}")]
|
||
Custom { message: String },
|
||
}
|
||
|
||
/// 验证结果
|
||
pub type ValidationResult<T> = Result<T, ValidationError>;
|
||
|
||
/// 验证器特征
|
||
pub trait Validator<T> {
|
||
fn validate(&self, value: &T) -> ValidationResult<()>;
|
||
}
|
||
|
||
/// 字符串验证器
|
||
#[derive(Debug, Clone)]
|
||
pub struct StringValidator {
|
||
pub field_name: String,
|
||
pub required: bool,
|
||
pub min_length: Option<usize>,
|
||
pub max_length: Option<usize>,
|
||
pub pattern: Option<Regex>,
|
||
pub custom_validators: Vec<Box<dyn Fn(&str) -> ValidationResult<()> + Send + Sync>>,
|
||
}
|
||
|
||
impl StringValidator {
|
||
pub fn new(field_name: &str) -> Self {
|
||
Self {
|
||
field_name: field_name.to_string(),
|
||
required: false,
|
||
min_length: None,
|
||
max_length: None,
|
||
pattern: None,
|
||
custom_validators: Vec::new(),
|
||
}
|
||
}
|
||
|
||
pub fn required(mut self) -> Self {
|
||
self.required = true;
|
||
self
|
||
}
|
||
|
||
pub fn min_length(mut self, min: usize) -> Self {
|
||
self.min_length = Some(min);
|
||
self
|
||
}
|
||
|
||
pub fn max_length(mut self, max: usize) -> Self {
|
||
self.max_length = Some(max);
|
||
self
|
||
}
|
||
|
||
pub fn length_range(mut self, min: usize, max: usize) -> Self {
|
||
self.min_length = Some(min);
|
||
self.max_length = Some(max);
|
||
self
|
||
}
|
||
|
||
pub fn pattern(mut self, pattern: &str) -> Result<Self, regex::Error> {
|
||
self.pattern = Some(Regex::new(pattern)?);
|
||
Ok(self)
|
||
}
|
||
|
||
pub fn email(self) -> Result<Self, regex::Error> {
|
||
self.pattern(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
|
||
}
|
||
|
||
pub fn phone(self) -> Result<Self, regex::Error> {
|
||
self.pattern(r"^1[3-9]\d{9}$")
|
||
}
|
||
|
||
pub fn url(self) -> Result<Self, regex::Error> {
|
||
self.pattern(r"^https?://[^\s/$.?#].[^\s]*$")
|
||
}
|
||
}
|
||
|
||
impl Validator<Option<String>> for StringValidator {
|
||
fn validate(&self, value: &Option<String>) -> ValidationResult<()> {
|
||
match value {
|
||
None => {
|
||
if self.required {
|
||
Err(ValidationError::Required {
|
||
field: self.field_name.clone(),
|
||
})
|
||
} else {
|
||
Ok(())
|
||
}
|
||
}
|
||
Some(s) => self.validate(s),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Validator<String> for StringValidator {
|
||
fn validate(&self, value: &String) -> ValidationResult<()> {
|
||
// 检查是否为空且必需
|
||
if value.is_empty() && self.required {
|
||
return Err(ValidationError::Required {
|
||
field: self.field_name.clone(),
|
||
});
|
||
}
|
||
|
||
// 如果为空且不是必需的,跳过其他验证
|
||
if value.is_empty() && !self.required {
|
||
return Ok(());
|
||
}
|
||
|
||
// 检查长度
|
||
if let Some(min) = self.min_length {
|
||
if value.len() < min {
|
||
return Err(ValidationError::Length {
|
||
field: self.field_name.clone(),
|
||
min,
|
||
max: self.max_length.unwrap_or(usize::MAX),
|
||
});
|
||
}
|
||
}
|
||
|
||
if let Some(max) = self.max_length {
|
||
if value.len() > max {
|
||
return Err(ValidationError::Length {
|
||
field: self.field_name.clone(),
|
||
min: self.min_length.unwrap_or(0),
|
||
max,
|
||
});
|
||
}
|
||
}
|
||
|
||
// 检查正则表达式
|
||
if let Some(pattern) = &self.pattern {
|
||
if !pattern.is_match(value) {
|
||
return Err(ValidationError::Format {
|
||
field: self.field_name.clone(),
|
||
message: "格式不匹配".to_string(),
|
||
});
|
||
}
|
||
}
|
||
|
||
// 执行自定义验证
|
||
for validator in &self.custom_validators {
|
||
validator(value)?;
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
/// Builder-style validator for numeric fields with optional inclusive bounds.
#[derive(Debug, Clone)]
pub struct NumberValidator<T> {
    /// Field name reported in every `ValidationError`.
    pub field_name: String,
    /// Whether a missing (`None`) value is an error.
    pub required: bool,
    /// Inclusive lower bound, if any.
    pub min_value: Option<T>,
    /// Inclusive upper bound, if any.
    pub max_value: Option<T>,
}

impl<T> NumberValidator<T>
where
    T: PartialOrd + Copy,
{
    /// Starts a validator for `field_name` with no bounds and `required = false`.
    pub fn new(field_name: &str) -> Self {
        Self {
            field_name: field_name.to_owned(),
            required: false,
            min_value: None,
            max_value: None,
        }
    }

    /// Makes a missing value an error.
    pub fn required(self) -> Self {
        Self {
            required: true,
            ..self
        }
    }

    /// Sets the inclusive lower bound.
    pub fn min_value(self, min: T) -> Self {
        Self {
            min_value: Some(min),
            ..self
        }
    }

    /// Sets the inclusive upper bound.
    pub fn max_value(self, max: T) -> Self {
        Self {
            max_value: Some(max),
            ..self
        }
    }

    /// Sets both inclusive bounds.
    pub fn range(self, min: T, max: T) -> Self {
        self.min_value(min).max_value(max)
    }
}
|
||
|
||
impl<T> Validator<Option<T>> for NumberValidator<T>
|
||
where
|
||
T: PartialOrd + Copy + std::fmt::Display,
|
||
{
|
||
fn validate(&self, value: &Option<T>) -> ValidationResult<()> {
|
||
match value {
|
||
None => {
|
||
if self.required {
|
||
Err(ValidationError::Required {
|
||
field: self.field_name.clone(),
|
||
})
|
||
} else {
|
||
Ok(())
|
||
}
|
||
}
|
||
Some(v) => self.validate(v),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl<T> Validator<T> for NumberValidator<T>
|
||
where
|
||
T: PartialOrd + Copy + std::fmt::Display,
|
||
{
|
||
fn validate(&self, value: &T) -> ValidationResult<()> {
|
||
if let Some(min) = self.min_value {
|
||
if *value < min {
|
||
return Err(ValidationError::Invalid {
|
||
field: self.field_name.clone(),
|
||
message: format!("值必须大于等于 {}", min),
|
||
});
|
||
}
|
||
}
|
||
|
||
if let Some(max) = self.max_value {
|
||
if *value > max {
|
||
return Err(ValidationError::Invalid {
|
||
field: self.field_name.clone(),
|
||
message: format!("值必须小于等于 {}", max),
|
||
});
|
||
}
|
||
}
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
/// 密码强度等级
|
||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||
pub enum PasswordStrength {
|
||
Weak,
|
||
Medium,
|
||
Strong,
|
||
VeryStrong,
|
||
}
|
||
|
||
/// 密码验证结果
|
||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||
pub struct PasswordValidation {
|
||
pub is_valid: bool,
|
||
pub strength: PasswordStrength,
|
||
pub score: u8,
|
||
pub feedback: Vec<String>,
|
||
}
|
||
|
||
/// 验证邮箱地址
|
||
pub fn validate_email(email: &str) -> bool {
|
||
let email_regex = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$").unwrap();
|
||
email_regex.is_match(email)
|
||
}
|
||
|
||
/// 验证中国手机号
|
||
pub fn validate_phone(phone: &str) -> bool {
|
||
let phone_regex = Regex::new(r"^1[3-9]\d{9}$").unwrap();
|
||
phone_regex.is_match(phone)
|
||
}
|
||
|
||
/// 验证 URL
|
||
pub fn validate_url(url: &str) -> bool {
|
||
let url_regex = Regex::new(r"^https?://[^\s/$.?#].[^\s]*$").unwrap();
|
||
url_regex.is_match(url)
|
||
}
|
||
|
||
/// Returns `true` if `ip` is a valid IPv4 or IPv6 address literal, as accepted
/// by `std::net::IpAddr`'s `FromStr` implementation.
pub fn validate_ip(ip: &str) -> bool {
    let parsed: Result<std::net::IpAddr, _> = ip.parse();
    parsed.is_ok()
}
|
||
|
||
/// 验证密码强度
|
||
pub fn validate_password(password: &str) -> PasswordValidation {
|
||
let mut score = 0u8;
|
||
let mut feedback = Vec::new();
|
||
|
||
// 长度检查
|
||
if password.len() >= 8 {
|
||
score += 1;
|
||
} else {
|
||
feedback.push("密码长度至少需要8位".to_string());
|
||
}
|
||
|
||
if password.len() >= 12 {
|
||
score += 1;
|
||
}
|
||
|
||
// 包含小写字母
|
||
if password.chars().any(|c| c.is_ascii_lowercase()) {
|
||
score += 1;
|
||
} else {
|
||
feedback.push("密码应包含小写字母".to_string());
|
||
}
|
||
|
||
// 包含大写字母
|
||
if password.chars().any(|c| c.is_ascii_uppercase()) {
|
||
score += 1;
|
||
} else {
|
||
feedback.push("密码应包含大写字母".to_string());
|
||
}
|
||
|
||
// 包含数字
|
||
if password.chars().any(|c| c.is_ascii_digit()) {
|
||
score += 1;
|
||
} else {
|
||
feedback.push("密码应包含数字".to_string());
|
||
}
|
||
|
||
// 包含特殊字符
|
||
if password.chars().any(|c| "!@#$%^&*()_+-=[]{}|;:,.<>?".contains(c)) {
|
||
score += 1;
|
||
} else {
|
||
feedback.push("密码应包含特殊字符".to_string());
|
||
}
|
||
|
||
// 不包含常见弱密码
|
||
let weak_passwords = [
|
||
"password", "123456", "123456789", "qwerty", "abc123",
|
||
"password123", "admin", "root", "user", "guest",
|
||
];
|
||
|
||
if weak_passwords.iter().any(|&weak| password.to_lowercase().contains(weak)) {
|
||
score = score.saturating_sub(2);
|
||
feedback.push("密码不能包含常见弱密码".to_string());
|
||
}
|
||
|
||
let strength = match score {
|
||
0..=2 => PasswordStrength::Weak,
|
||
3..=4 => PasswordStrength::Medium,
|
||
5..=6 => PasswordStrength::Strong,
|
||
_ => PasswordStrength::VeryStrong,
|
||
};
|
||
|
||
let is_valid = score >= 3 && password.len() >= 8;
|
||
|
||
if is_valid && feedback.is_empty() {
|
||
feedback.push("密码强度良好".to_string());
|
||
}
|
||
|
||
PasswordValidation {
|
||
is_valid,
|
||
strength,
|
||
score,
|
||
feedback,
|
||
}
|
||
}
|
||
|
||
/// 验证 JSON 格式
|
||
pub fn validate_json(json_str: &str) -> bool {
|
||
serde_json::from_str::<serde_json::Value>(json_str).is_ok()
|
||
}
|
||
|
||
/// 验证 UUID 格式
|
||
pub fn validate_uuid(uuid_str: &str) -> bool {
|
||
uuid::Uuid::parse_str(uuid_str).is_ok()
|
||
}
|
||
|
||
/// 批量验证器
|
||
#[derive(Debug)]
|
||
pub struct BatchValidator {
|
||
errors: Vec<ValidationError>,
|
||
}
|
||
|
||
impl BatchValidator {
|
||
pub fn new() -> Self {
|
||
Self {
|
||
errors: Vec::new(),
|
||
}
|
||
}
|
||
|
||
pub fn validate<T, V>(&mut self, validator: &V, value: &T) -> &mut Self
|
||
where
|
||
V: Validator<T>,
|
||
{
|
||
if let Err(error) = validator.validate(value) {
|
||
self.errors.push(error);
|
||
}
|
||
self
|
||
}
|
||
|
||
pub fn is_valid(&self) -> bool {
|
||
self.errors.is_empty()
|
||
}
|
||
|
||
pub fn errors(&self) -> &[ValidationError] {
|
||
&self.errors
|
||
}
|
||
|
||
pub fn into_result(self) -> ValidationResult<()> {
|
||
if self.errors.is_empty() {
|
||
Ok(())
|
||
} else {
|
||
// 返回第一个错误,实际使用中可能需要返回所有错误
|
||
Err(self.errors.into_iter().next().unwrap())
|
||
}
|
||
}
|
||
}
|
||
|
||
impl Default for BatchValidator {
|
||
fn default() -> Self {
|
||
Self::new()
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Typical valid addresses pass; a missing `@` or an empty local part fails.
    #[test]
    fn test_email_validation() {
        assert!(validate_email("test@example.com"));
        assert!(validate_email("user.name+tag@domain.co.uk"));
        assert!(!validate_email("invalid.email"));
        assert!(!validate_email("@domain.com"));
    }

    /// 11-digit numbers starting with 13–19 pass; others fail.
    #[test]
    fn test_phone_validation() {
        assert!(validate_phone("13812345678"));
        assert!(validate_phone("15987654321"));
        assert!(!validate_phone("12345678901"));
        assert!(!validate_phone("1381234567"));
    }

    /// A short password is weak and invalid; a varied long one is valid and strong.
    #[test]
    fn test_password_validation() {
        let weak = validate_password("123");
        assert_eq!(weak.strength, PasswordStrength::Weak);
        assert!(!weak.is_valid);

        let strong = validate_password("MyStr0ng!Pass");
        assert!(strong.is_valid);
        assert!(matches!(
            strong.strength,
            PasswordStrength::Strong | PasswordStrength::VeryStrong
        ));
    }

    /// The length range and the pattern are both enforced.
    #[test]
    fn test_string_validator() {
        let validator = StringValidator::new("username")
            .required()
            .length_range(3, 20)
            .pattern(r"^[a-zA-Z0-9_]+$")
            .unwrap();

        assert!(validator.validate(&"valid_user123".to_string()).is_ok());
        assert!(validator.validate(&"ab".to_string()).is_err()); // too short
        assert!(validator.validate(&"invalid-user".to_string()).is_err()); // illegal character
    }

    /// Both inclusive bounds are enforced.
    #[test]
    fn test_number_validator() {
        let validator = NumberValidator::new("age").required().range(0, 150);

        assert!(validator.validate(&25).is_ok());
        assert!(validator.validate(&-1).is_err()); // below the minimum
        assert!(validator.validate(&200).is_err()); // above the maximum
    }

    /// Failures from independent validators are all collected.
    #[test]
    fn test_batch_validator() {
        let email_validator = StringValidator::new("email").required().email().unwrap();
        let age_validator = NumberValidator::new("age").required().range(0, 150);

        let mut batch = BatchValidator::new();
        batch
            .validate(&email_validator, &"invalid.email".to_string())
            .validate(&age_validator, &200);

        assert!(!batch.is_valid());
        assert_eq!(batch.errors().len(), 2);
    }
}
|
||
``` |