Files
momentry_core/src/api/middleware.rs

194 lines
5.5 KiB
Rust

use axum::{
extract::{Request, State},
http::{header::HeaderMap, StatusCode},
middleware::Next,
response::Response,
};
use sha2::{Digest, Sha256};
use std::sync::Arc;
use crate::core::db::postgres_db::ApiKeyRecord;
use crate::core::db::PostgresDb;
/// Authentication context for a request whose API key has passed validation.
///
/// `api_key_validation` inserts a value of this type into the request extensions before
/// handing the request to the next handler, so downstream handlers can read it from there.
#[derive(Clone)]
pub struct ApiKeyAuth {
    /// Identifier of the validated key (a copy of `record.key_id`).
    pub key_id: String,
    /// The key's database record, as returned by `get_api_key_by_hash`.
    pub record: ApiKeyRecord,
}
/// State shared with the API-key middleware through axum's `State` extractor.
#[derive(Clone)]
pub struct ApiState {
    /// Database handle used to look up API keys by hash and to record key usage.
    pub db: Arc<PostgresDb>,
}
/// Route prefixes that may be served without an API key.
///
/// A path is public only when it has the exact shape `<prefix><id>/thumbnail`, where `<id>`
/// is a single, non-empty path segment.
const PUBLIC_PATHS: &[&str] = &[
    "/api/v1/faces/", // Face thumbnails: /api/v1/faces/{id}/thumbnail
];

/// Suffix every public path must end with.
const PUBLIC_PATH_SUFFIX: &str = "/thumbnail";

