Files
udmin/backend/src/main.rs
ayou 30716686ed feat(ws): 新增WebSocket实时通信支持与SSE独立服务
重构中间件结构,新增ws模块实现WebSocket流程执行实时推送
将SSE服务拆分为独立端口监听,默认8866
优化前端流式模式切换,支持WS/SSE协议选择
统一流式事件处理逻辑,完善错误处理与取消机制
更新Cargo.toml依赖,添加WebSocket相关库
调整代码组织结构,规范导入分组与注释
2025-09-21 22:15:33 +08:00

135 lines
4.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

mod db;
mod redis;
mod response;
mod error;
pub mod middlewares;
pub mod models;
pub mod services;
pub mod routes;
pub mod utils;
pub mod flow;
use anyhow::Context;
use axum::http::{HeaderValue, Method};
use axum::middleware;
use axum::Router;
use migration::MigratorTrait;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
/// Log timestamp formatter that renders the local time as
/// `YYYY-MM-DD HH:MM:SS.ffffff` (microsecond precision, no `T` separator and
/// no `Z`/offset suffix).
#[derive(Debug, Clone, Copy)]
struct LocalTimeFmt;

impl tracing_subscriber::fmt::time::FormatTime for LocalTimeFmt {
    fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer) -> std::fmt::Result {
        // This runs for every log event: stream chrono's lazily-formatted value
        // straight into the writer instead of allocating an intermediate String.
        write!(w, "{}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.6f"))
    }
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// 增强:支持通过 ENV_FILE 指定要加载的环境文件,并记录实际加载的文件
// - ENV_FILE=prod 或 production => .env.prod
// - ENV_FILE=dev 或 development => .env
// - ENV_FILE=staging => .env.staging
// - ENV_FILE=任意字符串 => 视为显式文件名或路径
let env_file_used: Option<String> = if let Ok(v) = std::env::var("ENV_FILE") {
let filename = match v.trim() {
"" => ".env".to_string(),
"prod" | "production" => ".env.prod".to_string(),
"dev" | "development" => ".env".to_string(),
"staging" | "pre" | "preprod" | "pre-production" => ".env.staging".to_string(),
other => other.to_string(),
};
match dotenvy::from_filename_override(&filename) {
Ok(_) => Some(filename),
Err(_) => Some(format!("{} (not found)", filename)),
}
} else {
match dotenvy::dotenv_override() {
Ok(path) => Some(path.to_string_lossy().to_string()),
Err(_) => None,
}
};
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_timer(LocalTimeFmt)
.init();
let db = db::init_db().await?;
// set global DB for tasks
db::set_db(db.clone()).expect("db set failure");
// initialize Redis connection
let redis_pool = redis::init_redis().await?;
redis::set_redis_pool(redis_pool)?;
// run migrations
migration::Migrator::up(&db, None).await.expect("migration up");
let allow_origins = std::env::var("CORS_ALLOW_ORIGINS").unwrap_or_else(|_| "http://localhost:5173".into());
let origin_values: Vec<HeaderValue> = allow_origins
.split(',')
.filter_map(|s| HeaderValue::from_str(s.trim()).ok())
.collect();
let allowed_methods = [
Method::GET,
Method::POST,
Method::PUT,
Method::PATCH,
Method::DELETE,
Method::OPTIONS,
];
let cors = if origin_values.is_empty() {
// 当允许任意来源时,不能与 allow_credentials(true) 同时使用
CorsLayer::new()
.allow_origin(Any)
.allow_methods(allowed_methods.clone())
.allow_headers([
axum::http::header::ACCEPT,
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
])
.allow_credentials(false)
} else {
CorsLayer::new()
.allow_origin(AllowOrigin::list(origin_values))
.allow_methods(allowed_methods)
.allow_headers([
axum::http::header::ACCEPT,
axum::http::header::CONTENT_TYPE,
axum::http::header::AUTHORIZATION,
])
.allow_credentials(true)
};
let api = routes::api_router().with_state(db.clone());
let app = Router::new()
.nest("/api", api)
.layer(cors)
.layer(middleware::from_fn_with_state(db.clone(), middlewares::logging::request_logger));
// 读取并记录最终使用的主机与端口(默认端口改为 9898
let app_host = std::env::var("APP_HOST").unwrap_or("0.0.0.0".into());
let app_port = std::env::var("APP_PORT").unwrap_or("9898".into());
if let Some(f) = &env_file_used { tracing::info!("env file loaded: {}", f); } else { tracing::info!("env file loaded: <none>"); }
tracing::info!("resolved APP_HOST={} APP_PORT={}", app_host, app_port);
let http_addr = format!("{}:{}", app_host, app_port);
tracing::info!("listening on {}", http_addr);
// HTTP 服务监听
let http_listener = tokio::net::TcpListener::bind(http_addr.clone()).await?;
let http_server = axum::serve(http_listener, app);
// WS 服务下沉到中间件
let ws_server = middlewares::ws::serve(db.clone());
// 新增SSE 服务独立端口监听(默认 8866可配 SSE_HOST/SSE_PORT
let sse_server = middlewares::sse::serve(db.clone());
tokio::try_join!(http_server, ws_server, sse_server)?;
Ok(())
}