use axum::{ extract::{Request, State}, http::{header::HeaderMap, StatusCode}, middleware::Next, response::Response, }; use sha2::{Digest, Sha256}; use std::sync::Arc; use crate::core::auth::jwt; use crate::core::db::postgres_db::ApiKeyRecord; use crate::core::db::PostgresDb; #[derive(Debug, Clone)] pub enum AuthSource { Session, Jwt, ApiKey, } #[derive(Debug, Clone)] pub struct UserAuth { pub user_id: i32, pub role: String, pub source: AuthSource, pub key_id: String, pub jwt_jti: Option, pub jwt_exp: Option>, } #[derive(Clone)] pub struct ApiState { pub db: Arc, } pub fn extract_cookies(headers: &HeaderMap) -> Vec<(String, String)> { let cookie_header = match headers.get("cookie").and_then(|v| v.to_str().ok()) { Some(c) => c, None => return Vec::new(), }; cookie_header .split(';') .filter_map(|pair| { let mut parts = pair.trim().splitn(2, '='); match (parts.next(), parts.next()) { (Some(k), Some(v)) => Some((k.to_lowercase(), v.to_string())), _ => None, } }) .collect() } fn hash_key(key: &str) -> String { let mut hasher = Sha256::new(); hasher.update(key.as_bytes()); format!("{:x}", hasher.finalize()) } fn percent_decode(s: &str) -> String { let mut result = String::new(); let mut chars = s.bytes(); while let Some(b) = chars.next() { match b { b'%' => { let hi = chars.next().and_then(|c| hex_val(c)).unwrap_or(0); let lo = chars.next().and_then(|c| hex_val(c)).unwrap_or(0); result.push((hi << 4 | lo) as char); } b'+' => result.push(' '), _ => result.push(b as char), } } result } fn hex_val(c: u8) -> Option { match c { b'0'..=b'9' => Some(c - b'0'), b'a'..=b'f' => Some(c - b'a' + 10), b'A'..=b'F' => Some(c - b'A' + 10), _ => None, } } fn extract_api_key(headers: &HeaderMap, uri: &axum::http::Uri) -> Result { if let Some(key) = headers .get("X-API-Key") .and_then(|v| v.to_str().ok()) { return Ok(key.to_string()); } if let Some(auth) = headers .get("Authorization") .and_then(|v| v.to_str().ok()) { // Check if it's a JWT (starts with eyJ) let trimmed = auth.strip_prefix("Bearer ").unwrap_or(auth); if !jwt::is_jwt(trimmed) { return Ok(trimmed.to_string()); } // If it IS a JWT, return it as-is — JWT branch handles it return Ok(trimmed.to_string()); } if let Some(query) = uri.query() { for pair in query.split('&') { let mut parts = pair.splitn(2, '='); if let (Some(k), Some(v)) = (parts.next(), parts.next()) { if k == "api_key" { return Ok(percent_decode(v)); } } } } Err(StatusCode::UNAUTHORIZED) } pub async fn unified_auth( State(state): State, mut request: Request, next: Next, ) -> Response { let headers = request.headers(); let uri = request.uri().clone(); // Priority 1: Cookie session (Portal) let cookies = extract_cookies(headers); if let Some(sid) = cookies.iter().find(|(k, _)| k == "session_id").map(|(_, v)| v.clone()) { match state.db.get_session_by_id(&sid).await { Ok(Some((_id, user_id, api_key_id, _expires_at))) => { let key_hash = hash_key(&api_key_id); match state.db.get_api_key_by_hash(&key_hash).await { Ok(Some(record)) if record.status == "active" => { let auth = UserAuth { user_id: user_id, role: record.key_type.clone(), source: AuthSource::Session, key_id: record.key_id.clone(), jwt_jti: None, jwt_exp: None, }; if let Err(e) = state.db.update_api_key_usage(&record.key_id, None).await { tracing::warn!("[AUTH] Failed to update key usage: {}", e); } request.extensions_mut().insert(auth); return next.run(request).await; } Ok(Some(_)) => { tracing::warn!("[AUTH] Session API key not active, removing session"); state.db.delete_session(&sid).await.ok(); } _ => {} } } Err(e) => tracing::error!("[AUTH] Session lookup error: {}", e), _ => {} } } // Priority 2: JWT (Authorization: Bearer ) if let Some(auth_header) = headers .get("Authorization") .and_then(|v| v.to_str().ok()) { if let Some(token) = auth_header.strip_prefix("Bearer ") { if jwt::is_jwt(token) { match jwt::verify_jwt(token) { Ok(claims) => { if !state.db.is_jwt_blacklisted(&claims.jti).await.unwrap_or(false) { let exp = chrono::DateTime::from_timestamp(claims.exp as i64, 0); let user_id: i32 = claims.sub.parse().unwrap_or(0); let auth = UserAuth { user_id, role: claims.role, source: AuthSource::Jwt, key_id: String::new(), jwt_jti: Some(claims.jti), jwt_exp: exp, }; request.extensions_mut().insert(auth); return next.run(request).await; } } Err(e) => { tracing::debug!("[AUTH] JWT verification failed: {}", e); } } } } } // Priority 3: API Key header / query param let api_key = match extract_api_key(headers, &uri) { Ok(key) => key, Err(status) => { return Response::builder() .status(status) .body(axum::body::Body::empty()) .unwrap(); } }; let key_hash = hash_key(&api_key); let record = match state.db.get_api_key_by_hash(&key_hash).await { Ok(Some(r)) => r, Ok(None) => { return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(axum::body::Body::empty()) .unwrap(); } Err(e) => { tracing::error!("[AUTH] DB error: {}", e); return Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(axum::body::Body::empty()) .unwrap(); } }; if record.status != "active" { return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(axum::body::Body::empty()) .unwrap(); } let auth = UserAuth { user_id: record.user_id.unwrap_or(0) as i32, role: record.key_type.clone(), source: AuthSource::ApiKey, key_id: record.key_id.clone(), jwt_jti: None, jwt_exp: None, }; if let Err(e) = state.db.update_api_key_usage(&record.key_id, None).await { tracing::warn!("[AUTH] Failed to update key usage: {}", e); } request.extensions_mut().insert(auth); next.run(request).await }