243 lines
7.7 KiB
Rust
243 lines
7.7 KiB
Rust
use axum::{
|
|
extract::{Request, State},
|
|
http::{header::HeaderMap, StatusCode},
|
|
middleware::Next,
|
|
response::Response,
|
|
};
|
|
use sha2::{Digest, Sha256};
|
|
use std::sync::Arc;
|
|
|
|
use crate::core::auth::jwt;
|
|
use crate::core::db::postgres_db::ApiKeyRecord;
|
|
use crate::core::db::PostgresDb;
|
|
|
|
/// Which kind of credential authenticated a request.
///
/// Fieldless, so it is `Copy` and can be compared directly
/// (e.g. `auth.source == AuthSource::Jwt`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthSource {
    /// Portal `session_id` cookie, resolved through the session's API key.
    Session,
    /// Bearer JWT that passed verification and the blacklist check.
    Jwt,
    /// Raw API key supplied in a request header or the `api_key` query parameter.
    ApiKey,
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct UserAuth {
|
|
pub user_id: i32,
|
|
pub role: String,
|
|
pub source: AuthSource,
|
|
pub key_id: String,
|
|
pub jwt_jti: Option<String>,
|
|
pub jwt_exp: Option<chrono::DateTime<chrono::Utc>>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct ApiState {
|
|
pub db: Arc<PostgresDb>,
|
|
}
|
|
|
|
pub fn extract_cookies(headers: &HeaderMap) -> Vec<(String, String)> {
|
|
let cookie_header = match headers.get("cookie").and_then(|v| v.to_str().ok()) {
|
|
Some(c) => c,
|
|
None => return Vec::new(),
|
|
};
|
|
cookie_header
|
|
.split(';')
|
|
.filter_map(|pair| {
|
|
let mut parts = pair.trim().splitn(2, '=');
|
|
match (parts.next(), parts.next()) {
|
|
(Some(k), Some(v)) => Some((k.to_lowercase(), v.to_string())),
|
|
_ => None,
|
|
}
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
fn hash_key(key: &str) -> String {
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(key.as_bytes());
|
|
format!("{:x}", hasher.finalize())
|
|
}
|
|
|
|
/// Decodes a form-urlencoded value such as a query-string parameter.
///
/// `+` becomes a space and each `%XX` escape (hex digits in either case) becomes
/// the byte it encodes. The resulting bytes are interpreted as UTF-8, so `%C3%A9`
/// decodes to `é`, and invalid UTF-8 is replaced with U+FFFD. A `%` that is not
/// followed by two hex digits is kept literally instead of being turned into a
/// NUL or other bogus byte.
fn percent_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while let Some(&byte) = bytes.get(i) {
        // Each step yields one output byte and how many input bytes it consumed.
        let (out, consumed) = match byte {
            b'%' => match bytes.get(i + 1..i + 3) {
                Some(&[hi, lo]) => match (hex_val(hi), hex_val(lo)) {
                    (Some(hi), Some(lo)) => ((hi << 4) | lo, 3),
                    _ => (b'%', 1),
                },
                // Truncated escape at the end of the input.
                _ => (b'%', 1),
            },
            b'+' => (b' ', 1),
            other => (other, 1),
        };
        decoded.push(out);
        i += consumed;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}

/// Returns the value (0–15) of an ASCII hex digit, or `None` if `c` is not one.
fn hex_val(c: u8) -> Option<u8> {
    match c {
        b'0'..=b'9' => Some(c - b'0'),
        b'a'..=b'f' => Some(c - b'a' + 10),
        b'A'..=b'F' => Some(c - b'A' + 10),
        _ => None,
    }
}
|
|
|
|
fn extract_api_key(headers: &HeaderMap, uri: &axum::http::Uri) -> Result<String, StatusCode> {
|
|
if let Some(key) = headers
|
|
.get("X-API-Key")
|
|
.and_then(|v| v.to_str().ok())
|
|
{
|
|
return Ok(key.to_string());
|
|
}
|
|
if let Some(auth) = headers
|
|
.get("Authorization")
|
|
.and_then(|v| v.to_str().ok())
|
|
{
|
|
// Check if it's a JWT (starts with eyJ)
|
|
let trimmed = auth.strip_prefix("Bearer ").unwrap_or(auth);
|
|
if !jwt::is_jwt(trimmed) {
|
|
return Ok(trimmed.to_string());
|
|
}
|
|
// If it IS a JWT, return it as-is — JWT branch handles it
|
|
return Ok(trimmed.to_string());
|
|
}
|
|
if let Some(query) = uri.query() {
|
|
for pair in query.split('&') {
|
|
let mut parts = pair.splitn(2, '=');
|
|
if let (Some(k), Some(v)) = (parts.next(), parts.next()) {
|
|
if k == "api_key" {
|
|
return Ok(percent_decode(v));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(StatusCode::UNAUTHORIZED)
|
|
}
|
|
|
|
pub async fn unified_auth(
|
|
State(state): State<ApiState>,
|
|
mut request: Request,
|
|
next: Next,
|
|
) -> Response {
|
|
let headers = request.headers();
|
|
let uri = request.uri().clone();
|
|
|
|
// Priority 1: Cookie session (Portal)
|
|
let cookies = extract_cookies(headers);
|
|
if let Some(sid) = cookies.iter().find(|(k, _)| k == "session_id").map(|(_, v)| v.clone()) {
|
|
match state.db.get_session_by_id(&sid).await {
|
|
Ok(Some((_id, user_id, api_key_id, _expires_at))) => {
|
|
let key_hash = hash_key(&api_key_id);
|
|
match state.db.get_api_key_by_hash(&key_hash).await {
|
|
Ok(Some(record)) if record.status == "active" => {
|
|
let auth = UserAuth {
|
|
user_id: user_id,
|
|
role: record.key_type.clone(),
|
|
source: AuthSource::Session,
|
|
key_id: record.key_id.clone(),
|
|
jwt_jti: None,
|
|
jwt_exp: None,
|
|
};
|
|
if let Err(e) = state.db.update_api_key_usage(&record.key_id, None).await {
|
|
tracing::warn!("[AUTH] Failed to update key usage: {}", e);
|
|
}
|
|
request.extensions_mut().insert(auth);
|
|
return next.run(request).await;
|
|
}
|
|
Ok(Some(_)) => {
|
|
tracing::warn!("[AUTH] Session API key not active, removing session");
|
|
state.db.delete_session(&sid).await.ok();
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
Err(e) => tracing::error!("[AUTH] Session lookup error: {}", e),
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
// Priority 2: JWT (Authorization: Bearer <eyJ...>)
|
|
if let Some(auth_header) = headers
|
|
.get("Authorization")
|
|
.and_then(|v| v.to_str().ok())
|
|
{
|
|
if let Some(token) = auth_header.strip_prefix("Bearer ") {
|
|
if jwt::is_jwt(token) {
|
|
match jwt::verify_jwt(token) {
|
|
Ok(claims) => {
|
|
if !state.db.is_jwt_blacklisted(&claims.jti).await.unwrap_or(false) {
|
|
let exp = chrono::DateTime::from_timestamp(claims.exp as i64, 0);
|
|
let user_id: i32 = claims.sub.parse().unwrap_or(0);
|
|
let auth = UserAuth {
|
|
user_id,
|
|
role: claims.role,
|
|
source: AuthSource::Jwt,
|
|
key_id: String::new(),
|
|
jwt_jti: Some(claims.jti),
|
|
jwt_exp: exp,
|
|
};
|
|
request.extensions_mut().insert(auth);
|
|
return next.run(request).await;
|
|
}
|
|
}
|
|
Err(e) => {
|
|
tracing::debug!("[AUTH] JWT verification failed: {}", e);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Priority 3: API Key header / query param
|
|
let api_key = match extract_api_key(headers, &uri) {
|
|
Ok(key) => key,
|
|
Err(status) => {
|
|
return Response::builder()
|
|
.status(status)
|
|
.body(axum::body::Body::empty())
|
|
.unwrap();
|
|
}
|
|
};
|
|
|
|
let key_hash = hash_key(&api_key);
|
|
let record = match state.db.get_api_key_by_hash(&key_hash).await {
|
|
Ok(Some(r)) => r,
|
|
Ok(None) => {
|
|
return Response::builder()
|
|
.status(StatusCode::UNAUTHORIZED)
|
|
.body(axum::body::Body::empty())
|
|
.unwrap();
|
|
}
|
|
Err(e) => {
|
|
tracing::error!("[AUTH] DB error: {}", e);
|
|
return Response::builder()
|
|
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
|
.body(axum::body::Body::empty())
|
|
.unwrap();
|
|
}
|
|
};
|
|
|
|
if record.status != "active" {
|
|
return Response::builder()
|
|
.status(StatusCode::UNAUTHORIZED)
|
|
.body(axum::body::Body::empty())
|
|
.unwrap();
|
|
}
|
|
|
|
let auth = UserAuth {
|
|
user_id: record.user_id.unwrap_or(0) as i32,
|
|
role: record.key_type.clone(),
|
|
source: AuthSource::ApiKey,
|
|
key_id: record.key_id.clone(),
|
|
jwt_jti: None,
|
|
jwt_exp: None,
|
|
};
|
|
|
|
if let Err(e) = state.db.update_api_key_usage(&record.key_id, None).await {
|
|
tracing::warn!("[AUTH] Failed to update key usage: {}", e);
|
|
}
|
|
|
|
request.extensions_mut().insert(auth);
|
|
next.run(request).await
|
|
}
|