feat: add redis

This commit is contained in:
2025-08-29 21:42:29 +08:00
parent af68d94efa
commit dc60a0a4bd
10 changed files with 875 additions and 52 deletions

View File

@ -1,4 +1,5 @@
mod db;
mod redis;
mod response;
mod error;
pub mod middlewares;
@ -6,6 +7,7 @@ pub mod models;
pub mod services;
pub mod routes;
pub mod utils;
pub mod workflow;
use axum::Router;
use axum::http::{HeaderValue, Method};
@ -37,6 +39,10 @@ async fn main() -> anyhow::Result<()> {
let db = db::init_db().await?;
// initialize Redis connection
let redis_pool = redis::init_redis().await?;
redis::set_redis_pool(redis_pool)?;
// run migrations
migration::Migrator::up(&db, None).await.expect("migration up");

View File

@ -2,7 +2,7 @@ use axum::{http::HeaderMap, http::header::AUTHORIZATION};
use chrono::{Utc, Duration as ChronoDuration};
use jsonwebtoken::{EncodingKey, DecodingKey, Header, Validation};
use serde::{Serialize, Deserialize};
use crate::error::AppError;
use crate::{error::AppError, redis::TokenRedis};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
@ -37,6 +37,21 @@ impl<S> axum::extract::FromRequestParts<S> for AuthUser where S: Send + Sync + '
let secret = std::env::var("JWT_SECRET").map_err(|_| AppError::Unauthorized)?;
let claims = decode_token(token, &secret)?;
if claims.typ != "access" { return Err(AppError::Unauthorized); }
// 验证token是否在Redis中存在可选添加环境变量控制是否启用Redis验证
let redis_validation_enabled = std::env::var("REDIS_TOKEN_VALIDATION")
.unwrap_or_else(|_| "true".to_string())
.parse::<bool>()
.unwrap_or(true);
if redis_validation_enabled {
let is_valid = TokenRedis::validate_access_token(token, claims.uid).await
.unwrap_or(false);
if !is_valid {
return Err(AppError::Unauthorized);
}
}
Ok(AuthUser { uid: claims.uid, username: claims.sub })
}
}

154
backend/src/redis.rs Normal file
View File

