Files
momentry_core/src/api/middleware.rs
2026-05-17 19:46:35 +08:00

243 lines
7.7 KiB
Rust

use axum::{
extract::{Request, State},
http::{header::HeaderMap, StatusCode},
middleware::Next,
response::Response,
};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use crate::core::auth::jwt;
use crate::core::db::postgres_db::ApiKeyRecord;
use crate::core::db::PostgresDb;
/// How a request was authenticated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AuthSource {
    /// Portal login via the `session_id` cookie.
    Session,
    /// `Authorization: Bearer <JWT>` header.
    Jwt,
    /// API key taken from the `X-API-Key` header, the `Authorization` header,
    /// or the `api_key` query parameter.
    ApiKey,
}
/// Authenticated caller identity. `unified_auth` inserts it into the request
/// extensions when authentication succeeds.
#[derive(Debug, Clone)]
pub struct UserAuth {
/// Numeric user id. May be 0, e.g. when the authenticating API key has no
/// associated user. NOTE(review): confirm downstream code treats 0 as "no user".
pub user_id: i32,
/// Role string: the API key's `key_type` for session/API-key auth, or the
/// JWT `role` claim for JWT auth.
pub role: String,
/// Which credential authenticated this request.
pub source: AuthSource,
/// `key_id` of the API key that was used; empty for JWT auth.
pub key_id: String,
/// JWT `jti` claim for JWT auth; `None` for the other sources.
pub jwt_jti: Option<String>,
/// Expiry derived from the JWT `exp` claim (seconds since the Unix epoch);
/// `None` for the other sources or if the timestamp is out of range.
pub jwt_exp: Option<chrono::DateTime<chrono::Utc>>,
}
/// State extracted by `unified_auth` through `State<ApiState>`.
#[derive(Clone)]
pub struct ApiState {
/// Postgres handle used for session, API-key and JWT-blacklist lookups.
pub db: Arc<PostgresDb>,
}
/// Parses every `Cookie` request header into `(name, value)` pairs.
///
/// All `Cookie` fields are read, not only the first one. HTTP/2 clients may
/// split cookies across several header fields (RFC 9113 §8.2.3). Reading only
/// the first field would silently drop cookies such as `session_id`.
///
/// Parsing rules:
/// - Header values that are not visible ASCII are skipped.
/// - Each value is split on `;`, and each pair is trimmed and split at its
///   first `=`. Pairs without `=` are dropped.
/// - Names are lowercased.
/// - Values are returned verbatim, with no unquoting or percent-decoding.
pub fn extract_cookies(headers: &HeaderMap) -> Vec<(String, String)> {
    headers
        .get_all("cookie")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|header| header.split(';'))
        .filter_map(|pair| {
            let (name, value) = pair.trim().split_once('=')?;
            Some((name.to_lowercase(), value.to_string()))
        })
        .collect()
}
/// Returns the SHA-256 digest of `key` as a lowercase hex string.
///
/// API keys are looked up in the database by this digest rather than by
/// their plaintext value.
fn hash_key(key: &str) -> String {
    let digest = Sha256::digest(key.as_bytes());
    format!("{:x}", digest)
}
/// Decodes a percent-encoded query-string value.
///
/// Decoding rules:
/// - `+` decodes to a space (form-urlencoded convention).
/// - `%XX`, where `XX` is two hex digits in either case, decodes to the byte
///   `0xXX`.
/// - A `%` that is not followed by two hex digits is kept literally, as in
///   the WHATWG URL percent-decode algorithm. It is not turned into a NUL and
///   does not consume the following characters.
///
/// The decoded bytes are interpreted as UTF-8, and invalid sequences become
/// U+FFFD. Multi-byte characters such as `%C3%A9` therefore decode to `é`
/// rather than to two unrelated Latin-1 characters.
fn percent_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_val);
                let lo = bytes.get(i + 2).copied().and_then(hex_val);
                match hi.zip(lo) {
                    Some((hi, lo)) => {
                        out.push(hi << 4 | lo);
                        i += 3;
                    }
                    None => {
                        // Malformed escape: keep the '%' and re-scan what follows.
                        out.push(b'%');
                        i += 1;
                    }
                }
            }
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}

