feat: add redis
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
mod db;
|
||||
mod redis;
|
||||
mod response;
|
||||
mod error;
|
||||
pub mod middlewares;
|
||||
@ -6,6 +7,7 @@ pub mod models;
|
||||
pub mod services;
|
||||
pub mod routes;
|
||||
pub mod utils;
|
||||
pub mod workflow;
|
||||
|
||||
use axum::Router;
|
||||
use axum::http::{HeaderValue, Method};
|
||||
@ -37,6 +39,10 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let db = db::init_db().await?;
|
||||
|
||||
// initialize Redis connection
|
||||
let redis_pool = redis::init_redis().await?;
|
||||
redis::set_redis_pool(redis_pool)?;
|
||||
|
||||
// run migrations
|
||||
migration::Migrator::up(&db, None).await.expect("migration up");
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ use axum::{http::HeaderMap, http::header::AUTHORIZATION};
|
||||
use chrono::{Utc, Duration as ChronoDuration};
|
||||
use jsonwebtoken::{EncodingKey, DecodingKey, Header, Validation};
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error::AppError;
|
||||
use crate::{error::AppError, redis::TokenRedis};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
@ -37,6 +37,21 @@ impl<S> axum::extract::FromRequestParts<S> for AuthUser where S: Send + Sync + '
|
||||
let secret = std::env::var("JWT_SECRET").map_err(|_| AppError::Unauthorized)?;
|
||||
let claims = decode_token(token, &secret)?;
|
||||
if claims.typ != "access" { return Err(AppError::Unauthorized); }
|
||||
|
||||
// 验证token是否在Redis中存在(可选:添加环境变量控制是否启用Redis验证)
|
||||
let redis_validation_enabled = std::env::var("REDIS_TOKEN_VALIDATION")
|
||||
.unwrap_or_else(|_| "true".to_string())
|
||||
.parse::<bool>()
|
||||
.unwrap_or(true);
|
||||
|
||||
if redis_validation_enabled {
|
||||
let is_valid = TokenRedis::validate_access_token(token, claims.uid).await
|
||||
.unwrap_or(false);
|
||||
if !is_valid {
|
||||
return Err(AppError::Unauthorized);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AuthUser { uid: claims.uid, username: claims.sub })
|
||||
}
|
||||
}
|
||||
|
||||
154
backend/src/redis.rs
Normal file
154
backend/src/redis.rs
Normal file
@ -0,0 +1,154 @@
|
||||
use redis::{Client, AsyncCommands};
|
||||
use redis::aio::ConnectionManager;
|
||||
use once_cell::sync::OnceCell;
|
||||
use crate::error::AppError;
|
||||
|
||||
pub type RedisPool = ConnectionManager;
|
||||
|
||||
static REDIS_POOL: OnceCell<RedisPool> = OnceCell::new();
|
||||
|
||||
/// 初始化Redis连接池
|
||||
pub async fn init_redis() -> Result<RedisPool, AppError> {
|
||||
let redis_url = std::env::var("REDIS_URL")
|
||||
.unwrap_or_else(|_| "redis://:123456@127.0.0.1:6379/9".into());
|
||||
|
||||
tracing::info!("Connecting to Redis at: {}", redis_url.replace(":123456", ":***"));
|
||||
|
||||
let client = Client::open(redis_url)
|
||||
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Failed to create Redis client: {}", e)))?;
|
||||
|
||||
let manager = ConnectionManager::new(client).await
|
||||
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Failed to create Redis connection manager: {}", e)))?;
|
||||
|
||||
tracing::info!("Redis connection established successfully");
|
||||
Ok(manager)
|
||||
}
|
||||
|
||||
/// 获取Redis连接池
|
||||
pub fn get_redis() -> Result<&'static RedisPool, AppError> {
|
||||
REDIS_POOL.get().ok_or_else(|| AppError::Anyhow(anyhow::anyhow!("Redis pool not initialized")))
|
||||
}
|
||||
|
||||
/// 设置Redis连接池
|
||||
pub fn set_redis_pool(pool: RedisPool) -> Result<(), AppError> {
|
||||
REDIS_POOL.set(pool)
|
||||
.map_err(|_| AppError::Anyhow(anyhow::anyhow!("Redis pool already initialized")))
|
||||
}
|
||||
|
||||
/// Redis工具函数
|
||||
pub struct RedisHelper;
|
||||
|
||||
impl RedisHelper {
|
||||
/// 设置带过期时间的键值对
|
||||
pub async fn set_ex(key: &str, value: &str, expire_seconds: u64) -> Result<(), AppError> {
|
||||
let mut conn = get_redis()?.clone();
|
||||
tracing::debug!("Redis SET_EX: key={}, value_len={}, expire_seconds={}", key, value.len(), expire_seconds);
|
||||
let _: String = conn.set_ex(key, value, expire_seconds).await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Redis set_ex failed for key {}: {}", key, e);
|
||||
AppError::Anyhow(anyhow::anyhow!("Redis set_ex failed: {}", e))
|
||||
})?;
|
||||
tracing::debug!("Redis SET_EX successful for key: {}", key);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取键值
|
||||
pub async fn get(key: &str) -> Result<Option<String>, AppError> {
|
||||
let mut conn = get_redis()?.clone();
|
||||
let result: Option<String> = conn.get(key).await
|
||||
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis get failed: {}", e)))?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 删除键
|
||||
pub async fn del(key: &str) -> Result<(), AppError> {
|
||||
let mut conn = get_redis()?.clone();
|
||||
let _: i32 = conn.del(key).await
|
||||
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis del failed: {}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 检查键是否存在
|
||||
pub async fn exists(key: &str) -> Result<bool, AppError> {
|
||||
let mut conn = get_redis()?.clone();
|
||||
let result: bool = conn.exists(key).await
|
||||
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis exists failed: {}", e)))?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// 设置键的过期时间
|
||||
pub async fn expire(key: &str, seconds: u64) -> Result<(), AppError> {
|
||||
let mut conn = get_redis()?.clone();
|
||||
let _: bool = conn.expire(key, seconds as i64).await
|
||||
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis expire failed: {}", e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 删除用户相关的所有token
|
||||
pub async fn del_user_tokens(user_id: i64) -> Result<(), AppError> {
|
||||
let pattern = format!("token:*:user:{}", user_id);
|
||||
let mut conn = get_redis()?.clone();
|
||||
|
||||
// 获取匹配的键
|
||||
let keys: Vec<String> = conn.keys(&pattern).await
|
||||
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis keys failed: {}", e)))?;
|
||||
|
||||
// 删除所有匹配的键
|
||||
if !keys.is_empty() {
|
||||
let _: i32 = conn.del(&keys).await
|
||||
.map_err(|e| AppError::Anyhow(anyhow::anyhow!("Redis del multiple failed: {}", e)))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Token相关的Redis操作
|
||||
pub struct TokenRedis;
|
||||
|
||||
impl TokenRedis {
|
||||
/// 存储访问token
|
||||
pub async fn store_access_token(token: &str, user_id: i64, expire_seconds: u64) -> Result<(), AppError> {
|
||||
let key = format!("token:access:user:{}", user_id);
|
||||
tracing::info!("Storing access token for user {} with key: {}, expires in {} seconds", user_id, key, expire_seconds);
|
||||
RedisHelper::set_ex(&key, token, expire_seconds).await
|
||||
}
|
||||
|
||||
/// 存储刷新token
|
||||
pub async fn store_refresh_token(token: &str, user_id: i64, expire_seconds: u64) -> Result<(), AppError> {
|
||||
let key = format!("token:refresh:user:{}", user_id);
|
||||
tracing::info!("Storing refresh token for user {} with key: {}, expires in {} seconds", user_id, key, expire_seconds);
|
||||
RedisHelper::set_ex(&key, token, expire_seconds).await
|
||||
}
|
||||
|
||||
/// 验证访问token
|
||||
pub async fn validate_access_token(token: &str, user_id: i64) -> Result<bool, AppError> {
|
||||
let key = format!("token:access:user:{}", user_id);
|
||||
let stored_token = RedisHelper::get(&key).await?;
|
||||
Ok(stored_token.as_deref() == Some(token))
|
||||
}
|
||||
|
||||
/// 验证刷新token
|
||||
pub async fn validate_refresh_token(token: &str, user_id: i64) -> Result<bool, AppError> {
|
||||
let key = format!("token:refresh:user:{}", user_id);
|
||||
let stored_token = RedisHelper::get(&key).await?;
|
||||
Ok(stored_token.as_deref() == Some(token))
|
||||
}
|
||||
|
||||
/// 删除用户的访问token
|
||||
pub async fn revoke_access_token(user_id: i64) -> Result<(), AppError> {
|
||||
let key = format!("token:access:user:{}", user_id);
|
||||
RedisHelper::del(&key).await
|
||||
}
|
||||
|
||||
/// 删除用户的刷新token
|
||||
pub async fn revoke_refresh_token(user_id: i64) -> Result<(), AppError> {
|
||||
let key = format!("token:refresh:user:{}", user_id);
|
||||
RedisHelper::del(&key).await
|
||||
}
|
||||
|
||||
/// 删除用户的所有token
|
||||
pub async fn revoke_all_tokens(user_id: i64) -> Result<(), AppError> {
|
||||
RedisHelper::del_user_tokens(user_id).await
|
||||
}
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
use sea_orm::{EntityTrait, ColumnTrait, QueryFilter, ActiveModelTrait, Set};
|
||||
use crate::{db::Db, models::{user, refresh_token}, utils::password, error::AppError};
|
||||
use crate::{db::Db, models::{user, refresh_token}, utils::password, error::AppError, redis::TokenRedis};
|
||||
use chrono::{Utc, Duration, FixedOffset};
|
||||
use sha2::{Sha256, Digest};
|
||||
use sea_orm::ActiveValue::NotSet;
|
||||
@ -11,13 +11,38 @@ pub async fn login(db: &Db, username: String, password_plain: String) -> Result<
|
||||
if u.status != 1 { return Err(AppError::Forbidden); }
|
||||
let ok = password::verify_password(&password_plain, &u.password_hash).map_err(|_| AppError::Unauthorized)?;
|
||||
if !ok { return Err(AppError::Unauthorized); }
|
||||
|
||||
let access_claims = crate::middlewares::jwt::new_access_claims(u.id, &u.username);
|
||||
let refresh_claims = crate::middlewares::jwt::new_refresh_claims(u.id, &u.username);
|
||||
let secret = std::env::var("JWT_SECRET").unwrap();
|
||||
let access = crate::middlewares::jwt::encode_token(&access_claims, &secret)?;
|
||||
let refresh = crate::middlewares::jwt::encode_token(&refresh_claims, &secret)?;
|
||||
|
||||
// persist refresh token hash
|
||||
// 获取过期时间(秒)
|
||||
let access_exp_secs = std::env::var("JWT_ACCESS_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1800);
|
||||
let refresh_exp_secs = std::env::var("JWT_REFRESH_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1209600);
|
||||
|
||||
// 先删除用户的所有旧token(防止多重登录)
|
||||
if let Err(e) = TokenRedis::revoke_all_tokens(u.id).await {
|
||||
tracing::warn!("Failed to revoke old tokens for user {}: {}", u.id, e);
|
||||
}
|
||||
|
||||
// 存储新token到Redis
|
||||
if let Err(e) = TokenRedis::store_access_token(&access, u.id, access_exp_secs as u64).await {
|
||||
tracing::error!("Failed to store access token to Redis for user {}: {}", u.id, e);
|
||||
// 不返回错误,降级到仅使用JWT模式
|
||||
} else {
|
||||
tracing::info!("Successfully stored access token to Redis for user {}", u.id);
|
||||
}
|
||||
|
||||
if let Err(e) = TokenRedis::store_refresh_token(&refresh, u.id, refresh_exp_secs as u64).await {
|
||||
tracing::error!("Failed to store refresh token to Redis for user {}: {}", u.id, e);
|
||||
// 不返回错误,降级到仅使用JWT模式
|
||||
} else {
|
||||
tracing::info!("Successfully stored refresh token to Redis for user {}", u.id);
|
||||
}
|
||||
|
||||
// persist refresh token hash to database (backup)
|
||||
let mut hasher = Sha256::new(); hasher.update(refresh.as_bytes());
|
||||
let token_hash = format!("{:x}", hasher.finalize());
|
||||
let exp_secs = std::env::var("JWT_REFRESH_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1209600);
|
||||
@ -36,15 +61,26 @@ pub async fn login(db: &Db, username: String, password_plain: String) -> Result<
|
||||
}
|
||||
|
||||
pub async fn logout(db: &Db, uid: i64) -> Result<(), AppError> {
|
||||
// 从 Redis 中删除所有 token
|
||||
let _ = TokenRedis::revoke_all_tokens(uid).await;
|
||||
|
||||
// 从数据库中删除 refresh token
|
||||
let _ = refresh_token::Entity::delete_many().filter(refresh_token::Column::UserId.eq(uid)).exec(db).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn rotate_refresh(db: &Db, uid: i64, old_refresh: String) -> Result<(String, String), AppError> {
|
||||
// 验证Redis中的refresh token
|
||||
let is_valid_redis = TokenRedis::validate_refresh_token(&old_refresh, uid).await.unwrap_or(false);
|
||||
|
||||
// 同时验证数据库中的token hash(备用验证)
|
||||
let mut hasher = Sha256::new(); hasher.update(old_refresh.as_bytes());
|
||||
let token_hash = format!("{:x}", hasher.finalize());
|
||||
let existing = refresh_token::Entity::find().filter(refresh_token::Column::UserId.eq(uid)).filter(refresh_token::Column::TokenHash.eq(token_hash.clone())).one(db).await?;
|
||||
if existing.is_none() { return Err(AppError::Unauthorized); }
|
||||
|
||||
if !is_valid_redis && existing.is_none() {
|
||||
return Err(AppError::Unauthorized);
|
||||
}
|
||||
|
||||
let u = user::Entity::find_by_id(uid).one(db).await?.ok_or(AppError::Unauthorized)?;
|
||||
let access_claims = crate::middlewares::jwt::new_access_claims(u.id, &u.username);
|
||||
@ -52,7 +88,16 @@ pub async fn rotate_refresh(db: &Db, uid: i64, old_refresh: String) -> Result<(S
|
||||
let secret = std::env::var("JWT_SECRET").unwrap();
|
||||
let access = crate::middlewares::jwt::encode_token(&access_claims, &secret)?;
|
||||
let refresh = crate::middlewares::jwt::encode_token(&refresh_claims, &secret)?;
|
||||
|
||||
// 获取过期时间
|
||||
let access_exp_secs = std::env::var("JWT_ACCESS_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1800);
|
||||
let refresh_exp_secs = std::env::var("JWT_REFRESH_EXP_SECS").ok().and_then(|v| v.parse().ok()).unwrap_or(1209600);
|
||||
|
||||
// 更新Redis中的token
|
||||
TokenRedis::store_access_token(&access, u.id, access_exp_secs as u64).await?;
|
||||
TokenRedis::store_refresh_token(&refresh, u.id, refresh_exp_secs as u64).await?;
|
||||
|
||||
// 更新数据库中的refresh token
|
||||
let _ = refresh_token::Entity::delete_many().filter(refresh_token::Column::UserId.eq(uid)).filter(refresh_token::Column::TokenHash.eq(token_hash)).exec(db).await?;
|
||||
let mut hasher2 = Sha256::new(); hasher2.update(refresh.as_bytes());
|
||||
let token_hash2 = format!("{:x}", hasher2.finalize());
|
||||
|
||||
Reference in New Issue
Block a user