@ -0,0 +1,154 @@
use redis::{Client, AsyncCommands};
use redis::aio::ConnectionManager;
use once_cell::sync::OnceCell;
use crate::error::AppError;
pub type RedisPool = ConnectionManager;
static REDIS_POOL: OnceCell<RedisPool> = OnceCell::new();
/// 初始化Redis连接池
pub async fn init_redis() -> Result<RedisPool, AppError> {
let redis_url = std::env::var("REDIS_URL")
.unwrap_or_else(|_| "redis://:123456@127.0.0.1:6379/9".into());
tracing::info!("Connecting to Redis at: {}", redis_url.replace(":123456", ":***"));
let client = Client::open(redis_url)
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Failed to create Redis client: {}", e)))?;
let manager = ConnectionManager::new(client).await
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Failed to create Redis connection manager: {}", e)))?;
tracing::info!("Redis connection established successfully");
Ok(manager)
}
/// 获取Redis连接池
pub fn get_redis() -> Result<&'static RedisPool, AppError> {
REDIS_POOL.get().ok_or_else(|| AppError::Anyhow(anyhow::anyhow!("Redis pool not initialized")))
}
/// 设置Redis连接池
pub fn set_redis_pool(pool: RedisPool) -> Result<(), AppError> {
REDIS_POOL.set(pool)
.map_err(|_| AppError::Anyhow(anyhow::anyhow!("Redis pool already initialized")))
}
/// Redis工具函数
pub struct RedisHelper;
impl RedisHelper {
/// 设置带过期时间的键值对
pub async fn set_ex(key: &str, value: &str, expire_seconds: u64) -> Result<(), AppError> {
let mut conn = get_redis()?.clone();
tracing::debug!("Redis SET_EX: key={}, value_len={}, expire_seconds={}", key, value.len(), expire_seconds);
let _: String = conn.set_ex(key, value, expire_seconds).await
.map_err(|e| {
tracing::error!("Redis set_ex failed for key {}: {}", key, e);
AppError::Anyhow(anyhow::anyhow!("Redis set_ex failed: {}", e))
})?;
tracing::debug!("Redis SET_EX successful for key: {}", key);
Ok(())
}
/// 获取键值
pub async fn get(key: &str) -> Result<Option<String>, AppError> {
let mut conn = get_redis()?.clone();
let result: Option<String> = conn.get(key).await
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis get failed: {}", e)))?;
Ok(result)
}
/// 删除键
pub async fn del(key: &str) -> Result<(), AppError> {
let mut conn = get_redis()?.clone();
let _: i32 = conn.del(key).await
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis del failed: {}", e)))?;
Ok(())
}
/// 检查键是否存在
pub async fn exists(key: &str) -> Result<bool, AppError> {
let mut conn = get_redis()?.clone();
let result: bool = conn.exists(key).await
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis exists failed: {}", e)))?;
Ok(result)
}
/// 设置键的过期时间
pub async fn expire(key: &str, seconds: u64) -> Result<(), AppError> {
let mut conn = get_redis()?.clone();
let _: bool = conn.expire(key, seconds as i64).await
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis expire failed: {}", e)))?;
Ok(())
}
/// 删除用户相关的所有token
pub async fn del_user_tokens(user_id: i64) -> Result<(), AppError> {
let pattern = format!("token:*:user:{}", user_id);
let mut conn = get_redis()?.clone();
// 获取匹配的键
let keys: Vec<String> = conn.keys(&pattern).await
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis keys failed: {}", e)))?;
// 删除所有匹配的键
if !keys.is_empty() {
let _: i32 = conn.del(&keys).await
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis del multiple failed: {}", e)))?;
}
Ok(())
}
}
/// Token相关的Redis操作
pub struct TokenRedis;
impl TokenRedis {
/// 存储访问token
pub async fn store_access_token(token: &str, user_id: i64, expire_seconds: u64) -> Result<(), AppError> {
let key = format!("token:access:user:{}", user_id);
tracing::info!("Storing access token for user {} with key: {}, expires in {} seconds", user_id, key, expire_seconds);
RedisHelper::set_ex(&key, token, expire_seconds).await
}
/// 存储刷新token
pub async fn store_refresh_token(token: &str, user_id: i64, expire_seconds: u64) -> Result<(), AppError> {
let key = format!("token:refresh:user:{}", user_id);
tracing::info!("Storing refresh token for user {} with key: {}, expires in {} seconds", user_id, key, expire_seconds);
RedisHelper::set_ex(&key, token, expire_seconds).await
}
/// 验证访问token
pub async fn validate_access_token(token: &str, user_id: i64) -> Result<bool, AppError> {
let key = format!("token:access:user:{}", user_id);
let stored_token = RedisHelper::get(&key).await?;
Ok(stored_token.as_deref() == Some(token))
}
/// 验证刷新token
pub async fn validate_refresh_token(token: &str, user_id: i64) -> Result<bool, AppError> {
let key = format!("token:refresh:user:{}", user_id);
let stored_token = RedisHelper::get(&key).await?;
Ok(stored_token.as_deref() == Some(token))
}
/// 删除用户的访问token
pub async fn revoke_access_token(user_id: i64) -> Result<(), AppError> {
let key = format!("token:access:user:{}", user_id);
RedisHelper::del(&key).await
}
/// 删除用户的刷新token
pub async fn revoke_refresh_token(user_id: i64) -> Result<(), AppError> {
let key = format!("token:refresh:user:{}", user_id);
RedisHelper::del(&key).await
}
/// 删除用户的所有token
pub async fn revoke_all_tokens(user_id: i64) -> Result<(), AppError> {
RedisHelper::del_user_tokens(user_id).await
}
}

View File

