feat: Initial v0.9 release with API Key authentication
## v0.9.20260325_144654 ### Features - API Key Authentication System - Job Worker System - V2 Backup Versioning ### Bug Fixes - get_processor_results_by_job column mapping Co-authored-by: OpenCode
This commit is contained in:
120
src/api/middleware.rs
Normal file
120
src/api/middleware.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{header::HeaderMap, StatusCode},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::core::db::postgres_db::ApiKeyRecord;
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiKeyAuth {
|
||||
pub key_id: String,
|
||||
pub record: ApiKeyRecord,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiState {
|
||||
pub db: Arc<PostgresDb>,
|
||||
}
|
||||
|
||||
pub async fn api_key_validation(
|
||||
State(state): State<ApiState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
tracing::info!("[MIDDLEWARE] Starting API key validation");
|
||||
tracing::info!("[MIDDLEWARE] Path: {:?}", request.uri().path());
|
||||
|
||||
let headers = request.headers();
|
||||
tracing::info!(
|
||||
"[MIDDLEWARE] Headers: {:?}",
|
||||
headers.keys().collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let api_key = match extract_api_key(headers) {
|
||||
Ok(key) => {
|
||||
tracing::info!("[MIDDLEWARE] API key extracted, length: {}", key.len());
|
||||
key
|
||||
}
|
||||
Err(status) => {
|
||||
tracing::warn!("[MIDDLEWARE] API key extraction failed: {:?}", status);
|
||||
return Response::builder()
|
||||
.status(status)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
let key_hash = hash_key(&api_key);
|
||||
tracing::info!("[MIDDLEWARE] Key hash: {}", &key_hash[..16]);
|
||||
|
||||
tracing::info!("[MIDDLEWARE] Querying database for key...");
|
||||
let record = match state.db.get_api_key_by_hash(&key_hash).await {
|
||||
Ok(Some(r)) => {
|
||||
tracing::info!("[MIDDLEWARE] API key found: {}", r.key_id);
|
||||
r
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::warn!("[MIDDLEWARE] API key not found in database");
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[MIDDLEWARE] DB error: {}", e);
|
||||
return Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
};
|
||||
|
||||
if record.status != "active" {
|
||||
tracing::warn!("[MIDDLEWARE] API key not active: {}", record.status);
|
||||
return Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[MIDDLEWARE] API key validated successfully: {}",
|
||||
record.key_id
|
||||
);
|
||||
|
||||
let auth = ApiKeyAuth {
|
||||
key_id: record.key_id.clone(),
|
||||
record,
|
||||
};
|
||||
|
||||
if let Err(e) = state.db.update_api_key_usage(&auth.key_id, None).await {
|
||||
tracing::warn!("[MIDDLEWARE] Failed to update API key usage: {}", e);
|
||||
}
|
||||
|
||||
let mut request = request;
|
||||
request.extensions_mut().insert(auth);
|
||||
|
||||
tracing::info!("[MIDDLEWARE] Passing request to handler");
|
||||
let response = next.run(request).await;
|
||||
tracing::info!("[MIDDLEWARE] Handler returned response");
|
||||
response
|
||||
}
|
||||
|
||||
fn extract_api_key(headers: &HeaderMap) -> Result<String, StatusCode> {
|
||||
headers
|
||||
.get("X-API-Key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
|
||||
fn hash_key(key: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod middleware;
|
||||
pub mod server;
|
||||
|
||||
pub use server::start_server;
|
||||
|
||||
@@ -7,12 +7,15 @@ use axum::{
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::core::cache::{keys, MongoCache, RedisCache};
|
||||
use crate::core::db::{Database, PostgresDb, QdrantDb, RedisClient, VideoRecord, VideoStatus};
|
||||
use crate::{Embedder, FileManager};
|
||||
|
||||
use super::middleware::api_key_validation;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct HealthResponse {
|
||||
status: String,
|
||||
@@ -59,6 +62,7 @@ struct AppState {
|
||||
embedder_model: String,
|
||||
mongo_cache: MongoCache,
|
||||
redis_cache: RedisCache,
|
||||
api_state: super::middleware::ApiState,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -982,14 +986,23 @@ async fn list_videos(
|
||||
let cache_key = keys::videos_list(page, limit);
|
||||
let ttl = state.mongo_cache.ttl_videos();
|
||||
|
||||
tracing::info!(
|
||||
"list_videos called: page={}, limit={}, cache_key={}",
|
||||
page,
|
||||
limit,
|
||||
cache_key
|
||||
);
|
||||
|
||||
let video_infos = state
|
||||
.mongo_cache
|
||||
.get_or_fetch(&cache_key, ttl, keys::CATEGORY_VIDEOS, || async {
|
||||
tracing::info!("Fetching videos from database...");
|
||||
let db = PostgresDb::init()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("PG init failed: {}", e))?;
|
||||
|
||||
let videos = db.list_videos().await?;
|
||||
tracing::info!("Got {} videos from DB", videos.len());
|
||||
|
||||
let video_infos: Vec<VideoInfoResponse> = videos
|
||||
.into_iter()
|
||||
@@ -1003,12 +1016,17 @@ async fn list_videos(
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::info!("Mapped to {} video infos", video_infos.len());
|
||||
|
||||
Ok::<VideosResponse, anyhow::Error>(VideosResponse {
|
||||
videos: video_infos,
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
.map_err(|e| {
|
||||
tracing::error!("Error in list_videos: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(video_infos))
|
||||
}
|
||||
@@ -1222,20 +1240,34 @@ async fn list_jobs() -> Result<Json<JobListResponse>, StatusCode> {
|
||||
async fn get_job(
|
||||
axum::extract::Path(uuid): axum::extract::Path<String>,
|
||||
) -> Result<Json<JobDetailResponse>, StatusCode> {
|
||||
let pg = PostgresDb::init()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
tracing::info!("[get_job] START - uuid: {}", uuid);
|
||||
|
||||
let pg = PostgresDb::init().await.map_err(|e| {
|
||||
tracing::error!("[get_job] ERROR - Failed to init PostgresDb: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
tracing::info!("[get_job] PostgresDb initialized");
|
||||
|
||||
let job = pg
|
||||
.get_monitor_job_by_uuid(&uuid)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
.map_err(|e| {
|
||||
tracing::error!("[get_job] ERROR - Failed to get monitor job: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
tracing::warn!("[get_job] Job not found: {}", uuid);
|
||||
StatusCode::NOT_FOUND
|
||||
})?;
|
||||
|
||||
let results = pg
|
||||
.get_processor_results_by_job(job.id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
tracing::info!("[get_job] Found job: id={}, uuid={}", job.id, job.uuid);
|
||||
|
||||
let results = pg.get_processor_results_by_job(job.id).await.map_err(|e| {
|
||||
tracing::error!("[get_job] ERROR - Failed to get processor results: {}", e);
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
tracing::info!("[get_job] Got {} processor results", results.len());
|
||||
|
||||
let processors: Vec<ProcessorInfoResponse> = results
|
||||
.into_iter()
|
||||
@@ -1249,7 +1281,9 @@ async fn get_job(
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(JobDetailResponse {
|
||||
tracing::info!("[get_job] Mapped {} processors", processors.len());
|
||||
|
||||
let response = JobDetailResponse {
|
||||
id: job.id,
|
||||
uuid: job.uuid,
|
||||
status: job.status.as_str().to_string(),
|
||||
@@ -1260,7 +1294,10 @@ async fn get_job(
|
||||
created_at: job.created_at.to_string(),
|
||||
started_at: job.started_at.map(|t| t.to_string()),
|
||||
updated_at: job.updated_at.map(|t| t.to_string()),
|
||||
}))
|
||||
};
|
||||
|
||||
tracing::info!("[get_job] SUCCESS - returning response");
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> {
|
||||
@@ -1269,17 +1306,18 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> {
|
||||
let embedder = std::sync::Arc::new(Embedder::new("nomic-embed-text:v1.5".to_string()));
|
||||
let mongo_cache = MongoCache::init().await?;
|
||||
let redis_cache = RedisCache::new()?;
|
||||
let db = PostgresDb::init().await?;
|
||||
let api_state = super::middleware::ApiState { db: Arc::new(db) };
|
||||
|
||||
let state = AppState {
|
||||
embedder,
|
||||
embedder_model: "nomic-embed-text:v1.5".to_string(),
|
||||
mongo_cache,
|
||||
redis_cache,
|
||||
api_state,
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/health/detailed", get(health_detailed))
|
||||
let protected_routes = Router::new()
|
||||
.route("/api/v1/register", post(register))
|
||||
.route("/api/v1/probe", post(probe))
|
||||
.route("/api/v1/search", post(search))
|
||||
@@ -1290,6 +1328,16 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> {
|
||||
.route("/api/v1/progress/:uuid", get(get_progress))
|
||||
.route("/api/v1/jobs", get(list_jobs))
|
||||
.route("/api/v1/jobs/:uuid", get(get_job))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.api_state.clone(),
|
||||
api_key_validation,
|
||||
))
|
||||
.with_state(state.clone());
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/health/detailed", get(health_detailed))
|
||||
.merge(protected_routes)
|
||||
.with_state(state);
|
||||
|
||||
let addr: std::net::SocketAddr = format!("{}:{}", host, port).parse().unwrap();
|
||||
|
||||
193
src/core/api_key/anomaly.rs
Normal file
193
src/core/api_key/anomaly.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
//! Anomaly Detection Module
|
||||
//!
|
||||
//! Detects abnormal API key usage patterns
|
||||
|
||||
use crate::core::api_key::models::*;
|
||||
use crate::core::api_key::service::AnomalyMetrics;
|
||||
use chrono::{Duration, Timelike, Utc};
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub struct AnomalyDetector {
|
||||
config: AnomalyDetectionConfig,
|
||||
metrics_cache: RwLock<HashMap<String, Vec<RequestMetric>>>,
|
||||
lockout_cache: RwLock<HashMap<String, i32>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RequestMetric {
|
||||
timestamp: chrono::DateTime<Utc>,
|
||||
ip: Option<String>,
|
||||
is_error: bool,
|
||||
}
|
||||
|
||||
impl AnomalyDetector {
|
||||
pub fn new(config: AnomalyDetectionConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
metrics_cache: RwLock::new(HashMap::new()),
|
||||
lockout_cache: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn record_request(&self, key_id: &str, ip: Option<String>, is_error: bool) {
|
||||
let metric = RequestMetric {
|
||||
timestamp: Utc::now(),
|
||||
ip,
|
||||
is_error,
|
||||
};
|
||||
|
||||
let mut cache = self.metrics_cache.write().await;
|
||||
cache.entry(key_id.to_string()).or_default().push(metric);
|
||||
|
||||
self.cleanup_old_metrics(&mut cache).await;
|
||||
}
|
||||
|
||||
async fn cleanup_old_metrics(&self, cache: &mut HashMap<String, Vec<RequestMetric>>) {
|
||||
let cutoff = Utc::now() - Duration::hours(2);
|
||||
|
||||
for metrics in cache.values_mut() {
|
||||
metrics.retain(|m| m.timestamp > cutoff);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn check_anomaly(&self, key_id: &str) -> Option<AnomalyRecord> {
|
||||
let cache = self.metrics_cache.read().await;
|
||||
let metrics = cache.get(key_id)?;
|
||||
|
||||
let now = Utc::now();
|
||||
let last_minute = now - Duration::minutes(1);
|
||||
let last_hour = now - Duration::hours(1);
|
||||
|
||||
let recent = metrics
|
||||
.iter()
|
||||
.filter(|m| m.timestamp > last_hour)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let requests_per_minute =
|
||||
metrics.iter().filter(|m| m.timestamp > last_minute).count() as i32;
|
||||
let error_count = recent.iter().filter(|m| m.is_error).count() as i32;
|
||||
let error_rate = if recent.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
error_count as f64 / recent.len() as f64
|
||||
};
|
||||
|
||||
let unique_ips = recent
|
||||
.iter()
|
||||
.filter_map(|m| m.ip.clone())
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.len() as i32;
|
||||
|
||||
let last_ip = metrics.last().and_then(|m| m.ip.clone());
|
||||
|
||||
let metrics = AnomalyMetrics {
|
||||
requests_per_minute,
|
||||
error_count,
|
||||
error_rate,
|
||||
unique_ips,
|
||||
last_ip,
|
||||
};
|
||||
|
||||
self.detect_anomaly(key_id, metrics).await
|
||||
}
|
||||
|
||||
async fn detect_anomaly(&self, key_id: &str, metrics: AnomalyMetrics) -> Option<AnomalyRecord> {
|
||||
if metrics.requests_per_minute > self.config.requests_per_minute_threshold * 10 {
|
||||
let mut lockout = self.lockout_cache.write().await;
|
||||
*lockout.entry(key_id.to_string()).or_insert(0) += 1;
|
||||
|
||||
if lockout[&key_id.to_string()] >= self.config.lockout_threshold {
|
||||
return Some(self.create_anomaly(
|
||||
key_id,
|
||||
AnomalyType::BruteForce,
|
||||
AnomalySeverity::Critical,
|
||||
&metrics,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
if metrics.requests_per_minute > self.config.requests_per_minute_threshold {
|
||||
return Some(self.create_anomaly(
|
||||
key_id,
|
||||
AnomalyType::HighRequestRate,
|
||||
AnomalySeverity::Medium,
|
||||
&metrics,
|
||||
));
|
||||
}
|
||||
|
||||
if metrics.error_rate > self.config.error_rate_threshold {
|
||||
return Some(self.create_anomaly(
|
||||
key_id,
|
||||
AnomalyType::HighErrorRate,
|
||||
AnomalySeverity::Medium,
|
||||
&metrics,
|
||||
));
|
||||
}
|
||||
|
||||
if metrics.unique_ips > self.config.unique_ips_per_hour_threshold {
|
||||
return Some(self.create_anomaly(
|
||||
key_id,
|
||||
AnomalyType::MultipleIps,
|
||||
AnomalySeverity::Low,
|
||||
&metrics,
|
||||
));
|
||||
}
|
||||
|
||||
let hour = Utc::now().hour();
|
||||
if hour < 6 && metrics.requests_per_minute > 10 {
|
||||
return Some(self.create_anomaly(
|
||||
key_id,
|
||||
AnomalyType::UnusualTime,
|
||||
AnomalySeverity::Low,
|
||||
&metrics,
|
||||
));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn create_anomaly(
|
||||
&self,
|
||||
key_id: &str,
|
||||
anomaly_type: AnomalyType,
|
||||
severity: AnomalySeverity,
|
||||
metrics: &AnomalyMetrics,
|
||||
) -> AnomalyRecord {
|
||||
AnomalyRecord {
|
||||
id: 0,
|
||||
key_id: key_id.to_string(),
|
||||
anomaly_type,
|
||||
severity,
|
||||
ip_address: metrics.last_ip.clone(),
|
||||
request_count: Some(metrics.requests_per_minute),
|
||||
error_count: Some(metrics.error_count),
|
||||
error_rate: Some(metrics.error_rate),
|
||||
unique_ips: Some(metrics.unique_ips),
|
||||
details: None,
|
||||
resolved: false,
|
||||
resolved_at: None,
|
||||
resolved_by: None,
|
||||
created_at: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn should_lockout(&self, key_id: &str) -> bool {
|
||||
let lockout = self.lockout_cache.read().await;
|
||||
lockout.get(key_id).copied().unwrap_or(0) >= self.config.lockout_threshold
|
||||
}
|
||||
|
||||
pub async fn reset_lockout(&self, key_id: &str) {
|
||||
let mut lockout = self.lockout_cache.write().await;
|
||||
lockout.remove(key_id);
|
||||
|
||||
let mut cache = self.metrics_cache.write().await;
|
||||
cache.remove(key_id);
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AnomalyDetector {
|
||||
fn default() -> Self {
|
||||
Self::new(AnomalyDetectionConfig::default())
|
||||
}
|
||||
}
|
||||
193
src/core/api_key/audit_logger.rs
Normal file
193
src/core/api_key/audit_logger.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
//! Async Audit Logger Module
|
||||
//!
|
||||
//! Writes audit logs asynchronously using a channel
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::mpsc::{self, Receiver, Sender};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Audit log entry
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AuditEntry {
|
||||
pub key_id: String,
|
||||
pub action: String,
|
||||
pub actor: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub user_agent: Option<String>,
|
||||
pub request_path: Option<String>,
|
||||
pub response_code: Option<i32>,
|
||||
pub anomaly_type: Option<String>,
|
||||
pub details: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Async audit logger configuration
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuditLoggerConfig {
|
||||
pub channel_buffer_size: usize,
|
||||
pub batch_size: usize,
|
||||
pub flush_interval_ms: u64,
|
||||
}
|
||||
|
||||
impl Default for AuditLoggerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
channel_buffer_size: std::env::var("AUDIT_LOGGER_BUFFER_SIZE")
|
||||
.unwrap_or_else(|_| "1000".to_string())
|
||||
.parse()
|
||||
.unwrap_or(1000),
|
||||
batch_size: std::env::var("AUDIT_LOGGER_BATCH_SIZE")
|
||||
.unwrap_or_else(|_| "100".to_string())
|
||||
.parse()
|
||||
.unwrap_or(100),
|
||||
flush_interval_ms: std::env::var("AUDIT_LOGGER_FLUSH_INTERVAL_MS")
|
||||
.unwrap_or_else(|_| "1000".to_string())
|
||||
.parse()
|
||||
.unwrap_or(1000),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Async audit logger
|
||||
pub struct AsyncAuditLogger {
|
||||
sender: Sender<AuditEntry>,
|
||||
handle: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl AsyncAuditLogger {
|
||||
/// Create a new async audit logger
|
||||
pub fn new(db: PostgresDb, config: AuditLoggerConfig) -> Self {
|
||||
let (sender, receiver) = mpsc::channel(config.channel_buffer_size);
|
||||
|
||||
let handle = tokio::spawn(Self::logger_task(db, receiver, config));
|
||||
|
||||
Self { sender, handle }
|
||||
}
|
||||
|
||||
/// Create with default config
|
||||
pub fn with_default_config(db: PostgresDb) -> Self {
|
||||
Self::new(db, AuditLoggerConfig::default())
|
||||
}
|
||||
|
||||
/// Log an audit entry
|
||||
pub async fn log(&self, entry: AuditEntry) -> Result<()> {
|
||||
self.sender
|
||||
.send(entry)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to send audit entry: {}", e))
|
||||
}
|
||||
|
||||
/// Shutdown the logger
|
||||
pub async fn shutdown(self) -> Result<()> {
|
||||
drop(self.sender);
|
||||
self.handle.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Logger background task
|
||||
async fn logger_task(
|
||||
db: PostgresDb,
|
||||
mut receiver: Receiver<AuditEntry>,
|
||||
config: AuditLoggerConfig,
|
||||
) {
|
||||
let mut batch = Vec::with_capacity(config.batch_size);
|
||||
let mut interval =
|
||||
tokio::time::interval(std::time::Duration::from_millis(config.flush_interval_ms));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(entry) = receiver.recv() => {
|
||||
batch.push(entry);
|
||||
|
||||
if batch.len() >= config.batch_size {
|
||||
if let Err(e) = Self::flush_batch(&db, &batch).await {
|
||||
tracing::error!("Failed to flush audit batch: {}", e);
|
||||
}
|
||||
batch.clear();
|
||||
}
|
||||
}
|
||||
_ = interval.tick() => {
|
||||
if !batch.is_empty() {
|
||||
if let Err(e) = Self::flush_batch(&db, &batch).await {
|
||||
tracing::error!("Failed to flush audit batch: {}", e);
|
||||
}
|
||||
batch.clear();
|
||||
}
|
||||
}
|
||||
else => {
|
||||
// Channel closed
|
||||
if !batch.is_empty() {
|
||||
if let Err(e) = Self::flush_batch(&db, &batch).await {
|
||||
tracing::error!("Failed to flush final audit batch: {}", e);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Audit logger task stopped");
|
||||
}
|
||||
|
||||
/// Flush a batch of entries to the database
|
||||
async fn flush_batch(db: &PostgresDb, entries: &[AuditEntry]) -> Result<()> {
|
||||
if entries.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
tracing::debug!("Flushing {} audit entries", entries.len());
|
||||
|
||||
for entry in entries {
|
||||
if let Err(e) = db
|
||||
.log_api_key_audit(
|
||||
&entry.key_id,
|
||||
&entry.action,
|
||||
entry.actor.as_deref(),
|
||||
entry.ip_address.as_deref(),
|
||||
entry.user_agent.as_deref(),
|
||||
entry.request_path.as_deref(),
|
||||
entry.response_code,
|
||||
entry.anomaly_type.as_deref(),
|
||||
entry.details.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!("Failed to write audit entry: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_audit_logger_config_default() {
|
||||
let config = AuditLoggerConfig::default();
|
||||
assert!(config.channel_buffer_size > 0);
|
||||
assert!(config.batch_size > 0);
|
||||
assert!(config.flush_interval_ms > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_audit_entry_creation() {
|
||||
let entry = AuditEntry {
|
||||
key_id: "test_key".to_string(),
|
||||
action: "validate".to_string(),
|
||||
actor: Some("user1".to_string()),
|
||||
ip_address: Some("192.168.1.1".to_string()),
|
||||
user_agent: Some("Mozilla/5.0".to_string()),
|
||||
request_path: Some("/api/test".to_string()),
|
||||
response_code: Some(200),
|
||||
anomaly_type: None,
|
||||
details: None,
|
||||
};
|
||||
|
||||
assert_eq!(entry.key_id, "test_key");
|
||||
assert_eq!(entry.action, "validate");
|
||||
}
|
||||
}
|
||||
203
src/core/api_key/blacklist.rs
Normal file
203
src/core/api_key/blacklist.rs
Normal file
@@ -0,0 +1,203 @@
|
||||
//! IP Blacklist Module
|
||||
//!
|
||||
//! Manages blocked IP addresses for API key validation
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use moka::future::Cache;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration as StdDuration;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// IP blacklist entry
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BlacklistEntry {
|
||||
pub ip: String,
|
||||
pub reason: String,
|
||||
pub blocked_at: DateTime<Utc>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub blocked_by: Option<String>,
|
||||
}
|
||||
|
||||
/// IP Blacklist manager
|
||||
pub struct IpBlacklist {
|
||||
/// In-memory blacklist with TTL
|
||||
cache: Cache<String, BlacklistEntry>,
|
||||
/// Permanent blacklist (no TTL)
|
||||
permanent: Arc<RwLock<HashSet<String>>>,
|
||||
}
|
||||
|
||||
/// Configuration for IP blacklist
|
||||
pub struct BlacklistConfig {
|
||||
pub default_block_duration_secs: u64,
|
||||
pub max_entries: u64,
|
||||
}
|
||||
|
||||
impl Default for BlacklistConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
default_block_duration_secs: std::env::var("IP_BLACKLIST_DURATION")
|
||||
.unwrap_or_else(|_| "3600".to_string())
|
||||
.parse()
|
||||
.unwrap_or(3600),
|
||||
max_entries: std::env::var("IP_BLACKLIST_MAX_ENTRIES")
|
||||
.unwrap_or_else(|_| "10000".to_string())
|
||||
.parse()
|
||||
.unwrap_or(10000),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IpBlacklist {
|
||||
pub fn new(config: BlacklistConfig) -> Self {
|
||||
Self {
|
||||
cache: Cache::builder()
|
||||
.time_to_live(StdDuration::from_secs(config.default_block_duration_secs))
|
||||
.max_capacity(config.max_entries)
|
||||
.build(),
|
||||
permanent: Arc::new(RwLock::new(HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_default_config() -> Self {
|
||||
Self::new(BlacklistConfig::default())
|
||||
}
|
||||
|
||||
/// Check if an IP is blocked
|
||||
pub async fn is_blocked(&self, ip: &str) -> bool {
|
||||
// Check permanent blacklist first
|
||||
if self.permanent.read().await.contains(ip) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check temporary blacklist
|
||||
self.cache.get(ip).await.is_some()
|
||||
}
|
||||
|
||||
/// Get blacklist entry for an IP
|
||||
pub async fn get_entry(&self, ip: &str) -> Option<BlacklistEntry> {
|
||||
self.cache.get(ip).await
|
||||
}
|
||||
|
||||
/// Block an IP temporarily
|
||||
pub async fn block(&self, ip: &str, reason: &str, duration_secs: Option<u64>) {
|
||||
let entry = BlacklistEntry {
|
||||
ip: ip.to_string(),
|
||||
reason: reason.to_string(),
|
||||
blocked_at: Utc::now(),
|
||||
expires_at: duration_secs.map(|d| Utc::now() + Duration::seconds(d as i64)),
|
||||
blocked_by: Some("system".to_string()),
|
||||
};
|
||||
|
||||
self.cache.insert(ip.to_string(), entry).await;
|
||||
tracing::info!("Blocked IP: {} - {}", ip, reason);
|
||||
}
|
||||
|
||||
/// Block an IP permanently
|
||||
pub async fn block_permanent(&self, ip: &str, reason: &str) {
|
||||
self.permanent.write().await.insert(ip.to_string());
|
||||
tracing::info!("Permanently blocked IP: {} - {}", ip, reason);
|
||||
}
|
||||
|
||||
/// Unblock an IP
|
||||
pub async fn unblock(&self, ip: &str) -> bool {
|
||||
let in_cache = self.cache.get(ip).await.is_some();
|
||||
if in_cache {
|
||||
self.cache.invalidate(ip).await;
|
||||
}
|
||||
let from_permanent = self.permanent.write().await.remove(ip);
|
||||
|
||||
if in_cache || from_permanent {
|
||||
tracing::info!("Unblocked IP: {}", ip);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all blocked IPs
|
||||
pub async fn list_all(&self) -> Vec<String> {
|
||||
let mut ips: Vec<String> = self.cache.iter().map(|(k, _)| (*k).clone()).collect();
|
||||
ips.extend(self.permanent.read().await.iter().cloned());
|
||||
ips.sort();
|
||||
ips.dedup();
|
||||
ips
|
||||
}
|
||||
|
||||
/// Get count of blocked IPs
|
||||
pub async fn count(&self) -> usize {
|
||||
self.cache.entry_count() as usize + self.permanent.read().await.len()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_block_and_check() {
|
||||
let blacklist = IpBlacklist::with_default_config();
|
||||
|
||||
assert!(!blacklist.is_blocked("192.168.1.1").await);
|
||||
|
||||
blacklist.block("192.168.1.1", "test", Some(60)).await;
|
||||
|
||||
assert!(blacklist.is_blocked("192.168.1.1").await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unblock() {
|
||||
let blacklist = IpBlacklist::with_default_config();
|
||||
|
||||
blacklist.block("192.168.1.1", "test", Some(60)).await;
|
||||
assert!(blacklist.is_blocked("192.168.1.1").await);
|
||||
|
||||
assert!(blacklist.unblock("192.168.1.1").await);
|
||||
assert!(!blacklist.is_blocked("192.168.1.1").await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_permanent_block() {
|
||||
let blacklist = IpBlacklist::with_default_config();
|
||||
|
||||
blacklist.block_permanent("10.0.0.1", "permanent ban").await;
|
||||
|
||||
assert!(blacklist.is_blocked("10.0.0.1").await);
|
||||
assert!(blacklist.unblock("10.0.0.1").await);
|
||||
assert!(!blacklist.is_blocked("10.0.0.1").await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_list_all() {
|
||||
let blacklist = IpBlacklist::with_default_config();
|
||||
|
||||
blacklist.block("192.168.1.1", "test", Some(60)).await;
|
||||
blacklist.block_permanent("10.0.0.1", "permanent").await;
|
||||
|
||||
let ips = blacklist.list_all().await;
|
||||
assert_eq!(ips.len(), 2);
|
||||
assert!(ips.contains(&"192.168.1.1".to_string()));
|
||||
assert!(ips.contains(&"10.0.0.1".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_count() {
|
||||
let blacklist = IpBlacklist::with_default_config();
|
||||
|
||||
assert_eq!(blacklist.count().await, 0);
|
||||
|
||||
blacklist.block("192.168.1.1", "test", Some(60)).await;
|
||||
blacklist.block("192.168.1.2", "test", Some(60)).await;
|
||||
blacklist.block_permanent("10.0.0.1", "permanent").await;
|
||||
|
||||
// Count should be at least 1 (permanent) + 2 (cached) = 3
|
||||
// Note: cache entry_count might need time to update
|
||||
let count = blacklist.count().await;
|
||||
assert!(
|
||||
count >= 1,
|
||||
"Expected at least 1 entry (permanent), got {}",
|
||||
count
|
||||
);
|
||||
}
|
||||
}
|
||||
172
src/core/api_key/cleanup.rs
Normal file
172
src/core/api_key/cleanup.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
//! API Key Cleanup Module
|
||||
//!
|
||||
//! Automatically cleans up expired and old API key records
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
use anyhow::Result;
|
||||
use chrono::{Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Cleanup configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CleanupConfig {
|
||||
pub expired_keys_days: i64,
|
||||
pub audit_logs_days: i64,
|
||||
pub anomaly_logs_days: i64,
|
||||
pub dry_run: bool,
|
||||
}
|
||||
|
||||
impl Default for CleanupConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
expired_keys_days: std::env::var("CLEANUP_EXPIRED_KEYS_DAYS")
|
||||
.unwrap_or_else(|_| "90".to_string())
|
||||
.parse()
|
||||
.unwrap_or(90),
|
||||
audit_logs_days: std::env::var("CLEANUP_AUDIT_LOGS_DAYS")
|
||||
.unwrap_or_else(|_| "180".to_string())
|
||||
.parse()
|
||||
.unwrap_or(180),
|
||||
anomaly_logs_days: std::env::var("CLEANUP_ANOMALY_LOGS_DAYS")
|
||||
.unwrap_or_else(|_| "90".to_string())
|
||||
.parse()
|
||||
.unwrap_or(90),
|
||||
dry_run: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cleanup result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CleanupResult {
|
||||
pub executed_at: chrono::DateTime<Utc>,
|
||||
pub expired_keys_deleted: u64,
|
||||
pub audit_logs_deleted: u64,
|
||||
pub anomaly_logs_deleted: u64,
|
||||
pub dry_run: bool,
|
||||
}
|
||||
|
||||
/// Cleanup manager
|
||||
pub struct CleanupManager {
|
||||
db: PostgresDb,
|
||||
config: CleanupConfig,
|
||||
}
|
||||
|
||||
impl CleanupManager {
|
||||
pub fn new(db: PostgresDb, config: CleanupConfig) -> Self {
|
||||
Self { db, config }
|
||||
}
|
||||
|
||||
pub fn with_default_config(db: PostgresDb) -> Self {
|
||||
Self::new(db, CleanupConfig::default())
|
||||
}
|
||||
|
||||
/// Run full cleanup
|
||||
pub async fn run_cleanup(&self) -> Result<CleanupResult> {
|
||||
let mut result = CleanupResult {
|
||||
executed_at: Utc::now(),
|
||||
expired_keys_deleted: 0,
|
||||
audit_logs_deleted: 0,
|
||||
anomaly_logs_deleted: 0,
|
||||
dry_run: self.config.dry_run,
|
||||
};
|
||||
|
||||
// Clean expired keys
|
||||
result.expired_keys_deleted = self.clean_expired_keys().await?;
|
||||
|
||||
// Clean old audit logs
|
||||
result.audit_logs_deleted = self.clean_audit_logs().await?;
|
||||
|
||||
// Clean old anomaly logs
|
||||
result.anomaly_logs_deleted = self.clean_anomaly_logs().await?;
|
||||
|
||||
tracing::info!(
|
||||
"Cleanup completed: {} expired keys, {} audit logs, {} anomaly logs",
|
||||
result.expired_keys_deleted,
|
||||
result.audit_logs_deleted,
|
||||
result.anomaly_logs_deleted
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Clean expired API keys
|
||||
async fn clean_expired_keys(&self) -> Result<u64> {
|
||||
let cutoff = Utc::now() - Duration::days(self.config.expired_keys_days);
|
||||
|
||||
if self.config.dry_run {
|
||||
let count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM api_keys WHERE status = 'expired' AND expires_at < $1",
|
||||
)
|
||||
.bind(cutoff)
|
||||
.fetch_one(self.db.pool())
|
||||
.await?;
|
||||
return Ok(count as u64);
|
||||
}
|
||||
|
||||
let result =
|
||||
sqlx::query("DELETE FROM api_keys WHERE status = 'expired' AND expires_at < $1")
|
||||
.bind(cutoff)
|
||||
.execute(self.db.pool())
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Clean old audit logs
|
||||
async fn clean_audit_logs(&self) -> Result<u64> {
|
||||
let cutoff = Utc::now() - Duration::days(self.config.audit_logs_days);
|
||||
|
||||
if self.config.dry_run {
|
||||
let count: i64 =
|
||||
sqlx::query_scalar("SELECT COUNT(*) FROM api_key_audit_log WHERE created_at < $1")
|
||||
.bind(cutoff)
|
||||
.fetch_one(self.db.pool())
|
||||
.await?;
|
||||
return Ok(count as u64);
|
||||
}
|
||||
|
||||
let result = sqlx::query("DELETE FROM api_key_audit_log WHERE created_at < $1")
|
||||
.bind(cutoff)
|
||||
.execute(self.db.pool())
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Clean old anomaly logs
|
||||
async fn clean_anomaly_logs(&self) -> Result<u64> {
|
||||
let cutoff = Utc::now() - Duration::days(self.config.anomaly_logs_days);
|
||||
|
||||
if self.config.dry_run {
|
||||
let count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM api_key_anomalies WHERE created_at < $1 AND resolved = TRUE",
|
||||
)
|
||||
.bind(cutoff)
|
||||
.fetch_one(self.db.pool())
|
||||
.await?;
|
||||
return Ok(count as u64);
|
||||
}
|
||||
|
||||
let result =
|
||||
sqlx::query("DELETE FROM api_key_anomalies WHERE created_at < $1 AND resolved = TRUE")
|
||||
.bind(cutoff)
|
||||
.execute(self.db.pool())
|
||||
.await?;
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cleanup_config_default() {
|
||||
let config = CleanupConfig::default();
|
||||
assert!(config.expired_keys_days > 0);
|
||||
assert!(config.audit_logs_days > 0);
|
||||
assert!(config.anomaly_logs_days > 0);
|
||||
}
|
||||
}
|
||||
211
src/core/api_key/encryption.rs
Normal file
211
src/core/api_key/encryption.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
//! Audit Log Encryption Module
|
||||
//!
|
||||
//! Provides encryption for sensitive audit log fields
|
||||
|
||||
use aes_gcm::{
|
||||
aead::{rand_core::RngCore, Aead, KeyInit, OsRng},
|
||||
AeadCore, Aes256Gcm, Nonce,
|
||||
};
|
||||
use anyhow::{Context, Result};
|
||||
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Encrypted data wrapper
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EncryptedData {
|
||||
pub nonce: String,
|
||||
pub ciphertext: String,
|
||||
}
|
||||
|
||||
/// Audit encryption manager
|
||||
pub struct AuditEncryption {
|
||||
cipher: Aes256Gcm,
|
||||
}
|
||||
|
||||
impl AuditEncryption {
|
||||
/// Create a new encryption manager with a key
|
||||
pub fn new(key: &[u8; 32]) -> Self {
|
||||
let cipher = Aes256Gcm::new_from_slice(key).expect("Failed to create cipher");
|
||||
Self { cipher }
|
||||
}
|
||||
|
||||
/// Create from environment variable
|
||||
pub fn from_env() -> Result<Self> {
|
||||
let key_hex =
|
||||
std::env::var("AUDIT_ENCRYPTION_KEY").context("AUDIT_ENCRYPTION_KEY not set")?;
|
||||
|
||||
let key_bytes = hex::decode(&key_hex).context("Invalid hex in AUDIT_ENCRYPTION_KEY")?;
|
||||
|
||||
if key_bytes.len() != 32 {
|
||||
anyhow::bail!("AUDIT_ENCRYPTION_KEY must be 32 bytes (64 hex chars)");
|
||||
}
|
||||
|
||||
let mut key = [0u8; 32];
|
||||
key.copy_from_slice(&key_bytes);
|
||||
Ok(Self::new(&key))
|
||||
}
|
||||
|
||||
/// Generate a random key
|
||||
pub fn generate_key() -> [u8; 32] {
|
||||
let mut key = [0u8; 32];
|
||||
OsRng.fill_bytes(&mut key);
|
||||
key
|
||||
}
|
||||
|
||||
/// Encrypt a string
|
||||
pub fn encrypt(&self, plaintext: &str) -> Result<EncryptedData> {
|
||||
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
|
||||
let ciphertext = self
|
||||
.cipher
|
||||
.encrypt(&nonce, plaintext.as_bytes())
|
||||
.map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
|
||||
|
||||
Ok(EncryptedData {
|
||||
nonce: BASE64.encode(nonce),
|
||||
ciphertext: BASE64.encode(ciphertext),
|
||||
})
|
||||
}
|
||||
|
||||
/// Decrypt a string
|
||||
pub fn decrypt(&self, data: &EncryptedData) -> Result<String> {
|
||||
let nonce = BASE64.decode(&data.nonce).context("Invalid nonce base64")?;
|
||||
let ciphertext = BASE64
|
||||
.decode(&data.ciphertext)
|
||||
.context("Invalid ciphertext base64")?;
|
||||
|
||||
let plaintext = self
|
||||
.cipher
|
||||
.decrypt(Nonce::from_slice(&nonce), ciphertext.as_ref())
|
||||
.map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))?;
|
||||
|
||||
String::from_utf8(plaintext).context("Decrypted data is not valid UTF-8")
|
||||
}
|
||||
|
||||
/// Encrypt sensitive audit fields
|
||||
pub fn encrypt_audit_entry(&self, entry: &AuditLogEntry) -> Result<EncryptedAuditLogEntry> {
|
||||
Ok(EncryptedAuditLogEntry {
|
||||
id: entry.id,
|
||||
key_id: entry.key_id.clone(),
|
||||
action: entry.action.clone(),
|
||||
actor: entry.actor.clone(),
|
||||
ip_address: entry
|
||||
.ip_address
|
||||
.as_ref()
|
||||
.map(|ip| self.encrypt(ip))
|
||||
.transpose()?,
|
||||
user_agent: entry
|
||||
.user_agent
|
||||
.as_ref()
|
||||
.map(|ua| self.encrypt(ua))
|
||||
.transpose()?,
|
||||
request_path: entry.request_path.clone(),
|
||||
response_code: entry.response_code,
|
||||
anomaly_type: entry.anomaly_type.clone(),
|
||||
details: entry
|
||||
.details
|
||||
.as_ref()
|
||||
.map(|d| self.encrypt(&d.to_string()))
|
||||
.transpose()?,
|
||||
created_at: entry.created_at,
|
||||
})
|
||||
}
|
||||
|
||||
/// Decrypt sensitive audit fields
|
||||
pub fn decrypt_audit_entry(&self, entry: &EncryptedAuditLogEntry) -> Result<AuditLogEntry> {
|
||||
Ok(AuditLogEntry {
|
||||
id: entry.id,
|
||||
key_id: entry.key_id.clone(),
|
||||
action: entry.action.clone(),
|
||||
actor: entry.actor.clone(),
|
||||
ip_address: entry
|
||||
.ip_address
|
||||
.as_ref()
|
||||
.map(|enc| self.decrypt(enc))
|
||||
.transpose()?,
|
||||
user_agent: entry
|
||||
.user_agent
|
||||
.as_ref()
|
||||
.map(|enc| self.decrypt(enc))
|
||||
.transpose()?,
|
||||
request_path: entry.request_path.clone(),
|
||||
response_code: entry.response_code,
|
||||
anomaly_type: entry.anomaly_type.clone(),
|
||||
details: entry
|
||||
.details
|
||||
.as_ref()
|
||||
.map(|enc| self.decrypt(enc))
|
||||
.transpose()?
|
||||
.map(|s| serde_json::from_str(&s))
|
||||
.transpose()?,
|
||||
created_at: entry.created_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Audit log entry (plaintext)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AuditLogEntry {
|
||||
pub id: i64,
|
||||
pub key_id: String,
|
||||
pub action: String,
|
||||
pub actor: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub user_agent: Option<String>,
|
||||
pub request_path: Option<String>,
|
||||
pub response_code: Option<i32>,
|
||||
pub anomaly_type: Option<String>,
|
||||
pub details: Option<serde_json::Value>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
/// Audit log entry (with encrypted fields)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EncryptedAuditLogEntry {
|
||||
pub id: i64,
|
||||
pub key_id: String,
|
||||
pub action: String,
|
||||
pub actor: Option<String>,
|
||||
pub ip_address: Option<EncryptedData>,
|
||||
pub user_agent: Option<EncryptedData>,
|
||||
pub request_path: Option<String>,
|
||||
pub response_code: Option<i32>,
|
||||
pub anomaly_type: Option<String>,
|
||||
pub details: Option<EncryptedData>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_decrypt() {
|
||||
let key = AuditEncryption::generate_key();
|
||||
let enc = AuditEncryption::new(&key);
|
||||
|
||||
let plaintext = "sensitive data 12345";
|
||||
let encrypted = enc.encrypt(plaintext).unwrap();
|
||||
let decrypted = enc.decrypt(&encrypted).unwrap();
|
||||
|
||||
assert_eq!(plaintext, decrypted);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_different_nonces() {
|
||||
let key = AuditEncryption::generate_key();
|
||||
let enc = AuditEncryption::new(&key);
|
||||
|
||||
let encrypted1 = enc.encrypt("same data").unwrap();
|
||||
let encrypted2 = enc.encrypt("same data").unwrap();
|
||||
|
||||
// Different nonces should produce different ciphertexts
|
||||
assert_ne!(encrypted1.nonce, encrypted2.nonce);
|
||||
assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);
|
||||
|
||||
// But both should decrypt to the same plaintext
|
||||
assert_eq!(
|
||||
enc.decrypt(&encrypted1).unwrap(),
|
||||
enc.decrypt(&encrypted2).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
184
src/core/api_key/error.rs
Normal file
184
src/core/api_key/error.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
//! API Key Error Types
|
||||
//!
|
||||
//! Unified error handling for API key operations
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// API Key related errors
|
||||
#[derive(Error, Debug)]
|
||||
pub enum ApiKeyError {
|
||||
#[error("API key not found: {key_id}")]
|
||||
NotFound { key_id: String },
|
||||
|
||||
#[error("API key expired: {key_id}")]
|
||||
Expired { key_id: String },
|
||||
|
||||
#[error("API key revoked: {key_id}")]
|
||||
Revoked { key_id: String },
|
||||
|
||||
#[error("API key suspended: {key_id}")]
|
||||
Suspended { key_id: String },
|
||||
|
||||
#[error("Invalid API key format")]
|
||||
InvalidFormat,
|
||||
|
||||
#[error("Insufficient permissions: required {required}, have {actual}")]
|
||||
InsufficientPermissions { required: String, actual: String },
|
||||
|
||||
#[error("Rate limit exceeded: retry after {retry_after_secs} seconds")]
|
||||
RateLimited { retry_after_secs: u64 },
|
||||
|
||||
#[error("IP blocked: {ip}")]
|
||||
IpBlocked { ip: String },
|
||||
|
||||
#[error("Rotation required: {reason}")]
|
||||
RotationRequired { reason: String },
|
||||
|
||||
#[error("Anomaly detected: {anomaly_type}")]
|
||||
AnomalyDetected { anomaly_type: String },
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
Database(#[from] sqlx::Error),
|
||||
|
||||
#[error("Cache error: {message}")]
|
||||
Cache { message: String },
|
||||
|
||||
#[error("External service error: {service} - {message}")]
|
||||
ExternalService { service: String, message: String },
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
/// Result type for API key operations
|
||||
pub type ApiKeyResult<T> = Result<T, ApiKeyError>;
|
||||
|
||||
impl ApiKeyError {
|
||||
/// Check if the error is retryable
|
||||
pub fn is_retryable(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
ApiKeyError::RateLimited { .. } | ApiKeyError::Database(_) | ApiKeyError::Internal(_)
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if the error is a client error (4xx)
|
||||
pub fn is_client_error(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
ApiKeyError::NotFound { .. }
|
||||
| ApiKeyError::Expired { .. }
|
||||
| ApiKeyError::Revoked { .. }
|
||||
| ApiKeyError::Suspended { .. }
|
||||
| ApiKeyError::InvalidFormat
|
||||
| ApiKeyError::InsufficientPermissions { .. }
|
||||
| ApiKeyError::RateLimited { .. }
|
||||
| ApiKeyError::IpBlocked { .. }
|
||||
| ApiKeyError::RotationRequired { .. }
|
||||
)
|
||||
}
|
||||
|
||||
/// Get HTTP status code for the error
|
||||
pub fn status_code(&self) -> u16 {
|
||||
match self {
|
||||
ApiKeyError::NotFound { .. } => 404,
|
||||
ApiKeyError::Expired { .. } => 401,
|
||||
ApiKeyError::Revoked { .. } => 401,
|
||||
ApiKeyError::Suspended { .. } => 403,
|
||||
ApiKeyError::InvalidFormat => 400,
|
||||
ApiKeyError::InsufficientPermissions { .. } => 403,
|
||||
ApiKeyError::RateLimited { .. } => 429,
|
||||
ApiKeyError::IpBlocked { .. } => 403,
|
||||
ApiKeyError::RotationRequired { .. } => 401,
|
||||
ApiKeyError::AnomalyDetected { .. } => 403,
|
||||
ApiKeyError::Database(_) => 500,
|
||||
ApiKeyError::Cache { .. } => 500,
|
||||
ApiKeyError::ExternalService { .. } => 502,
|
||||
ApiKeyError::Config(_) => 500,
|
||||
ApiKeyError::Internal(_) => 500,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get error code for API responses
|
||||
pub fn error_code(&self) -> &'static str {
|
||||
match self {
|
||||
ApiKeyError::NotFound { .. } => "api_key.not_found",
|
||||
ApiKeyError::Expired { .. } => "api_key.expired",
|
||||
ApiKeyError::Revoked { .. } => "api_key.revoked",
|
||||
ApiKeyError::Suspended { .. } => "api_key.suspended",
|
||||
ApiKeyError::InvalidFormat => "api_key.invalid_format",
|
||||
ApiKeyError::InsufficientPermissions { .. } => "api_key.insufficient_permissions",
|
||||
ApiKeyError::RateLimited { .. } => "api_key.rate_limited",
|
||||
ApiKeyError::IpBlocked { .. } => "api_key.ip_blocked",
|
||||
ApiKeyError::RotationRequired { .. } => "api_key.rotation_required",
|
||||
ApiKeyError::AnomalyDetected { .. } => "api_key.anomaly_detected",
|
||||
ApiKeyError::Database(_) => "internal.database_error",
|
||||
ApiKeyError::Cache { .. } => "internal.cache_error",
|
||||
ApiKeyError::ExternalService { .. } => "external.service_error",
|
||||
ApiKeyError::Config(_) => "internal.config_error",
|
||||
ApiKeyError::Internal(_) => "internal.error",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_error_status_codes() {
|
||||
assert_eq!(
|
||||
ApiKeyError::NotFound {
|
||||
key_id: "test".into()
|
||||
}
|
||||
.status_code(),
|
||||
404
|
||||
);
|
||||
assert_eq!(
|
||||
ApiKeyError::Expired {
|
||||
key_id: "test".into()
|
||||
}
|
||||
.status_code(),
|
||||
401
|
||||
);
|
||||
assert_eq!(ApiKeyError::InvalidFormat.status_code(), 400);
|
||||
assert_eq!(
|
||||
ApiKeyError::RateLimited {
|
||||
retry_after_secs: 60
|
||||
}
|
||||
.status_code(),
|
||||
429
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_codes() {
|
||||
assert_eq!(
|
||||
ApiKeyError::NotFound {
|
||||
key_id: "test".into()
|
||||
}
|
||||
.error_code(),
|
||||
"api_key.not_found"
|
||||
);
|
||||
assert_eq!(
|
||||
ApiKeyError::RateLimited {
|
||||
retry_after_secs: 60
|
||||
}
|
||||
.error_code(),
|
||||
"api_key.rate_limited"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_client_error() {
|
||||
assert!(ApiKeyError::NotFound {
|
||||
key_id: "test".into()
|
||||
}
|
||||
.is_client_error());
|
||||
assert!(ApiKeyError::InvalidFormat.is_client_error());
|
||||
assert!(!ApiKeyError::Database(sqlx::Error::RowNotFound).is_client_error());
|
||||
}
|
||||
}
|
||||
226
src/core/api_key/export.rs
Normal file
226
src/core/api_key/export.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
//! API Key Export/Import Module
|
||||
//!
|
||||
//! Supports exporting and importing API key records
|
||||
|
||||
use crate::core::db::postgres_db::PostgresDb;
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
/// Export format
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ExportFormat {
|
||||
Json,
|
||||
Csv,
|
||||
}
|
||||
|
||||
impl ExportFormat {
|
||||
pub fn extension(&self) -> &'static str {
|
||||
match self {
|
||||
ExportFormat::Json => "json",
|
||||
ExportFormat::Csv => "csv",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Exported API key record (without sensitive data)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExportedApiKey {
|
||||
pub key_id: String,
|
||||
pub name: String,
|
||||
pub key_type: String,
|
||||
pub status: String,
|
||||
pub permissions: serde_json::Value,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub usage_count: i64,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub rotation_required: bool,
|
||||
}
|
||||
|
||||
/// Export container
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExportContainer {
|
||||
pub exported_at: DateTime<Utc>,
|
||||
pub version: String,
|
||||
pub count: usize,
|
||||
pub keys: Vec<ExportedApiKey>,
|
||||
}
|
||||
|
||||
/// Import result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImportResult {
|
||||
pub imported: u32,
|
||||
pub skipped: u32,
|
||||
pub errors: Vec<String>,
|
||||
}
|
||||
|
||||
/// Export manager
|
||||
pub struct ExportManager {
|
||||
db: PostgresDb,
|
||||
}
|
||||
|
||||
impl ExportManager {
|
||||
pub fn new(db: PostgresDb) -> Self {
|
||||
Self { db }
|
||||
}
|
||||
|
||||
/// Export all API keys
|
||||
pub async fn export_all(&self, format: ExportFormat) -> Result<String> {
|
||||
let keys = self.db.list_api_keys().await?;
|
||||
|
||||
let exported: Vec<ExportedApiKey> = keys
|
||||
.into_iter()
|
||||
.map(|k| ExportedApiKey {
|
||||
key_id: k.key_id,
|
||||
name: k.name,
|
||||
key_type: k.key_type,
|
||||
status: k.status,
|
||||
permissions: k.permissions,
|
||||
expires_at: k.expires_at,
|
||||
usage_count: k.usage_count,
|
||||
created_at: k.created_at,
|
||||
rotation_required: k.rotation_required,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let container = ExportContainer {
|
||||
exported_at: Utc::now(),
|
||||
version: "1.0".to_string(),
|
||||
count: exported.len(),
|
||||
keys: exported,
|
||||
};
|
||||
|
||||
match format {
|
||||
ExportFormat::Json => Ok(serde_json::to_string_pretty(&container)?),
|
||||
ExportFormat::Csv => self.to_csv(&container),
|
||||
}
|
||||
}
|
||||
|
||||
/// Export to file
|
||||
pub async fn export_to_file(&self, path: &Path, format: ExportFormat) -> Result<usize> {
|
||||
let content = self.export_all(format).await?;
|
||||
let count = serde_json::from_str::<ExportContainer>(&content)
|
||||
.map(|c| c.count)
|
||||
.unwrap_or(0);
|
||||
|
||||
tokio::fs::write(path, content).await?;
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
/// Convert to CSV
|
||||
fn to_csv(&self, container: &ExportContainer) -> Result<String> {
|
||||
let mut csv = String::new();
|
||||
csv.push_str("key_id,name,key_type,status,usage_count,created_at,rotation_required\n");
|
||||
|
||||
for key in &container.keys {
|
||||
csv.push_str(&format!(
|
||||
"{},{},{},{},{},{},{}\n",
|
||||
key.key_id,
|
||||
key.name,
|
||||
key.key_type,
|
||||
key.status,
|
||||
key.usage_count,
|
||||
key.created_at.format("%Y-%m-%d %H:%M:%S"),
|
||||
key.rotation_required,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(csv)
|
||||
}
|
||||
}
|
||||
|
||||
/// Import manager
|
||||
pub struct ImportManager {
|
||||
db: PostgresDb,
|
||||
}
|
||||
|
||||
impl ImportManager {
|
||||
pub fn new(db: PostgresDb) -> Self {
|
||||
Self { db }
|
||||
}
|
||||
|
||||
/// Import from JSON string
|
||||
pub async fn import_from_json(&self, json: &str, overwrite: bool) -> Result<ImportResult> {
|
||||
let container: ExportContainer = serde_json::from_str(json)?;
|
||||
let mut result = ImportResult {
|
||||
imported: 0,
|
||||
skipped: 0,
|
||||
errors: vec![],
|
||||
};
|
||||
|
||||
for key in container.keys {
|
||||
match self.import_key(&key, overwrite).await {
|
||||
Ok(true) => result.imported += 1,
|
||||
Ok(false) => result.skipped += 1,
|
||||
Err(e) => {
|
||||
result.errors.push(format!("{}: {}", key.key_id, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Import from file
|
||||
pub async fn import_from_file(&self, path: &Path, overwrite: bool) -> Result<ImportResult> {
|
||||
let content = tokio::fs::read_to_string(path).await?;
|
||||
|
||||
if path.extension().map(|e| e == "json").unwrap_or(false) {
|
||||
self.import_from_json(&content, overwrite).await
|
||||
} else {
|
||||
anyhow::bail!("Unsupported file format")
|
||||
}
|
||||
}
|
||||
|
||||
/// Import a single key
|
||||
async fn import_key(&self, key: &ExportedApiKey, overwrite: bool) -> Result<bool> {
|
||||
// Check if key already exists
|
||||
let existing = self.db.get_api_key_by_key_id(&key.key_id).await?;
|
||||
|
||||
if existing.is_some() && !overwrite {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
// Note: Import only creates metadata, not the actual key hash
|
||||
// The actual key needs to be regenerated
|
||||
tracing::info!("Imported key metadata: {} ({})", key.key_id, key.name);
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_export_format_extension() {
|
||||
assert_eq!(ExportFormat::Json.extension(), "json");
|
||||
assert_eq!(ExportFormat::Csv.extension(), "csv");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_export_container_serialization() {
|
||||
let container = ExportContainer {
|
||||
exported_at: Utc::now(),
|
||||
version: "1.0".to_string(),
|
||||
count: 1,
|
||||
keys: vec![ExportedApiKey {
|
||||
key_id: "test_123".to_string(),
|
||||
name: "test".to_string(),
|
||||
key_type: "service".to_string(),
|
||||
status: "active".to_string(),
|
||||
permissions: serde_json::json!(["read"]),
|
||||
expires_at: None,
|
||||
usage_count: 0,
|
||||
created_at: Utc::now(),
|
||||
rotation_required: false,
|
||||
}],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string_pretty(&container).unwrap();
|
||||
assert!(json.contains("\"key_id\": \"test_123\""));
|
||||
}
|
||||
}
|
||||
304
src/core/api_key/gitea.rs
Normal file
304
src/core/api_key/gitea.rs
Normal file
@@ -0,0 +1,304 @@
|
||||
//! Gitea API Token Integration
|
||||
//!
|
||||
//! Manages Gitea Personal Access Tokens through the API Key system
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Gitea token scope
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum GiteaScope {
|
||||
ReadRepository,
|
||||
WriteRepository,
|
||||
ReadIssue,
|
||||
WriteIssue,
|
||||
ReadUser,
|
||||
WriteUser,
|
||||
ReadAdmin,
|
||||
WriteAdmin,
|
||||
ReadOrganization,
|
||||
WriteOrganization,
|
||||
ReadPackage,
|
||||
WritePackage,
|
||||
ReadNotification,
|
||||
WriteNotification,
|
||||
ReadActivitypub,
|
||||
WriteActivitypub,
|
||||
ReadMisc,
|
||||
WriteMisc,
|
||||
}
|
||||
|
||||
impl GiteaScope {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
GiteaScope::ReadRepository => "read:repository",
|
||||
GiteaScope::WriteRepository => "write:repository",
|
||||
GiteaScope::ReadIssue => "read:issue",
|
||||
GiteaScope::WriteIssue => "write:issue",
|
||||
GiteaScope::ReadUser => "read:user",
|
||||
GiteaScope::WriteUser => "write:user",
|
||||
GiteaScope::ReadAdmin => "read:admin",
|
||||
GiteaScope::WriteAdmin => "write:admin",
|
||||
GiteaScope::ReadOrganization => "read:organization",
|
||||
GiteaScope::WriteOrganization => "write:organization",
|
||||
GiteaScope::ReadPackage => "read:package",
|
||||
GiteaScope::WritePackage => "write:package",
|
||||
GiteaScope::ReadNotification => "read:notification",
|
||||
GiteaScope::WriteNotification => "write:notification",
|
||||
GiteaScope::ReadActivitypub => "read:activitypub",
|
||||
GiteaScope::WriteActivitypub => "write:activitypub",
|
||||
GiteaScope::ReadMisc => "read:misc",
|
||||
GiteaScope::WriteMisc => "write:misc",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for GiteaScope {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s {
|
||||
"read:repository" => Ok(GiteaScope::ReadRepository),
|
||||
"write:repository" => Ok(GiteaScope::WriteRepository),
|
||||
"read:issue" => Ok(GiteaScope::ReadIssue),
|
||||
"write:issue" => Ok(GiteaScope::WriteIssue),
|
||||
"read:user" => Ok(GiteaScope::ReadUser),
|
||||
"write:user" => Ok(GiteaScope::WriteUser),
|
||||
"read:admin" => Ok(GiteaScope::ReadAdmin),
|
||||
"write:admin" => Ok(GiteaScope::WriteAdmin),
|
||||
"read:organization" => Ok(GiteaScope::ReadOrganization),
|
||||
"write:organization" => Ok(GiteaScope::WriteOrganization),
|
||||
"read:package" => Ok(GiteaScope::ReadPackage),
|
||||
"write:package" => Ok(GiteaScope::WritePackage),
|
||||
"read:notification" => Ok(GiteaScope::ReadNotification),
|
||||
"write:notification" => Ok(GiteaScope::WriteNotification),
|
||||
"read:activitypub" => Ok(GiteaScope::ReadActivitypub),
|
||||
"write:activitypub" => Ok(GiteaScope::WriteActivitypub),
|
||||
"read:misc" => Ok(GiteaScope::ReadMisc),
|
||||
"write:misc" => Ok(GiteaScope::WriteMisc),
|
||||
_ => Err(format!("Invalid Gitea scope: {}", s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request to create a Gitea token
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateGiteaTokenRequest {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub token_name: String,
|
||||
pub scopes: Vec<GiteaScope>,
|
||||
}
|
||||
|
||||
/// Response from creating a Gitea token
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GiteaTokenResponse {
|
||||
pub id: i64,
|
||||
pub name: String,
|
||||
pub sha1: String,
|
||||
pub token_last_eight: String,
|
||||
}
|
||||
|
||||
/// List token response (without SHA1)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GiteaTokenInfo {
|
||||
pub id: i64,
|
||||
pub name: String,
|
||||
pub token_last_eight: String,
|
||||
}
|
||||
|
||||
/// Gitea API client
|
||||
pub struct GiteaClient {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl GiteaClient {
|
||||
pub fn new() -> Result<Self> {
|
||||
let base_url =
|
||||
std::env::var("GITEA_URL").unwrap_or_else(|_| "http://localhost:3001".to_string());
|
||||
|
||||
Ok(Self {
|
||||
client: Client::new(),
|
||||
base_url,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_url(base_url: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new access token for a user
|
||||
pub async fn create_token(
|
||||
&self,
|
||||
request: &CreateGiteaTokenRequest,
|
||||
) -> Result<GiteaTokenResponse> {
|
||||
let url = format!("{}/api/v1/users/{}/tokens", self.base_url, request.username);
|
||||
|
||||
let scopes: Vec<String> = request
|
||||
.scopes
|
||||
.iter()
|
||||
.map(|s| s.as_str().to_string())
|
||||
.collect();
|
||||
|
||||
let body = serde_json::json!({
|
||||
"name": request.token_name,
|
||||
"scopes": scopes,
|
||||
});
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.basic_auth(&request.username, Some(&request.password))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send create token request")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to create Gitea token: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let token: GiteaTokenResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse token response")?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// List all tokens for a user
|
||||
pub async fn list_tokens(&self, username: &str, password: &str) -> Result<Vec<GiteaTokenInfo>> {
|
||||
let url = format!("{}/api/v1/users/{}/tokens", self.base_url, username);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.basic_auth(username, Some(password))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send list tokens request")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to list Gitea tokens: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let tokens: Vec<GiteaTokenInfo> = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse tokens response")?;
|
||||
|
||||
Ok(tokens)
|
||||
}
|
||||
|
||||
/// Delete a token by name
|
||||
pub async fn delete_token(
|
||||
&self,
|
||||
username: &str,
|
||||
password: &str,
|
||||
token_name: &str,
|
||||
) -> Result<()> {
|
||||
let url = format!(
|
||||
"{}/api/v1/users/{}/tokens/{}",
|
||||
self.base_url, username, token_name
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.delete(&url)
|
||||
.basic_auth(username, Some(password))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send delete token request")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to delete Gitea token: {} - {}", status, text);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Verify a token is valid by making a test API call
|
||||
pub async fn verify_token(&self, token: &str) -> Result<bool> {
|
||||
let url = format!("{}/api/v1/user", self.base_url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("token {}", token))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send verify token request")?;
|
||||
|
||||
Ok(response.status().is_success())
|
||||
}
|
||||
|
||||
/// Get base URL
|
||||
pub fn base_url(&self) -> &str {
|
||||
&self.base_url
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for GiteaClient {
|
||||
fn default() -> Self {
|
||||
Self::new().unwrap_or_else(|_| Self::with_url("http://localhost:3001".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Integrated token record
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GiteaTokenRecord {
|
||||
pub gitea_token_id: i64,
|
||||
pub gitea_user: String,
|
||||
pub token_name: String,
|
||||
pub token_last_eight: String,
|
||||
pub scopes: Vec<String>,
|
||||
pub api_key_id: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_verified: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_gitea_scope_as_str() {
|
||||
assert_eq!(GiteaScope::ReadRepository.as_str(), "read:repository");
|
||||
assert_eq!(GiteaScope::WriteIssue.as_str(), "write:issue");
|
||||
assert_eq!(GiteaScope::ReadAdmin.as_str(), "read:admin");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gitea_scope_from_str() {
|
||||
use std::str::FromStr;
|
||||
|
||||
assert!(matches!(
|
||||
GiteaScope::from_str("read:repository").ok(),
|
||||
Some(GiteaScope::ReadRepository)
|
||||
));
|
||||
assert!(matches!(
|
||||
GiteaScope::from_str("write:issue").ok(),
|
||||
Some(GiteaScope::WriteIssue)
|
||||
));
|
||||
assert!(GiteaScope::from_str("invalid").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gitea_client_default() {
|
||||
let client = GiteaClient::default();
|
||||
assert_eq!(client.base_url(), "http://localhost:3001");
|
||||
}
|
||||
}
|
||||
45
src/core/api_key/mod.rs
Normal file
45
src/core/api_key/mod.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
//! API Key Management Module
|
||||
//!
|
||||
//! Features:
|
||||
//! - API Key generation with secure random
|
||||
//! - Key hashing (SHA256)
|
||||
//! - Anomaly detection
|
||||
//! - Forced rotation mechanism
|
||||
//! - Audit logging
|
||||
//! - Gitea token integration
|
||||
//! - n8n API key integration
|
||||
//! - Cached validation with rate limiting
|
||||
|
||||
pub mod anomaly;
|
||||
pub mod audit_logger;
|
||||
pub mod blacklist;
|
||||
pub mod cleanup;
|
||||
pub mod encryption;
|
||||
pub mod error;
|
||||
pub mod export;
|
||||
pub mod gitea;
|
||||
pub mod models;
|
||||
pub mod n8n;
|
||||
pub mod report;
|
||||
pub mod rotation;
|
||||
pub mod service;
|
||||
pub mod strength;
|
||||
pub mod validator;
|
||||
pub mod webhook;
|
||||
|
||||
pub use audit_logger::{AsyncAuditLogger, AuditEntry, AuditLoggerConfig};
|
||||
pub use blacklist::{BlacklistConfig, BlacklistEntry, IpBlacklist};
|
||||
pub use cleanup::{CleanupConfig, CleanupManager, CleanupResult};
|
||||
pub use encryption::{AuditEncryption, EncryptedAuditLogEntry, EncryptedData};
|
||||
pub use error::{ApiKeyError, ApiKeyResult};
|
||||
pub use export::{ExportFormat, ExportManager, ImportManager};
|
||||
pub use gitea::*;
|
||||
pub use models::*;
|
||||
pub use n8n::*;
|
||||
pub use report::{ApiKeyReport, ReportGenerator, ReportSummary};
|
||||
pub use service::ApiKeyService;
|
||||
pub use strength::{KeyStrength, KeyStrengthValidator, StrengthResult};
|
||||
pub use validator::{
|
||||
ApiKeyValidator, CacheStats, RateLimitResult, RateLimitStats, ValidatorConfig,
|
||||
};
|
||||
pub use webhook::{WebhookConfig, WebhookEvent, WebhookNotifier, WebhookPayload};
|
||||
228
src/core/api_key/models.rs
Normal file
228
src/core/api_key/models.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
//! API Key Models
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use utoipa::ToSchema;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, ToSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ApiKeyType {
|
||||
System,
|
||||
User,
|
||||
Service,
|
||||
Integration,
|
||||
Emergency,
|
||||
}
|
||||
|
||||
impl ApiKeyType {
|
||||
pub fn prefix(&self) -> &'static str {
|
||||
match self {
|
||||
ApiKeyType::System => "msys_",
|
||||
ApiKeyType::User => "muser_",
|
||||
ApiKeyType::Service => "msvc_",
|
||||
ApiKeyType::Integration => "mint_",
|
||||
ApiKeyType::Emergency => "memg_",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_ttl_days(&self) -> i64 {
|
||||
match self {
|
||||
ApiKeyType::System => 365,
|
||||
ApiKeyType::User => 90,
|
||||
ApiKeyType::Service => 180,
|
||||
ApiKeyType::Integration => 30,
|
||||
ApiKeyType::Emergency => 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn grace_period_hours(&self) -> i64 {
|
||||
match self {
|
||||
ApiKeyType::System => 72,
|
||||
ApiKeyType::User => 24,
|
||||
ApiKeyType::Service => 48,
|
||||
ApiKeyType::Integration => 24,
|
||||
ApiKeyType::Emergency => 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ApiKeyStatus {
|
||||
Active,
|
||||
Suspended,
|
||||
Expired,
|
||||
Revoked,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum RotationType {
|
||||
Scheduled,
|
||||
Manual,
|
||||
Forced,
|
||||
Emergency,
|
||||
AnomalyTriggered,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AnomalyType {
|
||||
HighRequestRate,
|
||||
HighErrorRate,
|
||||
MultipleIps,
|
||||
UnusualTime,
|
||||
BruteForce,
|
||||
DataExfiltration,
|
||||
}
|
||||
|
||||
impl AnomalyType {
|
||||
pub fn severity(&self) -> AnomalySeverity {
|
||||
match self {
|
||||
AnomalyType::HighRequestRate => AnomalySeverity::Medium,
|
||||
AnomalyType::HighErrorRate => AnomalySeverity::Medium,
|
||||
AnomalyType::MultipleIps => AnomalySeverity::Low,
|
||||
AnomalyType::UnusualTime => AnomalySeverity::Low,
|
||||
AnomalyType::BruteForce => AnomalySeverity::Critical,
|
||||
AnomalyType::DataExfiltration => AnomalySeverity::Critical,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AnomalySeverity {
|
||||
Low,
|
||||
Medium,
|
||||
High,
|
||||
Critical,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApiKey {
|
||||
pub id: i64,
|
||||
pub key_id: String,
|
||||
pub key_hash: String,
|
||||
pub key_prefix: String,
|
||||
pub name: String,
|
||||
pub key_type: ApiKeyType,
|
||||
pub user_id: Option<i64>,
|
||||
pub service_name: Option<String>,
|
||||
pub permissions: HashSet<String>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
pub last_used_at: Option<DateTime<Utc>>,
|
||||
pub last_used_ip: Option<String>,
|
||||
pub usage_count: i64,
|
||||
pub status: ApiKeyStatus,
|
||||
pub rotation_required: bool,
|
||||
pub rotation_reason: Option<String>,
|
||||
pub grace_period_end: Option<DateTime<Utc>>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct CreateApiKeyRequest {
|
||||
pub name: String,
|
||||
pub key_type: ApiKeyType,
|
||||
pub user_id: Option<i64>,
|
||||
pub service_name: Option<String>,
|
||||
pub permissions: Vec<String>,
|
||||
pub ttl_days: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct CreateApiKeyResponse {
|
||||
pub key: String,
|
||||
pub key_id: String,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub warning: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ValidateApiKeyRequest {
|
||||
pub key: String,
|
||||
pub ip_address: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ValidateApiKeyResponse {
|
||||
pub valid: bool,
|
||||
pub key_id: Option<String>,
|
||||
pub permissions: Option<Vec<String>>,
|
||||
pub error: Option<String>,
|
||||
pub requires_rotation: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AuditLogEntry {
|
||||
pub id: i64,
|
||||
pub key_id: String,
|
||||
pub action: String,
|
||||
pub actor: Option<String>,
|
||||
pub ip_address: Option<String>,
|
||||
pub user_agent: Option<String>,
|
||||
pub request_path: Option<String>,
|
||||
pub response_code: Option<i32>,
|
||||
pub anomaly_type: Option<AnomalyType>,
|
||||
pub details: Option<serde_json::Value>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnomalyRecord {
|
||||
pub id: i64,
|
||||
pub key_id: String,
|
||||
pub anomaly_type: AnomalyType,
|
||||
pub severity: AnomalySeverity,
|
||||
pub ip_address: Option<String>,
|
||||
pub request_count: Option<i32>,
|
||||
pub error_count: Option<i32>,
|
||||
pub error_rate: Option<f64>,
|
||||
pub unique_ips: Option<i32>,
|
||||
pub details: Option<serde_json::Value>,
|
||||
pub resolved: bool,
|
||||
pub resolved_at: Option<DateTime<Utc>>,
|
||||
pub resolved_by: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RotationStatus {
|
||||
pub key_id: String,
|
||||
pub requires_rotation: bool,
|
||||
pub reason: Option<String>,
|
||||
pub grace_period_end: Option<DateTime<Utc>>,
|
||||
pub in_grace_period: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnomalyDetectionConfig {
|
||||
pub requests_per_minute_threshold: i32,
|
||||
pub requests_per_hour_threshold: i32,
|
||||
pub error_rate_threshold: f64,
|
||||
pub unique_ips_per_hour_threshold: i32,
|
||||
pub lockout_threshold: i32,
|
||||
}
|
||||
|
||||
impl Default for AnomalyDetectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
requests_per_minute_threshold: 1000,
|
||||
requests_per_hour_threshold: 10000,
|
||||
error_rate_threshold: 0.5,
|
||||
unique_ips_per_hour_threshold: 5,
|
||||
lockout_threshold: 3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ApiKeyStats {
|
||||
pub total_keys: i64,
|
||||
pub active_keys: i64,
|
||||
pub expired_keys: i64,
|
||||
pub rotation_required: i64,
|
||||
pub anomalies_last_24h: i64,
|
||||
}
|
||||
211
src/core/api_key/n8n.rs
Normal file
211
src/core/api_key/n8n.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
//! n8n API Key Integration
|
||||
//!
|
||||
//! Manages n8n API Keys through the API Key system
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to create an n8n API key
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateN8nApiKeyRequest {
|
||||
pub label: String,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Response from creating an n8n API key
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct N8nApiKeyResponse {
|
||||
pub id: String,
|
||||
pub api_key: String,
|
||||
pub label: String,
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// n8n API key info (without raw key)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct N8nApiKeyInfo {
|
||||
pub id: String,
|
||||
pub label: String,
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// n8n API client
|
||||
pub struct N8nClient {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl N8nClient {
|
||||
pub fn new(api_key: String) -> Result<Self> {
|
||||
let base_url = std::env::var("N8N_URL")
|
||||
.unwrap_or_else(|_| "https://n8n.momentry.ddns.net".to_string());
|
||||
|
||||
Ok(Self {
|
||||
client: Client::new(),
|
||||
base_url,
|
||||
api_key,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_url(base_url: String, api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url,
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new API key
|
||||
pub async fn create_api_key(
|
||||
&self,
|
||||
request: &CreateN8nApiKeyRequest,
|
||||
) -> Result<N8nApiKeyResponse> {
|
||||
let url = format!("{}/api/v1/me/api-keys", self.base_url);
|
||||
|
||||
let mut body = serde_json::json!({
|
||||
"label": request.label,
|
||||
});
|
||||
|
||||
if let Some(expires_at) = request.expires_at {
|
||||
body["expiresAt"] = serde_json::json!(expires_at.to_rfc3339());
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("X-N8N-API-KEY", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send create API key request")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to create n8n API key: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let api_key: N8nApiKeyResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse API key response")?;
|
||||
|
||||
Ok(api_key)
|
||||
}
|
||||
|
||||
/// List all API keys for the authenticated user
|
||||
pub async fn list_api_keys(&self) -> Result<Vec<N8nApiKeyInfo>> {
|
||||
let url = format!("{}/api/v1/me/api-keys", self.base_url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("X-N8N-API-KEY", &self.api_key)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send list API keys request")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to list n8n API keys: {} - {}", status, text);
|
||||
}
|
||||
|
||||
let data: serde_json::Value = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse API keys response")?;
|
||||
|
||||
// n8n returns { data: [...] } format
|
||||
let keys: Vec<N8nApiKeyInfo> =
|
||||
serde_json::from_value(data["data"].clone()).unwrap_or_default();
|
||||
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
/// Delete an API key by ID
|
||||
pub async fn delete_api_key(&self, key_id: &str) -> Result<()> {
|
||||
let url = format!("{}/api/v1/me/api-keys/{}", self.base_url, key_id);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.delete(&url)
|
||||
.header("X-N8N-API-KEY", &self.api_key)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send delete API key request")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Failed to delete n8n API key: {} - {}", status, text);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Verify an API key is valid by making a test API call
|
||||
pub async fn verify_api_key(&self, api_key: &str) -> Result<bool> {
|
||||
let url = format!("{}/api/v1/workflows", self.base_url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("X-N8N-API-KEY", api_key)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send verify API key request")?;
|
||||
|
||||
Ok(response.status().is_success())
|
||||
}
|
||||
|
||||
/// Get base URL
|
||||
pub fn base_url(&self) -> &str {
|
||||
&self.base_url
|
||||
}
|
||||
}
|
||||
|
||||
/// Integrated n8n API key record
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct N8nApiKeyRecord {
|
||||
pub n8n_key_id: String,
|
||||
pub label: String,
|
||||
pub api_key_last_eight: String,
|
||||
pub momentry_api_key_id: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_verified: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Extract last 8 characters from API key for display
|
||||
pub fn extract_last_eight(api_key: &str) -> String {
|
||||
if api_key.len() <= 8 {
|
||||
api_key.to_string()
|
||||
} else {
|
||||
api_key[api_key.len() - 8..].to_string()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_extract_last_eight() {
|
||||
assert_eq!(extract_last_eight("n8n_api_1234567890abcdef"), "90abcdef");
|
||||
assert_eq!(extract_last_eight("short"), "short");
|
||||
assert_eq!(extract_last_eight("12345678"), "12345678");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_n8n_client_with_url() {
|
||||
let client =
|
||||
N8nClient::with_url("http://localhost:5678".to_string(), "test_key".to_string());
|
||||
assert_eq!(client.base_url(), "http://localhost:5678");
|
||||
}
|
||||
}
|
||||
233
src/core/api_key/report.rs
Normal file
233
src/core/api_key/report.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
//! API Key Statistics Report Module
|
||||
//!
|
||||
//! Generates usage statistics and reports for API keys
|
||||
|
||||
use crate::core::db::postgres_db::PostgresDb;
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Detailed statistics report
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApiKeyReport {
|
||||
pub generated_at: DateTime<Utc>,
|
||||
pub period: ReportPeriod,
|
||||
pub summary: ReportSummary,
|
||||
pub by_type: Vec<TypeStats>,
|
||||
pub by_status: Vec<StatusStats>,
|
||||
pub top_usage: Vec<UsageStats>,
|
||||
pub anomalies: AnomalyStats,
|
||||
}
|
||||
|
||||
/// Report period
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReportPeriod {
|
||||
pub start: DateTime<Utc>,
|
||||
pub end: DateTime<Utc>,
|
||||
pub days: i64,
|
||||
}
|
||||
|
||||
/// Summary statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReportSummary {
|
||||
pub total_keys: i64,
|
||||
pub active_keys: i64,
|
||||
pub expired_keys: i64,
|
||||
pub revoked_keys: i64,
|
||||
pub keys_needing_rotation: i64,
|
||||
pub total_usage: i64,
|
||||
}
|
||||
|
||||
/// Statistics by key type
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TypeStats {
|
||||
pub key_type: String,
|
||||
pub count: i64,
|
||||
pub active: i64,
|
||||
pub expired: i64,
|
||||
}
|
||||
|
||||
/// Statistics by status
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StatusStats {
|
||||
pub status: String,
|
||||
pub count: i64,
|
||||
pub percentage: f64,
|
||||
}
|
||||
|
||||
/// Top usage statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UsageStats {
|
||||
pub key_id: String,
|
||||
pub name: String,
|
||||
pub usage_count: i64,
|
||||
pub last_used: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Anomaly statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnomalyStats {
|
||||
pub total: i64,
|
||||
pub last_24h: i64,
|
||||
pub last_7d: i64,
|
||||
pub by_severity: Vec<SeverityStats>,
|
||||
}
|
||||
|
||||
/// Severity statistics
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SeverityStats {
|
||||
pub severity: String,
|
||||
pub count: i64,
|
||||
}
|
||||
|
||||
/// Report generator
|
||||
pub struct ReportGenerator {
|
||||
db: PostgresDb,
|
||||
}
|
||||
|
||||
impl ReportGenerator {
|
||||
pub fn new(db: PostgresDb) -> Self {
|
||||
Self { db }
|
||||
}
|
||||
|
||||
/// Generate a full report
|
||||
pub async fn generate_report(&self, days: i64) -> Result<ApiKeyReport> {
|
||||
let end = Utc::now();
|
||||
let start = end - Duration::days(days);
|
||||
|
||||
let stats = self.db.get_api_key_stats().await?;
|
||||
|
||||
Ok(ApiKeyReport {
|
||||
generated_at: Utc::now(),
|
||||
period: ReportPeriod { start, end, days },
|
||||
summary: ReportSummary {
|
||||
total_keys: stats.total_keys,
|
||||
active_keys: stats.active_keys,
|
||||
expired_keys: stats.expired_keys,
|
||||
revoked_keys: 0,
|
||||
keys_needing_rotation: stats.rotation_required,
|
||||
total_usage: 0,
|
||||
},
|
||||
by_type: vec![],
|
||||
by_status: vec![
|
||||
StatusStats {
|
||||
status: "active".to_string(),
|
||||
count: stats.active_keys,
|
||||
percentage: if stats.total_keys > 0 {
|
||||
(stats.active_keys as f64 / stats.total_keys as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
},
|
||||
StatusStats {
|
||||
status: "expired".to_string(),
|
||||
count: stats.expired_keys,
|
||||
percentage: if stats.total_keys > 0 {
|
||||
(stats.expired_keys as f64 / stats.total_keys as f64) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
},
|
||||
],
|
||||
top_usage: vec![],
|
||||
anomalies: AnomalyStats {
|
||||
total: 0,
|
||||
last_24h: stats.anomalies_last_24h,
|
||||
last_7d: 0,
|
||||
by_severity: vec![],
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate text report
|
||||
pub async fn generate_text_report(&self, days: i64) -> Result<String> {
|
||||
let report = self.generate_report(days).await?;
|
||||
|
||||
let mut output = String::new();
|
||||
output.push_str("=== API Key Statistics Report ===\n");
|
||||
output.push_str(&format!(
|
||||
"Generated: {}\n",
|
||||
report.generated_at.format("%Y-%m-%d %H:%M:%S")
|
||||
));
|
||||
output.push_str(&format!(
|
||||
"Period: {} to {} ({} days)\n\n",
|
||||
report.period.start.format("%Y-%m-%d"),
|
||||
report.period.end.format("%Y-%m-%d"),
|
||||
report.period.days
|
||||
));
|
||||
|
||||
output.push_str("--- Summary ---\n");
|
||||
output.push_str(&format!(
|
||||
"Total Keys: {}\n",
|
||||
report.summary.total_keys
|
||||
));
|
||||
output.push_str(&format!(
|
||||
"Active Keys: {}\n",
|
||||
report.summary.active_keys
|
||||
));
|
||||
output.push_str(&format!(
|
||||
"Expired Keys: {}\n",
|
||||
report.summary.expired_keys
|
||||
));
|
||||
output.push_str(&format!(
|
||||
"Rotation Required: {}\n\n",
|
||||
report.summary.keys_needing_rotation
|
||||
));
|
||||
|
||||
output.push_str("--- Status Distribution ---\n");
|
||||
for status in &report.by_status {
|
||||
output.push_str(&format!(
|
||||
"{:12}: {} ({:.1}%)\n",
|
||||
status.status, status.count, status.percentage
|
||||
));
|
||||
}
|
||||
|
||||
output.push_str(&format!(
|
||||
"\n--- Anomalies (Last 24h) ---\n{}\n",
|
||||
report.anomalies.last_24h
|
||||
));
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_report_serialization() {
|
||||
let report = ApiKeyReport {
|
||||
generated_at: Utc::now(),
|
||||
period: ReportPeriod {
|
||||
start: Utc::now() - Duration::days(30),
|
||||
end: Utc::now(),
|
||||
days: 30,
|
||||
},
|
||||
summary: ReportSummary {
|
||||
total_keys: 10,
|
||||
active_keys: 8,
|
||||
expired_keys: 2,
|
||||
revoked_keys: 0,
|
||||
keys_needing_rotation: 1,
|
||||
total_usage: 1000,
|
||||
},
|
||||
by_type: vec![],
|
||||
by_status: vec![StatusStats {
|
||||
status: "active".to_string(),
|
||||
count: 8,
|
||||
percentage: 80.0,
|
||||
}],
|
||||
top_usage: vec![],
|
||||
anomalies: AnomalyStats {
|
||||
total: 5,
|
||||
last_24h: 1,
|
||||
last_7d: 3,
|
||||
by_severity: vec![],
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string_pretty(&report).unwrap();
|
||||
assert!(json.contains("\"total_keys\": 10"));
|
||||
}
|
||||
}
|
||||
319
src/core/api_key/rotation.rs
Normal file
319
src/core/api_key/rotation.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
//! API Key Rotation Module
|
||||
//!
|
||||
//! Implements forced rotation mechanism with grace periods
|
||||
|
||||
use crate::core::api_key::models::*;
|
||||
use chrono::{Duration, Utc};
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub struct RotationManager {
|
||||
grace_periods: HashMap<ApiKeyType, i64>,
|
||||
rotation_queue: RwLock<Vec<RotationTask>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RotationTask {
|
||||
key_id: String,
|
||||
key_type: ApiKeyType,
|
||||
reason: RotationReason,
|
||||
created_at: chrono::DateTime<Utc>,
|
||||
scheduled_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum RotationReason {
|
||||
Expired,
|
||||
Manual,
|
||||
Forced,
|
||||
AnomalyDetected,
|
||||
SecurityBreach,
|
||||
PolicyChange,
|
||||
}
|
||||
|
||||
impl RotationReason {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
RotationReason::Expired => "expired",
|
||||
RotationReason::Manual => "manual",
|
||||
RotationReason::Forced => "forced",
|
||||
RotationReason::AnomalyDetected => "anomaly_detected",
|
||||
RotationReason::SecurityBreach => "security_breach",
|
||||
RotationReason::PolicyChange => "policy_change",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn requires_immediate_rotation(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
RotationReason::AnomalyDetected | RotationReason::SecurityBreach
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl RotationManager {
|
||||
pub fn new() -> Self {
|
||||
let mut grace_periods = HashMap::new();
|
||||
grace_periods.insert(ApiKeyType::System, 72);
|
||||
grace_periods.insert(ApiKeyType::User, 24);
|
||||
grace_periods.insert(ApiKeyType::Service, 48);
|
||||
grace_periods.insert(ApiKeyType::Integration, 24);
|
||||
grace_periods.insert(ApiKeyType::Emergency, 0);
|
||||
|
||||
Self {
|
||||
grace_periods,
|
||||
rotation_queue: RwLock::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_grace_period(&self, key_type: ApiKeyType) -> Duration {
|
||||
let hours = self.grace_periods.get(&key_type).copied().unwrap_or(24);
|
||||
Duration::hours(hours)
|
||||
}
|
||||
|
||||
pub fn calculate_grace_period_end(
|
||||
&self,
|
||||
key_type: ApiKeyType,
|
||||
triggered_at: chrono::DateTime<Utc>,
|
||||
) -> chrono::DateTime<Utc> {
|
||||
triggered_at + self.get_grace_period(key_type)
|
||||
}
|
||||
|
||||
pub fn is_in_grace_period(&self, grace_period_end: Option<chrono::DateTime<Utc>>) -> bool {
|
||||
match grace_period_end {
|
||||
Some(end) => Utc::now() < end,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_grace_period_expired(&self, grace_period_end: Option<chrono::DateTime<Utc>>) -> bool {
|
||||
match grace_period_end {
|
||||
Some(end) => Utc::now() >= end,
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn queue_rotation(
|
||||
&self,
|
||||
key_id: String,
|
||||
key_type: ApiKeyType,
|
||||
reason: RotationReason,
|
||||
) {
|
||||
let grace_period_end = self.calculate_grace_period_end(key_type, Utc::now());
|
||||
|
||||
let scheduled_at = if reason.requires_immediate_rotation() {
|
||||
Utc::now()
|
||||
} else {
|
||||
grace_period_end
|
||||
};
|
||||
|
||||
let task = RotationTask {
|
||||
key_id,
|
||||
key_type,
|
||||
reason,
|
||||
created_at: Utc::now(),
|
||||
scheduled_at,
|
||||
};
|
||||
|
||||
let mut queue = self.rotation_queue.write().await;
|
||||
queue.push(task);
|
||||
}
|
||||
|
||||
pub async fn get_pending_rotations(&self) -> Vec<RotationStatus> {
|
||||
let queue = self.rotation_queue.read().await;
|
||||
|
||||
queue
|
||||
.iter()
|
||||
.map(|task| {
|
||||
let grace_period_end =
|
||||
self.calculate_grace_period_end(task.key_type, task.created_at);
|
||||
RotationStatus {
|
||||
key_id: task.key_id.clone(),
|
||||
requires_rotation: true,
|
||||
reason: Some(task.reason.as_str().to_string()),
|
||||
grace_period_end: Some(grace_period_end),
|
||||
in_grace_period: self.is_in_grace_period(Some(grace_period_end)),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn get_overdue_rotations(&self) -> Vec<RotationStatus> {
|
||||
let queue = self.rotation_queue.read().await;
|
||||
let now = Utc::now();
|
||||
|
||||
queue
|
||||
.iter()
|
||||
.filter(|task| task.scheduled_at <= now)
|
||||
.map(|task| {
|
||||
let grace_period_end =
|
||||
self.calculate_grace_period_end(task.key_type, task.created_at);
|
||||
RotationStatus {
|
||||
key_id: task.key_id.clone(),
|
||||
requires_rotation: true,
|
||||
reason: Some(task.reason.as_str().to_string()),
|
||||
grace_period_end: Some(grace_period_end),
|
||||
in_grace_period: self.is_in_grace_period(Some(grace_period_end)),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn remove_from_queue(&self, key_id: &str) {
|
||||
let mut queue = self.rotation_queue.write().await;
|
||||
queue.retain(|task| task.key_id != key_id);
|
||||
}
|
||||
|
||||
pub fn check_rotation_required(&self, key: &ApiKey) -> RotationStatus {
|
||||
let now = Utc::now();
|
||||
|
||||
if key.status == ApiKeyStatus::Revoked {
|
||||
return RotationStatus {
|
||||
key_id: key.key_id.clone(),
|
||||
requires_rotation: false,
|
||||
reason: Some("key_revoked".to_string()),
|
||||
grace_period_end: None,
|
||||
in_grace_period: false,
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(expires_at) = key.expires_at {
|
||||
if now > expires_at {
|
||||
return RotationStatus {
|
||||
key_id: key.key_id.clone(),
|
||||
requires_rotation: true,
|
||||
reason: Some("expired".to_string()),
|
||||
grace_period_end: Some(
|
||||
self.calculate_grace_period_end(key.key_type, expires_at),
|
||||
),
|
||||
in_grace_period: self.is_in_grace_period(Some(
|
||||
self.calculate_grace_period_end(key.key_type, expires_at),
|
||||
)),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if key.rotation_required {
|
||||
return RotationStatus {
|
||||
key_id: key.key_id.clone(),
|
||||
requires_rotation: true,
|
||||
reason: key.rotation_reason.clone(),
|
||||
grace_period_end: key.grace_period_end,
|
||||
in_grace_period: self.is_in_grace_period(key.grace_period_end),
|
||||
};
|
||||
}
|
||||
|
||||
RotationStatus {
|
||||
key_id: key.key_id.clone(),
|
||||
requires_rotation: false,
|
||||
reason: None,
|
||||
grace_period_end: None,
|
||||
in_grace_period: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn should_auto_expire(&self, key: &ApiKey) -> bool {
|
||||
if key.status != ApiKeyStatus::Active {
|
||||
return false;
|
||||
}
|
||||
|
||||
if key.key_type == ApiKeyType::Emergency {
|
||||
if let Some(expires_at) = key.expires_at {
|
||||
return Utc::now() >= expires_at;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(grace_period_end) = key.grace_period_end {
|
||||
return self.is_grace_period_expired(Some(grace_period_end)) && key.rotation_required;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RotationManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RotationScheduler {
|
||||
check_interval_seconds: u64,
|
||||
}
|
||||
|
||||
impl RotationScheduler {
|
||||
pub fn new(check_interval_seconds: u64) -> Self {
|
||||
Self {
|
||||
check_interval_seconds,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_interval(&self) -> Duration {
|
||||
Duration::seconds(self.check_interval_seconds as i64)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RotationScheduler {
|
||||
fn default() -> Self {
|
||||
Self::new(3600)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_grace_period_calculation() {
|
||||
let manager = RotationManager::new();
|
||||
|
||||
let user_grace = manager.get_grace_period(ApiKeyType::User);
|
||||
assert_eq!(user_grace, Duration::hours(24));
|
||||
|
||||
let system_grace = manager.get_grace_period(ApiKeyType::System);
|
||||
assert_eq!(system_grace, Duration::hours(72));
|
||||
|
||||
let emergency_grace = manager.get_grace_period(ApiKeyType::Emergency);
|
||||
assert_eq!(emergency_grace, Duration::hours(0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rotation_reason_requires_immediate() {
|
||||
assert!(RotationReason::AnomalyDetected.requires_immediate_rotation());
|
||||
assert!(RotationReason::SecurityBreach.requires_immediate_rotation());
|
||||
assert!(!RotationReason::Expired.requires_immediate_rotation());
|
||||
assert!(!RotationReason::Manual.requires_immediate_rotation());
|
||||
assert!(!RotationReason::Forced.requires_immediate_rotation());
|
||||
assert!(!RotationReason::PolicyChange.requires_immediate_rotation());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_queue_rotation() {
|
||||
let manager = RotationManager::new();
|
||||
|
||||
manager
|
||||
.queue_rotation(
|
||||
"test_key_123".to_string(),
|
||||
ApiKeyType::User,
|
||||
RotationReason::Manual,
|
||||
)
|
||||
.await;
|
||||
|
||||
let pending = manager.get_pending_rotations().await;
|
||||
assert_eq!(pending.len(), 1);
|
||||
assert_eq!(pending[0].key_id, "test_key_123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_in_grace_period() {
|
||||
let manager = RotationManager::new();
|
||||
|
||||
let future_end = Utc::now() + Duration::hours(12);
|
||||
assert!(manager.is_in_grace_period(Some(future_end)));
|
||||
|
||||
let past_end = Utc::now() - Duration::hours(1);
|
||||
assert!(!manager.is_in_grace_period(Some(past_end)));
|
||||
|
||||
assert!(!manager.is_in_grace_period(None));
|
||||
}
|
||||
}
|
||||
276
src/core/api_key/service.rs
Normal file
276
src/core/api_key/service.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
//! API Key Service
|
||||
//!
|
||||
//! Core functionality for API key management
|
||||
|
||||
use crate::core::api_key::models::*;
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use sha2::{Digest, Sha256};
|
||||
use subtle::ConstantTimeEq;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub struct ApiKeyService {
|
||||
_db_url: String,
|
||||
config: AnomalyDetectionConfig,
|
||||
}
|
||||
|
||||
impl ApiKeyService {
|
||||
pub fn new(db_url: String) -> Self {
|
||||
Self {
|
||||
_db_url: db_url,
|
||||
config: AnomalyDetectionConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_config(db_url: String, config: AnomalyDetectionConfig) -> Self {
|
||||
Self {
|
||||
_db_url: db_url,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_key(&self, key_type: ApiKeyType) -> (String, String, String) {
|
||||
let uuid = Uuid::new_v4().to_string().replace("-", "");
|
||||
let timestamp = Utc::now().timestamp();
|
||||
let random_part = Uuid::new_v4().to_string().replace("-", "")[..8].to_string();
|
||||
|
||||
let key = format!(
|
||||
"{}{}_{}_{}",
|
||||
key_type.prefix(),
|
||||
uuid,
|
||||
timestamp,
|
||||
random_part
|
||||
);
|
||||
let hash = self.hash_key(&key);
|
||||
|
||||
(key, hash, uuid)
|
||||
}
|
||||
|
||||
pub fn hash_key(&self, key: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// Constant-time comparison of two hash strings
|
||||
///
|
||||
/// This prevents timing attacks when comparing sensitive data like
|
||||
/// API key hashes. Use this instead of `==` for security-critical comparisons.
|
||||
pub fn constant_time_compare(a: &str, b: &str) -> bool {
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
a.as_bytes().ct_eq(b.as_bytes()).into()
|
||||
}
|
||||
|
||||
pub fn create_key(&self, request: CreateApiKeyRequest) -> Result<CreateApiKeyResponse> {
|
||||
let ttl_days = request
|
||||
.ttl_days
|
||||
.unwrap_or(request.key_type.default_ttl_days());
|
||||
let expires_at = Utc::now() + Duration::days(ttl_days);
|
||||
|
||||
let (key, _key_hash, _) = self.generate_key(request.key_type);
|
||||
|
||||
let warning = if request.key_type == ApiKeyType::Emergency {
|
||||
"警告:緊急 Key 將在 24 小時後自動過期,請及時更新".to_string()
|
||||
} else if ttl_days < 30 {
|
||||
format!("警告:Key 有效期僅 {} 天,建議使用更長的有效期", ttl_days)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
Ok(CreateApiKeyResponse {
|
||||
key_id: self.extract_key_id(&key),
|
||||
key,
|
||||
expires_at,
|
||||
warning,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn extract_key_id(&self, key: &str) -> String {
|
||||
let parts: Vec<&str> = key.split('_').collect();
|
||||
if parts.len() >= 2 {
|
||||
format!("{}_{}", parts[0], parts[1])
|
||||
} else {
|
||||
key[..16.min(key.len())].to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_key(&self, request: ValidateApiKeyRequest) -> Result<ValidateApiKeyResponse> {
|
||||
let _key_hash = self.hash_key(&request.key);
|
||||
|
||||
Ok(ValidateApiKeyResponse {
|
||||
valid: true,
|
||||
key_id: Some(self.extract_key_id(&request.key)),
|
||||
permissions: Some(vec!["read".to_string(), "write".to_string()]),
|
||||
error: None,
|
||||
requires_rotation: false,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn require_rotation(&self, key_id: &str, reason: &str) -> Result<()> {
|
||||
tracing::info!("API Key {} requires rotation: {}", key_id, reason);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn check_anomaly(&self, key_id: &str, metrics: &AnomalyMetrics) -> Option<AnomalyRecord> {
|
||||
if metrics.requests_per_minute > self.config.requests_per_minute_threshold {
|
||||
return Some(AnomalyRecord {
|
||||
id: 0,
|
||||
key_id: key_id.to_string(),
|
||||
anomaly_type: AnomalyType::HighRequestRate,
|
||||
severity: AnomalySeverity::Medium,
|
||||
ip_address: metrics.last_ip.clone(),
|
||||
request_count: Some(metrics.requests_per_minute),
|
||||
error_count: Some(metrics.error_count),
|
||||
error_rate: Some(metrics.error_rate),
|
||||
unique_ips: Some(metrics.unique_ips),
|
||||
details: None,
|
||||
resolved: false,
|
||||
resolved_at: None,
|
||||
resolved_by: None,
|
||||
created_at: Utc::now(),
|
||||
});
|
||||
}
|
||||
|
||||
if metrics.error_rate > self.config.error_rate_threshold {
|
||||
return Some(AnomalyRecord {
|
||||
id: 0,
|
||||
key_id: key_id.to_string(),
|
||||
anomaly_type: AnomalyType::HighErrorRate,
|
||||
severity: AnomalySeverity::Medium,
|
||||
ip_address: metrics.last_ip.clone(),
|
||||
request_count: Some(metrics.requests_per_minute),
|
||||
error_count: Some(metrics.error_count),
|
||||
error_rate: Some(metrics.error_rate),
|
||||
unique_ips: Some(metrics.unique_ips),
|
||||
details: None,
|
||||
resolved: false,
|
||||
resolved_at: None,
|
||||
resolved_by: None,
|
||||
created_at: Utc::now(),
|
||||
});
|
||||
}
|
||||
|
||||
if metrics.unique_ips > self.config.unique_ips_per_hour_threshold {
|
||||
return Some(AnomalyRecord {
|
||||
id: 0,
|
||||
key_id: key_id.to_string(),
|
||||
anomaly_type: AnomalyType::MultipleIps,
|
||||
severity: AnomalySeverity::Low,
|
||||
ip_address: None,
|
||||
request_count: None,
|
||||
error_count: None,
|
||||
error_rate: None,
|
||||
unique_ips: Some(metrics.unique_ips),
|
||||
details: None,
|
||||
resolved: false,
|
||||
resolved_at: None,
|
||||
resolved_by: None,
|
||||
created_at: Utc::now(),
|
||||
});
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub fn calculate_grace_period_end(&self, key_type: ApiKeyType) -> DateTime<Utc> {
|
||||
Utc::now() + Duration::hours(key_type.grace_period_hours())
|
||||
}
|
||||
|
||||
pub fn is_in_grace_period(&self, grace_period_end: Option<DateTime<Utc>>) -> bool {
|
||||
match grace_period_end {
|
||||
Some(end) => Utc::now() < end,
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_key_prefix(&self, key_type: ApiKeyType) -> &'static str {
|
||||
key_type.prefix()
|
||||
}
|
||||
|
||||
pub fn get_config(&self) -> &AnomalyDetectionConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AnomalyMetrics {
|
||||
pub requests_per_minute: i32,
|
||||
pub error_count: i32,
|
||||
pub error_rate: f64,
|
||||
pub unique_ips: i32,
|
||||
pub last_ip: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for AnomalyMetrics {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
requests_per_minute: 0,
|
||||
error_count: 0,
|
||||
error_rate: 0.0,
|
||||
unique_ips: 0,
|
||||
last_ip: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_key() {
|
||||
let service = ApiKeyService::new("postgres://localhost".to_string());
|
||||
|
||||
let (key, hash, _) = service.generate_key(ApiKeyType::User);
|
||||
|
||||
assert!(key.starts_with("muser_"));
|
||||
assert_eq!(hash.len(), 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hash_key() {
|
||||
let service = ApiKeyService::new("postgres://localhost".to_string());
|
||||
|
||||
let hash = service.hash_key("test_key");
|
||||
|
||||
assert_eq!(hash.len(), 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_key_id() {
|
||||
let service = ApiKeyService::new("postgres://localhost".to_string());
|
||||
|
||||
let key_id = service.extract_key_id("muser_a1b2c3d4_1710998400_abc12345");
|
||||
|
||||
assert_eq!(key_id, "muser_a1b2c3d4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_grace_period() {
|
||||
let service = ApiKeyService::new("postgres://localhost".to_string());
|
||||
|
||||
let end = service.calculate_grace_period_end(ApiKeyType::User);
|
||||
|
||||
let hours = (end - Utc::now()).num_hours();
|
||||
assert!(
|
||||
hours >= 23 && hours <= 24,
|
||||
"expected 23-24 hours, got {}",
|
||||
hours
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_time_compare() {
|
||||
let a = "abcdef1234567890";
|
||||
let b = "abcdef1234567890";
|
||||
let c = "abcdef1234567891";
|
||||
let d = "short";
|
||||
|
||||
assert!(ApiKeyService::constant_time_compare(a, b));
|
||||
assert!(!ApiKeyService::constant_time_compare(a, c));
|
||||
assert!(!ApiKeyService::constant_time_compare(a, d));
|
||||
assert!(!ApiKeyService::constant_time_compare(d, a));
|
||||
}
|
||||
}
|
||||
209
src/core/api_key/strength.rs
Normal file
209
src/core/api_key/strength.rs
Normal file
@@ -0,0 +1,209 @@
|
||||
//! API Key Strength Validation
|
||||
//!
|
||||
//! Validates that API keys meet security requirements
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Key strength level
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum KeyStrength {
|
||||
Weak,
|
||||
Medium,
|
||||
Strong,
|
||||
VeryStrong,
|
||||
}
|
||||
|
||||
impl KeyStrength {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
KeyStrength::Weak => "weak",
|
||||
KeyStrength::Medium => "medium",
|
||||
KeyStrength::Strong => "strong",
|
||||
KeyStrength::VeryStrong => "very_strong",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_acceptable(&self) -> bool {
|
||||
!matches!(self, KeyStrength::Weak)
|
||||
}
|
||||
}
|
||||
|
||||
/// Key strength validation result
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StrengthResult {
|
||||
pub strength: KeyStrength,
|
||||
pub score: u32,
|
||||
pub max_score: u32,
|
||||
pub issues: Vec<String>,
|
||||
pub suggestions: Vec<String>,
|
||||
}
|
||||
|
||||
/// Key strength validator
|
||||
pub struct KeyStrengthValidator {
|
||||
min_length: usize,
|
||||
require_prefix: bool,
|
||||
}
|
||||
|
||||
impl Default for KeyStrengthValidator {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_length: 32,
|
||||
require_prefix: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl KeyStrengthValidator {
|
||||
pub fn new(min_length: usize, require_prefix: bool) -> Self {
|
||||
Self {
|
||||
min_length,
|
||||
require_prefix,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate key strength
|
||||
pub fn validate(&self, key: &str) -> StrengthResult {
|
||||
let mut score: u32 = 0;
|
||||
let mut issues = Vec::new();
|
||||
let mut suggestions = Vec::new();
|
||||
|
||||
// Check length
|
||||
if key.len() >= self.min_length {
|
||||
score += 25;
|
||||
} else {
|
||||
issues.push(format!(
|
||||
"Key length {} is less than minimum {}",
|
||||
key.len(),
|
||||
self.min_length
|
||||
));
|
||||
suggestions.push(format!("Use at least {} characters", self.min_length));
|
||||
}
|
||||
|
||||
// Check for valid prefix
|
||||
let valid_prefixes = ["msys_", "muser_", "msvc_", "mint_", "memg_"];
|
||||
let has_valid_prefix = valid_prefixes.iter().any(|p| key.starts_with(p));
|
||||
|
||||
if has_valid_prefix {
|
||||
score += 25;
|
||||
} else if self.require_prefix {
|
||||
issues.push("Key does not have a valid prefix".to_string());
|
||||
suggestions.push("Use a valid prefix: msys_, muser_, msvc_, mint_, memg_".to_string());
|
||||
}
|
||||
|
||||
// Check entropy (character variety)
|
||||
let has_lowercase = key.chars().any(|c| c.is_ascii_lowercase());
|
||||
let has_uppercase = key.chars().any(|c| c.is_ascii_uppercase());
|
||||
let has_digit = key.chars().any(|c| c.is_ascii_digit());
|
||||
let has_special = key.chars().any(|c| !c.is_ascii_alphanumeric());
|
||||
|
||||
let entropy_count = [has_lowercase, has_uppercase, has_digit, has_special]
|
||||
.iter()
|
||||
.filter(|&&x| x)
|
||||
.count();
|
||||
|
||||
score += (entropy_count as u32) * 12;
|
||||
|
||||
if entropy_count < 2 {
|
||||
issues.push("Low character variety".to_string());
|
||||
suggestions
|
||||
.push("Include lowercase, uppercase, digits, and special characters".to_string());
|
||||
}
|
||||
|
||||
// Check for sequential characters
|
||||
let has_sequential = key
|
||||
.as_bytes()
|
||||
.windows(3)
|
||||
.any(|w| w[1] == w[0] + 1 && w[2] == w[1] + 1);
|
||||
|
||||
if has_sequential {
|
||||
score = score.saturating_sub(10);
|
||||
issues.push("Contains sequential characters".to_string());
|
||||
}
|
||||
|
||||
// Check for repeated characters
|
||||
let has_repeated = key
|
||||
.as_bytes()
|
||||
.windows(3)
|
||||
.any(|w| w[0] == w[1] && w[1] == w[2]);
|
||||
|
||||
if has_repeated {
|
||||
score = score.saturating_sub(10);
|
||||
issues.push("Contains repeated characters".to_string());
|
||||
}
|
||||
|
||||
// Determine strength level
|
||||
let strength = if score >= 80 {
|
||||
KeyStrength::VeryStrong
|
||||
} else if score >= 60 {
|
||||
KeyStrength::Strong
|
||||
} else if score >= 40 {
|
||||
KeyStrength::Medium
|
||||
} else {
|
||||
KeyStrength::Weak
|
||||
};
|
||||
|
||||
StrengthResult {
|
||||
strength,
|
||||
score: score.min(100),
|
||||
max_score: 100,
|
||||
issues,
|
||||
suggestions,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if key is acceptable
|
||||
pub fn is_acceptable(&self, key: &str) -> bool {
|
||||
self.validate(key).strength.is_acceptable()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_strong_key() {
|
||||
let validator = KeyStrengthValidator::default();
|
||||
let result = validator.validate("msvc_a1B2c3D4e5F6g7H8i9J0k1L2m3N4o5P6");
|
||||
|
||||
assert!(result.strength.is_acceptable());
|
||||
assert!(result.score >= 60);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weak_key() {
|
||||
let validator = KeyStrengthValidator::default();
|
||||
let result = validator.validate("short");
|
||||
|
||||
assert_eq!(result.strength, KeyStrength::Weak);
|
||||
assert!(!result.issues.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sequential_penalty() {
|
||||
let validator = KeyStrengthValidator::default();
|
||||
let result_with = validator.validate("msvc_abc123def456ghi789jkl012mno345");
|
||||
let result_without = validator.validate("msvc_a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5");
|
||||
|
||||
// Sequential characters should reduce score
|
||||
assert!(
|
||||
result_without.score >= result_with.score
|
||||
|| result_with.issues.len() <= result_without.issues.len()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_key_strength_serialization() {
|
||||
let result = StrengthResult {
|
||||
strength: KeyStrength::Strong,
|
||||
score: 75,
|
||||
max_score: 100,
|
||||
issues: vec![],
|
||||
suggestions: vec![],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("\"strong\""));
|
||||
}
|
||||
}
|
||||
310
src/core/api_key/validator.rs
Normal file
310
src/core/api_key/validator.rs
Normal file
@@ -0,0 +1,310 @@
|
||||
//! API Key Validation with Cache and Rate Limiting
|
||||
//!
|
||||
//! Provides cached validation and rate limiting for API keys
|
||||
|
||||
use crate::core::db::postgres_db::ApiKeyRecord;
|
||||
use crate::core::db::PostgresDb;
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use moka::future::Cache;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration as StdDuration;
|
||||
|
||||
/// Cached API key record
|
||||
#[derive(Clone)]
|
||||
pub struct CachedApiKey {
|
||||
pub record: ApiKeyRecord,
|
||||
pub cached_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Rate limit result
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RateLimitResult {
|
||||
/// Request is allowed
|
||||
Allowed,
|
||||
/// Request is allowed but with warning
|
||||
AllowedWithWarning { remaining_attempts: u32 },
|
||||
/// Request is locked
|
||||
Locked {
|
||||
remaining_seconds: i64,
|
||||
attempts: u32,
|
||||
},
|
||||
}
|
||||
|
||||
/// Attempt tracking info
|
||||
#[derive(Clone)]
|
||||
struct AttemptInfo {
|
||||
count: u32,
|
||||
first_attempt: DateTime<Utc>,
|
||||
last_attempt: DateTime<Utc>,
|
||||
locked_until: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// API Key Validator with caching and rate limiting
|
||||
pub struct ApiKeyValidator {
|
||||
db: Arc<PostgresDb>,
|
||||
cache: Cache<String, CachedApiKey>,
|
||||
rate_limiter: Cache<String, AttemptInfo>,
|
||||
max_attempts: u32,
|
||||
lockout_duration: Duration,
|
||||
}
|
||||
|
||||
/// Configuration for ApiKeyValidator
|
||||
pub struct ValidatorConfig {
|
||||
pub cache_ttl_secs: u64,
|
||||
pub cache_max_capacity: u64,
|
||||
pub max_attempts: u32,
|
||||
pub lockout_duration_secs: i64,
|
||||
}
|
||||
|
||||
impl Default for ValidatorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
cache_ttl_secs: std::env::var("CACHE_TTL_SECONDS")
|
||||
.unwrap_or_else(|_| "300".to_string())
|
||||
.parse()
|
||||
.unwrap_or(300),
|
||||
cache_max_capacity: std::env::var("CACHE_MAX_CAPACITY")
|
||||
.unwrap_or_else(|_| "10000".to_string())
|
||||
.parse()
|
||||
.unwrap_or(10000),
|
||||
max_attempts: std::env::var("RATE_LIMIT_MAX_ATTEMPTS")
|
||||
.unwrap_or_else(|_| "5".to_string())
|
||||
.parse()
|
||||
.unwrap_or(5),
|
||||
lockout_duration_secs: std::env::var("RATE_LIMIT_WINDOW_SECONDS")
|
||||
.unwrap_or_else(|_| "900".to_string())
|
||||
.parse()
|
||||
.unwrap_or(900),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ApiKeyValidator {
|
||||
pub fn new(db: PostgresDb, config: ValidatorConfig) -> Self {
|
||||
Self {
|
||||
db: Arc::new(db),
|
||||
cache: Cache::builder()
|
||||
.time_to_live(StdDuration::from_secs(config.cache_ttl_secs))
|
||||
.time_to_idle(StdDuration::from_secs(config.cache_ttl_secs * 2))
|
||||
.max_capacity(config.cache_max_capacity)
|
||||
.build(),
|
||||
rate_limiter: Cache::builder()
|
||||
.time_to_live(StdDuration::from_secs(
|
||||
config.lockout_duration_secs as u64 * 2,
|
||||
))
|
||||
.max_capacity(10000)
|
||||
.build(),
|
||||
max_attempts: config.max_attempts,
|
||||
lockout_duration: Duration::seconds(config.lockout_duration_secs),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_default_config(db: PostgresDb) -> Self {
|
||||
Self::new(db, ValidatorConfig::default())
|
||||
}
|
||||
|
||||
/// Check rate limit for an IP
|
||||
pub async fn check_rate_limit(&self, ip: &str) -> RateLimitResult {
|
||||
match self.rate_limiter.get(ip).await {
|
||||
None => RateLimitResult::Allowed,
|
||||
Some(info) => {
|
||||
if let Some(locked_until) = info.locked_until {
|
||||
let remaining = locked_until - Utc::now();
|
||||
if remaining.num_seconds() > 0 {
|
||||
return RateLimitResult::Locked {
|
||||
remaining_seconds: remaining.num_seconds(),
|
||||
attempts: info.count,
|
||||
};
|
||||
}
|
||||
}
|
||||
if info.count >= self.max_attempts / 2 {
|
||||
RateLimitResult::AllowedWithWarning {
|
||||
remaining_attempts: self.max_attempts - info.count,
|
||||
}
|
||||
} else {
|
||||
RateLimitResult::Allowed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a failed attempt
|
||||
pub async fn record_failure(&self, ip: &str) -> RateLimitResult {
|
||||
let mut info = self.rate_limiter.get(ip).await.unwrap_or(AttemptInfo {
|
||||
count: 0,
|
||||
first_attempt: Utc::now(),
|
||||
last_attempt: Utc::now(),
|
||||
locked_until: None,
|
||||
});
|
||||
|
||||
info.count += 1;
|
||||
info.last_attempt = Utc::now();
|
||||
|
||||
if info.count >= self.max_attempts {
|
||||
info.locked_until = Some(Utc::now() + self.lockout_duration);
|
||||
self.rate_limiter.insert(ip.to_string(), info.clone()).await;
|
||||
|
||||
tracing::warn!("IP {} locked due to {} failed attempts", ip, info.count);
|
||||
|
||||
RateLimitResult::Locked {
|
||||
remaining_seconds: self.lockout_duration.num_seconds(),
|
||||
attempts: info.count,
|
||||
}
|
||||
} else {
|
||||
let remaining = self.max_attempts - info.count;
|
||||
self.rate_limiter.insert(ip.to_string(), info).await;
|
||||
|
||||
RateLimitResult::AllowedWithWarning {
|
||||
remaining_attempts: remaining,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a successful validation (clear rate limit)
|
||||
pub async fn record_success(&self, ip: &str) {
|
||||
self.rate_limiter.invalidate(ip).await;
|
||||
}
|
||||
|
||||
/// Manually unlock an IP
|
||||
pub async fn unlock_ip(&self, ip: &str) {
|
||||
self.rate_limiter.invalidate(ip).await;
|
||||
tracing::info!("Manually unlocked IP: {}", ip);
|
||||
}
|
||||
|
||||
/// Validate an API key with caching
|
||||
pub async fn validate(&self, key_hash: &str) -> Result<Option<ApiKeyRecord>> {
|
||||
// 1. Check cache
|
||||
if let Some(cached) = self.cache.get(key_hash).await {
|
||||
// Check if expired
|
||||
if let Some(expires_at) = cached.record.expires_at {
|
||||
if Utc::now() > expires_at {
|
||||
self.cache.invalidate(key_hash).await;
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if revoked
|
||||
if cached.record.status == "revoked" || cached.record.status == "suspended" {
|
||||
self.cache.invalidate(key_hash).await;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
return Ok(Some(cached.record));
|
||||
}
|
||||
|
||||
// 2. Query database
|
||||
let record = self.db.get_api_key_by_hash(key_hash).await?;
|
||||
|
||||
// 3. Cache if valid
|
||||
if let Some(ref r) = record {
|
||||
// Only cache active keys
|
||||
if r.status == "active" {
|
||||
self.cache
|
||||
.insert(
|
||||
key_hash.to_string(),
|
||||
CachedApiKey {
|
||||
record: r.clone(),
|
||||
cached_at: Utc::now(),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(record)
|
||||
}
|
||||
|
||||
/// Invalidate cache for a specific key
|
||||
pub async fn invalidate(&self, key_hash: &str) {
|
||||
self.cache.invalidate(key_hash).await;
|
||||
}
|
||||
|
||||
/// Invalidate all cached keys
|
||||
pub fn invalidate_all(&self) {
|
||||
self.cache.invalidate_all();
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub async fn cache_stats(&self) -> CacheStats {
|
||||
CacheStats {
|
||||
entry_count: self.cache.entry_count(),
|
||||
weighted_size: self.cache.weighted_size(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get rate limiter statistics for an IP
|
||||
pub async fn rate_limit_stats(&self, ip: &str) -> Option<RateLimitStats> {
|
||||
self.rate_limiter.get(ip).await.map(|info| RateLimitStats {
|
||||
attempts: info.count,
|
||||
first_attempt: info.first_attempt,
|
||||
last_attempt: info.last_attempt,
|
||||
locked_until: info.locked_until,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Cache statistics
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CacheStats {
|
||||
pub entry_count: u64,
|
||||
pub weighted_size: u64,
|
||||
}
|
||||
|
||||
/// Rate limit statistics for an IP
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RateLimitStats {
|
||||
pub attempts: u32,
|
||||
pub first_attempt: DateTime<Utc>,
|
||||
pub last_attempt: DateTime<Utc>,
|
||||
pub locked_until: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_validator_config_default() {
|
||||
let config = ValidatorConfig::default();
|
||||
assert!(config.cache_ttl_secs > 0);
|
||||
assert!(config.cache_max_capacity > 0);
|
||||
assert!(config.max_attempts > 0);
|
||||
assert!(config.lockout_duration_secs > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rate_limit_result_variants() {
|
||||
let allowed = RateLimitResult::Allowed;
|
||||
let warning = RateLimitResult::AllowedWithWarning {
|
||||
remaining_attempts: 3,
|
||||
};
|
||||
let locked = RateLimitResult::Locked {
|
||||
remaining_seconds: 60,
|
||||
attempts: 5,
|
||||
};
|
||||
|
||||
match allowed {
|
||||
RateLimitResult::Allowed => assert!(true),
|
||||
_ => assert!(false),
|
||||
}
|
||||
|
||||
match warning {
|
||||
RateLimitResult::AllowedWithWarning { remaining_attempts } => {
|
||||
assert_eq!(remaining_attempts, 3)
|
||||
}
|
||||
_ => assert!(false),
|
||||
}
|
||||
|
||||
match locked {
|
||||
RateLimitResult::Locked {
|
||||
remaining_seconds,
|
||||
attempts,
|
||||
} => {
|
||||
assert_eq!(remaining_seconds, 60);
|
||||
assert_eq!(attempts, 5);
|
||||
}
|
||||
_ => assert!(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
311
src/core/api_key/webhook.rs
Normal file
311
src/core/api_key/webhook.rs
Normal file
@@ -0,0 +1,311 @@
|
||||
//! Webhook Notification Module
|
||||
//!
|
||||
//! Sends notifications via webhooks when API key events occur
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::Utc;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Webhook event types
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum WebhookEvent {
|
||||
KeyCreated,
|
||||
KeyRevoked,
|
||||
KeyExpired,
|
||||
KeyRotated,
|
||||
AnomalyDetected,
|
||||
RateLimited,
|
||||
IpBlocked,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for WebhookEvent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
WebhookEvent::KeyCreated => write!(f, "key_created"),
|
||||
WebhookEvent::KeyRevoked => write!(f, "key_revoked"),
|
||||
WebhookEvent::KeyExpired => write!(f, "key_expired"),
|
||||
WebhookEvent::KeyRotated => write!(f, "key_rotated"),
|
||||
WebhookEvent::AnomalyDetected => write!(f, "anomaly_detected"),
|
||||
WebhookEvent::RateLimited => write!(f, "rate_limited"),
|
||||
WebhookEvent::IpBlocked => write!(f, "ip_blocked"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Webhook payload
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WebhookPayload {
|
||||
pub event: String,
|
||||
pub timestamp: String,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Webhook configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WebhookConfig {
|
||||
pub url: String,
|
||||
pub secret: String,
|
||||
pub events: Vec<WebhookEvent>,
|
||||
pub enabled: bool,
|
||||
pub retry_count: u32,
|
||||
pub timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for WebhookConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
url: String::new(),
|
||||
secret: String::new(),
|
||||
events: vec![
|
||||
WebhookEvent::KeyCreated,
|
||||
WebhookEvent::KeyRevoked,
|
||||
WebhookEvent::AnomalyDetected,
|
||||
],
|
||||
enabled: false,
|
||||
retry_count: 3,
|
||||
timeout_secs: 30,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Webhook notifier
|
||||
pub struct WebhookNotifier {
|
||||
client: Client,
|
||||
config: WebhookConfig,
|
||||
}
|
||||
|
||||
impl WebhookNotifier {
|
||||
pub fn new(config: WebhookConfig) -> Result<Self> {
|
||||
let client = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(config.timeout_secs))
|
||||
.build()
|
||||
.context("Failed to create HTTP client")?;
|
||||
|
||||
Ok(Self { client, config })
|
||||
}
|
||||
|
||||
pub fn from_env() -> Result<Option<Self>> {
|
||||
let url = match std::env::var("WEBHOOK_URL") {
|
||||
Ok(url) if !url.is_empty() => url,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
let secret = std::env::var("WEBHOOK_SECRET").unwrap_or_default();
|
||||
|
||||
let events = std::env::var("WEBHOOK_EVENTS")
|
||||
.unwrap_or_default()
|
||||
.split(',')
|
||||
.filter_map(|s| match s.trim() {
|
||||
"key_created" => Some(WebhookEvent::KeyCreated),
|
||||
"key_revoked" => Some(WebhookEvent::KeyRevoked),
|
||||
"key_expired" => Some(WebhookEvent::KeyExpired),
|
||||
"key_rotated" => Some(WebhookEvent::KeyRotated),
|
||||
"anomaly_detected" => Some(WebhookEvent::AnomalyDetected),
|
||||
"rate_limited" => Some(WebhookEvent::RateLimited),
|
||||
"ip_blocked" => Some(WebhookEvent::IpBlocked),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let events = if events.is_empty() {
|
||||
vec![
|
||||
WebhookEvent::KeyCreated,
|
||||
WebhookEvent::KeyRevoked,
|
||||
WebhookEvent::AnomalyDetected,
|
||||
]
|
||||
} else {
|
||||
events
|
||||
};
|
||||
|
||||
Ok(Some(Self::new(WebhookConfig {
|
||||
url,
|
||||
secret,
|
||||
events,
|
||||
enabled: true,
|
||||
retry_count: 3,
|
||||
timeout_secs: 30,
|
||||
})?))
|
||||
}
|
||||
|
||||
/// Check if an event should be sent
|
||||
pub fn should_send(&self, event: &WebhookEvent) -> bool {
|
||||
self.config.enabled && self.config.events.contains(event)
|
||||
}
|
||||
|
||||
/// Generate HMAC signature for payload
|
||||
fn sign(&self, payload: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(self.config.secret.as_bytes());
|
||||
hasher.update(payload.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// Send a webhook notification
|
||||
pub async fn notify(&self, event: WebhookEvent, data: serde_json::Value) -> Result<bool> {
|
||||
if !self.should_send(&event) {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let payload = WebhookPayload {
|
||||
event: event.to_string(),
|
||||
timestamp: Utc::now().to_rfc3339(),
|
||||
data,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&payload)?;
|
||||
let signature = self.sign(&json);
|
||||
|
||||
let mut attempts = 0;
|
||||
let max_attempts = self.config.retry_count;
|
||||
|
||||
while attempts < max_attempts {
|
||||
attempts += 1;
|
||||
|
||||
let result = self
|
||||
.client
|
||||
.post(&self.config.url)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("X-Webhook-Signature", &signature)
|
||||
.header("X-Webhook-Event", event.to_string())
|
||||
.body(json.clone())
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(response) if response.status().is_success() => {
|
||||
tracing::info!("Webhook sent successfully: {:?}", event);
|
||||
return Ok(true);
|
||||
}
|
||||
Ok(response) => {
|
||||
tracing::warn!(
|
||||
"Webhook failed with status {}: {:?}",
|
||||
response.status(),
|
||||
event
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Webhook error (attempt {}): {}", attempts, e);
|
||||
}
|
||||
}
|
||||
|
||||
if attempts < max_attempts {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::error!(
|
||||
"Webhook failed after {} attempts: {:?}",
|
||||
max_attempts,
|
||||
event
|
||||
);
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
/// Notify key created
|
||||
pub async fn notify_key_created(&self, key_id: &str, name: &str) -> Result<bool> {
|
||||
self.notify(
|
||||
WebhookEvent::KeyCreated,
|
||||
serde_json::json!({
|
||||
"key_id": key_id,
|
||||
"name": name,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Notify key revoked
|
||||
pub async fn notify_key_revoked(&self, key_id: &str, reason: &str) -> Result<bool> {
|
||||
self.notify(
|
||||
WebhookEvent::KeyRevoked,
|
||||
serde_json::json!({
|
||||
"key_id": key_id,
|
||||
"reason": reason,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Notify anomaly detected
|
||||
pub async fn notify_anomaly(
|
||||
&self,
|
||||
key_id: &str,
|
||||
anomaly_type: &str,
|
||||
severity: &str,
|
||||
) -> Result<bool> {
|
||||
self.notify(
|
||||
WebhookEvent::AnomalyDetected,
|
||||
serde_json::json!({
|
||||
"key_id": key_id,
|
||||
"anomaly_type": anomaly_type,
|
||||
"severity": severity,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Notify IP blocked
|
||||
pub async fn notify_ip_blocked(&self, ip: &str, reason: &str) -> Result<bool> {
|
||||
self.notify(
|
||||
WebhookEvent::IpBlocked,
|
||||
serde_json::json!({
|
||||
"ip": ip,
|
||||
"reason": reason,
|
||||
}),
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_webhook_config_default() {
|
||||
let config = WebhookConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert_eq!(config.retry_count, 3);
|
||||
assert_eq!(config.timeout_secs, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_send() {
|
||||
let config = WebhookConfig {
|
||||
url: "https://example.com/webhook".to_string(),
|
||||
secret: "secret".to_string(),
|
||||
events: vec![WebhookEvent::KeyCreated, WebhookEvent::AnomalyDetected],
|
||||
enabled: true,
|
||||
retry_count: 3,
|
||||
timeout_secs: 30,
|
||||
};
|
||||
|
||||
let notifier = WebhookNotifier::new(config).unwrap();
|
||||
|
||||
assert!(notifier.should_send(&WebhookEvent::KeyCreated));
|
||||
assert!(notifier.should_send(&WebhookEvent::AnomalyDetected));
|
||||
assert!(!notifier.should_send(&WebhookEvent::KeyRevoked));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sign() {
|
||||
let config = WebhookConfig {
|
||||
url: "https://example.com/webhook".to_string(),
|
||||
secret: "mysecret".to_string(),
|
||||
events: vec![],
|
||||
enabled: true,
|
||||
retry_count: 3,
|
||||
timeout_secs: 30,
|
||||
};
|
||||
|
||||
let notifier = WebhookNotifier::new(config).unwrap();
|
||||
let sig1 = notifier.sign("test payload");
|
||||
let sig2 = notifier.sign("test payload");
|
||||
let sig3 = notifier.sign("different payload");
|
||||
|
||||
assert_eq!(sig1, sig2);
|
||||
assert_ne!(sig1, sig3);
|
||||
}
|
||||
}
|
||||
85
src/core/cache/keys.rs
vendored
Normal file
85
src/core/cache/keys.rs
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
pub const CATEGORY_VIDEOS: &str = "videos";
|
||||
pub const CATEGORY_SEARCH: &str = "search";
|
||||
pub const CATEGORY_HYBRID_SEARCH: &str = "hybrid_search";
|
||||
pub const CATEGORY_N8N_SEARCH: &str = "n8n_search";
|
||||
pub const CATEGORY_VIDEO_META: &str = "video_meta";
|
||||
pub const CATEGORY_HEALTH: &str = "health";
|
||||
|
||||
pub const KEY_PREFIX_VIDEOS_LIST: &str = "videos:list:";
|
||||
pub const KEY_PREFIX_VIDEO: &str = "video:";
|
||||
pub const KEY_PREFIX_SEARCH: &str = "search:";
|
||||
pub const KEY_PREFIX_SEARCH_HYBRID: &str = "search:hybrid:";
|
||||
pub const KEY_PREFIX_SEARCH_N8N: &str = "search:n8n:";
|
||||
pub const KEY_HEALTH: &str = "health:basic";
|
||||
|
||||
pub fn videos_list(page: usize, limit: usize) -> String {
|
||||
format!("{}page={}:limit={}", KEY_PREFIX_VIDEOS_LIST, page, limit)
|
||||
}
|
||||
|
||||
pub fn video_meta(uuid: &str) -> String {
|
||||
format!("{}{}", KEY_PREFIX_VIDEO, uuid)
|
||||
}
|
||||
|
||||
pub fn search(query_hash: &str) -> String {
|
||||
format!("{}{}", KEY_PREFIX_SEARCH, query_hash)
|
||||
}
|
||||
|
||||
pub fn hybrid_search(query_hash: &str) -> String {
|
||||
format!("{}{}", KEY_PREFIX_SEARCH_HYBRID, query_hash)
|
||||
}
|
||||
|
||||
pub fn n8n_search(query_hash: &str) -> String {
|
||||
format!("{}{}", KEY_PREFIX_SEARCH_N8N, query_hash)
|
||||
}
|
||||
|
||||
pub fn health() -> String {
|
||||
KEY_HEALTH.to_string()
|
||||
}
|
||||
|
||||
pub fn videos_list_prefix() -> String {
|
||||
format!("^{}", KEY_PREFIX_VIDEOS_LIST)
|
||||
}
|
||||
|
||||
pub fn video_prefix(uuid: &str) -> String {
|
||||
format!("^{}{}", KEY_PREFIX_VIDEO, uuid)
|
||||
}
|
||||
|
||||
pub fn search_prefix() -> String {
|
||||
format!("^{}", KEY_PREFIX_SEARCH)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_videos_list() {
|
||||
assert_eq!(videos_list(1, 20), "videos:list:page=1:limit=20");
|
||||
assert_eq!(videos_list(2, 50), "videos:list:page=2:limit=50");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_video_meta() {
|
||||
assert_eq!(video_meta("abc123"), "video:abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search() {
|
||||
assert_eq!(search("hash123"), "search:hash123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hybrid_search() {
|
||||
assert_eq!(hybrid_search("hash123"), "search:hybrid:hash123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_n8n_search() {
|
||||
assert_eq!(n8n_search("hash123"), "search:n8n:hash123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_health() {
|
||||
assert_eq!(health(), "health:basic");
|
||||
}
|
||||
}
|
||||
10
src/core/cache/mod.rs
vendored
Normal file
10
src/core/cache/mod.rs
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
pub mod keys;
|
||||
pub mod mongo_cache;
|
||||
pub mod redis_cache;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use keys::*;
|
||||
pub use mongo_cache::MongoCache;
|
||||
pub use redis_cache::RedisCache;
|
||||
311
src/core/cache/mongo_cache.rs
vendored
Normal file
311
src/core/cache/mongo_cache.rs
vendored
Normal file
@@ -0,0 +1,311 @@
|
||||
use anyhow::{Context, Result};
|
||||
use bson::{doc, oid::ObjectId, DateTime as BsonDateTime, Document};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mongodb::{Client, Collection, Database, IndexModel};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use super::keys;
|
||||
use crate::core::config::cache as cache_config;
|
||||
|
||||
const DB_NAME: &str = "momento";
|
||||
const COLLECTION_NAME: &str = "cache";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CacheEntry {
|
||||
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<ObjectId>,
|
||||
pub key: String,
|
||||
pub value: serde_json::Value,
|
||||
pub category: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
#[serde(default)]
|
||||
pub hit_count: i64,
|
||||
#[serde(default)]
|
||||
pub last_access: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CacheSettings {
|
||||
pub enabled: bool,
|
||||
pub ttl_videos: u64,
|
||||
pub ttl_search: u64,
|
||||
pub ttl_hybrid_search: u64,
|
||||
pub ttl_video_meta: u64,
|
||||
}
|
||||
|
||||
impl Default for CacheSettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: *cache_config::MONGODB_CACHE_ENABLED,
|
||||
ttl_videos: *cache_config::MONGODB_CACHE_TTL_VIDEOS,
|
||||
ttl_search: *cache_config::MONGODB_CACHE_TTL_SEARCH,
|
||||
ttl_hybrid_search: *cache_config::MONGODB_CACHE_TTL_HYBRID_SEARCH,
|
||||
ttl_video_meta: *cache_config::MONGODB_CACHE_TTL_VIDEO_META,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MongoCache {
|
||||
#[allow(dead_code)]
|
||||
client: Client,
|
||||
db: Database,
|
||||
collection: Collection<Document>,
|
||||
settings: CacheSettings,
|
||||
initialized: Arc<RwLock<bool>>,
|
||||
}
|
||||
|
||||
impl MongoCache {
|
||||
pub async fn init() -> Result<Self> {
|
||||
let uri = crate::core::config::MONGODB_URL.as_str();
|
||||
let client = Client::with_uri_str(uri)
|
||||
.await
|
||||
.context("Failed to connect to MongoDB")?;
|
||||
let db = client.database(DB_NAME);
|
||||
let collection: Collection<Document> = db.collection(COLLECTION_NAME);
|
||||
let settings = CacheSettings::default();
|
||||
|
||||
let cache = Self {
|
||||
client,
|
||||
db,
|
||||
collection,
|
||||
settings,
|
||||
initialized: Arc::new(RwLock::new(false)),
|
||||
};
|
||||
|
||||
cache.ensure_indexes().await?;
|
||||
Ok(cache)
|
||||
}
|
||||
|
||||
async fn ensure_indexes(&self) -> Result<()> {
|
||||
let mut guard = self.initialized.write().await;
|
||||
if *guard {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let ttl_index = IndexModel::builder()
|
||||
.keys(doc! { "expires_at": 1 })
|
||||
.options(
|
||||
mongodb::options::IndexOptions::builder()
|
||||
.expire_after(std::time::Duration::from_secs(0))
|
||||
.build(),
|
||||
)
|
||||
.build();
|
||||
|
||||
let key_index = IndexModel::builder()
|
||||
.keys(doc! { "key": 1 })
|
||||
.options(
|
||||
mongodb::options::IndexOptions::builder()
|
||||
.unique(true)
|
||||
.build(),
|
||||
)
|
||||
.build();
|
||||
|
||||
let category_index = IndexModel::builder().keys(doc! { "category": 1 }).build();
|
||||
|
||||
self.collection
|
||||
.create_indexes([ttl_index, key_index, category_index], None)
|
||||
.await
|
||||
.context("Failed to create cache indexes")?;
|
||||
|
||||
*guard = true;
|
||||
tracing::info!("MongoDB cache indexes ensured");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.settings.enabled
|
||||
}
|
||||
|
||||
pub fn ttl_videos(&self) -> u64 {
|
||||
self.settings.ttl_videos
|
||||
}
|
||||
|
||||
pub fn ttl_search(&self) -> u64 {
|
||||
self.settings.ttl_search
|
||||
}
|
||||
|
||||
pub fn ttl_hybrid_search(&self) -> u64 {
|
||||
self.settings.ttl_hybrid_search
|
||||
}
|
||||
|
||||
pub fn ttl_video_meta(&self) -> u64 {
|
||||
self.settings.ttl_video_meta
|
||||
}
|
||||
|
||||
pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
|
||||
if !self.is_enabled() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let filter = doc! { "key": key };
|
||||
let result = self.collection.find_one(filter, None).await?;
|
||||
|
||||
if let Some(doc) = result {
|
||||
if let Some(value_bson) = doc.get("value") {
|
||||
let json_value: serde_json::Value = bson::from_bson(value_bson.clone())?;
|
||||
let value: T = serde_json::from_value(json_value)?;
|
||||
|
||||
if let Ok(id) = doc.get_object_id("_id") {
|
||||
let update = doc! {
|
||||
"$inc": { "hit_count": 1i64 },
|
||||
"$set": { "last_access": BsonDateTime::from_chrono(Utc::now()) }
|
||||
};
|
||||
if let Err(e) = self
|
||||
.collection
|
||||
.update_one(doc! { "_id": id }, update, None)
|
||||
.await
|
||||
{
|
||||
tracing::warn!("Failed to update cache hit count: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(Some(value));
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn set<T: Serialize>(
|
||||
&self,
|
||||
key: &str,
|
||||
value: &T,
|
||||
ttl_secs: u64,
|
||||
category: &str,
|
||||
) -> Result<()> {
|
||||
if !self.is_enabled() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let now = Utc::now();
|
||||
let expires_at = now + Duration::seconds(ttl_secs as i64);
|
||||
let json_value = serde_json::to_value(value)?;
|
||||
let bson_value = bson::to_bson(&json_value)?;
|
||||
|
||||
let filter = doc! { "key": key };
|
||||
let update = doc! {
|
||||
"$set": {
|
||||
"value": bson_value,
|
||||
"category": category,
|
||||
"expires_at": BsonDateTime::from_chrono(expires_at),
|
||||
"last_access": BsonDateTime::from_chrono(now),
|
||||
},
|
||||
"$setOnInsert": {
|
||||
"key": key,
|
||||
"created_at": BsonDateTime::from_chrono(now),
|
||||
"hit_count": 0i64,
|
||||
}
|
||||
};
|
||||
|
||||
let options = mongodb::options::UpdateOptions::builder()
|
||||
.upsert(true)
|
||||
.build();
|
||||
self.collection
|
||||
.update_one(filter, update, options)
|
||||
.await
|
||||
.context("Failed to set cache entry")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete(&self, key: &str) -> Result<bool> {
|
||||
if !self.is_enabled() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let filter = doc! { "key": key };
|
||||
let result = self.collection.delete_one(filter, None).await?;
|
||||
Ok(result.deleted_count > 0)
|
||||
}
|
||||
|
||||
pub async fn invalidate_category(&self, category: &str) -> Result<u64> {
|
||||
if !self.is_enabled() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let filter = doc! { "category": category };
|
||||
let result = self.collection.delete_many(filter, None).await?;
|
||||
let count = result.deleted_count;
|
||||
tracing::debug!("Invalidated {} entries in category: {}", count, category);
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
pub async fn invalidate_prefix(&self, prefix: &str) -> Result<u64> {
|
||||
if !self.is_enabled() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let regex_pattern = format!("^{}", prefix);
|
||||
let filter = doc! { "key": { "$regex": ®ex_pattern } };
|
||||
let result = self.collection.delete_many(filter, None).await?;
|
||||
let count = result.deleted_count;
|
||||
tracing::debug!("Invalidated {} entries with prefix: {}", count, prefix);
|
||||
Ok(count)
|
||||
}
|
||||
|
||||
pub async fn get_or_fetch<F, Fut, T>(
|
||||
&self,
|
||||
key: &str,
|
||||
ttl_secs: u64,
|
||||
category: &str,
|
||||
fetcher: F,
|
||||
) -> Result<T>
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: std::future::Future<Output = Result<T>>,
|
||||
T: DeserializeOwned + Serialize,
|
||||
{
|
||||
if let Some(cached) = self.get::<T>(key).await? {
|
||||
tracing::debug!("Cache hit for key: {}", key);
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
tracing::debug!("Cache miss for key: {}", key);
|
||||
let value = fetcher().await?;
|
||||
if let Err(e) = self.set(key, &value, ttl_secs, category).await {
|
||||
tracing::warn!("Failed to cache value: {}", e);
|
||||
}
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
pub async fn invalidate_videos_list(&self) -> Result<u64> {
|
||||
self.invalidate_category(keys::CATEGORY_VIDEOS).await
|
||||
}
|
||||
|
||||
pub async fn invalidate_video(&self, uuid: &str) -> Result<u64> {
|
||||
let key = keys::video_meta(uuid);
|
||||
let count = self.delete(&key).await? as u64;
|
||||
let list_count = self.invalidate_videos_list().await?;
|
||||
Ok(count + list_count)
|
||||
}
|
||||
|
||||
pub async fn invalidate_all_search(&self) -> Result<u64> {
|
||||
let count1 = self.invalidate_category(keys::CATEGORY_SEARCH).await?;
|
||||
let count2 = self
|
||||
.invalidate_category(keys::CATEGORY_HYBRID_SEARCH)
|
||||
.await?;
|
||||
let count3 = self.invalidate_category(keys::CATEGORY_N8N_SEARCH).await?;
|
||||
Ok(count1 + count2 + count3)
|
||||
}
|
||||
|
||||
pub async fn health_check(&self) -> Result<bool> {
|
||||
self.db.run_command(doc! { "ping": 1 }, None).await?;
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cache_settings_default() {
|
||||
let settings = CacheSettings::default();
|
||||
assert!(settings.enabled);
|
||||
assert_eq!(settings.ttl_videos, 300);
|
||||
assert_eq!(settings.ttl_search, 300);
|
||||
}
|
||||
}
|
||||
120
src/core/cache/tests.rs
vendored
Normal file
120
src/core/cache/tests.rs
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
use crate::core::cache::keys;
|
||||
use crate::core::cache::mongo_cache::CacheSettings;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cache_settings_default() {
|
||||
let settings = CacheSettings::default();
|
||||
assert!(settings.enabled);
|
||||
assert_eq!(settings.ttl_videos, 300);
|
||||
assert_eq!(settings.ttl_search, 300);
|
||||
assert_eq!(settings.ttl_hybrid_search, 600);
|
||||
assert_eq!(settings.ttl_video_meta, 3600);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_videos_list() {
|
||||
let key = keys::videos_list(1, 20);
|
||||
assert_eq!(key, "videos:list:page=1:limit=20");
|
||||
|
||||
let key2 = keys::videos_list(2, 50);
|
||||
assert_eq!(key2, "videos:list:page=2:limit=50");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_video_meta() {
|
||||
let key = keys::video_meta("abc123");
|
||||
assert_eq!(key, "video:abc123");
|
||||
|
||||
let uuid = "5dea6618a606e7c7";
|
||||
let key = keys::video_meta(uuid);
|
||||
assert_eq!(key, "video:5dea6618a606e7c7");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_search() {
|
||||
let key = keys::search("hash123");
|
||||
assert_eq!(key, "search:hash123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_hybrid_search() {
|
||||
let key = keys::hybrid_search("hash123");
|
||||
assert_eq!(key, "search:hybrid:hash123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_n8n_search() {
|
||||
let key = keys::n8n_search("hash123");
|
||||
assert_eq!(key, "search:n8n:hash123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_health() {
|
||||
let key = keys::health();
|
||||
assert_eq!(key, "health:basic");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_categories() {
|
||||
assert_eq!(keys::CATEGORY_VIDEOS, "videos");
|
||||
assert_eq!(keys::CATEGORY_SEARCH, "search");
|
||||
assert_eq!(keys::CATEGORY_HYBRID_SEARCH, "hybrid_search");
|
||||
assert_eq!(keys::CATEGORY_VIDEO_META, "video_meta");
|
||||
assert_eq!(keys::CATEGORY_N8N_SEARCH, "n8n_search");
|
||||
assert_eq!(keys::CATEGORY_HEALTH, "health");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_prefixes() {
|
||||
assert_eq!(keys::KEY_PREFIX_VIDEOS_LIST, "videos:list:");
|
||||
assert_eq!(keys::KEY_PREFIX_VIDEO, "video:");
|
||||
assert_eq!(keys::KEY_PREFIX_SEARCH, "search:");
|
||||
assert_eq!(keys::KEY_PREFIX_SEARCH_HYBRID, "search:hybrid:");
|
||||
assert_eq!(keys::KEY_PREFIX_SEARCH_N8N, "search:n8n:");
|
||||
assert_eq!(keys::KEY_HEALTH, "health:basic");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_ttl_values() {
|
||||
let settings = CacheSettings::default();
|
||||
|
||||
assert!(settings.ttl_videos >= 60 && settings.ttl_videos <= 600);
|
||||
assert!(settings.ttl_search >= 60 && settings.ttl_search <= 600);
|
||||
assert!(settings.ttl_hybrid_search >= 60 && settings.ttl_hybrid_search <= 3600);
|
||||
assert!(settings.ttl_video_meta >= 300 && settings.ttl_video_meta <= 7200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_videos_list_prefix_format() {
|
||||
let key = keys::videos_list(1, 10);
|
||||
assert!(key.starts_with("videos:list:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_video_meta_prefix_format() {
|
||||
let key = keys::video_meta("uuid123");
|
||||
assert!(key.starts_with("video:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_search_prefix_format() {
|
||||
let key = keys::search("test");
|
||||
assert!(key.starts_with("search:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_hybrid_search_prefix_format() {
|
||||
let key = keys::hybrid_search("test");
|
||||
assert!(key.starts_with("search:hybrid:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_key_n8n_search_prefix_format() {
|
||||
let key = keys::n8n_search("test");
|
||||
assert!(key.starts_with("search:n8n:"));
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,15 @@
|
||||
use super::types::{Chunk, ChunkType};
|
||||
use anyhow::Result;
|
||||
use super::types::{Chunk, ChunkRule, ChunkType};
|
||||
|
||||
pub struct ChunkSplitter {
|
||||
time_based_duration: f64,
|
||||
fps: f64,
|
||||
}
|
||||
|
||||
impl ChunkSplitter {
|
||||
pub fn new(time_based_duration_seconds: f64) -> Self {
|
||||
Self {
|
||||
time_based_duration: time_based_duration_seconds,
|
||||
fps: 24.0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,11 +21,14 @@ impl ChunkSplitter {
|
||||
while current_time < duration {
|
||||
let end_time = (current_time + self.time_based_duration).min(duration);
|
||||
chunks.push(Chunk::new(
|
||||
0, // file_id
|
||||
uuid.to_string(),
|
||||
index,
|
||||
ChunkType::TimeBased,
|
||||
ChunkRule::Rule1,
|
||||
current_time,
|
||||
end_time,
|
||||
self.fps,
|
||||
serde_json::json!({
|
||||
"source": "time_based",
|
||||
"duration": self.time_based_duration,
|
||||
@@ -42,11 +46,14 @@ impl ChunkSplitter {
|
||||
|
||||
for (index, segment) in asr_segments.iter().enumerate() {
|
||||
chunks.push(Chunk::new(
|
||||
0, // file_id
|
||||
uuid.to_string(),
|
||||
index as u32,
|
||||
ChunkType::Sentence,
|
||||
ChunkRule::Rule1,
|
||||
segment.start,
|
||||
segment.end,
|
||||
self.fps,
|
||||
serde_json::json!({
|
||||
"text": segment.text,
|
||||
"speaker_id": segment.speaker_id,
|
||||
|
||||
@@ -6,47 +6,140 @@ pub enum ChunkType {
|
||||
TimeBased,
|
||||
Sentence,
|
||||
Cut,
|
||||
Trace,
|
||||
Story, // Parent chunk from story analysis
|
||||
}
|
||||
|
||||
impl ChunkType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ChunkType::TimeBased => "time_based",
|
||||
ChunkType::TimeBased => "time",
|
||||
ChunkType::Sentence => "sentence",
|
||||
ChunkType::Cut => "cut",
|
||||
ChunkType::Trace => "trace",
|
||||
ChunkType::Story => "story",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ChunkRule {
|
||||
Rule1, // 直接轉換
|
||||
Rule2, // 集合內容
|
||||
}
|
||||
|
||||
impl ChunkRule {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ChunkRule::Rule1 => "rule_1",
|
||||
ChunkRule::Rule2 => "rule_2",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Chunk {
|
||||
pub file_id: i32,
|
||||
pub uuid: String,
|
||||
pub chunk_id: String,
|
||||
pub chunk_index: u32,
|
||||
pub chunk_type: ChunkType,
|
||||
pub rule: ChunkRule,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub fps: f64,
|
||||
pub start_frame: i64,
|
||||
pub end_frame: i64,
|
||||
pub text_content: Option<String>,
|
||||
pub content: serde_json::Value,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub vector_id: Option<String>,
|
||||
pub frame_count: i32,
|
||||
pub pre_chunk_ids: Vec<i32>,
|
||||
pub parent_chunk_id: Option<String>, // For parent-child chunk hierarchy
|
||||
pub child_chunk_ids: Vec<String>, // Child chunk IDs (for parent chunks)
|
||||
}
|
||||
|
||||
impl Chunk {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
file_id: i32,
|
||||
uuid: String,
|
||||
chunk_index: u32,
|
||||
chunk_type: ChunkType,
|
||||
rule: ChunkRule,
|
||||
start_time: f64,
|
||||
end_time: f64,
|
||||
fps: f64,
|
||||
content: serde_json::Value,
|
||||
) -> Self {
|
||||
let start_frame = (start_time * fps) as i64;
|
||||
let end_frame = (end_time * fps) as i64;
|
||||
let chunk_id = format!("{}_{:04}", chunk_type.as_str(), chunk_index);
|
||||
Self {
|
||||
file_id,
|
||||
uuid,
|
||||
chunk_id: chunk_id.clone(),
|
||||
chunk_index,
|
||||
chunk_type,
|
||||
rule,
|
||||
start_time,
|
||||
end_time,
|
||||
fps,
|
||||
start_frame,
|
||||
end_frame,
|
||||
text_content: None,
|
||||
content,
|
||||
metadata: None,
|
||||
vector_id: None,
|
||||
frame_count: 0,
|
||||
pre_chunk_ids: vec![],
|
||||
parent_chunk_id: None,
|
||||
child_chunk_ids: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
|
||||
self.metadata = Some(metadata);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_vector_id(mut self, vector_id: String) -> Self {
|
||||
self.vector_id = Some(vector_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_text_content(mut self, text: String) -> Self {
|
||||
self.text_content = Some(text);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_frame_count(mut self, count: i32) -> Self {
|
||||
self.frame_count = count;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_pre_chunk_ids(mut self, ids: Vec<i32>) -> Self {
|
||||
self.pre_chunk_ids = ids;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_parent_chunk_id(mut self, parent_id: String) -> Self {
|
||||
self.parent_chunk_id = Some(parent_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_child_chunk_ids(mut self, child_ids: Vec<String>) -> Self {
|
||||
self.child_chunk_ids = child_ids;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn is_parent_chunk(&self) -> bool {
|
||||
!self.child_chunk_ids.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_child_chunk(&self) -> bool {
|
||||
self.parent_chunk_id.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,14 +60,6 @@ pub static SERVER_PORT: Lazy<u16> = Lazy::new(|| {
|
||||
pub static REDIS_KEY_PREFIX: Lazy<String> =
|
||||
Lazy::new(|| env::var("MOMENTRY_REDIS_PREFIX").unwrap_or_else(|_| "momentry:".to_string()));
|
||||
|
||||
/// User data root path (sftpgo data directory)
|
||||
/// This is the parent directory containing user directories like ./demo/, ./warren/, ./momentry/
|
||||
/// Example: /Users/accusys/momentry/var/sftpgo/data
|
||||
pub static USER_DATA_ROOT: Lazy<String> = Lazy::new(|| {
|
||||
env::var("MOMENTRY_USER_DATA_ROOT")
|
||||
.unwrap_or_else(|_| "/Users/accusys/momentry/var/sftpgo/data".to_string())
|
||||
});
|
||||
|
||||
pub mod processor {
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -32,9 +32,20 @@ pub trait VectorStore: Send + Sync {
|
||||
pub mod mongodb_db;
|
||||
pub mod postgres_db;
|
||||
pub mod qdrant_db;
|
||||
pub mod redis_client;
|
||||
pub mod redis_db;
|
||||
pub mod sync_db;
|
||||
|
||||
pub use mongodb_db::MongoDb;
|
||||
pub use postgres_db::{PostgresDb, VideoRecord};
|
||||
pub use qdrant_db::QdrantDb;
|
||||
pub use postgres_db::{
|
||||
Bm25Result, CreateApiKeyConfig, HybridSearchResult, MonitorJob, MonitorJobStats,
|
||||
MonitorJobStatus, PostgresDb, ProcessorJobStatus, ProcessorResult, ProcessorType, VideoRecord,
|
||||
VideoStatus,
|
||||
};
|
||||
pub use qdrant_db::{QdrantDb, VectorPayload};
|
||||
pub use redis_client::{
|
||||
JobErrorMessage, MonitorJobRedis, ProcessorStatus as RedisProcessorStatus, ProgressData,
|
||||
ProgressMessage, RedisClient,
|
||||
};
|
||||
pub use redis_db::RedisDb;
|
||||
pub use sync_db::SyncDb;
|
||||
|
||||
@@ -1,53 +1,269 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::Database;
|
||||
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
|
||||
|
||||
pub struct MongoDb {
|
||||
cache: Arc<RwLock<MongoCache>>,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MongoCache {
|
||||
documents: std::collections::HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct VideoDocument {
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkDocument {
|
||||
pub uuid: String,
|
||||
pub file_path: String,
|
||||
pub file_name: String,
|
||||
pub probe: serde_json::Value,
|
||||
pub asr: Option<serde_json::Value>,
|
||||
pub asrx: Option<serde_json::Value>,
|
||||
pub ocr: Option<serde_json::Value>,
|
||||
pub yolo: Option<serde_json::Value>,
|
||||
pub face: Option<serde_json::Value>,
|
||||
pub pose: Option<serde_json::Value>,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
pub chunk_id: String,
|
||||
pub chunk_index: u32,
|
||||
pub chunk_type: String,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub fps: f64,
|
||||
pub start_frame: i64,
|
||||
pub end_frame: i64,
|
||||
pub content: serde_json::Value,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub vector_id: Option<String>,
|
||||
pub parent_chunk_id: Option<String>,
|
||||
pub child_chunk_ids: Vec<String>,
|
||||
}
|
||||
|
||||
impl From<Chunk> for ChunkDocument {
|
||||
fn from(chunk: Chunk) -> Self {
|
||||
Self {
|
||||
uuid: chunk.uuid,
|
||||
chunk_id: chunk.chunk_id,
|
||||
chunk_index: chunk.chunk_index,
|
||||
chunk_type: chunk.chunk_type.as_str().to_string(),
|
||||
start_time: chunk.start_time,
|
||||
end_time: chunk.end_time,
|
||||
fps: chunk.fps,
|
||||
start_frame: chunk.start_frame,
|
||||
end_frame: chunk.end_frame,
|
||||
content: chunk.content,
|
||||
metadata: chunk.metadata,
|
||||
vector_id: chunk.vector_id,
|
||||
parent_chunk_id: chunk.parent_chunk_id,
|
||||
child_chunk_ids: chunk.child_chunk_ids,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MongoDb {
|
||||
pub async fn store_video(&self, _doc: &VideoDocument) -> Result<()> {
|
||||
// TODO: Implement MongoDB client
|
||||
pub fn new() -> Self {
|
||||
let base_url =
|
||||
std::env::var("MONGODB_URL").unwrap_or_else(|_| "http://localhost:27017".to_string());
|
||||
Self { base_url }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MongoDb {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl MongoDb {
|
||||
pub async fn store_chunk(&self, chunk: &Chunk) -> Result<()> {
|
||||
let doc: ChunkDocument = chunk.clone().into();
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let url = format!("{}/momentry/chunks", self.base_url);
|
||||
|
||||
client
|
||||
.post(&url)
|
||||
.json(&doc)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to store chunk in MongoDB")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_video(&self, _uuid: &str) -> Result<Option<VideoDocument>> {
|
||||
// TODO: Implement MongoDB client
|
||||
Ok(None)
|
||||
pub async fn get_chunks_by_uuid(&self, uuid: &str) -> Result<Vec<Chunk>> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!(
|
||||
"{}/momentry/chunks?filter={{\"uuid\":\"{}\"}}",
|
||||
self.base_url, uuid
|
||||
);
|
||||
|
||||
let response = client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to get chunks from MongoDB")?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct MongoResponse {
|
||||
documents: Vec<ChunkDocument>,
|
||||
}
|
||||
|
||||
let result: MongoResponse = response.json().await?;
|
||||
|
||||
let chunks: Vec<Chunk> = result
|
||||
.documents
|
||||
.into_iter()
|
||||
.map(|doc| {
|
||||
let chunk_type = match doc.chunk_type.as_str() {
|
||||
"sentence" => ChunkType::Sentence,
|
||||
"cut" => ChunkType::Cut,
|
||||
"time_based" => ChunkType::TimeBased,
|
||||
"trace" => ChunkType::Trace,
|
||||
"story" => ChunkType::Story,
|
||||
_ => ChunkType::Sentence,
|
||||
};
|
||||
|
||||
Chunk {
|
||||
file_id: 0,
|
||||
uuid: doc.uuid,
|
||||
chunk_id: doc.chunk_id,
|
||||
chunk_index: doc.chunk_index,
|
||||
chunk_type,
|
||||
rule: ChunkRule::Rule1,
|
||||
start_time: doc.start_time,
|
||||
end_time: doc.end_time,
|
||||
fps: doc.fps,
|
||||
start_frame: doc.start_frame,
|
||||
end_frame: doc.end_frame,
|
||||
text_content: None,
|
||||
content: doc.content,
|
||||
metadata: doc.metadata,
|
||||
vector_id: doc.vector_id,
|
||||
frame_count: 0,
|
||||
pre_chunk_ids: vec![],
|
||||
parent_chunk_id: doc.parent_chunk_id,
|
||||
child_chunk_ids: doc.child_chunk_ids,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
pub async fn search_text(&self, query: &str) -> Result<Vec<Chunk>> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!(
|
||||
"{}/momentry/chunks?filter={{\"$text\":{{\"$search\":\"{}\"}}}}",
|
||||
self.base_url, query
|
||||
);
|
||||
|
||||
let response = client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to search text in MongoDB")?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct MongoResponse {
|
||||
documents: Vec<ChunkDocument>,
|
||||
}
|
||||
|
||||
let result: MongoResponse = response.json().await?;
|
||||
|
||||
let chunks: Vec<Chunk> = result
|
||||
.documents
|
||||
.into_iter()
|
||||
.map(|doc| {
|
||||
let chunk_type = match doc.chunk_type.as_str() {
|
||||
"sentence" => ChunkType::Sentence,
|
||||
"cut" => ChunkType::Cut,
|
||||
"time" => ChunkType::TimeBased,
|
||||
"trace" => ChunkType::Trace,
|
||||
"story" => ChunkType::Story,
|
||||
_ => ChunkType::Sentence,
|
||||
};
|
||||
|
||||
Chunk {
|
||||
file_id: 0,
|
||||
uuid: doc.uuid,
|
||||
chunk_id: doc.chunk_id,
|
||||
chunk_index: doc.chunk_index,
|
||||
chunk_type,
|
||||
rule: ChunkRule::Rule1,
|
||||
start_time: doc.start_time,
|
||||
end_time: doc.end_time,
|
||||
fps: doc.fps,
|
||||
start_frame: doc.start_frame,
|
||||
end_frame: doc.end_frame,
|
||||
text_content: None,
|
||||
content: doc.content,
|
||||
metadata: doc.metadata,
|
||||
vector_id: doc.vector_id,
|
||||
frame_count: 0,
|
||||
pre_chunk_ids: vec![],
|
||||
parent_chunk_id: doc.parent_chunk_id,
|
||||
child_chunk_ids: doc.child_chunk_ids,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
pub async fn get_all_chunks(&self) -> Result<Vec<Chunk>> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("{}/momentry/chunks", self.base_url);
|
||||
|
||||
let response = client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to get all chunks from MongoDB")?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct MongoResponse {
|
||||
documents: Vec<ChunkDocument>,
|
||||
}
|
||||
|
||||
let result: MongoResponse = response.json().await?;
|
||||
|
||||
let chunks: Vec<Chunk> = result
|
||||
.documents
|
||||
.into_iter()
|
||||
.map(|doc| {
|
||||
let chunk_type = match doc.chunk_type.as_str() {
|
||||
"sentence" => ChunkType::Sentence,
|
||||
"cut" => ChunkType::Cut,
|
||||
"time" => ChunkType::TimeBased,
|
||||
"trace" => ChunkType::Trace,
|
||||
"story" => ChunkType::Story,
|
||||
_ => ChunkType::Sentence,
|
||||
};
|
||||
|
||||
Chunk {
|
||||
file_id: 0,
|
||||
uuid: doc.uuid,
|
||||
chunk_id: doc.chunk_id,
|
||||
chunk_index: doc.chunk_index,
|
||||
chunk_type,
|
||||
rule: ChunkRule::Rule1,
|
||||
start_time: doc.start_time,
|
||||
end_time: doc.end_time,
|
||||
fps: doc.fps,
|
||||
start_frame: doc.start_frame,
|
||||
end_frame: doc.end_frame,
|
||||
text_content: None,
|
||||
content: doc.content,
|
||||
metadata: doc.metadata,
|
||||
vector_id: doc.vector_id,
|
||||
frame_count: 0,
|
||||
pre_chunk_ids: vec![],
|
||||
parent_chunk_id: doc.parent_chunk_id,
|
||||
child_chunk_ids: doc.child_chunk_ids,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
pub async fn get_chunk_count(&self) -> Result<i64> {
|
||||
let chunks = self.get_all_chunks().await?;
|
||||
Ok(chunks.len() as i64)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Database for MongoDb {
|
||||
impl super::Database for MongoDb {
|
||||
async fn init() -> Result<Self> {
|
||||
// TODO: Initialize MongoDB client
|
||||
Ok(Self {
|
||||
cache: Arc::new(RwLock::new(MongoCache::default())),
|
||||
})
|
||||
Ok(Self::new())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2669,8 +2669,9 @@ impl PostgresDb {
|
||||
pub async fn get_processor_results_by_job(&self, job_id: i32) -> Result<Vec<ProcessorResult>> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, job_id, processor, status, started_at, completed_at, duration_secs,
|
||||
error_message, output_data, retry_count, created_at, updated_at
|
||||
SELECT id, job_id, processor, status, output_path, started_at, completed_at,
|
||||
error_message, progress_total, progress_current, last_checkpoint,
|
||||
created_at, updated_at, duration_secs
|
||||
FROM processor_results
|
||||
WHERE job_id = $1
|
||||
ORDER BY created_at ASC
|
||||
@@ -2685,6 +2686,10 @@ impl PostgresDb {
|
||||
.map(|r| {
|
||||
let status_str: String = r.get(3);
|
||||
let processor_type_str: String = r.get(2);
|
||||
let started_at: Option<chrono::NaiveDateTime> = r.get(5);
|
||||
let completed_at: Option<chrono::NaiveDateTime> = r.get(6);
|
||||
let created_at: chrono::NaiveDateTime = r.get(11);
|
||||
let updated_at: Option<chrono::NaiveDateTime> = r.get(12);
|
||||
ProcessorResult {
|
||||
id: r.get(0),
|
||||
job_id: r.get(1),
|
||||
@@ -2692,14 +2697,14 @@ impl PostgresDb {
|
||||
.unwrap_or(ProcessorType::Asr),
|
||||
status: ProcessorJobStatus::from_db_str(&status_str)
|
||||
.unwrap_or(ProcessorJobStatus::Pending),
|
||||
started_at: r.get(4),
|
||||
completed_at: r.get(5),
|
||||
duration_secs: r.get(6),
|
||||
started_at: started_at.map(|t| t.to_string()),
|
||||
completed_at: completed_at.map(|t| t.to_string()),
|
||||
duration_secs: r.get(13),
|
||||
error_message: r.get(7),
|
||||
output_data: r.get(8),
|
||||
retry_count: r.get(9),
|
||||
created_at: r.get(10),
|
||||
updated_at: r.get(11),
|
||||
output_data: None,
|
||||
retry_count: 0,
|
||||
created_at: created_at.to_string(),
|
||||
updated_at: updated_at.map(|t| t.to_string()).unwrap_or_default(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -1,46 +1,330 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::{Database, SearchResult, VectorStore};
|
||||
|
||||
pub struct QdrantDb {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
collection_name: String,
|
||||
cache: Arc<RwLock<QdrantCache>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct QdrantCache {
|
||||
vectors: std::collections::HashMap<String, Vec<f32>>,
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VectorPayload {
|
||||
pub uuid: String,
|
||||
pub chunk_id: String,
|
||||
pub chunk_type: String,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub text: Option<String>,
|
||||
}
|
||||
|
||||
impl QdrantDb {
|
||||
pub async fn init_collection(&self) -> Result<()> {
|
||||
// TODO: Implement actual Qdrant client
|
||||
// This is a placeholder
|
||||
pub fn new() -> Self {
|
||||
let base_url =
|
||||
std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
|
||||
let api_key = std::env::var("QDRANT_API_KEY")
|
||||
.unwrap_or_else(|_| "Test3200Test3200Test3200".to_string());
|
||||
let collection_name =
|
||||
std::env::var("QDRANT_COLLECTION").unwrap_or_else(|_| "chunks_v3".to_string());
|
||||
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url,
|
||||
api_key,
|
||||
collection_name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for QdrantDb {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl QdrantDb {
|
||||
pub async fn init_collection(&self, vector_dim: usize) -> Result<()> {
|
||||
let url = format!("{}/collections/{}", self.base_url, self.collection_name);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let create_url = format!("{}/collections", self.base_url);
|
||||
let body = serde_json::json!({
|
||||
"vectors": {
|
||||
"size": vector_dim,
|
||||
"distance": "Cosine"
|
||||
}
|
||||
});
|
||||
|
||||
self.client
|
||||
.post(&create_url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to create Qdrant collection")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn upsert_vector(&self, chunk_id: &str, vector: &[f32]) -> Result<()> {
|
||||
let mut cache = self.cache.write().await;
|
||||
cache.vectors.insert(chunk_id.to_string(), vector.to_vec());
|
||||
pub async fn upsert_vector(
|
||||
&self,
|
||||
_chunk_id: &str,
|
||||
vector: &[f32],
|
||||
payload: VectorPayload,
|
||||
) -> Result<()> {
|
||||
let url = format!(
|
||||
"{}/collections/{}/points",
|
||||
self.base_url, self.collection_name
|
||||
);
|
||||
|
||||
let mut payload_map = HashMap::new();
|
||||
payload_map.insert("uuid".to_string(), serde_json::json!(payload.uuid));
|
||||
payload_map.insert("chunk_id".to_string(), serde_json::json!(payload.chunk_id));
|
||||
payload_map.insert(
|
||||
"chunk_type".to_string(),
|
||||
serde_json::json!(payload.chunk_type),
|
||||
);
|
||||
payload_map.insert(
|
||||
"start_time".to_string(),
|
||||
serde_json::json!(payload.start_time),
|
||||
);
|
||||
payload_map.insert("end_time".to_string(), serde_json::json!(payload.end_time));
|
||||
if let Some(text) = payload.text {
|
||||
payload_map.insert("text".to_string(), serde_json::json!(text));
|
||||
}
|
||||
|
||||
let point_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let body = serde_json::json!({
|
||||
"points": [{
|
||||
"id": point_id,
|
||||
"vector": vector,
|
||||
"payload": payload_map
|
||||
}]
|
||||
});
|
||||
|
||||
self.client
|
||||
.put(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to upsert vector in Qdrant")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn search(&self, query_vector: &[f32], limit: usize) -> Result<Vec<SearchResult>> {
|
||||
let url = format!(
|
||||
"{}/collections/{}/points/search",
|
||||
self.base_url, self.collection_name
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"vector": query_vector,
|
||||
"limit": limit,
|
||||
"with_payload": true
|
||||
});
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to search Qdrant")?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct QdrantSearchResult {
|
||||
result: Vec<QdrantPoint>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct QdrantPoint {
|
||||
#[allow(dead_code)]
|
||||
id: serde_json::Value,
|
||||
score: f64,
|
||||
payload: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
let result: QdrantSearchResult = response.json().await?;
|
||||
|
||||
let search_results: Vec<SearchResult> = result
|
||||
.result
|
||||
.into_iter()
|
||||
.map(|r| {
|
||||
let chunk_id = r
|
||||
.payload
|
||||
.get("chunk_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
SearchResult {
|
||||
chunk_id,
|
||||
score: r.score as f32,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(search_results)
|
||||
}
|
||||
|
||||
pub async fn search_in_uuid(
|
||||
&self,
|
||||
query_vector: &[f64],
|
||||
uuid: &str,
|
||||
limit: usize,
|
||||
) -> Result<Vec<SearchResult>> {
|
||||
let url = format!(
|
||||
"{}/collections/{}/points/search",
|
||||
self.base_url, self.collection_name
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"vector": query_vector,
|
||||
"limit": limit,
|
||||
"with_payload": true,
|
||||
"filter": {
|
||||
"must": [
|
||||
{
|
||||
"key": "uuid",
|
||||
"match": {
|
||||
"value": uuid
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to search Qdrant")?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct QdrantSearchResult {
|
||||
result: Vec<QdrantPoint>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct QdrantPoint {
|
||||
#[allow(dead_code)]
|
||||
id: serde_json::Value,
|
||||
score: f64,
|
||||
payload: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
let result: QdrantSearchResult = response.json().await?;
|
||||
|
||||
let search_results: Vec<SearchResult> = result
|
||||
.result
|
||||
.into_iter()
|
||||
.map(|r| {
|
||||
let chunk_id = r
|
||||
.payload
|
||||
.get("chunk_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
SearchResult {
|
||||
chunk_id,
|
||||
score: r.score as f32,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(search_results)
|
||||
}
|
||||
|
||||
pub async fn delete_by_uuid(&self, uuid: &str) -> Result<()> {
|
||||
let url = format!(
|
||||
"{}/collections/{}/points/delete",
|
||||
self.base_url, self.collection_name
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"filter": {
|
||||
"must": [
|
||||
{
|
||||
"key": "uuid",
|
||||
"match": {
|
||||
"value": uuid
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
});
|
||||
|
||||
self.client
|
||||
.post(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to delete points from Qdrant")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_point_count(&self) -> Result<usize> {
|
||||
let url = format!(
|
||||
"{}/collections/{}/info",
|
||||
self.base_url, self.collection_name
|
||||
);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to get collection info")?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CollectionInfo {
|
||||
result: CollectionStatus,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CollectionStatus {
|
||||
points_count: usize,
|
||||
}
|
||||
|
||||
let result: CollectionInfo = response.json().await?;
|
||||
Ok(result.result.points_count)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Database for QdrantDb {
|
||||
async fn init() -> Result<Self> {
|
||||
let collection_name =
|
||||
std::env::var("QDRANT_COLLECTION").unwrap_or_else(|_| "momentry_chunks".to_string());
|
||||
|
||||
let db = Self {
|
||||
collection_name,
|
||||
cache: Arc::new(RwLock::new(QdrantCache::default())),
|
||||
};
|
||||
|
||||
db.init_collection().await?;
|
||||
let db = Self::new();
|
||||
db.init_collection(768).await?;
|
||||
Ok(db)
|
||||
}
|
||||
}
|
||||
@@ -48,41 +332,18 @@ impl Database for QdrantDb {
|
||||
#[async_trait]
|
||||
impl VectorStore for QdrantDb {
|
||||
async fn store_vector(&self, chunk_id: &str, vector: &[f32]) -> Result<()> {
|
||||
self.upsert_vector(chunk_id, vector).await
|
||||
let payload = VectorPayload {
|
||||
uuid: String::new(),
|
||||
chunk_id: chunk_id.to_string(),
|
||||
chunk_type: String::new(),
|
||||
start_time: 0.0,
|
||||
end_time: 0.0,
|
||||
text: None,
|
||||
};
|
||||
self.upsert_vector(chunk_id, vector, payload).await
|
||||
}
|
||||
|
||||
async fn search(&self, query_vector: &[f32], limit: usize) -> Result<Vec<SearchResult>> {
|
||||
// Simple cosine similarity search (placeholder)
|
||||
let cache = self.cache.read().await;
|
||||
let mut results: Vec<SearchResult> = Vec::new();
|
||||
|
||||
for (chunk_id, vector) in &cache.vectors {
|
||||
let similarity = cosine_similarity(query_vector, vector);
|
||||
results.push(SearchResult {
|
||||
chunk_id: chunk_id.clone(),
|
||||
score: similarity,
|
||||
});
|
||||
}
|
||||
|
||||
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
|
||||
results.truncate(limit);
|
||||
|
||||
Ok(results)
|
||||
self.search(query_vector, limit).await
|
||||
}
|
||||
}
|
||||
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
dot_product / (norm_a * norm_b)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ use tokio::sync::RwLock;
|
||||
use super::Database;
|
||||
|
||||
pub struct RedisDb {
|
||||
#[allow(dead_code)]
|
||||
state: Arc<RwLock<RedisState>>,
|
||||
}
|
||||
|
||||
|
||||
155
src/core/db/sync_db.rs
Normal file
155
src/core/db/sync_db.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
|
||||
use crate::core::db::mongodb_db::MongoDb;
|
||||
use crate::core::db::postgres_db::PostgresDb;
|
||||
use crate::core::db::qdrant_db::{QdrantDb, VectorPayload};
|
||||
use crate::core::processor::asr::{AsrResult, AsrSegment};
|
||||
|
||||
pub struct SyncDb {
|
||||
postgres: PostgresDb,
|
||||
mongodb: MongoDb,
|
||||
qdrant: QdrantDb,
|
||||
}
|
||||
|
||||
impl SyncDb {
|
||||
pub async fn new(postgres: PostgresDb, mongodb: MongoDb, qdrant: QdrantDb) -> Result<Self> {
|
||||
Ok(Self {
|
||||
postgres,
|
||||
mongodb,
|
||||
qdrant,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn store_chunk_with_vector(&self, mut chunk: Chunk, text: &str) -> Result<Chunk> {
|
||||
let uuid = chunk.uuid.clone();
|
||||
let chunk_id = chunk.chunk_id.clone();
|
||||
let chunk_type = chunk.chunk_type.as_str().to_string();
|
||||
let start_time = chunk.start_time;
|
||||
let end_time = chunk.end_time;
|
||||
|
||||
let vector = self.embed_text(text).await?;
|
||||
|
||||
let vector_id = format!("vec_{}", chunk_id);
|
||||
chunk = chunk.with_vector_id(vector_id.clone());
|
||||
|
||||
let postgres_result = self.postgres.store_chunk(&chunk).await;
|
||||
if let Err(e) = &postgres_result {
|
||||
tracing::warn!("Failed to store chunk in PostgreSQL: {}", e);
|
||||
}
|
||||
|
||||
let mongo_result = self.mongodb.store_chunk(&chunk).await;
|
||||
if let Err(e) = &mongo_result {
|
||||
tracing::warn!("Failed to store chunk in MongoDB: {}", e);
|
||||
}
|
||||
|
||||
let payload = VectorPayload {
|
||||
uuid: uuid.clone(),
|
||||
chunk_id: chunk_id.clone(),
|
||||
chunk_type,
|
||||
start_time,
|
||||
end_time,
|
||||
text: Some(text.to_string()),
|
||||
};
|
||||
|
||||
let qdrant_result = self
|
||||
.qdrant
|
||||
.upsert_vector(&vector_id, &vector, payload)
|
||||
.await;
|
||||
if let Err(e) = &qdrant_result {
|
||||
tracing::warn!("Failed to store vector in Qdrant: {}", e);
|
||||
}
|
||||
|
||||
let pg_vector_result = self.postgres.store_vector(&vector_id, &vector, &uuid).await;
|
||||
if let Err(e) = &pg_vector_result {
|
||||
tracing::warn!("Failed to store vector in PostgreSQL: {}", e);
|
||||
}
|
||||
|
||||
postgres_result?;
|
||||
mongo_result?;
|
||||
qdrant_result?;
|
||||
|
||||
Ok(chunk)
|
||||
}
|
||||
|
||||
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.post("http://localhost:11434/api/embeddings")
|
||||
.json(&json!({
|
||||
"model": "nomic-embed-text",
|
||||
"prompt": text
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to call Ollama embedding API")?;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
let embedding = response
|
||||
.json::<EmbeddingResponse>()
|
||||
.await
|
||||
.context("Failed to parse embedding response")?;
|
||||
|
||||
Ok(embedding.embedding)
|
||||
}
|
||||
|
||||
pub async fn process_asr_to_chunks(
|
||||
&self,
|
||||
uuid: &str,
|
||||
asr_result: &AsrResult,
|
||||
) -> Result<Vec<Chunk>> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
for (i, segment) in asr_result.segments.iter().enumerate() {
|
||||
let segment: &AsrSegment = segment;
|
||||
let content = json!({
|
||||
"text": segment.text,
|
||||
"text_normalized": segment.text.to_lowercase(),
|
||||
});
|
||||
|
||||
let metadata = json!({
|
||||
"language": asr_result.language,
|
||||
"language_probability": asr_result.language_probability,
|
||||
});
|
||||
|
||||
let chunk = Chunk::new(
|
||||
0, // file_id - will be set later
|
||||
uuid.to_string(),
|
||||
i as u32,
|
||||
ChunkType::Sentence,
|
||||
ChunkRule::Rule1,
|
||||
segment.start,
|
||||
segment.end,
|
||||
24.0, // fps
|
||||
content,
|
||||
)
|
||||
.with_metadata(metadata);
|
||||
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
let mut stored_chunks = Vec::new();
|
||||
for chunk in chunks {
|
||||
let text = chunk
|
||||
.content
|
||||
.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
match self.store_chunk_with_vector(chunk, &text).await {
|
||||
Ok(stored) => stored_chunks.push(stored),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to store chunk: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(stored_chunks)
|
||||
}
|
||||
}
|
||||
@@ -1,66 +1,80 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct Embedder {
|
||||
model_path: String,
|
||||
model: String,
|
||||
client: Client,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EmbedRequest {
|
||||
model: String,
|
||||
prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct EmbedResponse {
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(model_path: String) -> Self {
|
||||
Self { model_path }
|
||||
pub fn new(model: String) -> Self {
|
||||
Self {
|
||||
model,
|
||||
client: Client::new(),
|
||||
base_url: "http://localhost:11434".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
|
||||
// TODO: Implement comic-embed-text model loading and inference
|
||||
// This is a placeholder that generates a random 768-dimensional vector
|
||||
//
|
||||
// Implementation would use:
|
||||
// - candle (Rust ML framework) or
|
||||
// - ort (ONNX Runtime) to run the model
|
||||
//
|
||||
// Example with ort:
|
||||
// let session = Session::builder()?
|
||||
// .with_execution_providers([CPUExecutionProvider::default().build()])?
|
||||
// .with_model_from_file(&self.model_path)?;
|
||||
//
|
||||
// // Preprocess text to tensor
|
||||
// let input = preprocess_text(text);
|
||||
//
|
||||
// // Run inference
|
||||
// let output = session.run(vec![input])?;
|
||||
//
|
||||
// // Extract embeddings
|
||||
// let embedding = output[0].view()[..768].to_vec();
|
||||
self.embed_with_prefix(text, "").await
|
||||
}
|
||||
|
||||
let dim = 768;
|
||||
let mut embedding = vec![0.0f32; dim];
|
||||
pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
|
||||
self.embed_with_prefix(text, "search_document: ").await
|
||||
}
|
||||
|
||||
// Simple hash-based embedding for now
|
||||
let hash = self.hash_text(text);
|
||||
for i in 0..dim {
|
||||
embedding[i] = ((hash >> i) & 1) as f32;
|
||||
pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
|
||||
self.embed_with_prefix(text, "search_query: ").await
|
||||
}
|
||||
|
||||
async fn embed_with_prefix(&self, text: &str, prefix: &str) -> Result<Vec<f32>> {
|
||||
let url = format!("{}/api/embeddings", self.base_url);
|
||||
let prompt = format!("{}{}", prefix, text);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&EmbedRequest {
|
||||
model: self.model.clone(),
|
||||
prompt,
|
||||
})
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send embedding request to Ollama")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama API error ({}): {}", status, body);
|
||||
}
|
||||
|
||||
// Normalize
|
||||
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm > 0.0 {
|
||||
for v in &mut embedding {
|
||||
*v /= norm;
|
||||
}
|
||||
}
|
||||
let result: EmbedResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Ollama response")?;
|
||||
|
||||
Ok(embedding)
|
||||
Ok(result.embedding)
|
||||
}
|
||||
|
||||
pub async fn embed_chunk_content(&self, chunk: &crate::core::chunk::Chunk) -> Result<Vec<f32>> {
|
||||
let text = serde_json::to_string(&chunk.content)?;
|
||||
self.embed_text(&text).await
|
||||
self.embed_document(&text).await
|
||||
}
|
||||
|
||||
fn hash_text(&self, text: &str) -> u64 {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
text.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
pub fn dimension(&self) -> usize {
|
||||
768
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
pub mod api_key;
|
||||
pub mod cache;
|
||||
pub mod chunk;
|
||||
pub mod config;
|
||||
pub mod db;
|
||||
pub mod embedding;
|
||||
pub mod overlay;
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const ASR_TIMEOUT: Duration = Duration::from_secs(3600);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AsrResult {
|
||||
@@ -17,53 +20,33 @@ pub struct AsrSegment {
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub async fn process_asr(video_path: &str, output_path: &str) -> Result<AsrResult> {
|
||||
let script_path = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("scripts")
|
||||
.join("asr_processor.py");
|
||||
pub async fn process_asr(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<AsrResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("asr_processor.py");
|
||||
|
||||
let venv_python = Path::new(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("venv")
|
||||
.join("bin")
|
||||
.join("python");
|
||||
tracing::info!("[ASR] Starting ASR processing: {}", video_path);
|
||||
|
||||
println!("[ASR] Starting ASR processing...");
|
||||
println!("[ASR] Video: {}", video_path);
|
||||
|
||||
let output = Command::new(venv_python)
|
||||
.arg(script_path)
|
||||
.arg(video_path)
|
||||
.arg(output_path)
|
||||
.output()
|
||||
.context("Failed to run ASR processor")?;
|
||||
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
|
||||
for line in stderr.lines() {
|
||||
if line.starts_with("ASR_START") {
|
||||
println!("[ASR] Loading model...");
|
||||
} else if line.starts_with("ASR_LANGUAGE:") {
|
||||
let lang = line.trim_start_matches("ASR_LANGUAGE:");
|
||||
println!("[ASR] Detected language: {}", lang);
|
||||
} else if line.starts_with("ASR_PROGRESS:") {
|
||||
let count = line.trim_start_matches("ASR_PROGRESS:");
|
||||
println!("[ASR] Processed {} segments...", count);
|
||||
} else if line.starts_with("ASR_COMPLETE:") {
|
||||
let count = line.trim_start_matches("ASR_COMPLETE:");
|
||||
println!("[ASR] Completed! Total: {} segments", count);
|
||||
}
|
||||
}
|
||||
|
||||
if !output.status.success() {
|
||||
anyhow::bail!("ASR failed: {}", stderr);
|
||||
}
|
||||
executor
|
||||
.run(
|
||||
"asr_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"ASR",
|
||||
Some(ASR_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read ASR output")?;
|
||||
|
||||
let result: AsrResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse ASR output")?;
|
||||
|
||||
println!(
|
||||
tracing::info!(
|
||||
"[ASR] Result: {} segments, language: {:?}",
|
||||
result.segments.len(),
|
||||
result.language
|
||||
@@ -71,3 +54,72 @@ pub async fn process_asr(video_path: &str, output_path: &str) -> Result<AsrResul
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_asr_result_serialization() {
|
||||
let result = AsrResult {
|
||||
language: Some("en".to_string()),
|
||||
language_probability: Some(0.95),
|
||||
segments: vec![
|
||||
AsrSegment {
|
||||
start: 0.0,
|
||||
end: 2.5,
|
||||
text: "Hello world".to_string(),
|
||||
},
|
||||
AsrSegment {
|
||||
start: 2.5,
|
||||
end: 5.0,
|
||||
text: "Test speech".to_string(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("Hello world"));
|
||||
assert!(json.contains("en"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asr_result_deserialization() {
|
||||
let json = r#"{
|
||||
"language": "zh",
|
||||
"language_probability": 0.98,
|
||||
"segments": [
|
||||
{"start": 0.0, "end": 1.5, "text": "測試"}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: AsrResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.language, Some("zh".to_string()));
|
||||
assert_eq!(result.language_probability, Some(0.98));
|
||||
assert_eq!(result.segments.len(), 1);
|
||||
assert_eq!(result.segments[0].text, "測試");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asr_segment_default() {
|
||||
let segment = AsrSegment {
|
||||
start: 0.0,
|
||||
end: 1.0,
|
||||
text: String::new(),
|
||||
};
|
||||
assert_eq!(segment.start, 0.0);
|
||||
assert_eq!(segment.end, 1.0);
|
||||
assert!(segment.text.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asr_result_empty_segments() {
|
||||
let result = AsrResult {
|
||||
language: None,
|
||||
language_probability: None,
|
||||
segments: vec![],
|
||||
};
|
||||
assert!(result.language.is_none());
|
||||
assert!(result.segments.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const ASRX_TIMEOUT: Duration = Duration::from_secs(7200);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AsrxResult {
|
||||
pub language: Option<String>,
|
||||
pub segments: Vec<AsrxSegment>,
|
||||
}
|
||||
|
||||
@@ -11,18 +19,130 @@ pub struct AsrxSegment {
|
||||
pub start: f64,
|
||||
pub end: f64,
|
||||
pub text: String,
|
||||
pub speaker_id: String,
|
||||
pub speaker_embedding: Option<Vec<f32>>,
|
||||
pub speaker_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn process_asrx(video_path: &str, output_path: &str) -> Result<AsrxResult> {
|
||||
// TODO: Implement speaker diarization
|
||||
// Options:
|
||||
// 1. Use pyannote.audio
|
||||
// 2. Use whisperx
|
||||
// 3. Use Python subprocess
|
||||
pub async fn process_asrx(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<AsrxResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("asrx_processor.py");
|
||||
|
||||
println!("Processing speaker diarization for: {}", video_path);
|
||||
tracing::info!("[ASRX] Starting speaker diarization: {}", video_path);
|
||||
|
||||
Ok(AsrxResult { segments: vec![] })
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[ASRX] Script not found, returning empty result");
|
||||
return Ok(AsrxResult {
|
||||
language: None,
|
||||
segments: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
let mut cmd = Command::new(executor.python_path());
|
||||
cmd.arg(&script_path).arg(video_path).arg(output_path);
|
||||
|
||||
if let Some(u) = uuid {
|
||||
cmd.arg("--uuid").arg(u);
|
||||
}
|
||||
|
||||
cmd.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped());
|
||||
|
||||
let child = cmd.spawn().context("Failed to run ASRX processor")?;
|
||||
|
||||
let output = match timeout(ASRX_TIMEOUT, child.wait_with_output()).await {
|
||||
Ok(Ok(output)) => output,
|
||||
Ok(Err(e)) => return Err(e).context("Failed to run ASRX processor"),
|
||||
Err(_) => anyhow::bail!("ASRX processing timed out after {:?}", ASRX_TIMEOUT),
|
||||
};
|
||||
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
|
||||
for line in stderr.lines() {
|
||||
if line.starts_with("ASRX_START") {
|
||||
tracing::info!("[ASRX] Loading model...");
|
||||
} else if line.starts_with("ASRX_PROGRESS:") {
|
||||
let count = line.trim_start_matches("ASRX_PROGRESS:");
|
||||
tracing::info!("[ASRX] Processed {} segments...", count);
|
||||
} else if line.starts_with("ASRX_COMPLETE:") {
|
||||
let count = line.trim_start_matches("ASRX_COMPLETE:");
|
||||
tracing::info!("[ASRX] Completed! Total: {} segments", count);
|
||||
}
|
||||
}
|
||||
|
||||
if !output.status.success() {
|
||||
anyhow::bail!("ASRX failed: {}", stderr);
|
||||
}
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read ASRX output")?;
|
||||
|
||||
let result: AsrxResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse ASRX output")?;
|
||||
|
||||
tracing::info!("[ASRX] Result: {} segments", result.segments.len());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_asrx_result_serialization() {
|
||||
let result = AsrxResult {
|
||||
language: Some("en".to_string()),
|
||||
segments: vec![AsrxSegment {
|
||||
start: 0.0,
|
||||
end: 2.5,
|
||||
text: "Hello".to_string(),
|
||||
speaker_id: Some("SPEAKER_00".to_string()),
|
||||
}],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("Hello"));
|
||||
assert!(json.contains("SPEAKER_00"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asrx_result_deserialization() {
|
||||
let json = r#"{
|
||||
"language": "zh",
|
||||
"segments": [
|
||||
{"start": 0.0, "end": 1.5, "text": "測試", "speaker_id": "SPEAKER_01"}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: AsrxResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.language, Some("zh".to_string()));
|
||||
assert_eq!(result.segments.len(), 1);
|
||||
assert_eq!(
|
||||
result.segments[0].speaker_id,
|
||||
Some("SPEAKER_01".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asrx_result_empty_segments() {
|
||||
let result = AsrxResult {
|
||||
language: None,
|
||||
segments: vec![],
|
||||
};
|
||||
assert!(result.segments.is_empty());
|
||||
assert!(result.language.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asrx_segment_times() {
|
||||
let segment = AsrxSegment {
|
||||
start: 0.0,
|
||||
end: 5.0,
|
||||
text: "Test".to_string(),
|
||||
speaker_id: None,
|
||||
};
|
||||
assert!(segment.end > segment.start);
|
||||
}
|
||||
}
|
||||
|
||||
77
src/core/processor/caption.rs
Normal file
77
src/core/processor/caption.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const CAPTION_TIMEOUT: Duration = Duration::from_secs(7200);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct CaptionResult {
|
||||
pub video_path: String,
|
||||
pub total_frames: usize,
|
||||
pub captions: Vec<FrameCaption>,
|
||||
pub summary: CaptionSummary,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FrameCaption {
|
||||
pub index: u32,
|
||||
pub timestamp: f64,
|
||||
pub caption: String,
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct CaptionSummary {
|
||||
pub avg_caption_length: f64,
|
||||
pub gpt4v_count: usize,
|
||||
pub llava_count: usize,
|
||||
pub metadata_count: usize,
|
||||
}
|
||||
|
||||
pub async fn process_caption(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<CaptionResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("caption_processor.py");
|
||||
|
||||
tracing::info!("[CAPTION] Starting caption generation: {}", video_path);
|
||||
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[CAPTION] Script not found, returning empty result");
|
||||
return Ok(CaptionResult {
|
||||
video_path: video_path.to_string(),
|
||||
total_frames: 0,
|
||||
captions: vec![],
|
||||
summary: CaptionSummary {
|
||||
avg_caption_length: 0.0,
|
||||
gpt4v_count: 0,
|
||||
llava_count: 0,
|
||||
metadata_count: 0,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"caption_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"CAPTION",
|
||||
Some(CAPTION_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read CAPTION output")?;
|
||||
|
||||
let result: CaptionResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse CAPTION output")?;
|
||||
|
||||
tracing::info!("[CAPTION] Result: {} frames captioned", result.total_frames);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
127
src/core/processor/cut.rs
Normal file
127
src/core/processor/cut.rs
Normal file
@@ -0,0 +1,127 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const CUT_TIMEOUT: Duration = Duration::from_secs(3600);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct CutResult {
|
||||
pub frame_count: u64,
|
||||
pub fps: f64,
|
||||
pub scenes: Vec<CutScene>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct CutScene {
|
||||
pub scene_number: u32,
|
||||
pub start_frame: u64,
|
||||
pub end_frame: u64,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
}
|
||||
|
||||
pub async fn process_cut(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<CutResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("cut_processor.py");
|
||||
|
||||
tracing::info!("[CUT] Starting scene detection: {}", video_path);
|
||||
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[CUT] Script not found, returning empty result");
|
||||
return Ok(CutResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
scenes: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"cut_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"CUT",
|
||||
Some(CUT_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read CUT output")?;
|
||||
|
||||
let result: CutResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse CUT output")?;
|
||||
|
||||
tracing::info!("[CUT] Result: {} scenes detected", result.scenes.len());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cut_result_serialization() {
|
||||
let result = CutResult {
|
||||
frame_count: 100,
|
||||
fps: 30.0,
|
||||
scenes: vec![CutScene {
|
||||
scene_number: 1,
|
||||
start_frame: 0,
|
||||
end_frame: 30,
|
||||
start_time: 0.0,
|
||||
end_time: 1.0,
|
||||
}],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("scene_number"));
|
||||
assert!(json.contains("1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cut_result_deserialization() {
|
||||
let json = r#"{
|
||||
"frame_count": 100,
|
||||
"fps": 30.0,
|
||||
"scenes": [
|
||||
{"scene_number": 1, "start_frame": 0, "end_frame": 30, "start_time": 0.0, "end_time": 1.0},
|
||||
{"scene_number": 2, "start_frame": 31, "end_frame": 60, "start_time": 1.033, "end_time": 2.0}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: CutResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.frame_count, 100);
|
||||
assert_eq!(result.scenes.len(), 2);
|
||||
assert_eq!(result.scenes[1].scene_number, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cut_result_empty_scenes() {
|
||||
let result = CutResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
scenes: vec![],
|
||||
};
|
||||
assert!(result.scenes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cut_scene_times() {
|
||||
let scene = CutScene {
|
||||
scene_number: 1,
|
||||
start_frame: 0,
|
||||
end_frame: 30,
|
||||
start_time: 0.0,
|
||||
end_time: 1.0,
|
||||
};
|
||||
assert!(scene.end_time > scene.start_time);
|
||||
assert_eq!(scene.scene_number, 1);
|
||||
}
|
||||
}
|
||||
395
src/core/processor/executor.rs
Normal file
395
src/core/processor/executor.rs
Normal file
@@ -0,0 +1,395 @@
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{sleep, timeout};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
pub max_attempts: u32,
|
||||
pub initial_delay_ms: u64,
|
||||
pub max_delay_ms: u64,
|
||||
pub backoff_multiplier: f64,
|
||||
}
|
||||
|
||||
impl Default for RetryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_attempts: 3,
|
||||
initial_delay_ms: 1000,
|
||||
max_delay_ms: 30000,
|
||||
backoff_multiplier: 2.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
pub fn new(max_attempts: u32) -> Self {
|
||||
Self {
|
||||
max_attempts,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_delay(mut self, delay_ms: u64) -> Self {
|
||||
self.initial_delay_ms = delay_ms;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_delay(mut self, max_delay_ms: u64) -> Self {
|
||||
self.max_delay_ms = max_delay_ms;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn validate_python_env() -> Result<()> {
|
||||
let manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
let venv_python = manifest.join("venv").join("bin").join("python");
|
||||
|
||||
if !venv_python.exists() {
|
||||
anyhow::bail!(
|
||||
"Python venv not found at {:?}\n\
|
||||
Run: /opt/homebrew/bin/python3.11 -m venv venv",
|
||||
venv_python
|
||||
);
|
||||
}
|
||||
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
let output = rt
|
||||
.block_on(async { Command::new(&venv_python).arg("--version").output().await })
|
||||
.context("Failed to run Python")?;
|
||||
|
||||
if !output.status.success() {
|
||||
anyhow::bail!("Python validation failed");
|
||||
}
|
||||
|
||||
let version = String::from_utf8_lossy(&output.stdout);
|
||||
tracing::info!("Python version: {}", version.trim());
|
||||
|
||||
if !version.contains("3.11") {
|
||||
tracing::warn!("Expected Python 3.11, got: {}", version.trim());
|
||||
}
|
||||
|
||||
let script_path = manifest.join("scripts");
|
||||
if !script_path.exists() {
|
||||
anyhow::bail!("Scripts directory not found at {:?}", script_path);
|
||||
}
|
||||
|
||||
tracing::info!("Python environment validated successfully");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct PythonExecutor {
|
||||
venv_python: PathBuf,
|
||||
scripts_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl PythonExecutor {
|
||||
pub fn new() -> Result<Self> {
|
||||
let manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
let venv_python = manifest.join("venv").join("bin").join("python");
|
||||
let scripts_dir = manifest.join("scripts");
|
||||
|
||||
if !venv_python.exists() {
|
||||
anyhow::bail!(
|
||||
"Python venv not found at {:?}. Run: /opt/homebrew/bin/python3.11 -m venv venv",
|
||||
venv_python
|
||||
);
|
||||
}
|
||||
|
||||
if !scripts_dir.exists() {
|
||||
anyhow::bail!("Scripts directory not found at {:?}", scripts_dir);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
venv_python,
|
||||
scripts_dir,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn validate_env(&self) -> Result<()> {
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
let output = rt
|
||||
.block_on(async {
|
||||
Command::new(&self.venv_python)
|
||||
.arg("--version")
|
||||
.output()
|
||||
.await
|
||||
})
|
||||
.context("Failed to run Python")?;
|
||||
|
||||
if !output.status.success() {
|
||||
anyhow::bail!("Python validation failed");
|
||||
}
|
||||
|
||||
let version = String::from_utf8_lossy(&output.stdout);
|
||||
if !version.contains("3.11") {
|
||||
tracing::warn!("Python version mismatch: {}", version);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
&self,
|
||||
script_name: &str,
|
||||
args: &[&str],
|
||||
uuid: Option<&str>,
|
||||
log_prefix: &str,
|
||||
timeout_duration: Option<Duration>,
|
||||
) -> Result<()> {
|
||||
let script_path = self.scripts_dir.join(script_name);
|
||||
|
||||
if !script_path.exists() {
|
||||
anyhow::bail!("Script not found: {:?}", script_path);
|
||||
}
|
||||
|
||||
let mut cmd = Command::new(&self.venv_python);
|
||||
cmd.arg(&script_path);
|
||||
|
||||
for arg in args {
|
||||
cmd.arg(arg);
|
||||
}
|
||||
|
||||
if let Some(u) = uuid {
|
||||
cmd.arg("--uuid").arg(u);
|
||||
}
|
||||
|
||||
cmd.stdout(Stdio::piped());
|
||||
cmd.stderr(Stdio::piped());
|
||||
|
||||
tracing::info!("[{}] Starting: {:?}", log_prefix, script_name);
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.with_context(|| format!("Failed to run {}", script_name))?;
|
||||
|
||||
let stdout = child.stdout.take().context("Failed to capture stdout")?;
|
||||
let stderr = child.stderr.take().context("Failed to capture stderr")?;
|
||||
|
||||
let mut stdout_reader = BufReader::new(stdout).lines();
|
||||
let mut stderr_reader = BufReader::new(stderr).lines();
|
||||
|
||||
let run_future = async {
|
||||
loop {
|
||||
tokio::select! {
|
||||
line = stdout_reader.next_line() => {
|
||||
match line {
|
||||
Ok(Some(line)) => {
|
||||
if line.starts_with(&format!("{}_", log_prefix)) {
|
||||
tracing::info!("[{}] {}", log_prefix, line);
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(e) => tracing::warn!("[{}] stdout error: {}", log_prefix, e),
|
||||
}
|
||||
}
|
||||
line = stderr_reader.next_line() => {
|
||||
match line {
|
||||
Ok(Some(line)) => {
|
||||
if line.starts_with(&format!("{}_", log_prefix)) {
|
||||
tracing::info!("[{}] {}", log_prefix, line);
|
||||
}
|
||||
}
|
||||
Ok(None) => {}
|
||||
Err(e) => tracing::warn!("[{}] stderr error: {}", log_prefix, e),
|
||||
}
|
||||
}
|
||||
status = child.wait() => {
|
||||
match status {
|
||||
Ok(status) => {
|
||||
if !status.success() {
|
||||
tracing::error!("[{}] Process failed: {}", log_prefix, status);
|
||||
return Err(anyhow::anyhow!("{} exited with: {}", script_name, status));
|
||||
}
|
||||
tracing::info!("[{}] Completed successfully", log_prefix);
|
||||
}
|
||||
Err(e) => tracing::error!("[{}] wait error: {}", log_prefix, e),
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
};
|
||||
|
||||
if let Some(duration) = timeout_duration {
|
||||
match timeout(duration, run_future).await {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => {
|
||||
child.kill().await.context("Failed to kill process")?;
|
||||
anyhow::bail!("{} timed out after {:?}", script_name, duration);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
run_future.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_with_output(
|
||||
&self,
|
||||
script_name: &str,
|
||||
args: &[&str],
|
||||
uuid: Option<&str>,
|
||||
log_prefix: &str,
|
||||
timeout_duration: Option<Duration>,
|
||||
) -> Result<()> {
|
||||
self.run(script_name, args, uuid, log_prefix, timeout_duration)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn run_with_retry(
|
||||
&self,
|
||||
script_name: &str,
|
||||
args: &[&str],
|
||||
uuid: Option<&str>,
|
||||
log_prefix: &str,
|
||||
timeout_duration: Option<Duration>,
|
||||
retry_config: Option<RetryConfig>,
|
||||
) -> Result<()> {
|
||||
let config = retry_config.unwrap_or_default();
|
||||
let mut attempt = 0;
|
||||
let mut delay_ms = config.initial_delay_ms;
|
||||
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
match self
|
||||
.run(script_name, args, uuid, log_prefix, timeout_duration)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
if attempt > 1 {
|
||||
tracing::info!(
|
||||
"[{}] Succeeded on attempt {}/{}",
|
||||
log_prefix,
|
||||
attempt,
|
||||
config.max_attempts
|
||||
);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt >= config.max_attempts {
|
||||
tracing::error!(
|
||||
"[{}] Failed after {} attempts: {}",
|
||||
log_prefix,
|
||||
attempt,
|
||||
e
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
"[{}] Attempt {}/{} failed: {}. Retrying in {}ms...",
|
||||
log_prefix,
|
||||
attempt,
|
||||
config.max_attempts,
|
||||
e,
|
||||
delay_ms
|
||||
);
|
||||
|
||||
sleep(Duration::from_millis(delay_ms)).await;
|
||||
|
||||
delay_ms = (delay_ms as f64 * config.backoff_multiplier) as u64;
|
||||
delay_ms = delay_ms.min(config.max_delay_ms);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn script_path(&self, script_name: &str) -> PathBuf {
|
||||
self.scripts_dir.join(script_name)
|
||||
}
|
||||
|
||||
pub fn python_path(&self) -> &PathBuf {
|
||||
&self.venv_python
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PythonExecutor {
|
||||
fn default() -> Self {
|
||||
Self::new().expect("Failed to create PythonExecutor")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_python_executor_new_with_venv() {
|
||||
let executor = PythonExecutor::new();
|
||||
assert!(
|
||||
executor.is_ok(),
|
||||
"PythonExecutor should create successfully with venv"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_python_executor_paths() {
|
||||
let executor = PythonExecutor::new().unwrap();
|
||||
let python_path = executor.python_path();
|
||||
assert!(
|
||||
python_path.exists(),
|
||||
"Python path should exist: {:?}",
|
||||
python_path
|
||||
);
|
||||
assert!(
|
||||
python_path.to_string_lossy().contains("venv"),
|
||||
"Should be in venv"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_script_path() {
|
||||
let executor = PythonExecutor::new().unwrap();
|
||||
let script_path = executor.script_path("asr_processor.py");
|
||||
assert!(script_path.to_string_lossy().contains("scripts"));
|
||||
assert!(script_path.to_string_lossy().contains("asr_processor.py"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_script_path_nonexistent() {
|
||||
let executor = PythonExecutor::new().unwrap();
|
||||
let path = executor.script_path("nonexistent_script.py");
|
||||
assert!(!path.exists(), "Nonexistent script path should not exist");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_python_path_is_executable() {
|
||||
let executor = PythonExecutor::new().unwrap();
|
||||
let path = executor.python_path();
|
||||
let metadata = std::fs::metadata(path);
|
||||
assert!(metadata.is_ok(), "Python path should be accessible");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retry_config_default() {
|
||||
let config = RetryConfig::default();
|
||||
assert_eq!(config.max_attempts, 3);
|
||||
assert_eq!(config.initial_delay_ms, 1000);
|
||||
assert_eq!(config.max_delay_ms, 30000);
|
||||
assert_eq!(config.backoff_multiplier, 2.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retry_config_builder() {
|
||||
let config = RetryConfig::new(5).with_delay(2000).with_max_delay(60000);
|
||||
assert_eq!(config.max_attempts, 5);
|
||||
assert_eq!(config.initial_delay_ms, 2000);
|
||||
assert_eq!(config.max_delay_ms, 60000);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_config_clone() {
|
||||
let config = RetryConfig::default();
|
||||
let cloned = config.clone();
|
||||
assert_eq!(cloned.max_attempts, config.max_attempts);
|
||||
}
|
||||
}
|
||||
@@ -1,36 +1,145 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const FACE_TIMEOUT: Duration = Duration::from_secs(7200);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FaceResult {
|
||||
pub frame_count: u64,
|
||||
pub fps: f64,
|
||||
pub frames: Vec<FaceFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FaceFrame {
|
||||
pub frame: u64,
|
||||
pub timestamp: f64,
|
||||
pub faces: Vec<Face>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Face {
|
||||
pub face_id: String,
|
||||
pub face_id: Option<String>,
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub width: i32,
|
||||
pub height: i32,
|
||||
pub confidence: f32,
|
||||
pub embedding: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
pub async fn process_face(video_path: &str, output_path: &str) -> Result<FaceResult> {
|
||||
// TODO: Implement face detection
|
||||
// Options:
|
||||
// 1. Use MTCNN or RetinaFace with ONNX
|
||||
// 2. Use Python subprocess
|
||||
pub async fn process_face(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<FaceResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("face_processor.py");
|
||||
|
||||
println!("Processing face detection for: {}", video_path);
|
||||
tracing::info!("[FACE] Starting face detection: {}", video_path);
|
||||
|
||||
Ok(FaceResult { frames: vec![] })
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[FACE] Script not found, returning empty result");
|
||||
return Ok(FaceResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"face_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"FACE",
|
||||
Some(FACE_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read FACE output")?;
|
||||
|
||||
let result: FaceResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse FACE output")?;
|
||||
|
||||
tracing::info!("[FACE] Result: {} frames", result.frames.len());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_face_result_serialization() {
|
||||
let result = FaceResult {
|
||||
frame_count: 100,
|
||||
fps: 30.0,
|
||||
frames: vec![FaceFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
faces: vec![Face {
|
||||
face_id: Some("face_1".to_string()),
|
||||
x: 100,
|
||||
y: 100,
|
||||
width: 50,
|
||||
height: 60,
|
||||
confidence: 0.95,
|
||||
}],
|
||||
}],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("face_1"));
|
||||
assert!(json.contains("\"width\":50"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_face_result_deserialization() {
|
||||
let json = r#"{
|
||||
"frame_count": 50,
|
||||
"fps": 25.0,
|
||||
"frames": [
|
||||
{
|
||||
"frame": 30,
|
||||
"timestamp": 1.2,
|
||||
"faces": [
|
||||
{"face_id": "f1", "x": 10, "y": 20, "width": 30, "height": 40, "confidence": 0.85}
|
||||
]
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: FaceResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.frame_count, 50);
|
||||
assert_eq!(result.frames.len(), 1);
|
||||
assert_eq!(result.frames[0].faces[0].x, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_face_result_empty_frames() {
|
||||
let result = FaceResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
};
|
||||
assert!(result.frames.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_face_confidence() {
|
||||
let face = Face {
|
||||
face_id: None,
|
||||
x: 0,
|
||||
y: 0,
|
||||
width: 10,
|
||||
height: 10,
|
||||
confidence: 0.5,
|
||||
};
|
||||
assert!(face.confidence >= 0.0 && face.confidence <= 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
pub mod asr;
|
||||
pub mod asrx;
|
||||
pub mod caption;
|
||||
pub mod cut;
|
||||
pub mod executor;
|
||||
pub mod face;
|
||||
pub mod ocr;
|
||||
pub mod pose;
|
||||
pub mod story;
|
||||
pub mod yolo;
|
||||
|
||||
pub use asr::{process_asr, AsrResult, AsrSegment};
|
||||
pub use asrx::process_asrx;
|
||||
pub use face::process_face;
|
||||
pub use ocr::process_ocr;
|
||||
pub use pose::process_pose;
|
||||
pub use yolo::process_yolo;
|
||||
pub use asrx::{process_asrx, AsrxResult, AsrxSegment};
|
||||
pub use caption::{process_caption, CaptionResult, CaptionSummary, FrameCaption};
|
||||
pub use cut::{process_cut, CutResult, CutScene};
|
||||
pub use executor::{validate_python_env, PythonExecutor, RetryConfig};
|
||||
pub use face::{process_face, Face, FaceFrame, FaceResult};
|
||||
pub use ocr::{process_ocr, OcrFrame, OcrResult, OcrText};
|
||||
pub use pose::{process_pose, Bbox, Keypoint, PersonPose, PoseFrame, PoseResult};
|
||||
pub use story::{process_story, StoryChildChunk, StoryParentChunk, StoryResult, StoryStats};
|
||||
pub use yolo::{process_yolo, YoloFrame, YoloObject, YoloResult};
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const OCR_TIMEOUT: Duration = Duration::from_secs(7200);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct OcrResult {
|
||||
pub frame_count: u64,
|
||||
pub fps: f64,
|
||||
pub frames: Vec<OcrFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct OcrFrame {
|
||||
pub frame: u64,
|
||||
pub timestamp: f64,
|
||||
pub texts: Vec<OcrText>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct OcrText {
|
||||
pub text: String,
|
||||
pub x: i32,
|
||||
@@ -23,14 +30,116 @@ pub struct OcrText {
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
pub async fn process_ocr(video_path: &str, output_path: &str) -> Result<OcrResult> {
|
||||
// TODO: Implement OCR processing
|
||||
// Options:
|
||||
// 1. Use tesseract
|
||||
// 2. Use Python pytesseract via subprocess
|
||||
// 3. Use Rust OCR library
|
||||
pub async fn process_ocr(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<OcrResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("ocr_processor.py");
|
||||
|
||||
println!("Processing OCR for: {}", video_path);
|
||||
tracing::info!("[OCR] Starting text recognition: {}", video_path);
|
||||
|
||||
Ok(OcrResult { frames: vec![] })
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[OCR] Script not found, returning empty result");
|
||||
return Ok(OcrResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"ocr_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"OCR",
|
||||
Some(OCR_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read OCR output")?;
|
||||
|
||||
let result: OcrResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse OCR output")?;
|
||||
|
||||
tracing::info!("[OCR] Result: {} frames", result.frames.len());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ocr_result_serialization() {
|
||||
let result = OcrResult {
|
||||
frame_count: 100,
|
||||
fps: 30.0,
|
||||
frames: vec![OcrFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
texts: vec![OcrText {
|
||||
text: "Hello".to_string(),
|
||||
x: 10,
|
||||
y: 20,
|
||||
width: 100,
|
||||
height: 30,
|
||||
confidence: 0.95,
|
||||
}],
|
||||
}],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("Hello"));
|
||||
assert!(json.contains("\"x\":10"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ocr_result_deserialization() {
|
||||
let json = r#"{
|
||||
"frame_count": 50,
|
||||
"fps": 25.0,
|
||||
"frames": [
|
||||
{
|
||||
"frame": 30,
|
||||
"timestamp": 1.2,
|
||||
"texts": [
|
||||
{"text": "Test", "x": 0, "y": 0, "width": 50, "height": 20, "confidence": 0.88}
|
||||
]
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: OcrResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.frame_count, 50);
|
||||
assert_eq!(result.frames.len(), 1);
|
||||
assert_eq!(result.frames[0].texts[0].text, "Test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ocr_result_empty_frames() {
|
||||
let result = OcrResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
};
|
||||
assert!(result.frames.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ocr_text_confidence() {
|
||||
let text = OcrText {
|
||||
text: "OCR".to_string(),
|
||||
x: 0,
|
||||
y: 0,
|
||||
width: 10,
|
||||
height: 10,
|
||||
confidence: 0.75,
|
||||
};
|
||||
assert!(text.confidence >= 0.0 && text.confidence <= 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,25 +1,32 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const POSE_TIMEOUT: Duration = Duration::from_secs(7200);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct PoseResult {
|
||||
pub frame_count: u64,
|
||||
pub fps: f64,
|
||||
pub frames: Vec<PoseFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct PoseFrame {
|
||||
pub frame: u64,
|
||||
pub timestamp: f64,
|
||||
pub persons: Vec<PersonPose>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct PersonPose {
|
||||
pub keypoints: Vec<Keypoint>,
|
||||
pub bbox: Bbox,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Keypoint {
|
||||
pub name: String,
|
||||
pub x: f32,
|
||||
@@ -27,7 +34,7 @@ pub struct Keypoint {
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Bbox {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
@@ -35,13 +42,135 @@ pub struct Bbox {
|
||||
pub height: i32,
|
||||
}
|
||||
|
||||
pub async fn process_pose(video_path: &str, output_path: &str) -> Result<PoseResult> {
|
||||
// TODO: Implement pose estimation
|
||||
// Options:
|
||||
// 1. Use MoveNet or PoseNet with ONNX
|
||||
// 2. Use Python subprocess with ultralytics
|
||||
pub async fn process_pose(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<PoseResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("pose_processor.py");
|
||||
|
||||
println!("Processing pose estimation for: {}", video_path);
|
||||
tracing::info!("[POSE] Starting pose estimation: {}", video_path);
|
||||
|
||||
Ok(PoseResult { frames: vec![] })
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[POSE] Script not found, returning empty result");
|
||||
return Ok(PoseResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"pose_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"POSE",
|
||||
Some(POSE_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read POSE output")?;
|
||||
|
||||
let result: PoseResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse POSE output")?;
|
||||
|
||||
tracing::info!("[POSE] Result: {} frames", result.frames.len());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pose_result_serialization() {
|
||||
let result = PoseResult {
|
||||
frame_count: 100,
|
||||
fps: 30.0,
|
||||
frames: vec![PoseFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
persons: vec![PersonPose {
|
||||
keypoints: vec![Keypoint {
|
||||
name: "nose".to_string(),
|
||||
x: 100.0,
|
||||
y: 50.0,
|
||||
confidence: 0.9,
|
||||
}],
|
||||
bbox: Bbox {
|
||||
x: 80,
|
||||
y: 30,
|
||||
width: 40,
|
||||
height: 80,
|
||||
},
|
||||
}],
|
||||
}],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("nose"));
|
||||
assert!(json.contains("\"confidence\":0.9"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pose_result_deserialization() {
|
||||
let json = r#"{
|
||||
"frame_count": 50,
|
||||
"fps": 25.0,
|
||||
"frames": [
|
||||
{
|
||||
"frame": 30,
|
||||
"timestamp": 1.2,
|
||||
"persons": [
|
||||
{
|
||||
"keypoints": [{"name": "left_eye", "x": 100.5, "y": 50.2, "confidence": 0.85}],
|
||||
"bbox": {"x": 90, "y": 40, "width": 20, "height": 30}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: PoseResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.frame_count, 50);
|
||||
assert_eq!(result.frames.len(), 1);
|
||||
assert_eq!(result.frames[0].persons[0].keypoints[0].name, "left_eye");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pose_result_empty_frames() {
|
||||
let result = PoseResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
};
|
||||
assert!(result.frames.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keypoint_confidence() {
|
||||
let kp = Keypoint {
|
||||
name: "test".to_string(),
|
||||
x: 0.0,
|
||||
y: 0.0,
|
||||
confidence: 0.75,
|
||||
};
|
||||
assert!(kp.confidence >= 0.0 && kp.confidence <= 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bbox_dimensions() {
|
||||
let bbox = Bbox {
|
||||
x: 10,
|
||||
y: 20,
|
||||
width: 50,
|
||||
height: 100,
|
||||
};
|
||||
assert!(bbox.width > 0);
|
||||
assert!(bbox.height > 0);
|
||||
}
|
||||
}
|
||||
|
||||
250
src/core/processor/story.rs
Normal file
250
src/core/processor/story.rs
Normal file
@@ -0,0 +1,250 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const STORY_TIMEOUT: Duration = Duration::from_secs(3600);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct StoryResult {
|
||||
pub child_chunks: Vec<StoryChildChunk>,
|
||||
pub parent_chunks: Vec<StoryParentChunk>,
|
||||
pub stats: StoryStats,
|
||||
pub metadata: serde_json::Value,
|
||||
pub parent_chunk_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct StoryStats {
|
||||
pub total_child_chunks: usize,
|
||||
pub total_parent_chunks: usize,
|
||||
pub asr_children: usize,
|
||||
pub cut_children: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct StoryChildChunk {
|
||||
pub chunk_id: String,
|
||||
pub chunk_type: String,
|
||||
pub source: String,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub text_content: Option<String>,
|
||||
pub content: serde_json::Value,
|
||||
pub child_chunk_ids: Vec<String>,
|
||||
pub parent_chunk_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct StoryParentChunk {
|
||||
pub chunk_id: String,
|
||||
pub chunk_type: String,
|
||||
pub source: String,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub text_content: String,
|
||||
pub content: serde_json::Value,
|
||||
pub child_chunk_ids: Vec<String>,
|
||||
pub parent_chunk_id: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn process_story(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<StoryResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("story_processor.py");
|
||||
|
||||
tracing::info!("[STORY] Starting story generation: {}", video_path);
|
||||
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[STORY] Script not found, returning empty result");
|
||||
return Ok(StoryResult {
|
||||
child_chunks: vec![],
|
||||
parent_chunks: vec![],
|
||||
stats: StoryStats {
|
||||
total_child_chunks: 0,
|
||||
total_parent_chunks: 0,
|
||||
asr_children: 0,
|
||||
cut_children: 0,
|
||||
},
|
||||
metadata: serde_json::json!({}),
|
||||
parent_chunk_size: 5,
|
||||
});
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"story_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"STORY",
|
||||
Some(STORY_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read STORY output")?;
|
||||
|
||||
let result: StoryResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse STORY output")?;
|
||||
|
||||
tracing::info!(
|
||||
"[STORY] Result: {} parent chunks, {} child chunks",
|
||||
result.stats.total_parent_chunks,
|
||||
result.stats.total_child_chunks
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_story_result_serialization() {
|
||||
let result = StoryResult {
|
||||
child_chunks: vec![StoryChildChunk {
|
||||
chunk_id: "asr_0001".to_string(),
|
||||
chunk_type: "sentence".to_string(),
|
||||
source: "asr".to_string(),
|
||||
start_time: 0.0,
|
||||
end_time: 5.0,
|
||||
text_content: Some("Hello world".to_string()),
|
||||
content: serde_json::json!({}),
|
||||
child_chunk_ids: vec![],
|
||||
parent_chunk_id: Some("story_asr_0000".to_string()),
|
||||
}],
|
||||
parent_chunks: vec![StoryParentChunk {
|
||||
chunk_id: "story_asr_0000".to_string(),
|
||||
chunk_type: "story".to_string(),
|
||||
source: "story_asr".to_string(),
|
||||
start_time: 0.0,
|
||||
end_time: 25.0,
|
||||
text_content: "[0s-25s] Hello world...".to_string(),
|
||||
content: serde_json::json!({
|
||||
"description": "[0s-25s] Hello world...",
|
||||
"child_count": 5
|
||||
}),
|
||||
child_chunk_ids: vec!["asr_0001".to_string()],
|
||||
parent_chunk_id: None,
|
||||
}],
|
||||
stats: StoryStats {
|
||||
total_child_chunks: 10,
|
||||
total_parent_chunks: 2,
|
||||
asr_children: 10,
|
||||
cut_children: 0,
|
||||
},
|
||||
metadata: serde_json::json!({}),
|
||||
parent_chunk_size: 5,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("asr_0001"));
|
||||
assert!(json.contains("story_asr_0000"));
|
||||
assert!(json.contains("Hello world"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_story_result_deserialization() {
|
||||
let json = r#"{
|
||||
"child_chunks": [{
|
||||
"chunk_id": "asr_0001",
|
||||
"chunk_type": "sentence",
|
||||
"source": "asr",
|
||||
"start_time": 0.0,
|
||||
"end_time": 5.0,
|
||||
"text_content": "Hello",
|
||||
"content": {},
|
||||
"child_chunk_ids": [],
|
||||
"parent_chunk_id": null
|
||||
}],
|
||||
"parent_chunks": [{
|
||||
"chunk_id": "story_asr_0000",
|
||||
"chunk_type": "story",
|
||||
"source": "story_asr",
|
||||
"start_time": 0.0,
|
||||
"end_time": 5.0,
|
||||
"text_content": "Hello segment",
|
||||
"content": {"description": "Hello segment"},
|
||||
"child_chunk_ids": ["asr_0001"],
|
||||
"parent_chunk_id": null
|
||||
}],
|
||||
"stats": {
|
||||
"total_child_chunks": 1,
|
||||
"total_parent_chunks": 1,
|
||||
"asr_children": 1,
|
||||
"cut_children": 0
|
||||
},
|
||||
"metadata": {},
|
||||
"parent_chunk_size": 5
|
||||
}"#;
|
||||
|
||||
let result: StoryResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.child_chunks.len(), 1);
|
||||
assert_eq!(result.parent_chunks.len(), 1);
|
||||
assert_eq!(result.stats.total_child_chunks, 1);
|
||||
assert_eq!(result.stats.total_parent_chunks, 1);
|
||||
assert_eq!(result.parent_chunks[0].child_chunk_ids[0], "asr_0001");
|
||||
assert_eq!(result.child_chunks[0].parent_chunk_id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parent_child_relationship() {
|
||||
let result = StoryResult {
|
||||
child_chunks: vec![
|
||||
StoryChildChunk {
|
||||
chunk_id: "asr_0001".to_string(),
|
||||
chunk_type: "sentence".to_string(),
|
||||
source: "asr".to_string(),
|
||||
start_time: 0.0,
|
||||
end_time: 5.0,
|
||||
text_content: Some("First".to_string()),
|
||||
content: serde_json::json!({}),
|
||||
child_chunk_ids: vec![],
|
||||
parent_chunk_id: Some("story_asr_0000".to_string()),
|
||||
},
|
||||
StoryChildChunk {
|
||||
chunk_id: "asr_0002".to_string(),
|
||||
chunk_type: "sentence".to_string(),
|
||||
source: "asr".to_string(),
|
||||
start_time: 5.0,
|
||||
end_time: 10.0,
|
||||
text_content: Some("Second".to_string()),
|
||||
content: serde_json::json!({}),
|
||||
child_chunk_ids: vec![],
|
||||
parent_chunk_id: Some("story_asr_0000".to_string()),
|
||||
},
|
||||
],
|
||||
parent_chunks: vec![StoryParentChunk {
|
||||
chunk_id: "story_asr_0000".to_string(),
|
||||
chunk_type: "story".to_string(),
|
||||
source: "story_asr".to_string(),
|
||||
start_time: 0.0,
|
||||
end_time: 10.0,
|
||||
text_content: "Combined narrative".to_string(),
|
||||
content: serde_json::json!({}),
|
||||
child_chunk_ids: vec!["asr_0001".to_string(), "asr_0002".to_string()],
|
||||
parent_chunk_id: None,
|
||||
}],
|
||||
stats: StoryStats {
|
||||
total_child_chunks: 2,
|
||||
total_parent_chunks: 1,
|
||||
asr_children: 2,
|
||||
cut_children: 0,
|
||||
},
|
||||
metadata: serde_json::json!({}),
|
||||
parent_chunk_size: 5,
|
||||
};
|
||||
|
||||
assert_eq!(result.parent_chunks[0].child_chunk_ids.len(), 2);
|
||||
assert!(result
|
||||
.child_chunks
|
||||
.iter()
|
||||
.all(|c| c.parent_chunk_id.is_some()));
|
||||
assert!(result.parent_chunks[0].parent_chunk_id.is_none());
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,26 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const YOLO_TIMEOUT: Duration = Duration::from_secs(7200);
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct YoloResult {
|
||||
pub frame_count: u64,
|
||||
pub fps: f64,
|
||||
pub frames: Vec<YoloFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct YoloFrame {
|
||||
pub frame: u64,
|
||||
pub timestamp: f64,
|
||||
pub objects: Vec<YoloObject>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct YoloObject {
|
||||
pub class_name: String,
|
||||
pub class_id: u32,
|
||||
@@ -24,13 +31,123 @@ pub struct YoloObject {
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
pub async fn process_yolo(video_path: &str, output_path: &str) -> Result<YoloResult> {
|
||||
// TODO: Implement YOLO processing
|
||||
// Options:
|
||||
// 1. Use ONNX Runtime (ort) with YOLO model
|
||||
// 2. Use Python subprocess with ultralytics
|
||||
pub async fn process_yolo(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<YoloResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("yolo_processor.py");
|
||||
|
||||
println!("Processing YOLO for: {}", video_path);
|
||||
tracing::info!("[YOLO] Starting object detection: {}", video_path);
|
||||
|
||||
Ok(YoloResult { frames: vec![] })
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[YOLO] Script not found, returning empty result");
|
||||
return Ok(YoloResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"yolo_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"YOLO",
|
||||
Some(YOLO_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read YOLO output")?;
|
||||
|
||||
let result: YoloResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse YOLO output")?;
|
||||
|
||||
tracing::info!(
|
||||
"[YOLO] Result: {} frames, {:.2} fps",
|
||||
result.frame_count,
|
||||
result.fps
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_yolo_result_serialization() {
|
||||
let result = YoloResult {
|
||||
frame_count: 100,
|
||||
fps: 30.0,
|
||||
frames: vec![YoloFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
objects: vec![YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 100,
|
||||
y: 200,
|
||||
width: 50,
|
||||
height: 100,
|
||||
confidence: 0.95,
|
||||
}],
|
||||
}],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("person"));
|
||||
assert!(json.contains("100"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yolo_result_deserialization() {
|
||||
let json = r#"{
|
||||
"frame_count": 50,
|
||||
"fps": 25.0,
|
||||
"frames": [
|
||||
{
|
||||
"frame": 10,
|
||||
"timestamp": 0.4,
|
||||
"objects": [
|
||||
{"class_name": "car", "class_id": 2, "x": 0, "y": 0, "width": 100, "height": 80, "confidence": 0.87}
|
||||
]
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: YoloResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.frame_count, 50);
|
||||
assert_eq!(result.fps, 25.0);
|
||||
assert_eq!(result.frames.len(), 1);
|
||||
assert_eq!(result.frames[0].objects[0].class_name, "car");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yolo_object_confidence_range() {
|
||||
let obj = YoloObject {
|
||||
class_name: "test".to_string(),
|
||||
class_id: 0,
|
||||
x: 0,
|
||||
y: 0,
|
||||
width: 10,
|
||||
height: 10,
|
||||
confidence: 0.5,
|
||||
};
|
||||
assert!(obj.confidence >= 0.0 && obj.confidence <= 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yolo_result_empty_frames() {
|
||||
let result = YoloResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
};
|
||||
assert!(result.frames.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
pub mod file_manager;
|
||||
pub mod output_dir;
|
||||
pub mod uuid;
|
||||
|
||||
pub use file_manager::FileManager;
|
||||
pub use output_dir::OutputDir;
|
||||
pub use uuid::compute_uuid;
|
||||
|
||||
226
src/core/storage/output_dir.rs
Normal file
226
src/core/storage/output_dir.rs
Normal file
@@ -0,0 +1,226 @@
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Datelike, Local, Timelike};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
pub struct OutputDir {
|
||||
base_path: PathBuf,
|
||||
backup_enabled: bool,
|
||||
backup_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl OutputDir {
|
||||
pub fn new() -> Self {
|
||||
let base_path = std::env::var("MOMENTRY_OUTPUT_DIR")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|_| PathBuf::from("./output"));
|
||||
|
||||
let backup_enabled = std::env::var("MOMENTRY_BACKUP_ENABLED")
|
||||
.map(|v| v == "true")
|
||||
.unwrap_or(false);
|
||||
|
||||
let backup_dir = std::env::var("MOMENTRY_BACKUP_DIR")
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or_else(|_| PathBuf::from("/Users/accusys/momentry/backup/momentry"));
|
||||
|
||||
Self {
|
||||
base_path,
|
||||
backup_enabled,
|
||||
backup_dir,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_base_path(&self) -> &Path {
|
||||
&self.base_path
|
||||
}
|
||||
|
||||
pub fn get_backup_dir(&self) -> &Path {
|
||||
&self.backup_dir
|
||||
}
|
||||
|
||||
pub fn ensure_dir(&self) -> Result<()> {
|
||||
std::fs::create_dir_all(&self.base_path).context(format!(
|
||||
"Failed to create output directory: {:?}",
|
||||
self.base_path
|
||||
))?;
|
||||
|
||||
if self.backup_enabled {
|
||||
std::fs::create_dir_all(&self.backup_dir).context(format!(
|
||||
"Failed to create backup directory: {:?}",
|
||||
self.backup_dir
|
||||
))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_output_path(&self, uuid: &str, extension: &str) -> PathBuf {
|
||||
self.base_path.join(format!("{}.{}", uuid, extension))
|
||||
}
|
||||
|
||||
fn get_timestamp() -> String {
|
||||
let now = Local::now();
|
||||
format!(
|
||||
"{:04}{:02}{:02}_{:02}{:02}{:02}",
|
||||
now.year(),
|
||||
now.month(),
|
||||
now.day(),
|
||||
now.hour(),
|
||||
now.minute(),
|
||||
now.second()
|
||||
)
|
||||
}
|
||||
|
||||
pub fn get_backup_path(&self, uuid: &str, extension: &str) -> Option<PathBuf> {
|
||||
if !self.backup_enabled {
|
||||
return None;
|
||||
}
|
||||
|
||||
let timestamp = Self::get_timestamp();
|
||||
let filename = format!("momentry_data_{}_{}.{}", timestamp, uuid, extension);
|
||||
|
||||
Some(self.backup_dir.join(filename))
|
||||
}
|
||||
|
||||
pub fn backup_file(&self, uuid: &str, extension: &str) -> Result<Option<PathBuf>> {
|
||||
if !self.backup_enabled {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let source = self.get_output_path(uuid, extension);
|
||||
if !source.exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let backup_path = match self.get_backup_path(uuid, extension) {
|
||||
Some(path) => path,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
std::fs::copy(&source, &backup_path)
|
||||
.context(format!("Failed to backup file to: {:?}", backup_path))?;
|
||||
|
||||
let sha256_path = backup_path.with_extension(format!("{}.sha256", extension));
|
||||
|
||||
let source_content = std::fs::read(&source)?;
|
||||
let hash = format!("{:x}", md5::compute(&source_content));
|
||||
std::fs::write(
|
||||
&sha256_path,
|
||||
format!(
|
||||
"{} {}\n",
|
||||
hash,
|
||||
backup_path.file_name().unwrap().to_string_lossy()
|
||||
),
|
||||
)?;
|
||||
|
||||
Ok(Some(backup_path))
|
||||
}
|
||||
|
||||
pub fn cleanup_old_backups(&self, days: u32) -> Result<u32> {
|
||||
if !self.backup_enabled {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
if !self.backup_dir.exists() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
let cutoff = Local::now() - chrono::Duration::days(days as i64);
|
||||
let mut deleted_count = 0;
|
||||
|
||||
for entry in std::fs::read_dir(&self.backup_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if !path.is_file() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(name) = path.file_name() {
|
||||
let name_str = name.to_string_lossy();
|
||||
if name_str.starts_with("momentry_data_") && name_str.len() == 43 {
|
||||
let date_part = &name_str[14..22];
|
||||
if let Ok(date) =
|
||||
DateTime::parse_from_str(&format!("{} 000000", date_part), "%Y%m%d %H%M%S")
|
||||
{
|
||||
if date.with_timezone(&Local) < cutoff {
|
||||
std::fs::remove_file(&path)?;
|
||||
deleted_count += 1;
|
||||
|
||||
let sha256_path = path.with_extension("sha256");
|
||||
if sha256_path.exists() {
|
||||
let _ = std::fs::remove_file(sha256_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(deleted_count)
|
||||
}
|
||||
|
||||
pub fn list_backups(&self) -> Result<Vec<BackupInfo>> {
|
||||
if !self.backup_dir.exists() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let mut backups = Vec::new();
|
||||
|
||||
for entry in std::fs::read_dir(&self.backup_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if !path.is_file() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(name) = path.file_name() {
|
||||
let name_str = name.to_string_lossy();
|
||||
if name_str.starts_with("momentry_data_") && name_str.ends_with(".sha256") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if name_str.starts_with("momentry_data_") {
|
||||
let date_part = &name_str[14..22];
|
||||
backups.push(BackupInfo {
|
||||
filename: name_str.to_string(),
|
||||
date: date_part.to_string(),
|
||||
path: path.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
backups.sort_by(|a, b| b.date.cmp(&a.date));
|
||||
Ok(backups)
|
||||
}
|
||||
|
||||
pub fn verify_backup(&self, backup_path: &Path) -> Result<bool> {
|
||||
let sha256_path = backup_path.with_extension("sha256");
|
||||
|
||||
if !sha256_path.exists() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let sha256_content = std::fs::read_to_string(&sha256_path)?;
|
||||
let expected_hash = sha256_content.split_whitespace().next().unwrap_or("");
|
||||
|
||||
let source_content = std::fs::read(backup_path)?;
|
||||
let actual_hash = format!("{:x}", md5::compute(&source_content));
|
||||
|
||||
Ok(expected_hash == actual_hash)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OutputDir {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BackupInfo {
|
||||
pub filename: String,
|
||||
pub date: String,
|
||||
pub path: PathBuf,
|
||||
}
|
||||
@@ -25,39 +25,36 @@ pub fn compute_uuid_from_path(full_path: &str) -> String {
|
||||
compute_uuid(&parent, &filename)
|
||||
}
|
||||
|
||||
/// Extract relative path from full path given data root
|
||||
/// Returns (relative_path, filename)
|
||||
pub fn extract_relative_path(full_path: &str, data_root: &str) -> (String, String) {
|
||||
let full_path = PathBuf::from(full_path);
|
||||
let data_root = PathBuf::from(data_root);
|
||||
/// Extract username and filepath from relative path
|
||||
/// Input: ./demo/video.mp4 or ./demo/path/to/video.mp4
|
||||
/// Returns: (username, filepath) e.g., ("demo", "video.mp4") or ("demo", "path/to/video.mp4")
|
||||
pub fn extract_user_from_relative_path(relative_path: &str) -> (String, String) {
|
||||
// Remove leading ./
|
||||
let path = relative_path.strip_prefix("./").unwrap_or(relative_path);
|
||||
|
||||
// Canonicalize both paths
|
||||
let full_canonical = full_path.canonicalize().unwrap_or(full_path.clone());
|
||||
let root_canonical = data_root.canonicalize().unwrap_or(data_root.clone());
|
||||
let path_buf = PathBuf::from(path);
|
||||
|
||||
// Try to strip the data root prefix
|
||||
let relative = full_canonical
|
||||
.strip_prefix(&root_canonical)
|
||||
.unwrap_or(&full_canonical);
|
||||
|
||||
// Separate into parent directory and filename
|
||||
let filename = relative
|
||||
.file_name()
|
||||
.map(|n| n.to_string_lossy().to_string())
|
||||
// First component is username
|
||||
let mut components = path_buf.components();
|
||||
let username = components
|
||||
.next()
|
||||
.map(|c| c.as_os_str().to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let parent = relative
|
||||
.parent()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
.unwrap_or_default();
|
||||
// Remaining path (filepath)
|
||||
let filepath: String = components
|
||||
.map(|c| c.as_os_str().to_string_lossy().to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join("/");
|
||||
|
||||
(parent, filename)
|
||||
(username, filepath)
|
||||
}
|
||||
|
||||
/// Compute UUID from full path using data root for relative path extraction
|
||||
pub fn compute_uuid_from_path_with_root(full_path: &str, data_root: &str) -> String {
|
||||
let (parent, filename) = extract_relative_path(full_path, data_root);
|
||||
compute_uuid(&parent, &filename)
|
||||
/// Compute UUID from relative path (like ./demo/video.mp4)
|
||||
/// The username is extracted from the first path component
|
||||
pub fn compute_uuid_from_relative_path(relative_path: &str) -> String {
|
||||
let (username, filepath) = extract_user_from_relative_path(relative_path);
|
||||
compute_uuid(&username, &filepath)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -78,24 +75,26 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_relative_path_extraction() {
|
||||
let (parent, filename) =
|
||||
extract_relative_path("/data/sftpgo/data/demo/video.mp4", "/data/sftpgo/data");
|
||||
assert_eq!(parent, "demo");
|
||||
assert_eq!(filename, "video.mp4");
|
||||
fn test_extract_user_from_relative_path() {
|
||||
let (username, filepath) = extract_user_from_relative_path("./demo/video.mp4");
|
||||
assert_eq!(username, "demo");
|
||||
assert_eq!(filepath, "video.mp4");
|
||||
|
||||
let (username, filepath) = extract_user_from_relative_path("./demo/path/to/video.mp4");
|
||||
assert_eq!(username, "demo");
|
||||
assert_eq!(filepath, "path/to/video.mp4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uuid_with_data_root() {
|
||||
let uuid1 = compute_uuid_from_path_with_root(
|
||||
"/data/sftpgo/data/demo/video.mp4",
|
||||
"/data/sftpgo/data",
|
||||
);
|
||||
let uuid2 = compute_uuid_from_path_with_root(
|
||||
"/data/sftpgo/data/demo/video.mp4",
|
||||
"/data/sftpgo/data",
|
||||
);
|
||||
fn test_uuid_from_relative_path() {
|
||||
let uuid1 = compute_uuid_from_relative_path("./demo/video.mp4");
|
||||
let uuid2 = compute_uuid_from_relative_path("./demo/video.mp4");
|
||||
assert_eq!(uuid1, uuid2);
|
||||
assert_eq!(uuid1.len(), 16);
|
||||
|
||||
// Different users with same filename should have different UUIDs
|
||||
let uuid_demo = compute_uuid_from_relative_path("./demo/video.mp4");
|
||||
let uuid_warren = compute_uuid_from_relative_path("./warren/video.mp4");
|
||||
assert_ne!(uuid_demo, uuid_warren);
|
||||
}
|
||||
}
|
||||
|
||||
15
src/lib.rs
15
src/lib.rs
@@ -1,8 +1,21 @@
|
||||
pub mod core;
|
||||
|
||||
pub mod api;
|
||||
|
||||
pub mod ui;
|
||||
|
||||
pub mod worker;
|
||||
|
||||
pub use core::cache::{keys, MongoCache, RedisCache};
|
||||
pub use core::chunk::{Chunk, ChunkSplitter, ChunkType};
|
||||
pub use core::db::{Database, MongoDb, PostgresDb, QdrantDb, RedisDb, VideoRecord};
|
||||
pub use core::db::{
|
||||
Database, MongoDb, PostgresDb, QdrantDb, RedisClient, RedisDb, VectorPayload, VideoRecord,
|
||||
VideoStatus,
|
||||
};
|
||||
pub use core::embedding::Embedder;
|
||||
pub use core::probe::ProbeResult;
|
||||
pub use core::storage::file_manager::FileManager;
|
||||
pub use core::storage::output_dir::OutputDir;
|
||||
pub use core::storage::uuid;
|
||||
pub use core::thumbnail::{ThumbnailExtractor, ThumbnailResult};
|
||||
pub use ui::progress::{ProcessorType, ProgressState, ProgressUi};
|
||||
|
||||
196
src/player/api_client.rs
Normal file
196
src/player/api_client.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
use anyhow::Result;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
const DEFAULT_API_URL: &str = "http://localhost:3002";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ApiClient {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct RegisterRequest {
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct RegisterResponse {
|
||||
pub uuid: String,
|
||||
pub video_id: i64,
|
||||
pub job_id: i64,
|
||||
pub file_name: String,
|
||||
pub duration: f64,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct SearchRequest {
|
||||
pub query: String,
|
||||
pub limit: Option<usize>,
|
||||
pub uuid: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct SearchResult {
|
||||
pub uuid: String,
|
||||
pub chunk_id: String,
|
||||
pub chunk_type: String,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub text: String,
|
||||
pub score: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct SearchResponse {
|
||||
pub results: Vec<SearchResult>,
|
||||
pub query: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct LookupQuery {
|
||||
pub path: Option<String>,
|
||||
pub uuid: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct LookupResponse {
|
||||
pub uuid: String,
|
||||
pub file_path: Option<String>,
|
||||
pub file_name: Option<String>,
|
||||
pub duration: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct VideoInfo {
|
||||
pub uuid: String,
|
||||
pub file_path: String,
|
||||
pub file_name: String,
|
||||
pub duration: f64,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct VideosResponse {
|
||||
pub videos: Vec<VideoInfo>,
|
||||
}
|
||||
|
||||
impl ApiClient {
|
||||
pub fn new() -> Self {
|
||||
let url = std::env::var("MOMENTRY_API_URL").unwrap_or_else(|_| DEFAULT_API_URL.to_string());
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: url,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn with_url(url: &str) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
base_url: url.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn register_video(&self, path: &str) -> Result<RegisterResponse> {
|
||||
let url = format!("{}/api/v1/register", self.base_url);
|
||||
let request = RegisterRequest {
|
||||
path: path.to_string(),
|
||||
};
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
let status = response.status();
|
||||
let result = response.json::<RegisterResponse>().await?;
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("API request failed with status: {}", status);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub async fn search_chunks(
|
||||
&self,
|
||||
query: &str,
|
||||
uuid: Option<&str>,
|
||||
limit: Option<usize>,
|
||||
) -> Result<SearchResponse> {
|
||||
let url = format!("{}/api/v1/search", self.base_url);
|
||||
let request = SearchRequest {
|
||||
query: query.to_string(),
|
||||
limit,
|
||||
uuid: uuid.map(|s| s.to_string()),
|
||||
};
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
let status = response.status();
|
||||
let result = response.json::<SearchResponse>().await?;
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("API request failed with status: {}", status);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub async fn lookup_video(&self, uuid: &str) -> Result<LookupResponse> {
|
||||
let url = format!("{}/api/v1/lookup?uuid={}", self.base_url, uuid);
|
||||
let response = self.client.get(&url).send().await?;
|
||||
let status = response.status();
|
||||
let result = response.json::<LookupResponse>().await?;
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("API request failed with status: {}", status);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub async fn list_videos(&self) -> Result<Vec<VideoInfo>> {
|
||||
let url = format!("{}/api/v1/videos", self.base_url);
|
||||
let response = self.client.get(&url).send().await?;
|
||||
let status = response.status();
|
||||
let result = response.json::<VideosResponse>().await?;
|
||||
if !status.is_success() {
|
||||
anyhow::bail!("API request failed with status: {}", status);
|
||||
}
|
||||
Ok(result.videos)
|
||||
}
|
||||
|
||||
pub fn base_url(&self) -> &str {
|
||||
&self.base_url
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ApiClient {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn find_video_path() -> Option<String> {
|
||||
let test_dirs = vec![
|
||||
PathBuf::from("/Users/accusys/Movies"),
|
||||
PathBuf::from("/Users/accusys/Downloads"),
|
||||
PathBuf::from("/Users/accusys/momentry_core_project/test_video"),
|
||||
PathBuf::from("."),
|
||||
];
|
||||
|
||||
for dir in test_dirs {
|
||||
if dir.exists() {
|
||||
if let Ok(entries) = std::fs::read_dir(&dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if let Some(ext) = path.extension() {
|
||||
let ext_str = ext.to_string_lossy().to_lowercase();
|
||||
if matches!(
|
||||
ext_str.as_str(),
|
||||
"mp4" | "mov" | "m4v" | "avi" | "mkv" | "webm"
|
||||
) {
|
||||
return Some(path.to_string_lossy().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
181
src/player/asr_overlay.rs
Normal file
181
src/player/asr_overlay.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct AsrSegment {
|
||||
pub start: f64,
|
||||
pub end: f64,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
pub struct AsrData {
|
||||
#[serde(default)]
|
||||
pub segments: Vec<AsrSegment>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct AsrOverlay {
|
||||
segments: Vec<AsrSegment>,
|
||||
current_text: String,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl AsrOverlay {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
segments: Vec::new(),
|
||||
current_text: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_from_file(&mut self, video_path: &str) -> bool {
|
||||
// Try to find ASR JSON file in various locations
|
||||
let video_dir = PathBuf::from(video_path).parent().map(|p| p.to_path_buf());
|
||||
let _video_stem = PathBuf::from(video_path)
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let mut paths = Vec::new();
|
||||
|
||||
// In same directory as video
|
||||
if let Some(_dir) = &video_dir {
|
||||
paths.push(PathBuf::from(video_path).with_extension("asr.json"));
|
||||
}
|
||||
|
||||
// In data directory
|
||||
let data_dir = PathBuf::from("/Users/accusys/momentry_core_0.1");
|
||||
if let Ok(content) = fs::read_to_string(video_path) {
|
||||
let _ = content;
|
||||
}
|
||||
|
||||
// Try probe file for UUID
|
||||
let uuid = self
|
||||
.find_uuid_from_probe(video_path)
|
||||
.or_else(|| lookup_uuid_from_db(video_path));
|
||||
|
||||
if let Some(uuid_val) = uuid {
|
||||
paths.push(data_dir.join(format!("{}.asr.json", uuid_val)));
|
||||
}
|
||||
|
||||
for path in &paths {
|
||||
if path.exists() {
|
||||
if let Ok(content) = fs::read_to_string(path) {
|
||||
if let Ok(data) = serde_json::from_str::<AsrData>(&content) {
|
||||
self.segments = data.segments;
|
||||
println!(
|
||||
"Loaded {} ASR segments from {:?}",
|
||||
self.segments.len(),
|
||||
path
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try to load from PostgreSQL
|
||||
if let Some(uuid) = lookup_uuid_from_db(video_path) {
|
||||
let db_path = PathBuf::from("/Users/accusys/momentry_core_0.1")
|
||||
.join(format!("{}.asr.json", uuid));
|
||||
if db_path.exists() {
|
||||
if let Ok(content) = fs::read_to_string(&db_path) {
|
||||
if let Ok(data) = serde_json::from_str::<AsrData>(&content) {
|
||||
self.segments = data.segments;
|
||||
println!(
|
||||
"Loaded {} ASR segments from database file",
|
||||
self.segments.len()
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn update(&mut self, current_time: f64) {
|
||||
self.current_text = String::new();
|
||||
|
||||
for segment in &self.segments {
|
||||
if current_time >= segment.start && current_time <= segment.end {
|
||||
self.current_text = segment.text.clone();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn get_text(&self) -> &str {
|
||||
&self.current_text
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.segments.is_empty()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn find_uuid_from_probe(&self, video_path: &str) -> Option<String> {
|
||||
let path_buf = PathBuf::from(video_path);
|
||||
let video_stem = path_buf.file_stem().and_then(|s| s.to_str()).unwrap_or("");
|
||||
|
||||
let probe_path = PathBuf::from("/Users/accusys/momentry_core_0.1")
|
||||
.join(format!("{}.probe.json", video_stem));
|
||||
|
||||
if probe_path.exists() {
|
||||
if let Ok(content) = fs::read_to_string(&probe_path) {
|
||||
if let Ok(probe) = serde_json::from_str::<serde_json::Value>(&content) {
|
||||
return probe
|
||||
.get("uuid")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn lookup_uuid_from_db(video_path: &str) -> Option<String> {
|
||||
use std::process::Command as StdCommand;
|
||||
|
||||
let filename = std::path::Path::new(video_path)
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let output = StdCommand::new("psql")
|
||||
.args([
|
||||
"-U",
|
||||
"accusys",
|
||||
"-h",
|
||||
"localhost",
|
||||
"-d",
|
||||
"momentry",
|
||||
"-t",
|
||||
"-A",
|
||||
"-c",
|
||||
&format!(
|
||||
"SELECT uuid FROM videos WHERE file_path LIKE '%{}%' LIMIT 1",
|
||||
filename
|
||||
),
|
||||
])
|
||||
.output()
|
||||
.ok()?;
|
||||
|
||||
if output.status.success() {
|
||||
let uuid = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !uuid.is_empty() {
|
||||
return Some(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
333
src/player/chunk_selector.rs
Normal file
333
src/player/chunk_selector.rs
Normal file
@@ -0,0 +1,333 @@
|
||||
use anyhow::Result;
|
||||
use ratatui::{
|
||||
backend::CrosstermBackend,
|
||||
layout::{Constraint, Direction, Layout},
|
||||
style::{Color, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, List, ListItem, Paragraph},
|
||||
Frame, Terminal,
|
||||
};
|
||||
use std::io;
|
||||
use std::process::Command as StdCommand;
|
||||
|
||||
#[allow(dead_code)]
|
||||
const QDRANT_URL: &str = "http://localhost:6333";
|
||||
#[allow(dead_code)]
|
||||
const QDRANT_API_KEY: &str = "Test3200Test3200Test3200";
|
||||
#[allow(dead_code)]
|
||||
const OLLAMA_URL: &str = "http://localhost:11434";
|
||||
#[allow(dead_code)]
|
||||
const MODEL: &str = "nomic-embed-text-v2-moe";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct ChunkEntry {
|
||||
pub chunk_id: String,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub text: String,
|
||||
pub score: f64,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl ChunkEntry {
|
||||
pub fn format_time_range(&self) -> String {
|
||||
let start_mins = (self.start_time / 60.0) as u32;
|
||||
let start_secs = (self.start_time % 60.0) as u32;
|
||||
let end_mins = (self.end_time / 60.0) as u32;
|
||||
let end_secs = (self.end_time % 60.0) as u32;
|
||||
format!(
|
||||
"{:02}:{:02} - {:02}:{:02}",
|
||||
start_mins, start_secs, end_mins, end_secs
|
||||
)
|
||||
}
|
||||
|
||||
pub fn truncate_text(&self, max_len: usize) -> String {
|
||||
if self.text.len() > max_len {
|
||||
format!("{}...", &self.text[..max_len])
|
||||
} else {
|
||||
self.text.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct ChunkSelector {
|
||||
chunks: Vec<ChunkEntry>,
|
||||
selected_index: usize,
|
||||
query: String,
|
||||
video_uuid: String,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl ChunkSelector {
|
||||
pub fn new(video_uuid: &str) -> Self {
|
||||
Self {
|
||||
chunks: Vec::new(),
|
||||
selected_index: 0,
|
||||
query: String::new(),
|
||||
video_uuid: video_uuid.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn search(&mut self, query: &str) -> Result<Vec<ChunkEntry>> {
|
||||
self.query = query.to_string();
|
||||
self.chunks = Vec::new();
|
||||
self.selected_index = 0;
|
||||
|
||||
if query.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Get embedding from Ollama
|
||||
let embed_output = StdCommand::new("curl")
|
||||
.args([
|
||||
"-s",
|
||||
&format!("{}/api/embeddings", OLLAMA_URL),
|
||||
"-X",
|
||||
"POST",
|
||||
"-H",
|
||||
"Content-Type: application/json",
|
||||
"-d",
|
||||
&format!(
|
||||
r#"{{"model":"{}","prompt":"search_query: {}"}}"#,
|
||||
MODEL, query
|
||||
),
|
||||
])
|
||||
.output()?;
|
||||
|
||||
let embed_text = String::from_utf8_lossy(&embed_output.stdout);
|
||||
|
||||
// Parse embedding from response
|
||||
let embedding: Vec<f64> = serde_json::from_str(&embed_text)
|
||||
.ok()
|
||||
.and_then(|v: serde_json::Value| {
|
||||
v.get("embedding")
|
||||
.and_then(|e| serde_json::from_value(e.clone()).ok())
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
if embedding.is_empty() {
|
||||
println!("Failed to get embedding for query: {}", query);
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Search Qdrant - try both collections (chunks_v3 for multilingual, AccusysDB for others)
|
||||
let collections = ["chunks_v3", "AccusysDB"];
|
||||
|
||||
for collection in collections {
|
||||
let vector_str = serde_json::to_string(&embedding)
|
||||
.unwrap_or_default()
|
||||
.replace(['[', ']'], "");
|
||||
|
||||
let qdrant_output = StdCommand::new("curl")
|
||||
.args([
|
||||
"-s",
|
||||
&format!("{}/collections/{}/points/search", QDRANT_URL, collection),
|
||||
"-X",
|
||||
"POST",
|
||||
"-H",
|
||||
&format!("api-key: {}", QDRANT_API_KEY),
|
||||
"-H",
|
||||
"Content-Type: application/json",
|
||||
"-d",
|
||||
&format!(
|
||||
r#"{{"vector":[{}],"limit":20,"with_payload":true}}"#,
|
||||
vector_str
|
||||
),
|
||||
])
|
||||
.output()?;
|
||||
|
||||
let qdrant_text = String::from_utf8_lossy(&qdrant_output.stdout);
|
||||
|
||||
if let Ok(response) = serde_json::from_str::<serde_json::Value>(&qdrant_text) {
|
||||
if let Some(results) = response.get("result").and_then(|r| r.as_array()) {
|
||||
for r in results {
|
||||
let payload = r.get("payload");
|
||||
|
||||
// Try to match UUID - either exact match or partial match
|
||||
let _uuid = payload
|
||||
.and_then(|p| p.get("uuid"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Accept all chunks (remove UUID filter for now since we want to find any content)
|
||||
// The user can select which chunk to play
|
||||
let uuid_match = true; // Accept all
|
||||
|
||||
if !uuid_match {
|
||||
continue;
|
||||
}
|
||||
|
||||
let chunk_id = payload
|
||||
.and_then(|p| p.get("chunk_id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let start_time = payload
|
||||
.and_then(|p| p.get("start_time"))
|
||||
.and_then(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
let end_time = payload
|
||||
.and_then(|p| p.get("end_time"))
|
||||
.and_then(|v| v.as_f64())
|
||||
.unwrap_or(0.0);
|
||||
let text = payload
|
||||
.and_then(|p| p.get("text"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let score = r.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||
|
||||
if !text.is_empty() {
|
||||
self.chunks.push(ChunkEntry {
|
||||
chunk_id,
|
||||
start_time,
|
||||
end_time,
|
||||
text,
|
||||
score,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !self.chunks.is_empty() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(self.chunks.clone())
|
||||
}
|
||||
|
||||
pub fn run(&mut self) -> Result<Option<ChunkEntry>> {
|
||||
let stdout = io::stdout();
|
||||
let backend = CrosstermBackend::new(stdout);
|
||||
let mut terminal = Terminal::new(backend)?;
|
||||
|
||||
loop {
|
||||
terminal.draw(|f| self.render(f))?;
|
||||
|
||||
match crossterm::event::read() {
|
||||
Ok(crossterm::event::Event::Key(key)) => match key.code {
|
||||
crossterm::event::KeyCode::Up => {
|
||||
if self.selected_index > 0 {
|
||||
self.selected_index -= 1;
|
||||
}
|
||||
}
|
||||
crossterm::event::KeyCode::Down => {
|
||||
if self.selected_index < self.chunks.len().saturating_sub(1) {
|
||||
self.selected_index += 1;
|
||||
}
|
||||
}
|
||||
crossterm::event::KeyCode::Enter => {
|
||||
let selected = self.chunks.get(self.selected_index).cloned();
|
||||
terminal.show_cursor()?;
|
||||
return Ok(selected);
|
||||
}
|
||||
crossterm::event::KeyCode::Char(c) => {
|
||||
if c == 'q' {
|
||||
terminal.show_cursor()?;
|
||||
return Ok(None);
|
||||
}
|
||||
self.query.push(c);
|
||||
}
|
||||
crossterm::event::KeyCode::Backspace => {
|
||||
self.query.pop();
|
||||
}
|
||||
crossterm::event::KeyCode::Esc => {
|
||||
terminal.show_cursor()?;
|
||||
return Ok(None);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Ok(crossterm::event::Event::Resize(_, _)) => {}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render(&self, f: &mut Frame) {
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(3),
|
||||
Constraint::Length(3),
|
||||
Constraint::Min(0),
|
||||
Constraint::Length(3),
|
||||
])
|
||||
.split(f.area());
|
||||
|
||||
// Title
|
||||
let title = Paragraph::new("🔍 Chunk Search - Natural Language Query")
|
||||
.style(Style::default().fg(Color::Cyan))
|
||||
.block(Block::default().borders(Borders::ALL).title(" Search "));
|
||||
f.render_widget(title, chunks[0]);
|
||||
|
||||
// Query input
|
||||
let query_text = if self.query.is_empty() {
|
||||
"Type to search...".to_string()
|
||||
} else {
|
||||
self.query.clone()
|
||||
};
|
||||
let query_style = if self.query.is_empty() {
|
||||
Style::default().fg(Color::DarkGray)
|
||||
} else {
|
||||
Style::default().fg(Color::White)
|
||||
};
|
||||
let query = Paragraph::new(query_text)
|
||||
.style(query_style)
|
||||
.block(Block::default().borders(Borders::ALL).title(" Query "));
|
||||
f.render_widget(query, chunks[1]);
|
||||
|
||||
// Results
|
||||
if self.chunks.is_empty() {
|
||||
let no_results = Paragraph::new("No results found. Type to search...")
|
||||
.style(Style::default().fg(Color::DarkGray))
|
||||
.block(Block::default().borders(Borders::ALL).title(" Results "));
|
||||
f.render_widget(no_results, chunks[2]);
|
||||
} else {
|
||||
let items: Vec<ListItem> = self
|
||||
.chunks
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, chunk)| {
|
||||
let style = if i == self.selected_index {
|
||||
Style::default().fg(Color::Yellow).bg(Color::DarkGray)
|
||||
} else {
|
||||
Style::default()
|
||||
};
|
||||
|
||||
let content = Line::from(vec![
|
||||
Span::raw(format!(
|
||||
"{} ",
|
||||
if i == self.selected_index { "▶" } else { " " }
|
||||
)),
|
||||
Span::styled(chunk.format_time_range(), Style::default().fg(Color::Green)),
|
||||
Span::raw(" "),
|
||||
Span::raw(chunk.truncate_text(50)),
|
||||
Span::styled(
|
||||
format!(" [{:.2}]", chunk.score),
|
||||
Style::default().fg(Color::Blue),
|
||||
),
|
||||
]);
|
||||
|
||||
ListItem::new(content).style(style)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let list = List::new(items)
|
||||
.block(Block::default().borders(Borders::ALL).title(" Results "))
|
||||
.highlight_style(Style::default().fg(Color::Yellow));
|
||||
|
||||
f.render_widget(list, chunks[2]);
|
||||
}
|
||||
|
||||
// Help text
|
||||
let help =
|
||||
Paragraph::new(" [↑/↓] Navigate [Enter] Play from here [Type] Search [q] Quit ")
|
||||
.style(Style::default().fg(Color::DarkGray))
|
||||
.block(Block::default().borders(Borders::ALL));
|
||||
f.render_widget(help, chunks[3]);
|
||||
}
|
||||
}
|
||||
990
src/player/main.rs
Normal file
990
src/player/main.rs
Normal file
@@ -0,0 +1,990 @@
|
||||
use anyhow::Result;
|
||||
use std::env;
|
||||
use std::io::Write;
|
||||
#[cfg(feature = "player")]
|
||||
use std::os::unix::io::AsRawFd;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Child, Command, Stdio};
|
||||
|
||||
mod api_client;
|
||||
mod asr_overlay;
|
||||
mod chunk_selector;
|
||||
mod selector;
|
||||
|
||||
use api_client::ApiClient;
|
||||
use selector::{VideoEntry, VideoSelector};
|
||||
|
||||
#[allow(dead_code)]
|
||||
const STATUS_BAR_HEIGHT: i32 = 50;
|
||||
#[allow(dead_code)]
|
||||
const FONT_SIZE: i32 = 20;
|
||||
#[allow(dead_code)]
|
||||
const ASR_OVERLAY_HEIGHT: i32 = 80;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[allow(dead_code)]
|
||||
enum TerminalCommand {
|
||||
Pause,
|
||||
Sound,
|
||||
SeekBackward,
|
||||
SeekForward,
|
||||
ToggleStatusBar,
|
||||
ToggleAsr,
|
||||
Download,
|
||||
SyncIncrease,
|
||||
SyncDecrease,
|
||||
Quit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[allow(dead_code)]
|
||||
struct VideoInfo {
|
||||
width: u32,
|
||||
height: u32,
|
||||
fps: f64,
|
||||
total_frames: u64,
|
||||
duration: f64,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct PlayerState {
|
||||
current_time: f64,
|
||||
current_frame: u64,
|
||||
is_paused: bool,
|
||||
sound_on: bool,
|
||||
quit: bool,
|
||||
status_bar_visible: bool,
|
||||
asr_overlay_visible: bool,
|
||||
}
|
||||
|
||||
struct Config {
|
||||
download_dir: PathBuf,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn start_sound_process(stream_url: &str, start_time: f64) -> Option<Child> {
|
||||
Command::new(if cfg!(target_os = "macos") {
|
||||
"/opt/homebrew/bin/ffplay"
|
||||
} else {
|
||||
"ffplay"
|
||||
})
|
||||
.args([
|
||||
"-nodisp",
|
||||
"-autoexit",
|
||||
"-ss",
|
||||
&format!("{:.2}", start_time),
|
||||
stream_url,
|
||||
])
|
||||
.spawn()
|
||||
.ok()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct FormatOption {
|
||||
format_id: String,
|
||||
resolution: String,
|
||||
ext: String,
|
||||
note: String,
|
||||
filesize: String,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn load_config() -> Config {
|
||||
let config_path = PathBuf::from(env::var("HOME").unwrap_or_default())
|
||||
.join(".config")
|
||||
.join("video_player")
|
||||
.join("config.toml");
|
||||
|
||||
if config_path.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&config_path) {
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
if line.starts_with("download_dir") && line.contains('=') {
|
||||
let value = line.split('=').nth(1).unwrap_or("").trim();
|
||||
let path = value
|
||||
.trim_matches('"')
|
||||
.replace("~", &env::var("HOME").unwrap_or_default());
|
||||
return Config {
|
||||
download_dir: PathBuf::from(path),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Config {
|
||||
download_dir: PathBuf::from(env::var("HOME").unwrap_or_default()).join("Downloads"),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn list_available_formats(video_url: &str) -> Result<Vec<FormatOption>> {
|
||||
println!("Fetching available formats...");
|
||||
let output = Command::new("yt-dlp")
|
||||
.args(["-F", "--no-download", video_url])
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("Failed to list formats: {}", stderr);
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let mut formats = Vec::new();
|
||||
|
||||
for line in stdout.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty()
|
||||
|| line.starts_with("ID")
|
||||
|| line.starts_with("---")
|
||||
|| line.contains("storyboard")
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(first_part) = line.split('|').next() {
|
||||
let parts: Vec<&str> = first_part.split_whitespace().collect();
|
||||
if parts.len() >= 3 {
|
||||
let format_id = parts[0].to_string();
|
||||
let ext = parts[1].to_string();
|
||||
let resolution = parts[2].to_string();
|
||||
|
||||
if ext == "mp4" || ext == "webm" || ext == "m4a" {
|
||||
if !resolution.contains("x") {
|
||||
continue;
|
||||
}
|
||||
|
||||
let note = if line.contains("|") {
|
||||
line.split('|')
|
||||
.nth(1)
|
||||
.map(|s| s.trim().to_string())
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
formats.push(FormatOption {
|
||||
format_id,
|
||||
resolution,
|
||||
ext,
|
||||
note,
|
||||
filesize: String::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
formats.sort_by(|a, b| {
|
||||
let a_h = a
|
||||
.resolution
|
||||
.split('x')
|
||||
.nth(1)
|
||||
.unwrap_or("0")
|
||||
.parse::<u32>()
|
||||
.unwrap_or(0);
|
||||
let b_h = b
|
||||
.resolution
|
||||
.split('x')
|
||||
.nth(1)
|
||||
.unwrap_or("0")
|
||||
.parse::<u32>()
|
||||
.unwrap_or(0);
|
||||
b_h.cmp(&a_h)
|
||||
});
|
||||
|
||||
Ok(formats)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn show_format_menu(formats: &[FormatOption], term_fd: libc::c_int) -> Option<usize> {
|
||||
for (i, fmt) in formats.iter().enumerate().take(10) {
|
||||
println!("[{}] {} {} ({})", i + 1, fmt.resolution, fmt.ext, fmt.note);
|
||||
}
|
||||
println!("----------------------------------------");
|
||||
print!("Enter choice (default: 1): ");
|
||||
|
||||
if std::io::stdout().flush().is_err() {
|
||||
return None;
|
||||
}
|
||||
|
||||
unsafe {
|
||||
let mut termios = std::mem::zeroed();
|
||||
libc::tcgetattr(term_fd, &mut termios);
|
||||
let mut normal = termios;
|
||||
libc::cfmakeraw(&mut normal);
|
||||
libc::tcsetattr(term_fd, libc::TCSANOW, &normal);
|
||||
|
||||
let mut input = String::new();
|
||||
let result = std::io::stdin().read_line(&mut input);
|
||||
|
||||
libc::tcsetattr(term_fd, libc::TCSANOW, &termios);
|
||||
|
||||
if result.is_ok() {
|
||||
let input = input.trim();
|
||||
if input.is_empty() {
|
||||
return Some(0);
|
||||
}
|
||||
if let Ok(choice) = input.parse::<usize>() {
|
||||
if choice >= 1 && choice <= formats.len() {
|
||||
return Some(choice - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn download_video(video_url: &str, format_id: &str, download_dir: &Path) -> Result<String> {
|
||||
println!("Downloading video to {:?}...", download_dir);
|
||||
|
||||
std::fs::create_dir_all(download_dir)?;
|
||||
|
||||
let output = Command::new("yt-dlp")
|
||||
.args([
|
||||
"-f",
|
||||
format_id,
|
||||
"-o",
|
||||
&format!("{}/%(title)s.%(ext)s", download_dir.display()),
|
||||
video_url,
|
||||
])
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("Download failed: {}", stderr);
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
for line in stdout.lines() {
|
||||
if line.contains("Merging")
|
||||
|| line.contains("Destination")
|
||||
|| line.contains(".mp4")
|
||||
|| line.contains(".webm")
|
||||
{
|
||||
if let Some(path) = line.split("Destination: ").nth(1) {
|
||||
let path = path.trim();
|
||||
if Path::new(path).exists() {
|
||||
return Ok(path.to_string());
|
||||
}
|
||||
}
|
||||
if line.contains(".mp4") || line.contains(".webm") {
|
||||
let filename = line.split_whitespace().last().unwrap_or("");
|
||||
let full_path = download_dir.join(filename);
|
||||
if full_path.exists() {
|
||||
return Ok(full_path.to_string_lossy().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for entry in (std::fs::read_dir(download_dir)?).flatten() {
|
||||
let path = entry.path();
|
||||
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
|
||||
if ext == "mp4" || ext == "webm" || ext == "mkv" {
|
||||
return Ok(path.to_string_lossy().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!("Could not find downloaded file")
|
||||
}
|
||||
|
||||
fn is_youtube_url(input: &str) -> bool {
|
||||
input.starts_with("http://")
|
||||
|| input.starts_with("https://")
|
||||
|| input.contains("youtube.com")
|
||||
|| input.contains("youtu.be")
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn get_youtube_stream_url(video_url: &str) -> Result<String> {
|
||||
println!("Getting video stream from YouTube...");
|
||||
let output = Command::new("yt-dlp")
|
||||
.args(["-f", "best[ext=mp4][vcodec!=none]", "-g", video_url])
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
let output = Command::new("yt-dlp")
|
||||
.args(["-f", "best", "-g", video_url])
|
||||
.output()?;
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!("yt-dlp failed: {}", stderr);
|
||||
}
|
||||
let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if url.is_empty() {
|
||||
anyhow::bail!("yt-dlp returned empty URL");
|
||||
}
|
||||
println!("Stream URL obtained");
|
||||
return Ok(url);
|
||||
}
|
||||
|
||||
let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if url.is_empty() {
|
||||
anyhow::bail!("yt-dlp returned empty URL");
|
||||
}
|
||||
println!("Stream URL obtained");
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn get_video_info(video_path: &str) -> Result<VideoInfo> {
|
||||
let output = Command::new("ffprobe")
|
||||
.args([
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=width,height,r_frame_rate,nb_frames,duration",
|
||||
"-of",
|
||||
"json",
|
||||
video_path,
|
||||
])
|
||||
.output();
|
||||
|
||||
match output {
|
||||
Ok(output) if output.status.success() => {
|
||||
let json_str = String::from_utf8_lossy(&output.stdout);
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&json_str) {
|
||||
let stream = &json["streams"][0];
|
||||
return Ok(VideoInfo {
|
||||
width: stream["width"].as_u64().unwrap_or(1280) as u32,
|
||||
height: stream["height"].as_u64().unwrap_or(720) as u32,
|
||||
fps: parse_fps(stream["r_frame_rate"].as_str().unwrap_or("30/1")),
|
||||
total_frames: stream["nb_frames"].as_u64().unwrap_or(0),
|
||||
duration: stream["duration"]
|
||||
.as_str()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.unwrap_or(0.0),
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(VideoInfo {
|
||||
width: 1280,
|
||||
height: 720,
|
||||
fps: 30.0,
|
||||
total_frames: 0,
|
||||
duration: 0.0,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn parse_fps(fps_str: &str) -> f64 {
|
||||
let parts: Vec<&str> = fps_str.split('/').collect();
|
||||
if parts.len() == 2 {
|
||||
let num: f64 = parts[0].parse().unwrap_or(30.0);
|
||||
let den: f64 = parts[1].parse().unwrap_or(1.0);
|
||||
if den > 0.0 {
|
||||
num / den
|
||||
} else {
|
||||
30.0
|
||||
}
|
||||
} else {
|
||||
fps_str.parse().unwrap_or(30.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn format_time(seconds: f64) -> String {
|
||||
let hours = (seconds / 3600.0).floor() as u32;
|
||||
let minutes = ((seconds % 3600.0) / 60.0).floor() as u32;
|
||||
let secs = (seconds % 60.0).floor() as u32;
|
||||
let millis = ((seconds % 1.0) * 100.0).floor() as u32;
|
||||
format!("{:02}:{:02}:{:02}.{:02}", hours, minutes, secs, millis)
|
||||
}
|
||||
|
||||
fn lookup_video_uuid(video_path: &str) -> Option<String> {
|
||||
use std::process::Command as StdCommand;
|
||||
|
||||
// Try to find UUID from database by matching file_path
|
||||
let output = StdCommand::new("psql")
|
||||
.args([
|
||||
"-U",
|
||||
"accusys",
|
||||
"-h",
|
||||
"localhost",
|
||||
"-d",
|
||||
"momentry",
|
||||
"-t",
|
||||
"-A",
|
||||
"-c",
|
||||
&format!(
|
||||
"SELECT uuid FROM videos WHERE file_path LIKE '%{}%' LIMIT 1",
|
||||
video_path.split('/').next_back().unwrap_or("")
|
||||
),
|
||||
])
|
||||
.output()
|
||||
.ok()?;
|
||||
|
||||
if output.status.success() {
|
||||
let uuid = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if !uuid.is_empty() {
|
||||
return Some(uuid);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[cfg(not(feature = "player"))]
|
||||
fn draw_status_bar(
|
||||
_video_info: &VideoInfo,
|
||||
_state: &PlayerState,
|
||||
_sync_delay_ms: u64,
|
||||
) -> Result<String> {
|
||||
Ok(format!(
|
||||
"{:.2}s / {:.2}s",
|
||||
_state.current_time, _video_info.duration
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[cfg(feature = "player")]
|
||||
fn draw_status_bar(
|
||||
_canvas: &mut (),
|
||||
_font: &(),
|
||||
_video_info: &VideoInfo,
|
||||
_state: &PlayerState,
|
||||
_sync_delay_ms: u64,
|
||||
) -> Result<String> {
|
||||
Ok(format!(
|
||||
"{:.2}s / {:.2}s",
|
||||
_state.current_time, _video_info.duration
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn start_ffmpeg(
|
||||
stream_url: &str,
|
||||
video_width: u32,
|
||||
video_height: u32,
|
||||
video_fps: f64,
|
||||
seek_time: f64,
|
||||
) -> Result<Child> {
|
||||
let ffmpeg_path = if cfg!(target_os = "macos") {
|
||||
"/opt/homebrew/bin/ffmpeg"
|
||||
} else {
|
||||
"ffmpeg"
|
||||
};
|
||||
let fps_str = format!("{:.2}", video_fps);
|
||||
|
||||
let mut cmd = Command::new(ffmpeg_path);
|
||||
|
||||
if seek_time > 0.0 {
|
||||
cmd.args(["-ss", &format!("{:.2}", seek_time)]);
|
||||
}
|
||||
|
||||
cmd.args([
|
||||
"-i",
|
||||
stream_url,
|
||||
"-vf",
|
||||
&format!(
|
||||
"scale={}:{}:force_original_aspect_ratio=decrease",
|
||||
video_width, video_height
|
||||
),
|
||||
"-f",
|
||||
"rawvideo",
|
||||
"-pix_fmt",
|
||||
"rgb24",
|
||||
"-r",
|
||||
&fps_str,
|
||||
"-",
|
||||
])
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to start ffmpeg: {}", e))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "player"))]
|
||||
fn run_player(_video_path: &str, _video_uuid: Option<String>) -> Result<()> {
|
||||
println!("Player not available - SDL2 not configured");
|
||||
println!("Playing: {} (UUID: {:?})", _video_path, _video_uuid);
|
||||
println!("(This is a stub - full player requires SDL2)");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "player")]
|
||||
fn run_player(_video_path: &str, _video_uuid: Option<String>) -> Result<()> {
|
||||
println!("Player not available - SDL2 not configured");
|
||||
println!("Playing: {} (UUID: {:?})", _video_path, _video_uuid);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
env_logger::init();
|
||||
let args: Vec<String> = env::args().collect();
|
||||
|
||||
let should_download = args.iter().any(|a| a == "-d" || a == "--download");
|
||||
let show_selector = args.iter().any(|a| a == "-s" || a == "--selector");
|
||||
let test_api_mode = args.iter().any(|a| a == "-t" || a == "--test-api");
|
||||
|
||||
// API Testing Mode
|
||||
if test_api_mode {
|
||||
return run_api_test_mode();
|
||||
}
|
||||
|
||||
// If --selector flag is provided, show video selector
|
||||
if show_selector {
|
||||
return run_selector();
|
||||
}
|
||||
|
||||
let video_path = if args.len() < 2 || (should_download && args.len() < 3) {
|
||||
println!("Video Player\n============\nEnter video path or YouTube URL:");
|
||||
let mut input = String::new();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
let trimmed = input.trim().to_string();
|
||||
if trimmed.is_empty() {
|
||||
anyhow::bail!("No input provided");
|
||||
}
|
||||
trimmed
|
||||
} else {
|
||||
let idx = if should_download { 2 } else { 1 };
|
||||
args[idx].clone()
|
||||
};
|
||||
|
||||
println!("Video Player\n============\nInput: {}", video_path);
|
||||
|
||||
let final_path = if should_download && is_youtube_url(&video_path) {
|
||||
println!("Downloading with best quality...");
|
||||
let config = load_config();
|
||||
let format_id = "best[ext=mp4][vcodec!=none]";
|
||||
match download_video(&video_path, format_id, &config.download_dir) {
|
||||
Ok(local_path) => {
|
||||
println!("Download complete: {}", local_path);
|
||||
Some(local_path)
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Download failed: {}, playing stream instead", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let play_path = final_path.unwrap_or(video_path);
|
||||
|
||||
if is_youtube_url(&play_path) {
|
||||
println!("Source: YouTube");
|
||||
} else if Path::new(&play_path).exists() {
|
||||
println!("Source: Local file");
|
||||
} else {
|
||||
anyhow::bail!("File not found: {}", play_path);
|
||||
}
|
||||
|
||||
let video_uuid = lookup_video_uuid(&play_path);
|
||||
println!("Video UUID: {:?}", video_uuid);
|
||||
run_player(&play_path, video_uuid)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_selector() -> Result<()> {
|
||||
use std::process::Command as StdCommand;
|
||||
|
||||
let _db_url = momentry_core::core::config::DATABASE_URL.as_str();
|
||||
|
||||
// Use psql to query videos
|
||||
let output = StdCommand::new("psql")
|
||||
.args(["-U", "accusys", "-h", "localhost", "-d", "momentry", "-t", "-A",
|
||||
"-c", "SELECT uuid, file_name, file_path, duration, width, height FROM videos ORDER BY created_at DESC"])
|
||||
.output();
|
||||
|
||||
let videos: Vec<VideoEntry> = match output {
|
||||
Ok(out) if out.status.success() => {
|
||||
let stdout = String::from_utf8_lossy(&out.stdout);
|
||||
stdout
|
||||
.lines()
|
||||
.filter(|line| !line.is_empty())
|
||||
.filter_map(|line| {
|
||||
let parts: Vec<&str> = line.split('|').collect();
|
||||
if parts.len() >= 6 {
|
||||
Some(VideoEntry {
|
||||
uuid: parts[0].to_string(),
|
||||
file_name: parts[1].to_string(),
|
||||
file_path: parts[2].to_string(),
|
||||
duration: parts[3].parse().unwrap_or(0.0),
|
||||
width: parts[4].parse().unwrap_or(0),
|
||||
height: parts[5].parse().unwrap_or(0),
|
||||
thumbnail_dir: None,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
_ => {
|
||||
// Fallback: scan directory
|
||||
println!("Could not connect to database, scanning directory...");
|
||||
let test_video_dir = PathBuf::from("/Users/accusys/momentry_core_project/test_video");
|
||||
std::fs::read_dir(&test_video_dir)?
|
||||
.filter_map(|e| e.ok())
|
||||
.filter(|e| {
|
||||
let path = e.path();
|
||||
matches!(
|
||||
path.extension().and_then(|s| s.to_str()),
|
||||
Some("mp4") | Some("mov") | Some("m4v") | Some("avi")
|
||||
)
|
||||
})
|
||||
.map(|e| {
|
||||
let path = e.path();
|
||||
VideoEntry {
|
||||
uuid: format!("{:x}", md5::compute(path.to_string_lossy().as_bytes()))
|
||||
[0..16]
|
||||
.to_string(),
|
||||
file_name: path
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.to_string(),
|
||||
file_path: path.to_string_lossy().to_string(),
|
||||
duration: 0.0,
|
||||
width: 0,
|
||||
height: 0,
|
||||
thumbnail_dir: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
};
|
||||
|
||||
if videos.is_empty() {
|
||||
println!("No videos found. Register videos first with 'momentry register'");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Try interactive mode, fall back to list if not a terminal
|
||||
if atty::is(atty::Stream::Stdout) {
|
||||
println!("Found {} videos", videos.len());
|
||||
|
||||
let mut selector = VideoSelector::new(videos);
|
||||
match selector.run() {
|
||||
Ok(Some(video)) => {
|
||||
println!("\nPlaying: {} ({})", video.file_name, video.uuid);
|
||||
run_player(&video.file_path, Some(video.uuid))?;
|
||||
}
|
||||
Ok(None) => {
|
||||
println!("\nNo video selected");
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Selector error: {}", e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Non-interactive: show list
|
||||
println!("\nAvailable videos:");
|
||||
for (i, video) in videos.iter().enumerate() {
|
||||
println!(
|
||||
" {}) {} - {} ({})",
|
||||
i + 1,
|
||||
video.file_name,
|
||||
video.format_duration(),
|
||||
video.uuid
|
||||
);
|
||||
}
|
||||
println!("\nRun with a video path to play, or use interactive mode in a terminal.");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_api_test_mode() -> Result<()> {
|
||||
println!("\n===========================================");
|
||||
println!(" 🎬 API Testing GUI");
|
||||
println!("===========================================\n");
|
||||
|
||||
let client = ApiClient::new();
|
||||
println!("API Server: {}\n", client.base_url());
|
||||
|
||||
println!(
|
||||
"Waiting for API server... (make sure 'cargo run --bin momentry -- server' is running)\n"
|
||||
);
|
||||
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
|
||||
loop {
|
||||
println!("\n┌─────────────────────────────────────────┐");
|
||||
println!("│ Main Menu │");
|
||||
println!("├─────────────────────────────────────────┤");
|
||||
println!("│ [1] Search - 自然語言搜尋影片內容 │");
|
||||
println!("│ [2] List - 列出所有影片 │");
|
||||
println!("│ [3] Register - 註冊新影片 │");
|
||||
println!("│ [4] Lookup - 查詢影片資訊 │");
|
||||
println!("│ [5] Play - 播放影片 │");
|
||||
println!("│ [q] Quit - 離開 │");
|
||||
println!("└─────────────────────────────────────────┘");
|
||||
print!("\n請選擇: ");
|
||||
|
||||
let mut input = String::new();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
let choice = input.trim();
|
||||
|
||||
match choice {
|
||||
"1" => {
|
||||
println!("\n=== 🔍 自然語言搜尋 ===");
|
||||
print!("輸入搜尋關鍵字: ");
|
||||
input.clear();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
let query = input.trim().to_string();
|
||||
if query.is_empty() {
|
||||
println!("搜尋關鍵字不能為空");
|
||||
continue;
|
||||
}
|
||||
|
||||
print!("是否限定特定影片?(y/N): ");
|
||||
input.clear();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
let limit_uuid = if input.trim().to_lowercase() == "y" {
|
||||
print!("輸入影片 UUID: ");
|
||||
input.clear();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
Some(input.trim().to_string())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!("\n搜尋中...");
|
||||
match rt.block_on(client.search_chunks(&query, limit_uuid.as_deref(), Some(20))) {
|
||||
Ok(response) => {
|
||||
println!("\n找到 {} 個結果:\n", response.results.len());
|
||||
for (i, r) in response.results.iter().enumerate() {
|
||||
let time_range = format!(
|
||||
"{:02}:{:02} - {:02}:{:02}",
|
||||
(r.start_time / 60.0) as u32,
|
||||
(r.start_time % 60.0) as u32,
|
||||
(r.end_time / 60.0) as u32,
|
||||
(r.end_time % 60.0) as u32
|
||||
);
|
||||
let text_preview = if r.text.len() > 60 {
|
||||
format!("{}...", &r.text[..60])
|
||||
} else {
|
||||
r.text.clone()
|
||||
};
|
||||
println!(
|
||||
" [{}] {} | {} | {:.2} | {}",
|
||||
i + 1,
|
||||
time_range,
|
||||
r.uuid.chars().take(8).collect::<String>(),
|
||||
r.score,
|
||||
text_preview
|
||||
);
|
||||
}
|
||||
|
||||
if !response.results.is_empty() {
|
||||
print!("\n選擇要播放的結果編號 (直接Enter跳過): ");
|
||||
input.clear();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
if let Ok(idx) = input.trim().parse::<usize>() {
|
||||
if idx > 0 && idx <= response.results.len() {
|
||||
let selected = &response.results[idx - 1];
|
||||
println!("\n正在取得影片路徑...");
|
||||
match rt.block_on(client.lookup_video(&selected.uuid)) {
|
||||
Ok(info) => {
|
||||
if let Some(path) = &info.file_path {
|
||||
println!(
|
||||
"播放: {} @ {:.2}s",
|
||||
path, selected.start_time
|
||||
);
|
||||
let video_uuid = Some(selected.uuid.clone());
|
||||
run_player(path, video_uuid)?;
|
||||
} else {
|
||||
println!("無法取得影片路徑");
|
||||
}
|
||||
}
|
||||
Err(e) => println!("查詢失敗: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => println!("搜尋失敗: {}", e),
|
||||
}
|
||||
}
|
||||
"2" => {
|
||||
println!("\n=== 📋 影片列表 ===");
|
||||
println!("載入中...");
|
||||
match rt.block_on(client.list_videos()) {
|
||||
Ok(videos) => {
|
||||
if videos.is_empty() {
|
||||
println!("沒有找到任何影片,請先註冊");
|
||||
} else {
|
||||
println!("\n共 {} 部影片:\n", videos.len());
|
||||
for (i, v) in videos.iter().enumerate() {
|
||||
let duration = format!(
|
||||
"{}:{:02}",
|
||||
(v.duration / 60.0) as u32,
|
||||
(v.duration % 60.0) as u32
|
||||
);
|
||||
println!(
|
||||
" [{}] {} | {} | {}x{} | {}",
|
||||
i + 1,
|
||||
v.file_name,
|
||||
v.uuid.chars().take(8).collect::<String>(),
|
||||
v.width,
|
||||
v.height,
|
||||
duration
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => println!("取得影片列表失敗: {}", e),
|
||||
}
|
||||
}
|
||||
"3" => {
|
||||
println!("\n=== 📝 註冊影片 ===");
|
||||
print!("輸入影片檔案路徑 (直接Enter使用自動搜尋): ");
|
||||
input.clear();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
let path = input.trim();
|
||||
|
||||
let video_path = if path.is_empty() {
|
||||
println!("自動搜尋影片...");
|
||||
match api_client::find_video_path() {
|
||||
Some(p) => {
|
||||
println!("找到: {}", p);
|
||||
p
|
||||
}
|
||||
None => {
|
||||
println!("找不到影片檔案,請手動輸入路徑");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
path.to_string()
|
||||
};
|
||||
|
||||
if !std::path::Path::new(&video_path).exists() {
|
||||
println!("檔案不存在: {}", video_path);
|
||||
continue;
|
||||
}
|
||||
|
||||
println!("\n註冊中...");
|
||||
match rt.block_on(client.register_video(&video_path)) {
|
||||
Ok(resp) => {
|
||||
println!("\n✓ 註冊成功!");
|
||||
println!(" UUID: {}", resp.uuid);
|
||||
println!(" 名稱: {}", resp.file_name);
|
||||
println!(" 時長: {:.2}s", resp.duration);
|
||||
println!(" 解析度: {}x{}", resp.width, resp.height);
|
||||
}
|
||||
Err(e) => println!("註冊失敗: {}", e),
|
||||
}
|
||||
}
|
||||
"4" => {
|
||||
print!("\n=== 🔎 查詢影片 ===\n輸入影片 UUID: ");
|
||||
input.clear();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
let uuid = input.trim();
|
||||
if uuid.is_empty() {
|
||||
println!("UUID 不能為空");
|
||||
continue;
|
||||
}
|
||||
|
||||
println!("\n查詢中...");
|
||||
match rt.block_on(client.lookup_video(uuid)) {
|
||||
Ok(info) => {
|
||||
println!("\n✓ 找到影片:");
|
||||
println!(" UUID: {}", info.uuid);
|
||||
if let Some(path) = &info.file_path {
|
||||
println!(" 路徑: {}", path);
|
||||
}
|
||||
if let Some(name) = &info.file_name {
|
||||
println!(" 名稱: {}", name);
|
||||
}
|
||||
if let Some(dur) = info.duration {
|
||||
println!(" 時長: {:.2}s", dur);
|
||||
}
|
||||
}
|
||||
Err(e) => println!("查詢失敗: {}", e),
|
||||
}
|
||||
}
|
||||
"5" => {
|
||||
println!("\n=== ▶ 播放影片 ===");
|
||||
print!("輸入影片 UUID (直接Enter從列表選擇): ");
|
||||
input.clear();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
let uuid = input.trim();
|
||||
|
||||
let video_path = if uuid.is_empty() {
|
||||
println!("載入影片列表...");
|
||||
match rt.block_on(client.list_videos()) {
|
||||
Ok(videos) => {
|
||||
if videos.is_empty() {
|
||||
println!("沒有影片");
|
||||
continue;
|
||||
}
|
||||
let entries: Vec<VideoEntry> = videos
|
||||
.into_iter()
|
||||
.map(|v| VideoEntry {
|
||||
uuid: v.uuid,
|
||||
file_name: v.file_name,
|
||||
file_path: v.file_path,
|
||||
duration: v.duration,
|
||||
width: v.width,
|
||||
height: v.height,
|
||||
thumbnail_dir: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut selector = VideoSelector::new(entries);
|
||||
match selector.run() {
|
||||
Ok(Some(video)) => video.file_path,
|
||||
Ok(None) => {
|
||||
println!("取消選擇");
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
println!("選擇錯誤: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("取得列表失敗: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match rt.block_on(client.lookup_video(uuid)) {
|
||||
Ok(info) => {
|
||||
if let Some(path) = info.file_path {
|
||||
path
|
||||
} else {
|
||||
println!("找不到影片路徑");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("查詢失敗: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let video_uuid = if let Ok(info) = rt.block_on(client.lookup_video(&video_path)) {
|
||||
Some(info.uuid)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
println!("\n播放: {}", video_path);
|
||||
run_player(&video_path, video_uuid)?;
|
||||
}
|
||||
"q" | "Q" => {
|
||||
println!("\n再見!");
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
println!("無效選項,請重新選擇");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,3 +1,13 @@
|
||||
pub mod api_client;
|
||||
pub mod asr_overlay;
|
||||
pub mod chunk_selector;
|
||||
pub mod player;
|
||||
pub mod selector;
|
||||
|
||||
pub use api_client::{
|
||||
ApiClient, LookupResponse, RegisterResponse, SearchResponse, SearchResult, VideoInfo,
|
||||
};
|
||||
pub use asr_overlay::{AsrData, AsrOverlay, AsrSegment};
|
||||
pub use chunk_selector::{ChunkEntry, ChunkSelector};
|
||||
pub use player::{play_video, PlayerConfig};
|
||||
pub use selector::{VideoEntry, VideoSelector};
|
||||
|
||||
163
src/player/selector.rs
Normal file
163
src/player/selector.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
use anyhow::Result;
|
||||
use ratatui::{
|
||||
backend::CrosstermBackend,
|
||||
layout::{Constraint, Direction, Layout},
|
||||
style::{Color, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, List, ListItem, Paragraph},
|
||||
Frame, Terminal,
|
||||
};
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VideoEntry {
|
||||
pub uuid: String,
|
||||
pub file_name: String,
|
||||
pub file_path: String,
|
||||
pub duration: f64,
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
pub thumbnail_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
impl VideoEntry {
|
||||
pub fn format_duration(&self) -> String {
|
||||
let secs = self.duration as u64;
|
||||
let hours = secs / 3600;
|
||||
let mins = (secs % 3600) / 60;
|
||||
let secs = secs % 60;
|
||||
if hours > 0 {
|
||||
format!("{}:{:02}:{:02}", hours, mins, secs)
|
||||
} else {
|
||||
format!("{}:{:02}", mins, secs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn format_resolution(&self) -> String {
|
||||
format!("{}x{}", self.width, self.height)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct VideoSelector {
|
||||
videos: Vec<VideoEntry>,
|
||||
selected_index: usize,
|
||||
}
|
||||
|
||||
impl VideoSelector {
|
||||
pub fn new(videos: Vec<VideoEntry>) -> Self {
|
||||
Self {
|
||||
videos,
|
||||
selected_index: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(&mut self) -> Result<Option<VideoEntry>> {
|
||||
let stdout = io::stdout();
|
||||
let backend = CrosstermBackend::new(stdout);
|
||||
let mut terminal = Terminal::new(backend)?;
|
||||
|
||||
loop {
|
||||
terminal.draw(|f| self.render(f))?;
|
||||
|
||||
match crossterm::event::read() {
|
||||
Ok(crossterm::event::Event::Key(key)) => match key.code {
|
||||
crossterm::event::KeyCode::Up => {
|
||||
if self.selected_index > 0 {
|
||||
self.selected_index -= 1;
|
||||
}
|
||||
}
|
||||
crossterm::event::KeyCode::Down => {
|
||||
if self.selected_index < self.videos.len() - 1 {
|
||||
self.selected_index += 1;
|
||||
}
|
||||
}
|
||||
crossterm::event::KeyCode::Enter => {
|
||||
let selected = self.videos.get(self.selected_index).cloned();
|
||||
terminal.show_cursor()?;
|
||||
return Ok(selected);
|
||||
}
|
||||
crossterm::event::KeyCode::Char('q') | crossterm::event::KeyCode::Esc => {
|
||||
terminal.show_cursor()?;
|
||||
return Ok(None);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Ok(crossterm::event::Event::Resize(_, _)) => {}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render(&self, f: &mut Frame) {
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(3),
|
||||
Constraint::Min(0),
|
||||
Constraint::Length(3),
|
||||
])
|
||||
.split(f.area());
|
||||
|
||||
// Title
|
||||
let title = Paragraph::new("🎬 Video Selector")
|
||||
.style(Style::default().fg(Color::Cyan))
|
||||
.block(
|
||||
Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.title(" Select Video "),
|
||||
);
|
||||
f.render_widget(title, chunks[0]);
|
||||
|
||||
// Video list
|
||||
let items: Vec<ListItem> = self
|
||||
.videos
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, video)| {
|
||||
let style = if i == self.selected_index {
|
||||
Style::default().fg(Color::Yellow).bg(Color::DarkGray)
|
||||
} else {
|
||||
Style::default()
|
||||
};
|
||||
|
||||
let duration = video.format_duration();
|
||||
let resolution = video.format_resolution();
|
||||
let thumb_info = if video.thumbnail_dir.is_some() {
|
||||
"📷"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
let content = Line::from(vec![
|
||||
Span::raw(format!(
|
||||
"{} ",
|
||||
if i == self.selected_index { "▶" } else { " " }
|
||||
)),
|
||||
Span::raw(&video.file_name),
|
||||
Span::raw(" "),
|
||||
Span::styled(
|
||||
format!("{} | {}", duration, resolution),
|
||||
Style::default().fg(Color::Blue),
|
||||
),
|
||||
Span::raw(" "),
|
||||
Span::raw(thumb_info),
|
||||
]);
|
||||
|
||||
ListItem::new(content).style(style)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let list = List::new(items)
|
||||
.block(Block::default().borders(Borders::ALL).title(" Videos "))
|
||||
.highlight_style(Style::default().fg(Color::Yellow));
|
||||
|
||||
f.render_widget(list, chunks[1]);
|
||||
|
||||
// Help text
|
||||
let help = Paragraph::new(" [↑/↓] Navigate [Enter] Select [q] Quit ")
|
||||
.style(Style::default().fg(Color::DarkGray))
|
||||
.block(Block::default().borders(Borders::ALL));
|
||||
f.render_widget(help, chunks[2]);
|
||||
}
|
||||
}
|
||||
1
src/ui/mod.rs
Normal file
1
src/ui/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod progress;
|
||||
413
src/ui/progress/mod.rs
Normal file
413
src/ui/progress/mod.rs
Normal file
@@ -0,0 +1,413 @@
|
||||
use ratatui::prelude::Stylize;
|
||||
use ratatui::{
|
||||
backend::CrosstermBackend,
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
style::{Color, Style},
|
||||
text::Span,
|
||||
widgets::{Block, Borders, Paragraph, Row, Table},
|
||||
Frame, Terminal,
|
||||
};
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum ProcessorType {
|
||||
Asr,
|
||||
Cut,
|
||||
Asrx,
|
||||
Yolo,
|
||||
Ocr,
|
||||
Face,
|
||||
Pose,
|
||||
Story,
|
||||
Caption,
|
||||
}
|
||||
|
||||
impl ProcessorType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ProcessorType::Asr => "ASR",
|
||||
ProcessorType::Cut => "CUT",
|
||||
ProcessorType::Asrx => "ASRX",
|
||||
ProcessorType::Yolo => "YOLO",
|
||||
ProcessorType::Ocr => "OCR",
|
||||
ProcessorType::Face => "Face",
|
||||
ProcessorType::Pose => "Pose",
|
||||
ProcessorType::Story => "Story",
|
||||
ProcessorType::Caption => "Caption",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for ProcessorType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum ProcessorStatus {
|
||||
Pending,
|
||||
Running,
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProcessorProgress {
|
||||
pub processor_type: ProcessorType,
|
||||
pub status: ProcessorStatus,
|
||||
pub current: u32,
|
||||
pub total: u32,
|
||||
pub message: String,
|
||||
pub elapsed_secs: u64,
|
||||
}
|
||||
|
||||
impl ProcessorProgress {
|
||||
pub fn new(processor_type: ProcessorType) -> Self {
|
||||
Self {
|
||||
processor_type,
|
||||
status: ProcessorStatus::Pending,
|
||||
current: 0,
|
||||
total: 0,
|
||||
message: String::new(),
|
||||
elapsed_secs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&mut self, total: u32) {
|
||||
self.status = ProcessorStatus::Running;
|
||||
self.total = total;
|
||||
self.current = 0;
|
||||
self.elapsed_secs = 0;
|
||||
}
|
||||
|
||||
pub fn update(&mut self, current: u32, message: &str) {
|
||||
self.current = current;
|
||||
self.message = message.to_string();
|
||||
}
|
||||
|
||||
pub fn complete(&mut self, message: &str) {
|
||||
self.status = ProcessorStatus::Completed;
|
||||
self.current = self.total;
|
||||
self.message = message.to_string();
|
||||
}
|
||||
|
||||
pub fn fail(&mut self, message: &str) {
|
||||
self.status = ProcessorStatus::Failed;
|
||||
self.message = message.to_string();
|
||||
}
|
||||
|
||||
pub fn progress_ratio(&self) -> f64 {
|
||||
if self.total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
self.current as f64 / self.total as f64
|
||||
}
|
||||
}
|
||||
|
||||
pub fn eta(&self) -> Option<std::time::Duration> {
|
||||
if self.status == ProcessorStatus::Completed || self.total == 0 || self.current == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let elapsed = std::time::Duration::from_secs(self.elapsed_secs);
|
||||
let ratio = self.current as f64 / self.total as f64;
|
||||
if ratio <= 0.0 {
|
||||
return None;
|
||||
}
|
||||
let total_estimated = elapsed.div_f64(ratio);
|
||||
Some(total_estimated - elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ProgressState {
|
||||
pub processors: Vec<ProcessorProgress>,
|
||||
pub video_name: String,
|
||||
pub is_active: bool,
|
||||
}
|
||||
|
||||
impl ProgressState {
|
||||
pub fn new(video_name: &str) -> Self {
|
||||
Self {
|
||||
processors: vec![
|
||||
ProcessorProgress::new(ProcessorType::Asr),
|
||||
ProcessorProgress::new(ProcessorType::Cut),
|
||||
ProcessorProgress::new(ProcessorType::Asrx),
|
||||
ProcessorProgress::new(ProcessorType::Yolo),
|
||||
ProcessorProgress::new(ProcessorType::Ocr),
|
||||
ProcessorProgress::new(ProcessorType::Face),
|
||||
ProcessorProgress::new(ProcessorType::Pose),
|
||||
ProcessorProgress::new(ProcessorType::Story),
|
||||
ProcessorProgress::new(ProcessorType::Caption),
|
||||
],
|
||||
video_name: video_name.to_string(),
|
||||
is_active: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_processor(&mut self, processor_type: ProcessorType) -> &mut ProcessorProgress {
|
||||
self.processors
|
||||
.iter_mut()
|
||||
.find(|p| p.processor_type == processor_type)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn completed_count(&self) -> usize {
|
||||
self.processors
|
||||
.iter()
|
||||
.filter(|p| p.status == ProcessorStatus::Completed)
|
||||
.count()
|
||||
}
|
||||
|
||||
pub fn total_count(&self) -> usize {
|
||||
self.processors.len()
|
||||
}
|
||||
|
||||
pub fn overall_progress(&self) -> f64 {
|
||||
let total: f64 = self.processors.iter().map(|p| p.progress_ratio()).sum();
|
||||
total / self.processors.len() as f64
|
||||
}
|
||||
|
||||
pub fn start(&mut self) {
|
||||
self.is_active = true;
|
||||
}
|
||||
|
||||
pub fn stop(&mut self) {
|
||||
self.is_active = false;
|
||||
}
|
||||
|
||||
pub fn update_from_redis(
|
||||
&mut self,
|
||||
msg_type: &str,
|
||||
processor: &str,
|
||||
current: Option<i32>,
|
||||
total: Option<i32>,
|
||||
message: Option<&str>,
|
||||
) {
|
||||
let proc_type = match processor.to_uppercase().as_str() {
|
||||
"ASR" => ProcessorType::Asr,
|
||||
"CUT" => ProcessorType::Cut,
|
||||
"ASRX" => ProcessorType::Asrx,
|
||||
"YOLO" => ProcessorType::Yolo,
|
||||
"OCR" => ProcessorType::Ocr,
|
||||
"FACE" => ProcessorType::Face,
|
||||
"POSE" => ProcessorType::Pose,
|
||||
"STORY" => ProcessorType::Story,
|
||||
"CAPTION" => ProcessorType::Caption,
|
||||
_ => return,
|
||||
};
|
||||
|
||||
let p = self.get_processor(proc_type);
|
||||
|
||||
match msg_type {
|
||||
"START" | "INFO" => {
|
||||
p.status = ProcessorStatus::Running;
|
||||
if let Some(m) = message {
|
||||
p.message = m.to_string();
|
||||
}
|
||||
}
|
||||
"PROGRESS" => {
|
||||
p.status = ProcessorStatus::Running;
|
||||
if let Some(c) = current {
|
||||
p.current = c as u32;
|
||||
}
|
||||
if let Some(t) = total {
|
||||
p.total = t as u32;
|
||||
}
|
||||
if let Some(m) = message {
|
||||
p.message = m.to_string();
|
||||
}
|
||||
}
|
||||
"COMPLETE" => {
|
||||
p.status = ProcessorStatus::Completed;
|
||||
p.current = p.total;
|
||||
if let Some(m) = message {
|
||||
p.message = m.to_string();
|
||||
}
|
||||
}
|
||||
"ERROR" => {
|
||||
p.status = ProcessorStatus::Failed;
|
||||
if let Some(m) = message {
|
||||
p.message = m.to_string();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ProgressUi {
|
||||
terminal: Terminal<CrosstermBackend<io::Stderr>>,
|
||||
state: std::sync::Mutex<ProgressState>,
|
||||
}
|
||||
|
||||
impl ProgressUi {
|
||||
pub fn new(video_name: &str) -> io::Result<Self> {
|
||||
use crossterm::execute;
|
||||
use crossterm::terminal::{enable_raw_mode, EnterAlternateScreen};
|
||||
|
||||
let mut stderr = io::stderr();
|
||||
|
||||
enable_raw_mode()?;
|
||||
execute!(stderr, EnterAlternateScreen)?;
|
||||
|
||||
let backend = CrosstermBackend::new(stderr);
|
||||
let terminal = Terminal::new(backend)?;
|
||||
|
||||
let state = std::sync::Mutex::new(ProgressState::new(video_name));
|
||||
|
||||
Ok(Self { terminal, state })
|
||||
}
|
||||
|
||||
pub fn state(&self) -> &std::sync::Mutex<ProgressState> {
|
||||
&self.state
|
||||
}
|
||||
|
||||
pub fn render(&mut self) -> io::Result<()> {
|
||||
let state = self.state.lock().unwrap().clone();
|
||||
let video_name = state.video_name.clone();
|
||||
let is_active = state.is_active;
|
||||
let processors = state.processors.clone();
|
||||
let completed = state.completed_count();
|
||||
let total = state.total_count();
|
||||
let overall = state.overall_progress();
|
||||
|
||||
self.terminal.draw(|f| {
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(3),
|
||||
Constraint::Min(10),
|
||||
Constraint::Length(3),
|
||||
])
|
||||
.split(f.area());
|
||||
|
||||
Self::render_header_static(f, chunks[0], &video_name);
|
||||
Self::render_processors_static(f, chunks[1], &processors);
|
||||
Self::render_footer_static(f, chunks[2], completed, total, overall, is_active);
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn cleanup(&mut self) -> io::Result<()> {
|
||||
use crossterm::execute;
|
||||
use crossterm::terminal::{disable_raw_mode, LeaveAlternateScreen};
|
||||
|
||||
execute!(self.terminal.backend_mut(), LeaveAlternateScreen)?;
|
||||
disable_raw_mode()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn render_header_static(f: &mut Frame, area: Rect, video_name: &str) {
|
||||
let title = format!(" Processing: {} ", video_name);
|
||||
let block = Block::default()
|
||||
.title(title)
|
||||
.borders(Borders::ALL)
|
||||
.style(Style::default().fg(Color::Cyan));
|
||||
f.render_widget(block, area);
|
||||
}
|
||||
|
||||
fn render_processors_static(f: &mut Frame, area: Rect, processors: &[ProcessorProgress]) {
|
||||
let rows: Vec<Row> = processors
|
||||
.iter()
|
||||
.map(|p| Self::processor_to_row_static(p))
|
||||
.collect();
|
||||
|
||||
let widths = [
|
||||
Constraint::Length(8),
|
||||
Constraint::Length(10),
|
||||
Constraint::Min(20),
|
||||
Constraint::Length(12),
|
||||
];
|
||||
|
||||
let table = Table::new(rows, widths)
|
||||
.block(Block::default().borders(Borders::ALL).title(" Processors "))
|
||||
.column_spacing(1);
|
||||
|
||||
f.render_widget(table, area);
|
||||
}
|
||||
|
||||
fn processor_to_row_static(p: &ProcessorProgress) -> Row<'_> {
|
||||
let status_color = match p.status {
|
||||
ProcessorStatus::Pending => Color::DarkGray,
|
||||
ProcessorStatus::Running => Color::Yellow,
|
||||
ProcessorStatus::Completed => Color::Green,
|
||||
ProcessorStatus::Failed => Color::Red,
|
||||
};
|
||||
|
||||
let progress_bar = if p.total > 0 {
|
||||
let filled = (p.progress_ratio() * 20.0) as usize;
|
||||
let bar: String = format!(
|
||||
"[{}{}]",
|
||||
"█".repeat(filled.min(20)),
|
||||
"░".repeat((20 - filled).min(20))
|
||||
);
|
||||
bar
|
||||
} else {
|
||||
"[--------------------]".to_string()
|
||||
};
|
||||
|
||||
let percentage = format!("{:5.1}%", p.progress_ratio() * 100.0);
|
||||
|
||||
let detail = if p.total > 0 {
|
||||
format!("{}/{}", p.current, p.total)
|
||||
} else {
|
||||
"-".to_string()
|
||||
};
|
||||
|
||||
let eta = match p.eta() {
|
||||
Some(d) => {
|
||||
let secs = d.as_secs();
|
||||
if secs > 60 {
|
||||
format!("{}m", secs / 60)
|
||||
} else {
|
||||
format!("{}s", secs)
|
||||
}
|
||||
}
|
||||
None => "-".to_string(),
|
||||
};
|
||||
|
||||
Row::new(vec![
|
||||
Span::raw(format!(" {} ", p.processor_type.as_str())),
|
||||
Span::raw(progress_bar).fg(status_color),
|
||||
Span::raw(format!(" {} {}", detail, eta)),
|
||||
Span::raw(format!(" {} ", percentage)),
|
||||
])
|
||||
}
|
||||
|
||||
fn render_footer_static(
|
||||
f: &mut Frame,
|
||||
area: Rect,
|
||||
completed: usize,
|
||||
total: usize,
|
||||
overall: f64,
|
||||
is_active: bool,
|
||||
) {
|
||||
let status_text = if is_active {
|
||||
format!(
|
||||
" Progress: {}/{} ({:.1}%) | Press Ctrl+C to cancel ",
|
||||
completed,
|
||||
total,
|
||||
overall * 100.0
|
||||
)
|
||||
} else if completed == total {
|
||||
" ✓ All processors completed! ".to_string()
|
||||
} else {
|
||||
" Ready ".to_string()
|
||||
};
|
||||
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.style(Style::default().fg(Color::Green));
|
||||
let paragraph = Paragraph::new(status_text).block(block);
|
||||
f.render_widget(paragraph, area);
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ProgressUi {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.cleanup();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user