/// Returns the value (0–15) of an ASCII hex digit, or `None` if `c` is not one.
fn hex_val(c: u8) -> Option<u8> {
    match c {
        b'0'..=b'9' => Some(c - b'0'),
        b'a'..=b'f' => Some(c - b'a' + 10),
        b'A'..=b'F' => Some(c - b'A' + 10),
        _ => None,
    }
}
/// Extracts a raw API key from the request. The first source that is present
/// wins, even if its value later turns out to be invalid.
///
/// Sources, in order:
/// 1. the `X-API-Key` header;
/// 2. the `Authorization` header, with an optional `Bearer ` prefix stripped;
/// 3. the first `api_key` query-string parameter, percent-decoded.
///
/// NOTE(review): keys passed in the query string can end up in proxy or access
/// logs. Consider whether that source should stay enabled.
///
/// # Errors
/// Returns `StatusCode::UNAUTHORIZED` when none of the sources is present.
fn extract_api_key(headers: &HeaderMap, uri: &axum::http::Uri) -> Result<String, StatusCode> {
    if let Some(key) = headers.get("X-API-Key").and_then(|v| v.to_str().ok()) {
        return Ok(key.to_string());
    }

    if let Some(auth) = headers.get("Authorization").and_then(|v| v.to_str().ok()) {
        // Bearer values are returned as-is, including JWT-shaped ones. The JWT
        // path in `unified_auth` runs before this function, so a token that
        // reaches here was not accepted there and will simply fail the
        // API-key lookup.
        let token = auth.strip_prefix("Bearer ").unwrap_or(auth);
        return Ok(token.to_string());
    }

    uri.query()
        .and_then(|query| {
            query.split('&').find_map(|pair| {
                let (name, value) = pair.split_once('=')?;
                (name == "api_key").then(|| percent_decode(value))
            })
        })
        .ok_or(StatusCode::UNAUTHORIZED)
}
pub async fn unified_auth(
State(state): State<ApiState>,
mut request: Request,
next: Next,
) -> Response {
let headers = request.headers();
let uri = request.uri().clone();
// Priority 1: Cookie session (Portal)
let cookies = extract_cookies(headers);
if let Some(sid) = cookies.iter().find(|(k, _)| k == "session_id").map(|(_, v)| v.clone()) {
match state.db.get_session_by_id(&sid).await {
Ok(Some((_id, user_id, api_key_id, _expires_at))) => {
let key_hash = hash_key(&api_key_id);
match state.db.get_api_key_by_hash(&key_hash).await {
Ok(Some(record)) if record.status == "active" => {
let auth = UserAuth {
user_id: user_id,
role: record.key_type.clone(),
source: AuthSource::Session,
key_id: record.key_id.clone(),
jwt_jti: None,
jwt_exp: None,
};
if let Err(e) = state.db.update_api_key_usage(&record.key_id, None).await {
tracing::warn!("[AUTH] Failed to update key usage: {}", e);
}
request.extensions_mut().insert(auth);
return next.run(request).await;
}
Ok(Some(_)) => {
tracing::warn!("[AUTH] Session API key not active, removing session");
state.db.delete_session(&sid).await.ok();
}
_ => {}
}
}
Err(e) => tracing::error!("[AUTH] Session lookup error: {}", e),
_ => {}
}
}
// Priority 2: JWT (Authorization: Bearer <eyJ...>)
if let Some(auth_header) = headers
.get("Authorization")
.and_then(|v| v.to_str().ok())
{
if let Some(token) = auth_header.strip_prefix("Bearer ") {
if jwt::is_jwt(token) {
match jwt::verify_jwt(token) {
Ok(claims) => {
if !state.db.is_jwt_blacklisted(&claims.jti).await.unwrap_or(false) {
let exp = chrono::DateTime::from_timestamp(claims.exp as i64, 0);
let user_id: i32 = claims.sub.parse().unwrap_or(0);
let auth = UserAuth {
user_id,
role: claims.role,
source: AuthSource::Jwt,
key_id: String::new(),
jwt_jti: Some(claims.jti),
jwt_exp: exp,
};
request.extensions_mut().insert(auth);
return next.run(request).await;
}
}
Err(e) => {
tracing::debug!("[AUTH] JWT verification failed: {}", e);
}
}
}
}
}
// Priority 3: API Key header / query param
let api_key = match extract_api_key(headers, &uri) {
Ok(key) => key,
Err(status) => {
return Response::builder()
.status(status)
.body(axum::body::Body::empty())
.unwrap();
}
};
let key_hash = hash_key(&api_key);
let record = match state.db.get_api_key_by_hash(&key_hash).await {
Ok(Some(r)) => r,
Ok(None) => {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(axum::body::Body::empty())
.unwrap();
}
Err(e) => {
tracing::error!("[AUTH] DB error: {}", e);
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(axum::body::Body::empty())
.unwrap();
}
};
if record.status != "active" {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(axum::body::Body::empty())
.unwrap();
}
let auth = UserAuth {
user_id: record.user_id.unwrap_or(0) as i32,
role: record.key_type.clone(),
source: AuthSource::ApiKey,
key_id: record.key_id.clone(),
jwt_jti: None,
jwt_exp: None,
};
if let Err(e) = state.db.update_api_key_usage(&record.key_id, None).await {
tracing::warn!("[AUTH] Failed to update key usage: {}", e);
}
request.extensions_mut().insert(auth);
next.run(request).await
}