@ -1,5 +1,5 @@
use sea_orm::{EntityTrait, ColumnTrait, QueryFilter, ActiveModelTrait, Set};
use crate::{db::Db, models::{user, refresh_token}, utils::password, error::AppError};
use crate::{db::Db, models::{user, refresh_token}, utils::password, error::AppError, redis::TokenRedis};
use chrono::{Utc, Duration, FixedOffset};
use sha2::{Sha256, Digest};
use sea_orm::ActiveValue::NotSet;
@ -11,13 +11,38 @@ pub async fn login(db: &Db, username: String, password_plain: String) -> Result<
if u.status != 1 { return Err(AppError::Forbidden); }
let ok = password::verify_password(&password_plain, &u.password_hash).map_err(|_| AppError::Unauthorized)?;
if !ok { return Err(AppError::Unauthorized); }
let access_claims = crate::middlewares::jwt::new_access_claims(u.id, &u.username);
let refresh_claims = crate::middlewares::jwt::new_refresh_claims(u.id, &u.username);
let secret = std::env::var("JWT_SECRET").unwrap();
let access = crate::middlewares::jwt::encode_token(&access_claims, &secret)?;
let refresh = crate::middlewares::jwt::encode_token(&refresh_claims, &secret)?;
// persist refresh token hash
// 获取过期时间(秒)
let access_exp_secs = std::env::var("JWT_ACCESS_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1800);
let refresh_exp_secs = std::env::var("JWT_REFRESH_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1209600);
// 先删除用户的所有旧token防止多重登录
if let Err(e) = TokenRedis::revoke_all_tokens(u.id).await {
tracing::warn!("Failed to revoke old tokens for user {}: {}", u.id, e);
}
// 存储新token到Redis
if let Err(e) = TokenRedis::store_access_token(&access, u.id, access_exp_secs as u64).await {
tracing::error!("Failed to store access token to Redis for user {}: {}", u.id, e);
// 不返回错误降级到仅使用JWT模式
} else {
tracing::info!("Successfully stored access token to Redis for user {}", u.id);
}
if let Err(e) = TokenRedis::store_refresh_token(&refresh, u.id, refresh_exp_secs as u64).await {
tracing::error!("Failed to store refresh token to Redis for user {}: {}", u.id, e);
// 不返回错误降级到仅使用JWT模式
} else {
tracing::info!("Successfully stored refresh token to Redis for user {}", u.id);
}
// persist refresh token hash to database (backup)
let mut hasher = Sha256::new(); hasher.update(refresh.as_bytes());
let token_hash = format!("{:x}", hasher.finalize());
let exp_secs = std::env::var("JWT_REFRESH_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1209600);
@ -36,15 +61,26 @@ pub async fn login(db: &Db, username: String, password_plain: String) -> Result<
}
pub async fn logout(db: &Db, uid: i64) -> Result<(), AppError> {
// 从 Redis 中删除所有 token
let _ = TokenRedis::revoke_all_tokens(uid).await;
// 从数据库中删除 refresh token
let _ = refresh_token::Entity::delete_many().filter(refresh_token::Column::UserId.eq(uid)).exec(db).await?;
Ok(())
}
pub async fn rotate_refresh(db: &Db, uid: i64, old_refresh: String) -> Result<(String, String), AppError> {
// 验证Redis中的refresh token
let is_valid_redis = TokenRedis::validate_refresh_token(&old_refresh, uid).await.unwrap_or(false);
// 同时验证数据库中的token hash备用验证
let mut hasher = Sha256::new(); hasher.update(old_refresh.as_bytes());
let token_hash = format!("{:x}", hasher.finalize());
let existing = refresh_token::Entity::find().filter(refresh_token::Column::UserId.eq(uid)).filter(refresh_token::Column::TokenHash.eq(token_hash.clone())).one(db).await?;
if existing.is_none() { return Err(AppError::Unauthorized); }
if !is_valid_redis && existing.is_none() {
return Err(AppError::Unauthorized);
}
let u = user::Entity::find_by_id(uid).one(db).await?.ok_or(AppError::Unauthorized)?;
let access_claims = crate::middlewares::jwt::new_access_claims(u.id, &u.username);
@ -52,7 +88,16 @@ pub async fn rotate_refresh(db: &Db, uid: i64, old_refresh: String) -> Result<(S
let secret = std::env::var("JWT_SECRET").unwrap();
let access = crate::middlewares::jwt::encode_token(&access_claims, &secret)?;
let refresh = crate::middlewares::jwt::encode_token(&refresh_claims, &secret)?;
// 获取过期时间
let access_exp_secs = std::env::var("JWT_ACCESS_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1800);
let refresh_exp_secs = std::env::var("JWT_REFRESH_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1209600);
// 更新Redis中的token
TokenRedis::store_access_token(&access, u.id, access_exp_secs as u64).await?;
TokenRedis::store_refresh_token(&refresh, u.id, refresh_exp_secs as u64).await?;
// 更新数据库中的refresh token
let _ = refresh_token::Entity::delete_many().filter(refresh_token::Column::UserId.eq(uid)).filter(refresh_token::Column::TokenHash.eq(token_hash)).exec(db).await?;
let mut hasher2 = Sha256::new(); hasher2.update(refresh.as_bytes());
let token_hash2 = format!("{:x}", hasher2.finalize());