update: pipeline, search, clip, embedding fixes
This commit is contained in:
@@ -7,13 +7,25 @@ use axum::{
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::core::auth::jwt;
|
||||
use crate::core::db::postgres_db::ApiKeyRecord;
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiKeyAuth {
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AuthSource {
|
||||
Session,
|
||||
Jwt,
|
||||
ApiKey,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UserAuth {
|
||||
pub user_id: i32,
|
||||
pub role: String,
|
||||
pub source: AuthSource,
|
||||
pub key_id: String,
|
||||
pub record: ApiKeyRecord,
|
||||
pub jwt_jti: Option<String>,
|
||||
pub jwt_exp: Option<chrono::DateTime<chrono::Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -21,143 +33,27 @@ pub struct ApiState {
|
||||
pub db: Arc<PostgresDb>,
|
||||
}
|
||||
|
||||
const PUBLIC_PATHS: &[&str] = &[
|
||||
"/api/v1/faces/", // Thumbnail paths (partial match)
|
||||
];
|
||||
|
||||
fn is_public_path(path: &str) -> bool {
|
||||
PUBLIC_PATHS.iter().any(|prefix| path.starts_with(prefix)) && path.ends_with("/thumbnail")
|
||||
pub fn extract_cookies(headers: &HeaderMap) -> Vec<(String, String)> {
|
||||
let cookie_header = match headers.get("cookie").and_then(|v| v.to_str().ok()) {
|
||||
Some(c) => c,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
cookie_header
|
||||
.split(';')
|
||||
.filter_map(|pair| {
|
||||
let mut parts = pair.trim().splitn(2, '=');
|
||||
match (parts.next(), parts.next()) {
|
||||
(Some(k), Some(v)) => Some((k.to_lowercase(), v.to_string())),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn api_key_validation(
|
||||
State(state): State<ApiState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let path = request.uri().path();
|
||||
tracing::info!("[MIDDLEWARE] Starting API key validation");
|
||||
tracing::info!("[MIDDLEWARE] Path: {:?}", path);
|
||||
|
||||
if is_public_path(path) {
|
||||
tracing::info!("[MIDDLEWARE] Public path, skipping auth: {}", path);
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let headers = request.headers();
|
||||
tracing::info!("[MIDDLEWARE] All headers: {:?}", headers);
|
||||
|
||||
let uri = request.uri().clone();
|
||||
let api_key = match extract_api_key(headers, &uri) {
|
||||
Ok(key) => {
|
||||
tracing::info!("[MIDDLEWARE] API key extracted, length: {}", key.len());
|
||||
if key.len() > 8 {
|
||||
tracing::info!(
|
||||
"[MIDDLEWARE] Key value: {}...{}",
|
||||
&key[..4],
|
||||
&key[key.len() - 4..]
|
||||
);
|
||||
} else {
|
||||
tracing::info!("[MIDDLEWARE] Key value: ****");
|
||||
}
|
||||
key
|
||||
}
|
||||
Err(status) => {
|
||||
tracing::warn!("[MIDDLEWARE] API key extraction failed: {:?}", status);
|
||||
return Response::builder()
|
||||
.status(status)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let key_hash = hash_key(&api_key);
|
||||
tracing::info!("[MIDDLEWARE] Key hash: {}", &key_hash[..16]);
|
||||
|
||||
tracing::info!("[MIDDLEWARE] Querying database for key...");
|
||||
let record = match state.db.get_api_key_by_hash(&key_hash).await {
|
||||
Ok(Some(r)) => {
|
||||
tracing::info!("[MIDDLEWARE] API key found: {}", r.key_id);
|
||||
r
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::warn!(
|
||||
"[MIDDLEWARE] API key NOT FOUND in database for hash: {}",
|
||||
&key_hash[..16]
|
||||
);
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[MIDDLEWARE] DB error: {}", e);
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
if record.status != "active" {
|
||||
tracing::warn!("[MIDDLEWARE] API key not active: {}", record.status);
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[MIDDLEWARE] API key validated successfully: {}",
|
||||
record.key_id
|
||||
);
|
||||
|
||||
let auth = ApiKeyAuth {
|
||||
key_id: record.key_id.clone(),
|
||||
record,
|
||||
};
|
||||
|
||||
if let Err(e) = state.db.update_api_key_usage(&auth.key_id, None).await {
|
||||
tracing::warn!("[MIDDLEWARE] Failed to update API key usage: {}", e);
|
||||
}
|
||||
|
||||
let mut request = request;
|
||||
request.extensions_mut().insert(auth);
|
||||
|
||||
tracing::info!("[MIDDLEWARE] Passing request to handler");
|
||||
let response = next.run(request).await;
|
||||
tracing::info!("[MIDDLEWARE] Handler returned response");
|
||||
response
|
||||
}
|
||||
|
||||
fn extract_api_key(headers: &HeaderMap, uri: &axum::http::Uri) -> Result<String, StatusCode> {
|
||||
// 1. X-API-Key header
|
||||
if let Some(key) = headers
|
||||
.get("X-API-Key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
return Ok(key.to_string());
|
||||
}
|
||||
// 2. Authorization: Bearer <key>
|
||||
if let Some(auth) = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
if let Some(key) = auth.strip_prefix("Bearer ") {
|
||||
return Ok(key.to_string());
|
||||
}
|
||||
}
|
||||
// 3. ?api_key=<key> query parameter
|
||||
if let Some(query) = uri.query() {
|
||||
for pair in query.split('&') {
|
||||
let mut parts = pair.splitn(2, '=');
|
||||
if let (Some(k), Some(v)) = (parts.next(), parts.next()) {
|
||||
if k == "api_key" {
|
||||
return Ok(percent_decode(v));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
fn hash_key(key: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
fn percent_decode(s: &str) -> String {
|
||||
@@ -186,8 +82,161 @@ fn hex_val(c: u8) -> Option<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
fn hash_key(key: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
fn extract_api_key(headers: &HeaderMap, uri: &axum::http::Uri) -> Result<String, StatusCode> {
|
||||
if let Some(key) = headers
|
||||
.get("X-API-Key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
return Ok(key.to_string());
|
||||
}
|
||||
if let Some(auth) = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
// Check if it's a JWT (starts with eyJ)
|
||||
let trimmed = auth.strip_prefix("Bearer ").unwrap_or(auth);
|
||||
if !jwt::is_jwt(trimmed) {
|
||||
return Ok(trimmed.to_string());
|
||||
}
|
||||
// If it IS a JWT, return it as-is — JWT branch handles it
|
||||
return Ok(trimmed.to_string());
|
||||
}
|
||||
if let Some(query) = uri.query() {
|
||||
for pair in query.split('&') {
|
||||
let mut parts = pair.splitn(2, '=');
|
||||
if let (Some(k), Some(v)) = (parts.next(), parts.next()) {
|
||||
if k == "api_key" {
|
||||
return Ok(percent_decode(v));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
|
||||
pub async fn unified_auth(
|
||||
State(state): State<ApiState>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let headers = request.headers();
|
||||
let uri = request.uri().clone();
|
||||
|
||||
// Priority 1: Cookie session (Portal)
|
||||
let cookies = extract_cookies(headers);
|
||||
if let Some(sid) = cookies.iter().find(|(k, _)| k == "session_id").map(|(_, v)| v.clone()) {
|
||||
match state.db.get_session_by_id(&sid).await {
|
||||
Ok(Some((_id, user_id, api_key_id, _expires_at))) => {
|
||||
let key_hash = hash_key(&api_key_id);
|
||||
match state.db.get_api_key_by_hash(&key_hash).await {
|
||||
Ok(Some(record)) if record.status == "active" => {
|
||||
let auth = UserAuth {
|
||||
user_id: user_id,
|
||||
role: record.key_type.clone(),
|
||||
source: AuthSource::Session,
|
||||
key_id: record.key_id.clone(),
|
||||
jwt_jti: None,
|
||||
jwt_exp: None,
|
||||
};
|
||||
if let Err(e) = state.db.update_api_key_usage(&record.key_id, None).await {
|
||||
tracing::warn!("[AUTH] Failed to update key usage: {}", e);
|
||||
}
|
||||
request.extensions_mut().insert(auth);
|
||||
return next.run(request).await;
|
||||
}
|
||||
Ok(Some(_)) => {
|
||||
tracing::warn!("[AUTH] Session API key not active, removing session");
|
||||
state.db.delete_session(&sid).await.ok();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::error!("[AUTH] Session lookup error: {}", e),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: JWT (Authorization: Bearer <eyJ...>)
|
||||
if let Some(auth_header) = headers
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
if let Some(token) = auth_header.strip_prefix("Bearer ") {
|
||||
if jwt::is_jwt(token) {
|
||||
match jwt::verify_jwt(token) {
|
||||
Ok(claims) => {
|
||||
if !state.db.is_jwt_blacklisted(&claims.jti).await.unwrap_or(false) {
|
||||
let exp = chrono::DateTime::from_timestamp(claims.exp as i64, 0);
|
||||
let user_id: i32 = claims.sub.parse().unwrap_or(0);
|
||||
let auth = UserAuth {
|
||||
user_id,
|
||||
role: claims.role,
|
||||
source: AuthSource::Jwt,
|
||||
key_id: String::new(),
|
||||
jwt_jti: Some(claims.jti),
|
||||
jwt_exp: exp,
|
||||
};
|
||||
request.extensions_mut().insert(auth);
|
||||
return next.run(request).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::debug!("[AUTH] JWT verification failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 3: API Key header / query param
|
||||
let api_key = match extract_api_key(headers, &uri) {
|
||||
Ok(key) => key,
|
||||
Err(status) => {
|
||||
return Response::builder()
|
||||
.status(status)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let key_hash = hash_key(&api_key);
|
||||
let record = match state.db.get_api_key_by_hash(&key_hash).await {
|
||||
Ok(Some(r)) => r,
|
||||
Ok(None) => {
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[AUTH] DB error: {}", e);
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
if record.status != "active" {
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let auth = UserAuth {
|
||||
user_id: record.user_id.unwrap_or(0) as i32,
|
||||
role: record.key_type.clone(),
|
||||
source: AuthSource::ApiKey,
|
||||
key_id: record.key_id.clone(),
|
||||
jwt_jti: None,
|
||||
jwt_exp: None,
|
||||
};
|
||||
|
||||
if let Err(e) = state.db.update_api_key_usage(&record.key_id, None).await {
|
||||
tracing::warn!("[AUTH] Failed to update key usage: {}", e);
|
||||
}
|
||||
|
||||
request.extensions_mut().insert(auth);
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user