194 lines
5.5 KiB
Rust
194 lines
5.5 KiB
Rust
use axum::{
|
|
extract::{Request, State},
|
|
http::{header::HeaderMap, StatusCode},
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use sha2::{Digest, Sha256};
|
|
use std::sync::Arc;
|
|
|
|
use crate::core::db::postgres_db::ApiKeyRecord;
|
|
use crate::core::db::PostgresDb;
|
|
|
|
#[derive(Clone)]
|
|
pub struct ApiKeyAuth {
|
|
pub key_id: String,
|
|
pub record: ApiKeyRecord,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ApiState {
|
|
pub db: Arc<PostgresDb>,
|
|
}
|
|
|
|
/// Path prefixes under which unauthenticated access may be granted.
///
/// A prefix match alone is not sufficient: `is_public_path` additionally
/// requires the path to end in `/thumbnail`.
const PUBLIC_PATHS: &[&str] = &[
    "/api/v1/faces/", // Thumbnail paths (partial match)
];

/// Returns `true` when `path` may be served without an API key.
///
/// A path qualifies only if it ends with `/thumbnail` *and* starts with one
/// of the prefixes listed in `PUBLIC_PATHS`.
fn is_public_path(path: &str) -> bool {
    // Cheap suffix test first; anything that is not a thumbnail is never public.
    if !path.ends_with("/thumbnail") {
        return false;
    }

    PUBLIC_PATHS
        .iter()
        .copied()
        .any(|allowed_prefix| path.starts_with(allowed_prefix))
}
|
|
|
|
pub async fn api_key_validation(
|
|
State(state): State<ApiState>,
|
|
request: Request,
|
|
next: Next,
|
|
) -> Response {
|
|
let path = request.uri().path();
|
|
tracing::info!("[MIDDLEWARE] Starting API key validation");
|
|
tracing::info!("[MIDDLEWARE] Path: {:?}", path);
|
|
|
|
if is_public_path(path) {
|
|
tracing::info!("[MIDDLEWARE] Public path, skipping auth: {}", path);
|
|
return next.run(request).await;
|
|
}
|
|
|
|
let headers = request.headers();
|
|
tracing::info!("[MIDDLEWARE] All headers: {:?}", headers);
|
|
|
|
let uri = request.uri().clone();
|
|
let api_key = match extract_api_key(headers, &uri) {
|
|
Ok(key) => {
|
|
tracing::info!("[MIDDLEWARE] API key extracted, length: {}", key.len());
|
|
if key.len() > 8 {
|
|
tracing::info!(
|
|
"[MIDDLEWARE] Key value: {}...{}",
|
|
&key[..4],
|
|
&key[key.len() - 4..]
|
|
);
|
|
} else {
|
|
tracing::info!("[MIDDLEWARE] Key value: ****");
|
|
}
|
|
key
|
|
}
|
|
Err(status) => {
|
|
tracing::warn!("[MIDDLEWARE] API key extraction failed: {:?}", status);
|
|
return Response::builder()
|
|
.status(status)
|
|
.body(axum::body::Body::empty())
|
|
.unwrap();
|
|
}
|
|
};
|
|
|
|
let key_hash = hash_key(&api_key);
|
|
tracing::info!("[MIDDLEWARE] Key hash: {}", &key_hash[..16]);
|
|
|
|
tracing::info!("[MIDDLEWARE] Querying database for key...");
|
|
let record = match state.db.get_api_key_by_hash(&key_hash).await {
|
|
Ok(Some(r)) => {
|
|
tracing::info!("[MIDDLEWARE] API key found: {}", r.key_id);
|
|
r
|
|
}
|
|
Ok(None) => {
|
|
tracing::warn!(
|
|
"[MIDDLEWARE] API key NOT FOUND in database for hash: {}",
|
|
&key_hash[..16]
|
|
);
|
|
return Response::builder()
|
|
.status(StatusCode::UNAUTHORIZED)
|
|
.body(axum::body::Body::empty())
|
|
.unwrap();
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("[MIDDLEWARE] DB error: {}", e);
|
|
return Response::builder()
|
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
|
.body(axum::body::Body::empty())
|
|
.unwrap();
|
|
}
|
|
};
|
|
|
|
if record.status != "active" {
|
|
tracing::warn!("[MIDDLEWARE] API key not active: {}", record.status);
|
|
return Response::builder()
|
|
.status(StatusCode::UNAUTHORIZED)
|
|
.body(axum::body::Body::empty())
|
|
.unwrap();
|
|
}
|
|
|
|
tracing::info!(
|
|
"[MIDDLEWARE] API key validated successfully: {}",
|
|
record.key_id
|
|
);
|
|
|
|
let auth = ApiKeyAuth {
|
|
key_id: record.key_id.clone(),
|
|
record,
|
|
};
|
|
|
|
if let Err(e) = state.db.update_api_key_usage(&auth.key_id, None).await {
|
|
tracing::warn!("[MIDDLEWARE] Failed to update API key usage: {}", e);
|
|
}
|
|
|
|
let mut request = request;
|
|
request.extensions_mut().insert(auth);
|
|
|
|
tracing::info!("[MIDDLEWARE] Passing request to handler");
|
|
let response = next.run(request).await;
|
|
tracing::info!("[MIDDLEWARE] Handler returned response");
|
|
response
|
|
}
|
|
|
|
fn extract_api_key(headers: &HeaderMap, uri: &axum::http::Uri) -> Result<String, StatusCode> {
|
|
// 1. X-API-Key header
|
|
if let Some(key) = headers
|
|
.get("X-API-Key")
|
|
.and_then(|v| v.to_str().ok())
|
|
{
|
|
return Ok(key.to_string());
|
|
}
|
|
// 2. Authorization: Bearer <key>
|
|
if let Some(auth) = headers
|
|
.get("Authorization")
|
|
.and_then(|v| v.to_str().ok())
|
|
{
|
|
if let Some(key) = auth.strip_prefix("Bearer ") {
|
|
return Ok(key.to_string());
|
|
}
|
|
}
|
|
// 3. ?api_key=<key> query parameter
|
|
if let Some(query) = uri.query() {
|
|
for pair in query.split('&') {
|
|
let mut parts = pair.splitn(2, '=');
|
|
if let (Some(k), Some(v)) = (parts.next(), parts.next()) {
|
|
if k == "api_key" {
|
|
return Ok(percent_decode(v));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(StatusCode::UNAUTHORIZED)
|
|
}
|
|
|
|
/// Decodes a URL query-string value.
///
/// * `%XX` (two hex digits, either case) becomes the byte `0xXX`.
/// * `+` becomes a space.
/// * Every other byte is copied through unchanged.
///
/// The decoded bytes are then interpreted as UTF-8, so a multi-byte sequence
/// such as `%C3%A9` yields `é`. Byte sequences that are not valid UTF-8 are
/// replaced with U+FFFD rather than causing an error.
///
/// A `%` that is not followed by two hex digits (including a `%` at the very
/// end of the input) is kept literally and the characters after it are
/// decoded normally; nothing is dropped and no NUL byte is invented.
fn percent_decode(s: &str) -> String {
    let raw = s.as_bytes();
    let mut decoded = Vec::with_capacity(raw.len());
    let mut pos = 0;

    while pos < raw.len() {
        let byte = raw[pos];
        pos += 1;

        match byte {
            b'%' => {
                let hi = raw.get(pos).copied().and_then(hex_val);
                let lo = raw.get(pos + 1).copied().and_then(hex_val);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    decoded.push(hi << 4 | lo);
                    pos += 2;
                } else {
                    // Malformed escape: emit the '%' as-is and let the loop
                    // handle whatever follows as ordinary input.
                    decoded.push(b'%');
                }
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
    }

    // Escapes describe bytes, not characters, so the UTF-8 interpretation
    // has to happen once over the whole decoded buffer.
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Returns the numeric value (0–15) of an ASCII hex digit, or `None` if `c`
/// is not one of `0-9`, `a-f`, `A-F`.
fn hex_val(c: u8) -> Option<u8> {
    match c {
        b'0'..=b'9' => Some(c - b'0'),
        b'a'..=b'f' => Some(c - b'a' + 10),
        b'A'..=b'F' => Some(c - b'A' + 10),
        _ => None,
    }
}
|
|
|
|
fn hash_key(key: &str) -> String {
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(key.as_bytes());
|
|
format!("{:x}", hasher.finalize())
|
|
}
|