use axum::{ extract::{Request, State}, http::{header::HeaderMap, StatusCode}, middleware::Next, response::Response, }; use sha2::{Digest, Sha256}; use std::sync::Arc; use crate::core::db::postgres_db::ApiKeyRecord; use crate::core::db::PostgresDb; #[derive(Clone)] pub struct ApiKeyAuth { pub key_id: String, pub record: ApiKeyRecord, } #[derive(Clone)] pub struct ApiState { pub db: Arc, } pub async fn api_key_validation( State(state): State, request: Request, next: Next, ) -> Response { tracing::info!("[MIDDLEWARE] Starting API key validation"); tracing::info!("[MIDDLEWARE] Path: {:?}", request.uri().path()); let headers = request.headers(); tracing::info!("[MIDDLEWARE] All headers: {:?}", headers); let api_key = match extract_api_key(headers) { Ok(key) => { tracing::info!("[MIDDLEWARE] API key extracted, length: {}", key.len()); if key.len() > 8 { tracing::info!( "[MIDDLEWARE] Key value: {}...{}", &key[..4], &key[key.len() - 4..] ); } else { tracing::info!("[MIDDLEWARE] Key value: ****"); } key } Err(status) => { tracing::warn!("[MIDDLEWARE] API key extraction failed: {:?}", status); return Response::builder() .status(status) .body(axum::body::Body::empty()) .unwrap(); } }; let key_hash = hash_key(&api_key); tracing::info!("[MIDDLEWARE] Key hash: {}", &key_hash[..16]); tracing::info!("[MIDDLEWARE] Querying database for key..."); let record = match state.db.get_api_key_by_hash(&key_hash).await { Ok(Some(r)) => { tracing::info!("[MIDDLEWARE] API key found: {}", r.key_id); r } Ok(None) => { tracing::warn!( "[MIDDLEWARE] API key NOT FOUND in database for hash: {}", &key_hash[..16] ); return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(axum::body::Body::empty()) .unwrap(); } Err(e) => { tracing::error!("[MIDDLEWARE] DB error: {}", e); return Response::builder() .status(StatusCode::INTERNAL_SERVER_ERROR) .body(axum::body::Body::empty()) .unwrap(); } }; if record.status != "active" { tracing::warn!("[MIDDLEWARE] API key not active: {}", record.status); return Response::builder() .status(StatusCode::UNAUTHORIZED) .body(axum::body::Body::empty()) .unwrap(); } tracing::info!( "[MIDDLEWARE] API key validated successfully: {}", record.key_id ); let auth = ApiKeyAuth { key_id: record.key_id.clone(), record, }; if let Err(e) = state.db.update_api_key_usage(&auth.key_id, None).await { tracing::warn!("[MIDDLEWARE] Failed to update API key usage: {}", e); } let mut request = request; request.extensions_mut().insert(auth); tracing::info!("[MIDDLEWARE] Passing request to handler"); let response = next.run(request).await; tracing::info!("[MIDDLEWARE] Handler returned response"); response } fn extract_api_key(headers: &HeaderMap) -> Result { headers .get("X-API-Key") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) .ok_or(StatusCode::UNAUTHORIZED) } fn hash_key(key: &str) -> String { let mut hasher = Sha256::new(); hasher.update(key.as_bytes()); format!("{:x}", hasher.finalize()) }