// 重构中间件结构,新增ws模块实现WebSocket流程执行实时推送
// 将SSE服务拆分为独立端口监听,默认8866
// 优化前端流式模式切换,支持WS/SSE协议选择
// 统一流式事件处理逻辑,完善错误处理与取消机制
// 更新Cargo.toml依赖,添加WebSocket相关库
// 调整代码组织结构,规范导入分组与注释
mod db;
mod redis;
mod response;
mod error;
pub mod middlewares;
pub mod models;
pub mod services;
pub mod routes;
pub mod utils;
pub mod flow;

use anyhow::Context as _;
use axum::http::{HeaderValue, Method};
use axum::middleware;
use axum::Router;
use migration::MigratorTrait;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};

// 自定义日志时间格式:YYYY-MM-DD HH:MM:SS.ssssss(不带 T 和 Z)
|
||
struct LocalTimeFmt;
|
||
|
||
impl tracing_subscriber::fmt::time::FormatTime for LocalTimeFmt {
|
||
fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer) -> std::fmt::Result {
|
||
let now = chrono::Local::now();
|
||
w.write_str(&now.format("%Y-%m-%d %H:%M:%S%.6f").to_string())
|
||
}
|
||
}
|
||
|
||
#[tokio::main]
|
||
async fn main() -> anyhow::Result<()> {
|
||
// 增强:支持通过 ENV_FILE 指定要加载的环境文件,并记录实际加载的文件
|
||
// - ENV_FILE=prod 或 production => .env.prod
|
||
// - ENV_FILE=dev 或 development => .env
|
||
// - ENV_FILE=staging => .env.staging
|
||
// - ENV_FILE=任意字符串 => 视为显式文件名或路径
|
||
let env_file_used: Option<String> = if let Ok(v) = std::env::var("ENV_FILE") {
|
||
let filename = match v.trim() {
|
||
"" => ".env".to_string(),
|
||
"prod" | "production" => ".env.prod".to_string(),
|
||
"dev" | "development" => ".env".to_string(),
|
||
"staging" | "pre" | "preprod" | "pre-production" => ".env.staging".to_string(),
|
||
other => other.to_string(),
|
||
};
|
||
match dotenvy::from_filename_override(&filename) {
|
||
Ok(_) => Some(filename),
|
||
Err(_) => Some(format!("{} (not found)", filename)),
|
||
}
|
||
} else {
|
||
match dotenvy::dotenv_override() {
|
||
Ok(path) => Some(path.to_string_lossy().to_string()),
|
||
Err(_) => None,
|
||
}
|
||
};
|
||
|
||
tracing_subscriber::fmt()
|
||
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
|
||
.with_timer(LocalTimeFmt)
|
||
.init();
|
||
|
||
let db = db::init_db().await?;
|
||
// set global DB for tasks
|
||
db::set_db(db.clone()).expect("db set failure");
|
||
|
||
// initialize Redis connection
|
||
let redis_pool = redis::init_redis().await?;
|
||
redis::set_redis_pool(redis_pool)?;
|
||
|
||
// run migrations
|
||
migration::Migrator::up(&db, None).await.expect("migration up");
|
||
|
||
let allow_origins = std::env::var("CORS_ALLOW_ORIGINS").unwrap_or_else(|_| "http://localhost:5173".into());
|
||
let origin_values: Vec<HeaderValue> = allow_origins
|
||
.split(',')
|
||
.filter_map(|s| HeaderValue::from_str(s.trim()).ok())
|
||
.collect();
|
||
|
||
let allowed_methods = [
|
||
Method::GET,
|
||
Method::POST,
|
||
Method::PUT,
|
||
Method::PATCH,
|
||
Method::DELETE,
|
||
Method::OPTIONS,
|
||
];
|
||
|
||
let cors = if origin_values.is_empty() {
|
||
// 当允许任意来源时,不能与 allow_credentials(true) 同时使用
|
||
CorsLayer::new()
|
||
.allow_origin(Any)
|
||
.allow_methods(allowed_methods.clone())
|
||
.allow_headers([
|
||
axum::http::header::ACCEPT,
|
||
axum::http::header::CONTENT_TYPE,
|
||
axum::http::header::AUTHORIZATION,
|
||
])
|
||
.allow_credentials(false)
|
||
} else {
|
||
CorsLayer::new()
|
||
.allow_origin(AllowOrigin::list(origin_values))
|
||
.allow_methods(allowed_methods)
|
||
.allow_headers([
|
||
axum::http::header::ACCEPT,
|
||
axum::http::header::CONTENT_TYPE,
|
||
axum::http::header::AUTHORIZATION,
|
||
])
|
||
.allow_credentials(true)
|
||
};
|
||
|
||
let api = routes::api_router().with_state(db.clone());
|
||
|
||
let app = Router::new()
|
||
.nest("/api", api)
|
||
.layer(cors)
|
||
.layer(middleware::from_fn_with_state(db.clone(), middlewares::logging::request_logger));
|
||
|
||
// 读取并记录最终使用的主机与端口(默认端口改为 9898)
|
||
let app_host = std::env::var("APP_HOST").unwrap_or("0.0.0.0".into());
|
||
let app_port = std::env::var("APP_PORT").unwrap_or("9898".into());
|
||
if let Some(f) = &env_file_used { tracing::info!("env file loaded: {}", f); } else { tracing::info!("env file loaded: <none>"); }
|
||
tracing::info!("resolved APP_HOST={} APP_PORT={}", app_host, app_port);
|
||
|
||
let http_addr = format!("{}:{}", app_host, app_port);
|
||
tracing::info!("listening on {}", http_addr);
|
||
|
||
// HTTP 服务监听
|
||
let http_listener = tokio::net::TcpListener::bind(http_addr.clone()).await?;
|
||
let http_server = axum::serve(http_listener, app);
|
||
|
||
// WS 服务下沉到中间件
|
||
let ws_server = middlewares::ws::serve(db.clone());
|
||
|
||
// 新增:SSE 服务独立端口监听(默认 8866,可配 SSE_HOST/SSE_PORT)
|
||
let sse_server = middlewares::sse::serve(db.clone());
|
||
tokio::try_join!(http_server, ws_server, sse_server)?;
|
||
Ok(())
|
||
} |