/// Returns `true` when `path` is a thumbnail route that bypasses API-key authentication.
///
/// The text between a `PUBLIC_PATHS` prefix and `/thumbnail` must be exactly one segment. It
/// must be non-empty, contain no `/`, and not be a dot-segment (`.` or `..`).
///
/// A looser `starts_with && ends_with` test would also accept `/api/v1/faces/thumbnail`
/// (prefix and suffix overlapping). It would likewise accept deeper paths such as
/// `/api/v1/faces/1/x/thumbnail`, exempting any route of those shapes from authentication.
fn is_public_path(path: &str) -> bool {
    PUBLIC_PATHS.iter().any(|prefix| {
        matches!(
            path.strip_prefix(prefix)
                .and_then(|rest| rest.strip_suffix(PUBLIC_PATH_SUFFIX)),
            Some(segment) if !segment.is_empty()
                && !segment.contains('/')
                && segment != "."
                && segment != ".."
        )
    })
}
pub async fn api_key_validation(
State(state): State<ApiState>,
request: Request,
next: Next,
) -> Response {
let path = request.uri().path();
tracing::info!("[MIDDLEWARE] Starting API key validation");
tracing::info!("[MIDDLEWARE] Path: {:?}", path);
if is_public_path(path) {
tracing::info!("[MIDDLEWARE] Public path, skipping auth: {}", path);
return next.run(request).await;
}
let headers = request.headers();
tracing::info!("[MIDDLEWARE] All headers: {:?}", headers);
let uri = request.uri().clone();
let api_key = match extract_api_key(headers, &uri) {
Ok(key) => {
tracing::info!("[MIDDLEWARE] API key extracted, length: {}", key.len());
if key.len() > 8 {
tracing::info!(
"[MIDDLEWARE] Key value: {}...{}",
&key[..4],
&key[key.len() - 4..]
);
} else {
tracing::info!("[MIDDLEWARE] Key value: ****");
}
key
}
Err(status) => {
tracing::warn!("[MIDDLEWARE] API key extraction failed: {:?}", status);
return Response::builder()
.status(status)
.body(axum::body::Body::empty())
.unwrap();
}
};
let key_hash = hash_key(&api_key);
tracing::info!("[MIDDLEWARE] Key hash: {}", &key_hash[..16]);
tracing::info!("[MIDDLEWARE] Querying database for key...");
let record = match state.db.get_api_key_by_hash(&key_hash).await {
Ok(Some(r)) => {
tracing::info!("[MIDDLEWARE] API key found: {}", r.key_id);
r
}
Ok(None) => {
tracing::warn!(
"[MIDDLEWARE] API key NOT FOUND in database for hash: {}",
&key_hash[..16]
);
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(axum::body::Body::empty())
.unwrap();
}
Err(e) => {
tracing::error!("[MIDDLEWARE] DB error: {}", e);
return Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(axum::body::Body::empty())
.unwrap();
}
};
if record.status != "active" {
tracing::warn!("[MIDDLEWARE] API key not active: {}", record.status);
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(axum::body::Body::empty())
.unwrap();
}
tracing::info!(
"[MIDDLEWARE] API key validated successfully: {}",
record.key_id
);
let auth = ApiKeyAuth {
key_id: record.key_id.clone(),
record,
};
if let Err(e) = state.db.update_api_key_usage(&auth.key_id, None).await {
tracing::warn!("[MIDDLEWARE] Failed to update API key usage: {}", e);
}
let mut request = request;
request.extensions_mut().insert(auth);
tracing::info!("[MIDDLEWARE] Passing request to handler");
let response = next.run(request).await;
tracing::info!("[MIDDLEWARE] Handler returned response");
response
}
/// Extracts the caller's API key from the request.
///
/// Sources are checked in priority order:
/// 1. the `X-API-Key` header;
/// 2. `Authorization: Bearer <key>`, with the scheme name matched case-insensitively as
///    RFC 7235 §2.1 requires (`bearer`, `BEARER` are accepted too);
/// 3. the `api_key` query parameter, percent-decoded.
///
/// Surrounding whitespace in header values is ignored. A source that is present but empty is
/// treated as absent, so the next source is consulted.
///
/// NOTE(review): keys passed in the query string tend to end up in access logs, proxy logs and
/// browser history. This source is kept for compatibility, but headers should be preferred.
///
/// # Errors
/// Returns `StatusCode::UNAUTHORIZED` when no source yields a non-empty key.
fn extract_api_key(headers: &HeaderMap, uri: &axum::http::Uri) -> Result<String, StatusCode> {
    // 1. X-API-Key header
    if let Some(key) = headers
        .get("X-API-Key")
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .filter(|k| !k.is_empty())
    {
        return Ok(key.to_owned());
    }

    // 2. Authorization: <scheme> <token>, honoured only for the Bearer scheme.
    if let Some(token) = headers
        .get("Authorization")
        .and_then(|v| v.to_str().ok())
        .and_then(|auth| auth.trim().split_once(' '))
        .and_then(|(scheme, token)| scheme.eq_ignore_ascii_case("Bearer").then_some(token.trim()))
        .filter(|t| !t.is_empty())
    {
        return Ok(token.to_owned());
    }

    // 3. ?api_key=<key> query parameter; only `name=value` pairs are considered.
    uri.query()
        .into_iter()
        .flat_map(|query| query.split('&'))
        .filter_map(|pair| pair.split_once('='))
        .filter(|(name, _)| *name == "api_key")
        .map(|(_, value)| percent_decode(value))
        .find(|key| !key.is_empty())
        .ok_or(StatusCode::UNAUTHORIZED)
}
/// Decodes an `application/x-www-form-urlencoded` query value.
///
/// - `%XX` (two hex digits) becomes the byte `0xXX`.
/// - `+` becomes a space.
/// - Every other byte is copied unchanged.
///
/// Decoding happens at the byte level and the result is then interpreted as UTF-8. As a result,
/// multi-byte escapes such as `%C3%A9` decode to `é` instead of two Latin-1 characters. Invalid
/// UTF-8 is replaced with U+FFFD.
///
/// A `%` that is not followed by two hex digits is kept literally (e.g. `100%` or `%zz`). The
/// characters after it are neither swallowed nor turned into a NUL byte.
fn percent_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let escape = bytes
                    .get(i + 1)
                    .copied()
                    .and_then(hex_val)
                    .zip(bytes.get(i + 2).copied().and_then(hex_val));
                match escape {
                    Some((hi, lo)) => {
                        decoded.push((hi << 4) | lo);
                        i += 3;
                        continue;
                    }
                    // Malformed escape: keep the '%' and re-scan what follows as normal input.
                    None => decoded.push(b'%'),
                }
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Returns the value (0–15) of the ASCII hex digit `c`, or `None` if `c` is not a hex digit.
fn hex_val(c: u8) -> Option<u8> {
    match c {
        b'0'..=b'9' => Some(c - b'0'),
        b'a'..=b'f' => Some(c - b'a' + 10),
        b'A'..=b'F' => Some(c - b'A' + 10),
        _ => None,
    }
}
/// Returns the SHA-256 digest of `key`'s UTF-8 bytes as a lowercase hexadecimal string
/// (64 characters).
fn hash_key(key: &str) -> String {
    let digest = Sha256::digest(key.as_bytes());
    format!("{:x}", digest)
}