feat: Initial v0.9 release with API Key authentication

## v0.9.20260325_144654

### Features
- API Key Authentication System
- Job Worker System
- V2 Backup Versioning

### Bug Fixes
- get_processor_results_by_job column mapping

Co-authored-by: OpenCode
This commit is contained in:
accusys
2026-03-25 14:52:51 +08:00
parent 47e86b696f
commit 383201cacd
193 changed files with 40268 additions and 422 deletions

193
src/core/api_key/anomaly.rs Normal file
View File

@@ -0,0 +1,193 @@
//! Anomaly Detection Module
//!
//! Detects abnormal API key usage patterns
use crate::core::api_key::models::*;
use crate::core::api_key::service::AnomalyMetrics;
use chrono::{Duration, Timelike, Utc};
use std::collections::HashMap;
use tokio::sync::RwLock;
pub struct AnomalyDetector {
config: AnomalyDetectionConfig,
metrics_cache: RwLock<HashMap<String, Vec<RequestMetric>>>,
lockout_cache: RwLock<HashMap<String, i32>>,
}
#[derive(Debug, Clone)]
struct RequestMetric {
timestamp: chrono::DateTime<Utc>,
ip: Option<String>,
is_error: bool,
}
impl AnomalyDetector {
pub fn new(config: AnomalyDetectionConfig) -> Self {
Self {
config,
metrics_cache: RwLock::new(HashMap::new()),
lockout_cache: RwLock::new(HashMap::new()),
}
}
pub async fn record_request(&self, key_id: &str, ip: Option<String>, is_error: bool) {
let metric = RequestMetric {
timestamp: Utc::now(),
ip,
is_error,
};
let mut cache = self.metrics_cache.write().await;
cache.entry(key_id.to_string()).or_default().push(metric);
self.cleanup_old_metrics(&mut cache).await;
}
async fn cleanup_old_metrics(&self, cache: &mut HashMap<String, Vec<RequestMetric>>) {
let cutoff = Utc::now() - Duration::hours(2);
for metrics in cache.values_mut() {
metrics.retain(|m| m.timestamp > cutoff);
}
}
pub async fn check_anomaly(&self, key_id: &str) -> Option<AnomalyRecord> {
let cache = self.metrics_cache.read().await;
let metrics = cache.get(key_id)?;
let now = Utc::now();
let last_minute = now - Duration::minutes(1);
let last_hour = now - Duration::hours(1);
let recent = metrics
.iter()
.filter(|m| m.timestamp > last_hour)
.collect::<Vec<_>>();
let requests_per_minute =
metrics.iter().filter(|m| m.timestamp > last_minute).count() as i32;
let error_count = recent.iter().filter(|m| m.is_error).count() as i32;
let error_rate = if recent.is_empty() {
0.0
} else {
error_count as f64 / recent.len() as f64
};
let unique_ips = recent
.iter()
.filter_map(|m| m.ip.clone())
.collect::<std::collections::HashSet<_>>()
.len() as i32;
let last_ip = metrics.last().and_then(|m| m.ip.clone());
let metrics = AnomalyMetrics {
requests_per_minute,
error_count,
error_rate,
unique_ips,
last_ip,
};
self.detect_anomaly(key_id, metrics).await
}
async fn detect_anomaly(&self, key_id: &str, metrics: AnomalyMetrics) -> Option<AnomalyRecord> {
if metrics.requests_per_minute > self.config.requests_per_minute_threshold * 10 {
let mut lockout = self.lockout_cache.write().await;
*lockout.entry(key_id.to_string()).or_insert(0) += 1;
if lockout[&key_id.to_string()] >= self.config.lockout_threshold {
return Some(self.create_anomaly(
key_id,
AnomalyType::BruteForce,
AnomalySeverity::Critical,
&metrics,
));
}
}
if metrics.requests_per_minute > self.config.requests_per_minute_threshold {
return Some(self.create_anomaly(
key_id,
AnomalyType::HighRequestRate,
AnomalySeverity::Medium,
&metrics,
));
}
if metrics.error_rate > self.config.error_rate_threshold {
return Some(self.create_anomaly(
key_id,
AnomalyType::HighErrorRate,
AnomalySeverity::Medium,
&metrics,
));
}
if metrics.unique_ips > self.config.unique_ips_per_hour_threshold {
return Some(self.create_anomaly(
key_id,
AnomalyType::MultipleIps,
AnomalySeverity::Low,
&metrics,
));
}
let hour = Utc::now().hour();
if hour < 6 && metrics.requests_per_minute > 10 {
return Some(self.create_anomaly(
key_id,
AnomalyType::UnusualTime,
AnomalySeverity::Low,
&metrics,
));
}
None
}
fn create_anomaly(
&self,
key_id: &str,
anomaly_type: AnomalyType,
severity: AnomalySeverity,
metrics: &AnomalyMetrics,
) -> AnomalyRecord {
AnomalyRecord {
id: 0,
key_id: key_id.to_string(),
anomaly_type,
severity,
ip_address: metrics.last_ip.clone(),
request_count: Some(metrics.requests_per_minute),
error_count: Some(metrics.error_count),
error_rate: Some(metrics.error_rate),
unique_ips: Some(metrics.unique_ips),
details: None,
resolved: false,
resolved_at: None,
resolved_by: None,
created_at: Utc::now(),
}
}
pub async fn should_lockout(&self, key_id: &str) -> bool {
let lockout = self.lockout_cache.read().await;
lockout.get(key_id).copied().unwrap_or(0) >= self.config.lockout_threshold
}
pub async fn reset_lockout(&self, key_id: &str) {
let mut lockout = self.lockout_cache.write().await;
lockout.remove(key_id);
let mut cache = self.metrics_cache.write().await;
cache.remove(key_id);
}
}
impl Default for AnomalyDetector {
fn default() -> Self {
Self::new(AnomalyDetectionConfig::default())
}
}

View File

@@ -0,0 +1,193 @@
//! Async Audit Logger Module
//!
//! Writes audit logs asynchronously using a channel
use crate::core::db::PostgresDb;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::task::JoinHandle;
/// Audit log entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEntry {
pub key_id: String,
pub action: String,
pub actor: Option<String>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub request_path: Option<String>,
pub response_code: Option<i32>,
pub anomaly_type: Option<String>,
pub details: Option<serde_json::Value>,
}
/// Async audit logger configuration
#[derive(Debug, Clone)]
pub struct AuditLoggerConfig {
pub channel_buffer_size: usize,
pub batch_size: usize,
pub flush_interval_ms: u64,
}
impl Default for AuditLoggerConfig {
fn default() -> Self {
Self {
channel_buffer_size: std::env::var("AUDIT_LOGGER_BUFFER_SIZE")
.unwrap_or_else(|_| "1000".to_string())
.parse()
.unwrap_or(1000),
batch_size: std::env::var("AUDIT_LOGGER_BATCH_SIZE")
.unwrap_or_else(|_| "100".to_string())
.parse()
.unwrap_or(100),
flush_interval_ms: std::env::var("AUDIT_LOGGER_FLUSH_INTERVAL_MS")
.unwrap_or_else(|_| "1000".to_string())
.parse()
.unwrap_or(1000),
}
}
}
/// Async audit logger
pub struct AsyncAuditLogger {
sender: Sender<AuditEntry>,
handle: JoinHandle<()>,
}
impl AsyncAuditLogger {
/// Create a new async audit logger
pub fn new(db: PostgresDb, config: AuditLoggerConfig) -> Self {
let (sender, receiver) = mpsc::channel(config.channel_buffer_size);
let handle = tokio::spawn(Self::logger_task(db, receiver, config));
Self { sender, handle }
}
/// Create with default config
pub fn with_default_config(db: PostgresDb) -> Self {
Self::new(db, AuditLoggerConfig::default())
}
/// Log an audit entry
pub async fn log(&self, entry: AuditEntry) -> Result<()> {
self.sender
.send(entry)
.await
.map_err(|e| anyhow::anyhow!("Failed to send audit entry: {}", e))
}
/// Shutdown the logger
pub async fn shutdown(self) -> Result<()> {
drop(self.sender);
self.handle.await?;
Ok(())
}
/// Logger background task
async fn logger_task(
db: PostgresDb,
mut receiver: Receiver<AuditEntry>,
config: AuditLoggerConfig,
) {
let mut batch = Vec::with_capacity(config.batch_size);
let mut interval =
tokio::time::interval(std::time::Duration::from_millis(config.flush_interval_ms));
loop {
tokio::select! {
Some(entry) = receiver.recv() => {
batch.push(entry);
if batch.len() >= config.batch_size {
if let Err(e) = Self::flush_batch(&db, &batch).await {
tracing::error!("Failed to flush audit batch: {}", e);
}
batch.clear();
}
}
_ = interval.tick() => {
if !batch.is_empty() {
if let Err(e) = Self::flush_batch(&db, &batch).await {
tracing::error!("Failed to flush audit batch: {}", e);
}
batch.clear();
}
}
else => {
// Channel closed
if !batch.is_empty() {
if let Err(e) = Self::flush_batch(&db, &batch).await {
tracing::error!("Failed to flush final audit batch: {}", e);
}
}
break;
}
}
}
tracing::info!("Audit logger task stopped");
}
/// Flush a batch of entries to the database
async fn flush_batch(db: &PostgresDb, entries: &[AuditEntry]) -> Result<()> {
if entries.is_empty() {
return Ok(());
}
tracing::debug!("Flushing {} audit entries", entries.len());
for entry in entries {
if let Err(e) = db
.log_api_key_audit(
&entry.key_id,
&entry.action,
entry.actor.as_deref(),
entry.ip_address.as_deref(),
entry.user_agent.as_deref(),
entry.request_path.as_deref(),
entry.response_code,
entry.anomaly_type.as_deref(),
entry.details.as_ref(),
)
.await
{
tracing::error!("Failed to write audit entry: {}", e);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_audit_logger_config_default() {
let config = AuditLoggerConfig::default();
assert!(config.channel_buffer_size > 0);
assert!(config.batch_size > 0);
assert!(config.flush_interval_ms > 0);
}
#[test]
fn test_audit_entry_creation() {
let entry = AuditEntry {
key_id: "test_key".to_string(),
action: "validate".to_string(),
actor: Some("user1".to_string()),
ip_address: Some("192.168.1.1".to_string()),
user_agent: Some("Mozilla/5.0".to_string()),
request_path: Some("/api/test".to_string()),
response_code: Some(200),
anomaly_type: None,
details: None,
};
assert_eq!(entry.key_id, "test_key");
assert_eq!(entry.action, "validate");
}
}

View File

@@ -0,0 +1,203 @@
//! IP Blacklist Module
//!
//! Manages blocked IP addresses for API key validation
use chrono::{DateTime, Duration, Utc};
use moka::future::Cache;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration as StdDuration;
use tokio::sync::RwLock;
/// IP blacklist entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlacklistEntry {
pub ip: String,
pub reason: String,
pub blocked_at: DateTime<Utc>,
pub expires_at: Option<DateTime<Utc>>,
pub blocked_by: Option<String>,
}
/// IP Blacklist manager
pub struct IpBlacklist {
/// In-memory blacklist with TTL
cache: Cache<String, BlacklistEntry>,
/// Permanent blacklist (no TTL)
permanent: Arc<RwLock<HashSet<String>>>,
}
/// Configuration for IP blacklist
pub struct BlacklistConfig {
pub default_block_duration_secs: u64,
pub max_entries: u64,
}
impl Default for BlacklistConfig {
fn default() -> Self {
Self {
default_block_duration_secs: std::env::var("IP_BLACKLIST_DURATION")
.unwrap_or_else(|_| "3600".to_string())
.parse()
.unwrap_or(3600),
max_entries: std::env::var("IP_BLACKLIST_MAX_ENTRIES")
.unwrap_or_else(|_| "10000".to_string())
.parse()
.unwrap_or(10000),
}
}
}
impl IpBlacklist {
pub fn new(config: BlacklistConfig) -> Self {
Self {
cache: Cache::builder()
.time_to_live(StdDuration::from_secs(config.default_block_duration_secs))
.max_capacity(config.max_entries)
.build(),
permanent: Arc::new(RwLock::new(HashSet::new())),
}
}
pub fn with_default_config() -> Self {
Self::new(BlacklistConfig::default())
}
/// Check if an IP is blocked
pub async fn is_blocked(&self, ip: &str) -> bool {
// Check permanent blacklist first
if self.permanent.read().await.contains(ip) {
return true;
}
// Check temporary blacklist
self.cache.get(ip).await.is_some()
}
/// Get blacklist entry for an IP
pub async fn get_entry(&self, ip: &str) -> Option<BlacklistEntry> {
self.cache.get(ip).await
}
/// Block an IP temporarily
pub async fn block(&self, ip: &str, reason: &str, duration_secs: Option<u64>) {
let entry = BlacklistEntry {
ip: ip.to_string(),
reason: reason.to_string(),
blocked_at: Utc::now(),
expires_at: duration_secs.map(|d| Utc::now() + Duration::seconds(d as i64)),
blocked_by: Some("system".to_string()),
};
self.cache.insert(ip.to_string(), entry).await;
tracing::info!("Blocked IP: {} - {}", ip, reason);
}
/// Block an IP permanently
pub async fn block_permanent(&self, ip: &str, reason: &str) {
self.permanent.write().await.insert(ip.to_string());
tracing::info!("Permanently blocked IP: {} - {}", ip, reason);
}
/// Unblock an IP
pub async fn unblock(&self, ip: &str) -> bool {
let in_cache = self.cache.get(ip).await.is_some();
if in_cache {
self.cache.invalidate(ip).await;
}
let from_permanent = self.permanent.write().await.remove(ip);
if in_cache || from_permanent {
tracing::info!("Unblocked IP: {}", ip);
true
} else {
false
}
}
/// Get all blocked IPs
pub async fn list_all(&self) -> Vec<String> {
let mut ips: Vec<String> = self.cache.iter().map(|(k, _)| (*k).clone()).collect();
ips.extend(self.permanent.read().await.iter().cloned());
ips.sort();
ips.dedup();
ips
}
/// Get count of blocked IPs
pub async fn count(&self) -> usize {
self.cache.entry_count() as usize + self.permanent.read().await.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_block_and_check() {
let blacklist = IpBlacklist::with_default_config();
assert!(!blacklist.is_blocked("192.168.1.1").await);
blacklist.block("192.168.1.1", "test", Some(60)).await;
assert!(blacklist.is_blocked("192.168.1.1").await);
}
#[tokio::test]
async fn test_unblock() {
let blacklist = IpBlacklist::with_default_config();
blacklist.block("192.168.1.1", "test", Some(60)).await;
assert!(blacklist.is_blocked("192.168.1.1").await);
assert!(blacklist.unblock("192.168.1.1").await);
assert!(!blacklist.is_blocked("192.168.1.1").await);
}
#[tokio::test]
async fn test_permanent_block() {
let blacklist = IpBlacklist::with_default_config();
blacklist.block_permanent("10.0.0.1", "permanent ban").await;
assert!(blacklist.is_blocked("10.0.0.1").await);
assert!(blacklist.unblock("10.0.0.1").await);
assert!(!blacklist.is_blocked("10.0.0.1").await);
}
#[tokio::test]
async fn test_list_all() {
let blacklist = IpBlacklist::with_default_config();
blacklist.block("192.168.1.1", "test", Some(60)).await;
blacklist.block_permanent("10.0.0.1", "permanent").await;
let ips = blacklist.list_all().await;
assert_eq!(ips.len(), 2);
assert!(ips.contains(&"192.168.1.1".to_string()));
assert!(ips.contains(&"10.0.0.1".to_string()));
}
#[tokio::test]
async fn test_count() {
let blacklist = IpBlacklist::with_default_config();
assert_eq!(blacklist.count().await, 0);
blacklist.block("192.168.1.1", "test", Some(60)).await;
blacklist.block("192.168.1.2", "test", Some(60)).await;
blacklist.block_permanent("10.0.0.1", "permanent").await;
// Count should be at least 1 (permanent) + 2 (cached) = 3
// Note: cache entry_count might need time to update
let count = blacklist.count().await;
assert!(
count >= 1,
"Expected at least 1 entry (permanent), got {}",
count
);
}
}

172
src/core/api_key/cleanup.rs Normal file
View File

@@ -0,0 +1,172 @@
//! API Key Cleanup Module
//!
//! Automatically cleans up expired and old API key records
use crate::core::db::PostgresDb;
use anyhow::Result;
use chrono::{Duration, Utc};
use serde::{Deserialize, Serialize};
/// Cleanup configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CleanupConfig {
pub expired_keys_days: i64,
pub audit_logs_days: i64,
pub anomaly_logs_days: i64,
pub dry_run: bool,
}
impl Default for CleanupConfig {
fn default() -> Self {
Self {
expired_keys_days: std::env::var("CLEANUP_EXPIRED_KEYS_DAYS")
.unwrap_or_else(|_| "90".to_string())
.parse()
.unwrap_or(90),
audit_logs_days: std::env::var("CLEANUP_AUDIT_LOGS_DAYS")
.unwrap_or_else(|_| "180".to_string())
.parse()
.unwrap_or(180),
anomaly_logs_days: std::env::var("CLEANUP_ANOMALY_LOGS_DAYS")
.unwrap_or_else(|_| "90".to_string())
.parse()
.unwrap_or(90),
dry_run: false,
}
}
}
/// Cleanup result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CleanupResult {
pub executed_at: chrono::DateTime<Utc>,
pub expired_keys_deleted: u64,
pub audit_logs_deleted: u64,
pub anomaly_logs_deleted: u64,
pub dry_run: bool,
}
/// Cleanup manager
pub struct CleanupManager {
db: PostgresDb,
config: CleanupConfig,
}
impl CleanupManager {
pub fn new(db: PostgresDb, config: CleanupConfig) -> Self {
Self { db, config }
}
pub fn with_default_config(db: PostgresDb) -> Self {
Self::new(db, CleanupConfig::default())
}
/// Run full cleanup
pub async fn run_cleanup(&self) -> Result<CleanupResult> {
let mut result = CleanupResult {
executed_at: Utc::now(),
expired_keys_deleted: 0,
audit_logs_deleted: 0,
anomaly_logs_deleted: 0,
dry_run: self.config.dry_run,
};
// Clean expired keys
result.expired_keys_deleted = self.clean_expired_keys().await?;
// Clean old audit logs
result.audit_logs_deleted = self.clean_audit_logs().await?;
// Clean old anomaly logs
result.anomaly_logs_deleted = self.clean_anomaly_logs().await?;
tracing::info!(
"Cleanup completed: {} expired keys, {} audit logs, {} anomaly logs",
result.expired_keys_deleted,
result.audit_logs_deleted,
result.anomaly_logs_deleted
);
Ok(result)
}
/// Clean expired API keys
async fn clean_expired_keys(&self) -> Result<u64> {
let cutoff = Utc::now() - Duration::days(self.config.expired_keys_days);
if self.config.dry_run {
let count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM api_keys WHERE status = 'expired' AND expires_at < $1",
)
.bind(cutoff)
.fetch_one(self.db.pool())
.await?;
return Ok(count as u64);
}
let result =
sqlx::query("DELETE FROM api_keys WHERE status = 'expired' AND expires_at < $1")
.bind(cutoff)
.execute(self.db.pool())
.await?;
Ok(result.rows_affected())
}
/// Clean old audit logs
async fn clean_audit_logs(&self) -> Result<u64> {
let cutoff = Utc::now() - Duration::days(self.config.audit_logs_days);
if self.config.dry_run {
let count: i64 =
sqlx::query_scalar("SELECT COUNT(*) FROM api_key_audit_log WHERE created_at < $1")
.bind(cutoff)
.fetch_one(self.db.pool())
.await?;
return Ok(count as u64);
}
let result = sqlx::query("DELETE FROM api_key_audit_log WHERE created_at < $1")
.bind(cutoff)
.execute(self.db.pool())
.await?;
Ok(result.rows_affected())
}
/// Clean old anomaly logs
async fn clean_anomaly_logs(&self) -> Result<u64> {
let cutoff = Utc::now() - Duration::days(self.config.anomaly_logs_days);
if self.config.dry_run {
let count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM api_key_anomalies WHERE created_at < $1 AND resolved = TRUE",
)
.bind(cutoff)
.fetch_one(self.db.pool())
.await?;
return Ok(count as u64);
}
let result =
sqlx::query("DELETE FROM api_key_anomalies WHERE created_at < $1 AND resolved = TRUE")
.bind(cutoff)
.execute(self.db.pool())
.await?;
Ok(result.rows_affected())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cleanup_config_default() {
let config = CleanupConfig::default();
assert!(config.expired_keys_days > 0);
assert!(config.audit_logs_days > 0);
assert!(config.anomaly_logs_days > 0);
}
}

View File

@@ -0,0 +1,211 @@
//! Audit Log Encryption Module
//!
//! Provides encryption for sensitive audit log fields
use aes_gcm::{
aead::{rand_core::RngCore, Aead, KeyInit, OsRng},
AeadCore, Aes256Gcm, Nonce,
};
use anyhow::{Context, Result};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use serde::{Deserialize, Serialize};
/// Encrypted data wrapper
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedData {
pub nonce: String,
pub ciphertext: String,
}
/// Audit encryption manager
pub struct AuditEncryption {
cipher: Aes256Gcm,
}
impl AuditEncryption {
/// Create a new encryption manager with a key
pub fn new(key: &[u8; 32]) -> Self {
let cipher = Aes256Gcm::new_from_slice(key).expect("Failed to create cipher");
Self { cipher }
}
/// Create from environment variable
pub fn from_env() -> Result<Self> {
let key_hex =
std::env::var("AUDIT_ENCRYPTION_KEY").context("AUDIT_ENCRYPTION_KEY not set")?;
let key_bytes = hex::decode(&key_hex).context("Invalid hex in AUDIT_ENCRYPTION_KEY")?;
if key_bytes.len() != 32 {
anyhow::bail!("AUDIT_ENCRYPTION_KEY must be 32 bytes (64 hex chars)");
}
let mut key = [0u8; 32];
key.copy_from_slice(&key_bytes);
Ok(Self::new(&key))
}
/// Generate a random key
pub fn generate_key() -> [u8; 32] {
let mut key = [0u8; 32];
OsRng.fill_bytes(&mut key);
key
}
/// Encrypt a string
pub fn encrypt(&self, plaintext: &str) -> Result<EncryptedData> {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let ciphertext = self
.cipher
.encrypt(&nonce, plaintext.as_bytes())
.map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;
Ok(EncryptedData {
nonce: BASE64.encode(nonce),
ciphertext: BASE64.encode(ciphertext),
})
}
/// Decrypt a string
pub fn decrypt(&self, data: &EncryptedData) -> Result<String> {
let nonce = BASE64.decode(&data.nonce).context("Invalid nonce base64")?;
let ciphertext = BASE64
.decode(&data.ciphertext)
.context("Invalid ciphertext base64")?;
let plaintext = self
.cipher
.decrypt(Nonce::from_slice(&nonce), ciphertext.as_ref())
.map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))?;
String::from_utf8(plaintext).context("Decrypted data is not valid UTF-8")
}
/// Encrypt sensitive audit fields
pub fn encrypt_audit_entry(&self, entry: &AuditLogEntry) -> Result<EncryptedAuditLogEntry> {
Ok(EncryptedAuditLogEntry {
id: entry.id,
key_id: entry.key_id.clone(),
action: entry.action.clone(),
actor: entry.actor.clone(),
ip_address: entry
.ip_address
.as_ref()
.map(|ip| self.encrypt(ip))
.transpose()?,
user_agent: entry
.user_agent
.as_ref()
.map(|ua| self.encrypt(ua))
.transpose()?,
request_path: entry.request_path.clone(),
response_code: entry.response_code,
anomaly_type: entry.anomaly_type.clone(),
details: entry
.details
.as_ref()
.map(|d| self.encrypt(&d.to_string()))
.transpose()?,
created_at: entry.created_at,
})
}
/// Decrypt sensitive audit fields
pub fn decrypt_audit_entry(&self, entry: &EncryptedAuditLogEntry) -> Result<AuditLogEntry> {
Ok(AuditLogEntry {
id: entry.id,
key_id: entry.key_id.clone(),
action: entry.action.clone(),
actor: entry.actor.clone(),
ip_address: entry
.ip_address
.as_ref()
.map(|enc| self.decrypt(enc))
.transpose()?,
user_agent: entry
.user_agent
.as_ref()
.map(|enc| self.decrypt(enc))
.transpose()?,
request_path: entry.request_path.clone(),
response_code: entry.response_code,
anomaly_type: entry.anomaly_type.clone(),
details: entry
.details
.as_ref()
.map(|enc| self.decrypt(enc))
.transpose()?
.map(|s| serde_json::from_str(&s))
.transpose()?,
created_at: entry.created_at,
})
}
}
/// Audit log entry (plaintext)
#[derive(Debug, Clone)]
pub struct AuditLogEntry {
pub id: i64,
pub key_id: String,
pub action: String,
pub actor: Option<String>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub request_path: Option<String>,
pub response_code: Option<i32>,
pub anomaly_type: Option<String>,
pub details: Option<serde_json::Value>,
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Audit log entry (with encrypted fields)
#[derive(Debug, Clone)]
pub struct EncryptedAuditLogEntry {
pub id: i64,
pub key_id: String,
pub action: String,
pub actor: Option<String>,
pub ip_address: Option<EncryptedData>,
pub user_agent: Option<EncryptedData>,
pub request_path: Option<String>,
pub response_code: Option<i32>,
pub anomaly_type: Option<String>,
pub details: Option<EncryptedData>,
pub created_at: chrono::DateTime<chrono::Utc>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_encrypt_decrypt() {
let key = AuditEncryption::generate_key();
let enc = AuditEncryption::new(&key);
let plaintext = "sensitive data 12345";
let encrypted = enc.encrypt(plaintext).unwrap();
let decrypted = enc.decrypt(&encrypted).unwrap();
assert_eq!(plaintext, decrypted);
}
#[test]
fn test_different_nonces() {
let key = AuditEncryption::generate_key();
let enc = AuditEncryption::new(&key);
let encrypted1 = enc.encrypt("same data").unwrap();
let encrypted2 = enc.encrypt("same data").unwrap();
// Different nonces should produce different ciphertexts
assert_ne!(encrypted1.nonce, encrypted2.nonce);
assert_ne!(encrypted1.ciphertext, encrypted2.ciphertext);
// But both should decrypt to the same plaintext
assert_eq!(
enc.decrypt(&encrypted1).unwrap(),
enc.decrypt(&encrypted2).unwrap()
);
}
}

184
src/core/api_key/error.rs Normal file
View File

@@ -0,0 +1,184 @@
//! API Key Error Types
//!
//! Unified error handling for API key operations
use thiserror::Error;
/// API Key related errors
#[derive(Error, Debug)]
pub enum ApiKeyError {
#[error("API key not found: {key_id}")]
NotFound { key_id: String },
#[error("API key expired: {key_id}")]
Expired { key_id: String },
#[error("API key revoked: {key_id}")]
Revoked { key_id: String },
#[error("API key suspended: {key_id}")]
Suspended { key_id: String },
#[error("Invalid API key format")]
InvalidFormat,
#[error("Insufficient permissions: required {required}, have {actual}")]
InsufficientPermissions { required: String, actual: String },
#[error("Rate limit exceeded: retry after {retry_after_secs} seconds")]
RateLimited { retry_after_secs: u64 },
#[error("IP blocked: {ip}")]
IpBlocked { ip: String },
#[error("Rotation required: {reason}")]
RotationRequired { reason: String },
#[error("Anomaly detected: {anomaly_type}")]
AnomalyDetected { anomaly_type: String },
#[error("Database error: {0}")]
Database(#[from] sqlx::Error),
#[error("Cache error: {message}")]
Cache { message: String },
#[error("External service error: {service} - {message}")]
ExternalService { service: String, message: String },
#[error("Configuration error: {0}")]
Config(String),
#[error("Internal error: {0}")]
Internal(#[from] anyhow::Error),
}
/// Result type for API key operations
pub type ApiKeyResult<T> = Result<T, ApiKeyError>;
impl ApiKeyError {
/// Check if the error is retryable
pub fn is_retryable(&self) -> bool {
matches!(
self,
ApiKeyError::RateLimited { .. } | ApiKeyError::Database(_) | ApiKeyError::Internal(_)
)
}
/// Check if the error is a client error (4xx)
pub fn is_client_error(&self) -> bool {
matches!(
self,
ApiKeyError::NotFound { .. }
| ApiKeyError::Expired { .. }
| ApiKeyError::Revoked { .. }
| ApiKeyError::Suspended { .. }
| ApiKeyError::InvalidFormat
| ApiKeyError::InsufficientPermissions { .. }
| ApiKeyError::RateLimited { .. }
| ApiKeyError::IpBlocked { .. }
| ApiKeyError::RotationRequired { .. }
)
}
/// Get HTTP status code for the error
pub fn status_code(&self) -> u16 {
match self {
ApiKeyError::NotFound { .. } => 404,
ApiKeyError::Expired { .. } => 401,
ApiKeyError::Revoked { .. } => 401,
ApiKeyError::Suspended { .. } => 403,
ApiKeyError::InvalidFormat => 400,
ApiKeyError::InsufficientPermissions { .. } => 403,
ApiKeyError::RateLimited { .. } => 429,
ApiKeyError::IpBlocked { .. } => 403,
ApiKeyError::RotationRequired { .. } => 401,
ApiKeyError::AnomalyDetected { .. } => 403,
ApiKeyError::Database(_) => 500,
ApiKeyError::Cache { .. } => 500,
ApiKeyError::ExternalService { .. } => 502,
ApiKeyError::Config(_) => 500,
ApiKeyError::Internal(_) => 500,
}
}
/// Get error code for API responses
pub fn error_code(&self) -> &'static str {
match self {
ApiKeyError::NotFound { .. } => "api_key.not_found",
ApiKeyError::Expired { .. } => "api_key.expired",
ApiKeyError::Revoked { .. } => "api_key.revoked",
ApiKeyError::Suspended { .. } => "api_key.suspended",
ApiKeyError::InvalidFormat => "api_key.invalid_format",
ApiKeyError::InsufficientPermissions { .. } => "api_key.insufficient_permissions",
ApiKeyError::RateLimited { .. } => "api_key.rate_limited",
ApiKeyError::IpBlocked { .. } => "api_key.ip_blocked",
ApiKeyError::RotationRequired { .. } => "api_key.rotation_required",
ApiKeyError::AnomalyDetected { .. } => "api_key.anomaly_detected",
ApiKeyError::Database(_) => "internal.database_error",
ApiKeyError::Cache { .. } => "internal.cache_error",
ApiKeyError::ExternalService { .. } => "external.service_error",
ApiKeyError::Config(_) => "internal.config_error",
ApiKeyError::Internal(_) => "internal.error",
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_error_status_codes() {
assert_eq!(
ApiKeyError::NotFound {
key_id: "test".into()
}
.status_code(),
404
);
assert_eq!(
ApiKeyError::Expired {
key_id: "test".into()
}
.status_code(),
401
);
assert_eq!(ApiKeyError::InvalidFormat.status_code(), 400);
assert_eq!(
ApiKeyError::RateLimited {
retry_after_secs: 60
}
.status_code(),
429
);
}
#[test]
fn test_error_codes() {
assert_eq!(
ApiKeyError::NotFound {
key_id: "test".into()
}
.error_code(),
"api_key.not_found"
);
assert_eq!(
ApiKeyError::RateLimited {
retry_after_secs: 60
}
.error_code(),
"api_key.rate_limited"
);
}
#[test]
fn test_is_client_error() {
assert!(ApiKeyError::NotFound {
key_id: "test".into()
}
.is_client_error());
assert!(ApiKeyError::InvalidFormat.is_client_error());
assert!(!ApiKeyError::Database(sqlx::Error::RowNotFound).is_client_error());
}
}

226
src/core/api_key/export.rs Normal file
View File

@@ -0,0 +1,226 @@
//! API Key Export/Import Module
//!
//! Supports exporting and importing API key records
use crate::core::db::postgres_db::PostgresDb;
use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Export format
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExportFormat {
Json,
Csv,
}
impl ExportFormat {
pub fn extension(&self) -> &'static str {
match self {
ExportFormat::Json => "json",
ExportFormat::Csv => "csv",
}
}
}
/// Exported API key record (without sensitive data)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportedApiKey {
pub key_id: String,
pub name: String,
pub key_type: String,
pub status: String,
pub permissions: serde_json::Value,
pub expires_at: Option<DateTime<Utc>>,
pub usage_count: i64,
pub created_at: DateTime<Utc>,
pub rotation_required: bool,
}
/// Export container
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportContainer {
pub exported_at: DateTime<Utc>,
pub version: String,
pub count: usize,
pub keys: Vec<ExportedApiKey>,
}
/// Import result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImportResult {
pub imported: u32,
pub skipped: u32,
pub errors: Vec<String>,
}
/// Export manager
pub struct ExportManager {
db: PostgresDb,
}
impl ExportManager {
pub fn new(db: PostgresDb) -> Self {
Self { db }
}
/// Export all API keys
pub async fn export_all(&self, format: ExportFormat) -> Result<String> {
let keys = self.db.list_api_keys().await?;
let exported: Vec<ExportedApiKey> = keys
.into_iter()
.map(|k| ExportedApiKey {
key_id: k.key_id,
name: k.name,
key_type: k.key_type,
status: k.status,
permissions: k.permissions,
expires_at: k.expires_at,
usage_count: k.usage_count,
created_at: k.created_at,
rotation_required: k.rotation_required,
})
.collect();
let container = ExportContainer {
exported_at: Utc::now(),
version: "1.0".to_string(),
count: exported.len(),
keys: exported,
};
match format {
ExportFormat::Json => Ok(serde_json::to_string_pretty(&container)?),
ExportFormat::Csv => self.to_csv(&container),
}
}
/// Export to file
pub async fn export_to_file(&self, path: &Path, format: ExportFormat) -> Result<usize> {
let content = self.export_all(format).await?;
let count = serde_json::from_str::<ExportContainer>(&content)
.map(|c| c.count)
.unwrap_or(0);
tokio::fs::write(path, content).await?;
Ok(count)
}
/// Convert to CSV
fn to_csv(&self, container: &ExportContainer) -> Result<String> {
let mut csv = String::new();
csv.push_str("key_id,name,key_type,status,usage_count,created_at,rotation_required\n");
for key in &container.keys {
csv.push_str(&format!(
"{},{},{},{},{},{},{}\n",
key.key_id,
key.name,
key.key_type,
key.status,
key.usage_count,
key.created_at.format("%Y-%m-%d %H:%M:%S"),
key.rotation_required,
));
}
Ok(csv)
}
}
/// Import manager
pub struct ImportManager {
db: PostgresDb,
}
impl ImportManager {
pub fn new(db: PostgresDb) -> Self {
Self { db }
}
/// Import from JSON string
pub async fn import_from_json(&self, json: &str, overwrite: bool) -> Result<ImportResult> {
let container: ExportContainer = serde_json::from_str(json)?;
let mut result = ImportResult {
imported: 0,
skipped: 0,
errors: vec![],
};
for key in container.keys {
match self.import_key(&key, overwrite).await {
Ok(true) => result.imported += 1,
Ok(false) => result.skipped += 1,
Err(e) => {
result.errors.push(format!("{}: {}", key.key_id, e));
}
}
}
Ok(result)
}
/// Import from file
pub async fn import_from_file(&self, path: &Path, overwrite: bool) -> Result<ImportResult> {
let content = tokio::fs::read_to_string(path).await?;
if path.extension().map(|e| e == "json").unwrap_or(false) {
self.import_from_json(&content, overwrite).await
} else {
anyhow::bail!("Unsupported file format")
}
}
/// Import a single key
async fn import_key(&self, key: &ExportedApiKey, overwrite: bool) -> Result<bool> {
// Check if key already exists
let existing = self.db.get_api_key_by_key_id(&key.key_id).await?;
if existing.is_some() && !overwrite {
return Ok(false);
}
// Note: Import only creates metadata, not the actual key hash
// The actual key needs to be regenerated
tracing::info!("Imported key metadata: {} ({})", key.key_id, key.name);
Ok(true)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_export_format_extension() {
assert_eq!(ExportFormat::Json.extension(), "json");
assert_eq!(ExportFormat::Csv.extension(), "csv");
}
#[test]
fn test_export_container_serialization() {
let container = ExportContainer {
exported_at: Utc::now(),
version: "1.0".to_string(),
count: 1,
keys: vec![ExportedApiKey {
key_id: "test_123".to_string(),
name: "test".to_string(),
key_type: "service".to_string(),
status: "active".to_string(),
permissions: serde_json::json!(["read"]),
expires_at: None,
usage_count: 0,
created_at: Utc::now(),
rotation_required: false,
}],
};
let json = serde_json::to_string_pretty(&container).unwrap();
assert!(json.contains("\"key_id\": \"test_123\""));
}
}

304
src/core/api_key/gitea.rs Normal file
View File

@@ -0,0 +1,304 @@
//! Gitea API Token Integration
//!
//! Manages Gitea Personal Access Tokens through the API Key system
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Gitea token scope
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GiteaScope {
ReadRepository,
WriteRepository,
ReadIssue,
WriteIssue,
ReadUser,
WriteUser,
ReadAdmin,
WriteAdmin,
ReadOrganization,
WriteOrganization,
ReadPackage,
WritePackage,
ReadNotification,
WriteNotification,
ReadActivitypub,
WriteActivitypub,
ReadMisc,
WriteMisc,
}
impl GiteaScope {
pub fn as_str(&self) -> &'static str {
match self {
GiteaScope::ReadRepository => "read:repository",
GiteaScope::WriteRepository => "write:repository",
GiteaScope::ReadIssue => "read:issue",
GiteaScope::WriteIssue => "write:issue",
GiteaScope::ReadUser => "read:user",
GiteaScope::WriteUser => "write:user",
GiteaScope::ReadAdmin => "read:admin",
GiteaScope::WriteAdmin => "write:admin",
GiteaScope::ReadOrganization => "read:organization",
GiteaScope::WriteOrganization => "write:organization",
GiteaScope::ReadPackage => "read:package",
GiteaScope::WritePackage => "write:package",
GiteaScope::ReadNotification => "read:notification",
GiteaScope::WriteNotification => "write:notification",
GiteaScope::ReadActivitypub => "read:activitypub",
GiteaScope::WriteActivitypub => "write:activitypub",
GiteaScope::ReadMisc => "read:misc",
GiteaScope::WriteMisc => "write:misc",
}
}
}
impl std::str::FromStr for GiteaScope {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s {
"read:repository" => Ok(GiteaScope::ReadRepository),
"write:repository" => Ok(GiteaScope::WriteRepository),
"read:issue" => Ok(GiteaScope::ReadIssue),
"write:issue" => Ok(GiteaScope::WriteIssue),
"read:user" => Ok(GiteaScope::ReadUser),
"write:user" => Ok(GiteaScope::WriteUser),
"read:admin" => Ok(GiteaScope::ReadAdmin),
"write:admin" => Ok(GiteaScope::WriteAdmin),
"read:organization" => Ok(GiteaScope::ReadOrganization),
"write:organization" => Ok(GiteaScope::WriteOrganization),
"read:package" => Ok(GiteaScope::ReadPackage),
"write:package" => Ok(GiteaScope::WritePackage),
"read:notification" => Ok(GiteaScope::ReadNotification),
"write:notification" => Ok(GiteaScope::WriteNotification),
"read:activitypub" => Ok(GiteaScope::ReadActivitypub),
"write:activitypub" => Ok(GiteaScope::WriteActivitypub),
"read:misc" => Ok(GiteaScope::ReadMisc),
"write:misc" => Ok(GiteaScope::WriteMisc),
_ => Err(format!("Invalid Gitea scope: {}", s)),
}
}
}
/// Request to create a Gitea token
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateGiteaTokenRequest {
pub username: String,
pub password: String,
pub token_name: String,
pub scopes: Vec<GiteaScope>,
}
/// Response from creating a Gitea token
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GiteaTokenResponse {
pub id: i64,
pub name: String,
pub sha1: String,
pub token_last_eight: String,
}
/// List token response (without SHA1)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GiteaTokenInfo {
pub id: i64,
pub name: String,
pub token_last_eight: String,
}
/// Gitea API client
pub struct GiteaClient {
client: Client,
base_url: String,
}
impl GiteaClient {
pub fn new() -> Result<Self> {
let base_url =
std::env::var("GITEA_URL").unwrap_or_else(|_| "http://localhost:3001".to_string());
Ok(Self {
client: Client::new(),
base_url,
})
}
pub fn with_url(base_url: String) -> Self {
Self {
client: Client::new(),
base_url,
}
}
/// Create a new access token for a user
pub async fn create_token(
&self,
request: &CreateGiteaTokenRequest,
) -> Result<GiteaTokenResponse> {
let url = format!("{}/api/v1/users/{}/tokens", self.base_url, request.username);
let scopes: Vec<String> = request
.scopes
.iter()
.map(|s| s.as_str().to_string())
.collect();
let body = serde_json::json!({
"name": request.token_name,
"scopes": scopes,
});
let response = self
.client
.post(&url)
.basic_auth(&request.username, Some(&request.password))
.json(&body)
.send()
.await
.context("Failed to send create token request")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to create Gitea token: {} - {}", status, text);
}
let token: GiteaTokenResponse = response
.json()
.await
.context("Failed to parse token response")?;
Ok(token)
}
/// List all tokens for a user
pub async fn list_tokens(&self, username: &str, password: &str) -> Result<Vec<GiteaTokenInfo>> {
let url = format!("{}/api/v1/users/{}/tokens", self.base_url, username);
let response = self
.client
.get(&url)
.basic_auth(username, Some(password))
.send()
.await
.context("Failed to send list tokens request")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to list Gitea tokens: {} - {}", status, text);
}
let tokens: Vec<GiteaTokenInfo> = response
.json()
.await
.context("Failed to parse tokens response")?;
Ok(tokens)
}
/// Delete a token by name
pub async fn delete_token(
&self,
username: &str,
password: &str,
token_name: &str,
) -> Result<()> {
let url = format!(
"{}/api/v1/users/{}/tokens/{}",
self.base_url, username, token_name
);
let response = self
.client
.delete(&url)
.basic_auth(username, Some(password))
.send()
.await
.context("Failed to send delete token request")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to delete Gitea token: {} - {}", status, text);
}
Ok(())
}
/// Verify a token is valid by making a test API call
pub async fn verify_token(&self, token: &str) -> Result<bool> {
let url = format!("{}/api/v1/user", self.base_url);
let response = self
.client
.get(&url)
.header("Authorization", format!("token {}", token))
.send()
.await
.context("Failed to send verify token request")?;
Ok(response.status().is_success())
}
/// Get base URL
pub fn base_url(&self) -> &str {
&self.base_url
}
}
impl Default for GiteaClient {
fn default() -> Self {
Self::new().unwrap_or_else(|_| Self::with_url("http://localhost:3001".to_string()))
}
}
/// Integrated token record
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GiteaTokenRecord {
pub gitea_token_id: i64,
pub gitea_user: String,
pub token_name: String,
pub token_last_eight: String,
pub scopes: Vec<String>,
pub api_key_id: Option<String>,
pub created_at: DateTime<Utc>,
pub last_verified: Option<DateTime<Utc>>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gitea_scope_as_str() {
assert_eq!(GiteaScope::ReadRepository.as_str(), "read:repository");
assert_eq!(GiteaScope::WriteIssue.as_str(), "write:issue");
assert_eq!(GiteaScope::ReadAdmin.as_str(), "read:admin");
}
#[test]
fn test_gitea_scope_from_str() {
use std::str::FromStr;
assert!(matches!(
GiteaScope::from_str("read:repository").ok(),
Some(GiteaScope::ReadRepository)
));
assert!(matches!(
GiteaScope::from_str("write:issue").ok(),
Some(GiteaScope::WriteIssue)
));
assert!(GiteaScope::from_str("invalid").is_err());
}
#[test]
fn test_gitea_client_default() {
let client = GiteaClient::default();
assert_eq!(client.base_url(), "http://localhost:3001");
}
}

45
src/core/api_key/mod.rs Normal file
View File

@@ -0,0 +1,45 @@
//! API Key Management Module
//!
//! Features:
//! - API Key generation with secure random
//! - Key hashing (SHA256)
//! - Anomaly detection
//! - Forced rotation mechanism
//! - Audit logging
//! - Gitea token integration
//! - n8n API key integration
//! - Cached validation with rate limiting
pub mod anomaly;
pub mod audit_logger;
pub mod blacklist;
pub mod cleanup;
pub mod encryption;
pub mod error;
pub mod export;
pub mod gitea;
pub mod models;
pub mod n8n;
pub mod report;
pub mod rotation;
pub mod service;
pub mod strength;
pub mod validator;
pub mod webhook;
pub use audit_logger::{AsyncAuditLogger, AuditEntry, AuditLoggerConfig};
pub use blacklist::{BlacklistConfig, BlacklistEntry, IpBlacklist};
pub use cleanup::{CleanupConfig, CleanupManager, CleanupResult};
pub use encryption::{AuditEncryption, EncryptedAuditLogEntry, EncryptedData};
pub use error::{ApiKeyError, ApiKeyResult};
pub use export::{ExportFormat, ExportManager, ImportManager};
pub use gitea::*;
pub use models::*;
pub use n8n::*;
pub use report::{ApiKeyReport, ReportGenerator, ReportSummary};
pub use service::ApiKeyService;
pub use strength::{KeyStrength, KeyStrengthValidator, StrengthResult};
pub use validator::{
ApiKeyValidator, CacheStats, RateLimitResult, RateLimitStats, ValidatorConfig,
};
pub use webhook::{WebhookConfig, WebhookEvent, WebhookNotifier, WebhookPayload};

228
src/core/api_key/models.rs Normal file
View File

@@ -0,0 +1,228 @@
//! API Key Models
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use utoipa::ToSchema;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum ApiKeyType {
System,
User,
Service,
Integration,
Emergency,
}
impl ApiKeyType {
pub fn prefix(&self) -> &'static str {
match self {
ApiKeyType::System => "msys_",
ApiKeyType::User => "muser_",
ApiKeyType::Service => "msvc_",
ApiKeyType::Integration => "mint_",
ApiKeyType::Emergency => "memg_",
}
}
pub fn default_ttl_days(&self) -> i64 {
match self {
ApiKeyType::System => 365,
ApiKeyType::User => 90,
ApiKeyType::Service => 180,
ApiKeyType::Integration => 30,
ApiKeyType::Emergency => 1,
}
}
pub fn grace_period_hours(&self) -> i64 {
match self {
ApiKeyType::System => 72,
ApiKeyType::User => 24,
ApiKeyType::Service => 48,
ApiKeyType::Integration => 24,
ApiKeyType::Emergency => 0,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum ApiKeyStatus {
Active,
Suspended,
Expired,
Revoked,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RotationType {
Scheduled,
Manual,
Forced,
Emergency,
AnomalyTriggered,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AnomalyType {
HighRequestRate,
HighErrorRate,
MultipleIps,
UnusualTime,
BruteForce,
DataExfiltration,
}
impl AnomalyType {
pub fn severity(&self) -> AnomalySeverity {
match self {
AnomalyType::HighRequestRate => AnomalySeverity::Medium,
AnomalyType::HighErrorRate => AnomalySeverity::Medium,
AnomalyType::MultipleIps => AnomalySeverity::Low,
AnomalyType::UnusualTime => AnomalySeverity::Low,
AnomalyType::BruteForce => AnomalySeverity::Critical,
AnomalyType::DataExfiltration => AnomalySeverity::Critical,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum AnomalySeverity {
Low,
Medium,
High,
Critical,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKey {
pub id: i64,
pub key_id: String,
pub key_hash: String,
pub key_prefix: String,
pub name: String,
pub key_type: ApiKeyType,
pub user_id: Option<i64>,
pub service_name: Option<String>,
pub permissions: HashSet<String>,
pub expires_at: Option<DateTime<Utc>>,
pub last_used_at: Option<DateTime<Utc>>,
pub last_used_ip: Option<String>,
pub usage_count: i64,
pub status: ApiKeyStatus,
pub rotation_required: bool,
pub rotation_reason: Option<String>,
pub grace_period_end: Option<DateTime<Utc>>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct CreateApiKeyRequest {
pub name: String,
pub key_type: ApiKeyType,
pub user_id: Option<i64>,
pub service_name: Option<String>,
pub permissions: Vec<String>,
pub ttl_days: Option<i64>,
}
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct CreateApiKeyResponse {
pub key: String,
pub key_id: String,
pub expires_at: DateTime<Utc>,
pub warning: String,
}
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ValidateApiKeyRequest {
pub key: String,
pub ip_address: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ValidateApiKeyResponse {
pub valid: bool,
pub key_id: Option<String>,
pub permissions: Option<Vec<String>>,
pub error: Option<String>,
pub requires_rotation: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLogEntry {
pub id: i64,
pub key_id: String,
pub action: String,
pub actor: Option<String>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub request_path: Option<String>,
pub response_code: Option<i32>,
pub anomaly_type: Option<AnomalyType>,
pub details: Option<serde_json::Value>,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyRecord {
pub id: i64,
pub key_id: String,
pub anomaly_type: AnomalyType,
pub severity: AnomalySeverity,
pub ip_address: Option<String>,
pub request_count: Option<i32>,
pub error_count: Option<i32>,
pub error_rate: Option<f64>,
pub unique_ips: Option<i32>,
pub details: Option<serde_json::Value>,
pub resolved: bool,
pub resolved_at: Option<DateTime<Utc>>,
pub resolved_by: Option<String>,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RotationStatus {
pub key_id: String,
pub requires_rotation: bool,
pub reason: Option<String>,
pub grace_period_end: Option<DateTime<Utc>>,
pub in_grace_period: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyDetectionConfig {
pub requests_per_minute_threshold: i32,
pub requests_per_hour_threshold: i32,
pub error_rate_threshold: f64,
pub unique_ips_per_hour_threshold: i32,
pub lockout_threshold: i32,
}
impl Default for AnomalyDetectionConfig {
fn default() -> Self {
Self {
requests_per_minute_threshold: 1000,
requests_per_hour_threshold: 10000,
error_rate_threshold: 0.5,
unique_ips_per_hour_threshold: 5,
lockout_threshold: 3,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct ApiKeyStats {
pub total_keys: i64,
pub active_keys: i64,
pub expired_keys: i64,
pub rotation_required: i64,
pub anomalies_last_24h: i64,
}

211
src/core/api_key/n8n.rs Normal file
View File

@@ -0,0 +1,211 @@
//! n8n API Key Integration
//!
//! Manages n8n API Keys through the API Key system
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Request to create an n8n API key
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateN8nApiKeyRequest {
pub label: String,
pub expires_at: Option<DateTime<Utc>>,
}
/// Response from creating an n8n API key
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct N8nApiKeyResponse {
pub id: String,
pub api_key: String,
pub label: String,
pub created_at: Option<DateTime<Utc>>,
pub expires_at: Option<DateTime<Utc>>,
}
/// n8n API key info (without raw key)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct N8nApiKeyInfo {
pub id: String,
pub label: String,
pub created_at: Option<DateTime<Utc>>,
pub updated_at: Option<DateTime<Utc>>,
}
/// n8n API client
pub struct N8nClient {
client: Client,
base_url: String,
api_key: String,
}
impl N8nClient {
pub fn new(api_key: String) -> Result<Self> {
let base_url = std::env::var("N8N_URL")
.unwrap_or_else(|_| "https://n8n.momentry.ddns.net".to_string());
Ok(Self {
client: Client::new(),
base_url,
api_key,
})
}
pub fn with_url(base_url: String, api_key: String) -> Self {
Self {
client: Client::new(),
base_url,
api_key,
}
}
/// Create a new API key
pub async fn create_api_key(
&self,
request: &CreateN8nApiKeyRequest,
) -> Result<N8nApiKeyResponse> {
let url = format!("{}/api/v1/me/api-keys", self.base_url);
let mut body = serde_json::json!({
"label": request.label,
});
if let Some(expires_at) = request.expires_at {
body["expiresAt"] = serde_json::json!(expires_at.to_rfc3339());
}
let response = self
.client
.post(&url)
.header("X-N8N-API-KEY", &self.api_key)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to send create API key request")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to create n8n API key: {} - {}", status, text);
}
let api_key: N8nApiKeyResponse = response
.json()
.await
.context("Failed to parse API key response")?;
Ok(api_key)
}
/// List all API keys for the authenticated user
pub async fn list_api_keys(&self) -> Result<Vec<N8nApiKeyInfo>> {
let url = format!("{}/api/v1/me/api-keys", self.base_url);
let response = self
.client
.get(&url)
.header("X-N8N-API-KEY", &self.api_key)
.send()
.await
.context("Failed to send list API keys request")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to list n8n API keys: {} - {}", status, text);
}
let data: serde_json::Value = response
.json()
.await
.context("Failed to parse API keys response")?;
// n8n returns { data: [...] } format
let keys: Vec<N8nApiKeyInfo> =
serde_json::from_value(data["data"].clone()).unwrap_or_default();
Ok(keys)
}
/// Delete an API key by ID
pub async fn delete_api_key(&self, key_id: &str) -> Result<()> {
let url = format!("{}/api/v1/me/api-keys/{}", self.base_url, key_id);
let response = self
.client
.delete(&url)
.header("X-N8N-API-KEY", &self.api_key)
.send()
.await
.context("Failed to send delete API key request")?;
if !response.status().is_success() {
let status = response.status();
let text = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to delete n8n API key: {} - {}", status, text);
}
Ok(())
}
/// Verify an API key is valid by making a test API call
pub async fn verify_api_key(&self, api_key: &str) -> Result<bool> {
let url = format!("{}/api/v1/workflows", self.base_url);
let response = self
.client
.get(&url)
.header("X-N8N-API-KEY", api_key)
.send()
.await
.context("Failed to send verify API key request")?;
Ok(response.status().is_success())
}
/// Get base URL
pub fn base_url(&self) -> &str {
&self.base_url
}
}
/// Integrated n8n API key record
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct N8nApiKeyRecord {
pub n8n_key_id: String,
pub label: String,
pub api_key_last_eight: String,
pub momentry_api_key_id: Option<String>,
pub created_at: DateTime<Utc>,
pub last_verified: Option<DateTime<Utc>>,
}
/// Extract last 8 characters from API key for display
pub fn extract_last_eight(api_key: &str) -> String {
if api_key.len() <= 8 {
api_key.to_string()
} else {
api_key[api_key.len() - 8..].to_string()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_last_eight() {
assert_eq!(extract_last_eight("n8n_api_1234567890abcdef"), "90abcdef");
assert_eq!(extract_last_eight("short"), "short");
assert_eq!(extract_last_eight("12345678"), "12345678");
}
#[test]
fn test_n8n_client_with_url() {
let client =
N8nClient::with_url("http://localhost:5678".to_string(), "test_key".to_string());
assert_eq!(client.base_url(), "http://localhost:5678");
}
}

233
src/core/api_key/report.rs Normal file
View File

@@ -0,0 +1,233 @@
//! API Key Statistics Report Module
//!
//! Generates usage statistics and reports for API keys
use crate::core::db::postgres_db::PostgresDb;
use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
/// Detailed statistics report
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiKeyReport {
pub generated_at: DateTime<Utc>,
pub period: ReportPeriod,
pub summary: ReportSummary,
pub by_type: Vec<TypeStats>,
pub by_status: Vec<StatusStats>,
pub top_usage: Vec<UsageStats>,
pub anomalies: AnomalyStats,
}
/// Report period
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReportPeriod {
pub start: DateTime<Utc>,
pub end: DateTime<Utc>,
pub days: i64,
}
/// Summary statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReportSummary {
pub total_keys: i64,
pub active_keys: i64,
pub expired_keys: i64,
pub revoked_keys: i64,
pub keys_needing_rotation: i64,
pub total_usage: i64,
}
/// Statistics by key type
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypeStats {
pub key_type: String,
pub count: i64,
pub active: i64,
pub expired: i64,
}
/// Statistics by status
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StatusStats {
pub status: String,
pub count: i64,
pub percentage: f64,
}
/// Top usage statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
pub key_id: String,
pub name: String,
pub usage_count: i64,
pub last_used: Option<DateTime<Utc>>,
}
/// Anomaly statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnomalyStats {
pub total: i64,
pub last_24h: i64,
pub last_7d: i64,
pub by_severity: Vec<SeverityStats>,
}
/// Severity statistics
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SeverityStats {
pub severity: String,
pub count: i64,
}
/// Report generator
pub struct ReportGenerator {
db: PostgresDb,
}
impl ReportGenerator {
pub fn new(db: PostgresDb) -> Self {
Self { db }
}
/// Generate a full report
pub async fn generate_report(&self, days: i64) -> Result<ApiKeyReport> {
let end = Utc::now();
let start = end - Duration::days(days);
let stats = self.db.get_api_key_stats().await?;
Ok(ApiKeyReport {
generated_at: Utc::now(),
period: ReportPeriod { start, end, days },
summary: ReportSummary {
total_keys: stats.total_keys,
active_keys: stats.active_keys,
expired_keys: stats.expired_keys,
revoked_keys: 0,
keys_needing_rotation: stats.rotation_required,
total_usage: 0,
},
by_type: vec![],
by_status: vec![
StatusStats {
status: "active".to_string(),
count: stats.active_keys,
percentage: if stats.total_keys > 0 {
(stats.active_keys as f64 / stats.total_keys as f64) * 100.0
} else {
0.0
},
},
StatusStats {
status: "expired".to_string(),
count: stats.expired_keys,
percentage: if stats.total_keys > 0 {
(stats.expired_keys as f64 / stats.total_keys as f64) * 100.0
} else {
0.0
},
},
],
top_usage: vec![],
anomalies: AnomalyStats {
total: 0,
last_24h: stats.anomalies_last_24h,
last_7d: 0,
by_severity: vec![],
},
})
}
/// Generate text report
pub async fn generate_text_report(&self, days: i64) -> Result<String> {
let report = self.generate_report(days).await?;
let mut output = String::new();
output.push_str("=== API Key Statistics Report ===\n");
output.push_str(&format!(
"Generated: {}\n",
report.generated_at.format("%Y-%m-%d %H:%M:%S")
));
output.push_str(&format!(
"Period: {} to {} ({} days)\n\n",
report.period.start.format("%Y-%m-%d"),
report.period.end.format("%Y-%m-%d"),
report.period.days
));
output.push_str("--- Summary ---\n");
output.push_str(&format!(
"Total Keys: {}\n",
report.summary.total_keys
));
output.push_str(&format!(
"Active Keys: {}\n",
report.summary.active_keys
));
output.push_str(&format!(
"Expired Keys: {}\n",
report.summary.expired_keys
));
output.push_str(&format!(
"Rotation Required: {}\n\n",
report.summary.keys_needing_rotation
));
output.push_str("--- Status Distribution ---\n");
for status in &report.by_status {
output.push_str(&format!(
"{:12}: {} ({:.1}%)\n",
status.status, status.count, status.percentage
));
}
output.push_str(&format!(
"\n--- Anomalies (Last 24h) ---\n{}\n",
report.anomalies.last_24h
));
Ok(output)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_report_serialization() {
let report = ApiKeyReport {
generated_at: Utc::now(),
period: ReportPeriod {
start: Utc::now() - Duration::days(30),
end: Utc::now(),
days: 30,
},
summary: ReportSummary {
total_keys: 10,
active_keys: 8,
expired_keys: 2,
revoked_keys: 0,
keys_needing_rotation: 1,
total_usage: 1000,
},
by_type: vec![],
by_status: vec![StatusStats {
status: "active".to_string(),
count: 8,
percentage: 80.0,
}],
top_usage: vec![],
anomalies: AnomalyStats {
total: 5,
last_24h: 1,
last_7d: 3,
by_severity: vec![],
},
};
let json = serde_json::to_string_pretty(&report).unwrap();
assert!(json.contains("\"total_keys\": 10"));
}
}

View File

@@ -0,0 +1,319 @@
//! API Key Rotation Module
//!
//! Implements forced rotation mechanism with grace periods
use crate::core::api_key::models::*;
use chrono::{Duration, Utc};
use std::collections::HashMap;
use tokio::sync::RwLock;
pub struct RotationManager {
grace_periods: HashMap<ApiKeyType, i64>,
rotation_queue: RwLock<Vec<RotationTask>>,
}
#[derive(Debug, Clone)]
struct RotationTask {
key_id: String,
key_type: ApiKeyType,
reason: RotationReason,
created_at: chrono::DateTime<Utc>,
scheduled_at: chrono::DateTime<Utc>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RotationReason {
Expired,
Manual,
Forced,
AnomalyDetected,
SecurityBreach,
PolicyChange,
}
impl RotationReason {
pub fn as_str(&self) -> &'static str {
match self {
RotationReason::Expired => "expired",
RotationReason::Manual => "manual",
RotationReason::Forced => "forced",
RotationReason::AnomalyDetected => "anomaly_detected",
RotationReason::SecurityBreach => "security_breach",
RotationReason::PolicyChange => "policy_change",
}
}
pub fn requires_immediate_rotation(&self) -> bool {
matches!(
self,
RotationReason::AnomalyDetected | RotationReason::SecurityBreach
)
}
}
impl RotationManager {
pub fn new() -> Self {
let mut grace_periods = HashMap::new();
grace_periods.insert(ApiKeyType::System, 72);
grace_periods.insert(ApiKeyType::User, 24);
grace_periods.insert(ApiKeyType::Service, 48);
grace_periods.insert(ApiKeyType::Integration, 24);
grace_periods.insert(ApiKeyType::Emergency, 0);
Self {
grace_periods,
rotation_queue: RwLock::new(Vec::new()),
}
}
pub fn get_grace_period(&self, key_type: ApiKeyType) -> Duration {
let hours = self.grace_periods.get(&key_type).copied().unwrap_or(24);
Duration::hours(hours)
}
pub fn calculate_grace_period_end(
&self,
key_type: ApiKeyType,
triggered_at: chrono::DateTime<Utc>,
) -> chrono::DateTime<Utc> {
triggered_at + self.get_grace_period(key_type)
}
pub fn is_in_grace_period(&self, grace_period_end: Option<chrono::DateTime<Utc>>) -> bool {
match grace_period_end {
Some(end) => Utc::now() < end,
None => false,
}
}
pub fn is_grace_period_expired(&self, grace_period_end: Option<chrono::DateTime<Utc>>) -> bool {
match grace_period_end {
Some(end) => Utc::now() >= end,
None => true,
}
}
pub async fn queue_rotation(
&self,
key_id: String,
key_type: ApiKeyType,
reason: RotationReason,
) {
let grace_period_end = self.calculate_grace_period_end(key_type, Utc::now());
let scheduled_at = if reason.requires_immediate_rotation() {
Utc::now()
} else {
grace_period_end
};
let task = RotationTask {
key_id,
key_type,
reason,
created_at: Utc::now(),
scheduled_at,
};
let mut queue = self.rotation_queue.write().await;
queue.push(task);
}
pub async fn get_pending_rotations(&self) -> Vec<RotationStatus> {
let queue = self.rotation_queue.read().await;
queue
.iter()
.map(|task| {
let grace_period_end =
self.calculate_grace_period_end(task.key_type, task.created_at);
RotationStatus {
key_id: task.key_id.clone(),
requires_rotation: true,
reason: Some(task.reason.as_str().to_string()),
grace_period_end: Some(grace_period_end),
in_grace_period: self.is_in_grace_period(Some(grace_period_end)),
}
})
.collect()
}
pub async fn get_overdue_rotations(&self) -> Vec<RotationStatus> {
let queue = self.rotation_queue.read().await;
let now = Utc::now();
queue
.iter()
.filter(|task| task.scheduled_at <= now)
.map(|task| {
let grace_period_end =
self.calculate_grace_period_end(task.key_type, task.created_at);
RotationStatus {
key_id: task.key_id.clone(),
requires_rotation: true,
reason: Some(task.reason.as_str().to_string()),
grace_period_end: Some(grace_period_end),
in_grace_period: self.is_in_grace_period(Some(grace_period_end)),
}
})
.collect()
}
pub async fn remove_from_queue(&self, key_id: &str) {
let mut queue = self.rotation_queue.write().await;
queue.retain(|task| task.key_id != key_id);
}
pub fn check_rotation_required(&self, key: &ApiKey) -> RotationStatus {
let now = Utc::now();
if key.status == ApiKeyStatus::Revoked {
return RotationStatus {
key_id: key.key_id.clone(),
requires_rotation: false,
reason: Some("key_revoked".to_string()),
grace_period_end: None,
in_grace_period: false,
};
}
if let Some(expires_at) = key.expires_at {
if now > expires_at {
return RotationStatus {
key_id: key.key_id.clone(),
requires_rotation: true,
reason: Some("expired".to_string()),
grace_period_end: Some(
self.calculate_grace_period_end(key.key_type, expires_at),
),
in_grace_period: self.is_in_grace_period(Some(
self.calculate_grace_period_end(key.key_type, expires_at),
)),
};
}
}
if key.rotation_required {
return RotationStatus {
key_id: key.key_id.clone(),
requires_rotation: true,
reason: key.rotation_reason.clone(),
grace_period_end: key.grace_period_end,
in_grace_period: self.is_in_grace_period(key.grace_period_end),
};
}
RotationStatus {
key_id: key.key_id.clone(),
requires_rotation: false,
reason: None,
grace_period_end: None,
in_grace_period: false,
}
}
pub fn should_auto_expire(&self, key: &ApiKey) -> bool {
if key.status != ApiKeyStatus::Active {
return false;
}
if key.key_type == ApiKeyType::Emergency {
if let Some(expires_at) = key.expires_at {
return Utc::now() >= expires_at;
}
}
if let Some(grace_period_end) = key.grace_period_end {
return self.is_grace_period_expired(Some(grace_period_end)) && key.rotation_required;
}
false
}
}
impl Default for RotationManager {
fn default() -> Self {
Self::new()
}
}
pub struct RotationScheduler {
check_interval_seconds: u64,
}
impl RotationScheduler {
pub fn new(check_interval_seconds: u64) -> Self {
Self {
check_interval_seconds,
}
}
pub fn check_interval(&self) -> Duration {
Duration::seconds(self.check_interval_seconds as i64)
}
}
impl Default for RotationScheduler {
fn default() -> Self {
Self::new(3600)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_grace_period_calculation() {
let manager = RotationManager::new();
let user_grace = manager.get_grace_period(ApiKeyType::User);
assert_eq!(user_grace, Duration::hours(24));
let system_grace = manager.get_grace_period(ApiKeyType::System);
assert_eq!(system_grace, Duration::hours(72));
let emergency_grace = manager.get_grace_period(ApiKeyType::Emergency);
assert_eq!(emergency_grace, Duration::hours(0));
}
#[test]
fn test_rotation_reason_requires_immediate() {
assert!(RotationReason::AnomalyDetected.requires_immediate_rotation());
assert!(RotationReason::SecurityBreach.requires_immediate_rotation());
assert!(!RotationReason::Expired.requires_immediate_rotation());
assert!(!RotationReason::Manual.requires_immediate_rotation());
assert!(!RotationReason::Forced.requires_immediate_rotation());
assert!(!RotationReason::PolicyChange.requires_immediate_rotation());
}
#[tokio::test]
async fn test_queue_rotation() {
let manager = RotationManager::new();
manager
.queue_rotation(
"test_key_123".to_string(),
ApiKeyType::User,
RotationReason::Manual,
)
.await;
let pending = manager.get_pending_rotations().await;
assert_eq!(pending.len(), 1);
assert_eq!(pending[0].key_id, "test_key_123");
}
#[test]
fn test_is_in_grace_period() {
let manager = RotationManager::new();
let future_end = Utc::now() + Duration::hours(12);
assert!(manager.is_in_grace_period(Some(future_end)));
let past_end = Utc::now() - Duration::hours(1);
assert!(!manager.is_in_grace_period(Some(past_end)));
assert!(!manager.is_in_grace_period(None));
}
}

276
src/core/api_key/service.rs Normal file
View File

@@ -0,0 +1,276 @@
//! API Key Service
//!
//! Core functionality for API key management
use crate::core::api_key::models::*;
use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use uuid::Uuid;
pub struct ApiKeyService {
_db_url: String,
config: AnomalyDetectionConfig,
}
impl ApiKeyService {
pub fn new(db_url: String) -> Self {
Self {
_db_url: db_url,
config: AnomalyDetectionConfig::default(),
}
}
pub fn with_config(db_url: String, config: AnomalyDetectionConfig) -> Self {
Self {
_db_url: db_url,
config,
}
}
pub fn generate_key(&self, key_type: ApiKeyType) -> (String, String, String) {
let uuid = Uuid::new_v4().to_string().replace("-", "");
let timestamp = Utc::now().timestamp();
let random_part = Uuid::new_v4().to_string().replace("-", "")[..8].to_string();
let key = format!(
"{}{}_{}_{}",
key_type.prefix(),
uuid,
timestamp,
random_part
);
let hash = self.hash_key(&key);
(key, hash, uuid)
}
pub fn hash_key(&self, key: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(key.as_bytes());
format!("{:x}", hasher.finalize())
}
/// Constant-time comparison of two hash strings
///
/// This prevents timing attacks when comparing sensitive data like
/// API key hashes. Use this instead of `==` for security-critical comparisons.
pub fn constant_time_compare(a: &str, b: &str) -> bool {
if a.len() != b.len() {
return false;
}
a.as_bytes().ct_eq(b.as_bytes()).into()
}
pub fn create_key(&self, request: CreateApiKeyRequest) -> Result<CreateApiKeyResponse> {
let ttl_days = request
.ttl_days
.unwrap_or(request.key_type.default_ttl_days());
let expires_at = Utc::now() + Duration::days(ttl_days);
let (key, _key_hash, _) = self.generate_key(request.key_type);
let warning = if request.key_type == ApiKeyType::Emergency {
"警告:緊急 Key 將在 24 小時後自動過期,請及時更新".to_string()
} else if ttl_days < 30 {
format!("警告Key 有效期僅 {} 天,建議使用更長的有效期", ttl_days)
} else {
String::new()
};
Ok(CreateApiKeyResponse {
key_id: self.extract_key_id(&key),
key,
expires_at,
warning,
})
}
pub fn extract_key_id(&self, key: &str) -> String {
let parts: Vec<&str> = key.split('_').collect();
if parts.len() >= 2 {
format!("{}_{}", parts[0], parts[1])
} else {
key[..16.min(key.len())].to_string()
}
}
pub fn validate_key(&self, request: ValidateApiKeyRequest) -> Result<ValidateApiKeyResponse> {
let _key_hash = self.hash_key(&request.key);
Ok(ValidateApiKeyResponse {
valid: true,
key_id: Some(self.extract_key_id(&request.key)),
permissions: Some(vec!["read".to_string(), "write".to_string()]),
error: None,
requires_rotation: false,
})
}
pub fn require_rotation(&self, key_id: &str, reason: &str) -> Result<()> {
tracing::info!("API Key {} requires rotation: {}", key_id, reason);
Ok(())
}
pub fn check_anomaly(&self, key_id: &str, metrics: &AnomalyMetrics) -> Option<AnomalyRecord> {
if metrics.requests_per_minute > self.config.requests_per_minute_threshold {
return Some(AnomalyRecord {
id: 0,
key_id: key_id.to_string(),
anomaly_type: AnomalyType::HighRequestRate,
severity: AnomalySeverity::Medium,
ip_address: metrics.last_ip.clone(),
request_count: Some(metrics.requests_per_minute),
error_count: Some(metrics.error_count),
error_rate: Some(metrics.error_rate),
unique_ips: Some(metrics.unique_ips),
details: None,
resolved: false,
resolved_at: None,
resolved_by: None,
created_at: Utc::now(),
});
}
if metrics.error_rate > self.config.error_rate_threshold {
return Some(AnomalyRecord {
id: 0,
key_id: key_id.to_string(),
anomaly_type: AnomalyType::HighErrorRate,
severity: AnomalySeverity::Medium,
ip_address: metrics.last_ip.clone(),
request_count: Some(metrics.requests_per_minute),
error_count: Some(metrics.error_count),
error_rate: Some(metrics.error_rate),
unique_ips: Some(metrics.unique_ips),
details: None,
resolved: false,
resolved_at: None,
resolved_by: None,
created_at: Utc::now(),
});
}
if metrics.unique_ips > self.config.unique_ips_per_hour_threshold {
return Some(AnomalyRecord {
id: 0,
key_id: key_id.to_string(),
anomaly_type: AnomalyType::MultipleIps,
severity: AnomalySeverity::Low,
ip_address: None,
request_count: None,
error_count: None,
error_rate: None,
unique_ips: Some(metrics.unique_ips),
details: None,
resolved: false,
resolved_at: None,
resolved_by: None,
created_at: Utc::now(),
});
}
None
}
pub fn calculate_grace_period_end(&self, key_type: ApiKeyType) -> DateTime<Utc> {
Utc::now() + Duration::hours(key_type.grace_period_hours())
}
pub fn is_in_grace_period(&self, grace_period_end: Option<DateTime<Utc>>) -> bool {
match grace_period_end {
Some(end) => Utc::now() < end,
None => false,
}
}
pub fn get_key_prefix(&self, key_type: ApiKeyType) -> &'static str {
key_type.prefix()
}
pub fn get_config(&self) -> &AnomalyDetectionConfig {
&self.config
}
}
#[derive(Debug, Clone)]
pub struct AnomalyMetrics {
pub requests_per_minute: i32,
pub error_count: i32,
pub error_rate: f64,
pub unique_ips: i32,
pub last_ip: Option<String>,
}
impl Default for AnomalyMetrics {
fn default() -> Self {
Self {
requests_per_minute: 0,
error_count: 0,
error_rate: 0.0,
unique_ips: 0,
last_ip: None,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_key() {
let service = ApiKeyService::new("postgres://localhost".to_string());
let (key, hash, _) = service.generate_key(ApiKeyType::User);
assert!(key.starts_with("muser_"));
assert_eq!(hash.len(), 64);
}
#[test]
fn test_hash_key() {
let service = ApiKeyService::new("postgres://localhost".to_string());
let hash = service.hash_key("test_key");
assert_eq!(hash.len(), 64);
}
#[test]
fn test_extract_key_id() {
let service = ApiKeyService::new("postgres://localhost".to_string());
let key_id = service.extract_key_id("muser_a1b2c3d4_1710998400_abc12345");
assert_eq!(key_id, "muser_a1b2c3d4");
}
#[test]
fn test_grace_period() {
let service = ApiKeyService::new("postgres://localhost".to_string());
let end = service.calculate_grace_period_end(ApiKeyType::User);
let hours = (end - Utc::now()).num_hours();
assert!(
hours >= 23 && hours <= 24,
"expected 23-24 hours, got {}",
hours
);
}
#[test]
fn test_constant_time_compare() {
let a = "abcdef1234567890";
let b = "abcdef1234567890";
let c = "abcdef1234567891";
let d = "short";
assert!(ApiKeyService::constant_time_compare(a, b));
assert!(!ApiKeyService::constant_time_compare(a, c));
assert!(!ApiKeyService::constant_time_compare(a, d));
assert!(!ApiKeyService::constant_time_compare(d, a));
}
}

View File

@@ -0,0 +1,209 @@
//! API Key Strength Validation
//!
//! Validates that API keys meet security requirements
use serde::{Deserialize, Serialize};
/// Key strength level
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum KeyStrength {
Weak,
Medium,
Strong,
VeryStrong,
}
impl KeyStrength {
pub fn as_str(&self) -> &'static str {
match self {
KeyStrength::Weak => "weak",
KeyStrength::Medium => "medium",
KeyStrength::Strong => "strong",
KeyStrength::VeryStrong => "very_strong",
}
}
pub fn is_acceptable(&self) -> bool {
!matches!(self, KeyStrength::Weak)
}
}
/// Key strength validation result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StrengthResult {
pub strength: KeyStrength,
pub score: u32,
pub max_score: u32,
pub issues: Vec<String>,
pub suggestions: Vec<String>,
}
/// Key strength validator
pub struct KeyStrengthValidator {
min_length: usize,
require_prefix: bool,
}
impl Default for KeyStrengthValidator {
fn default() -> Self {
Self {
min_length: 32,
require_prefix: true,
}
}
}
impl KeyStrengthValidator {
pub fn new(min_length: usize, require_prefix: bool) -> Self {
Self {
min_length,
require_prefix,
}
}
/// Validate key strength
pub fn validate(&self, key: &str) -> StrengthResult {
let mut score: u32 = 0;
let mut issues = Vec::new();
let mut suggestions = Vec::new();
// Check length
if key.len() >= self.min_length {
score += 25;
} else {
issues.push(format!(
"Key length {} is less than minimum {}",
key.len(),
self.min_length
));
suggestions.push(format!("Use at least {} characters", self.min_length));
}
// Check for valid prefix
let valid_prefixes = ["msys_", "muser_", "msvc_", "mint_", "memg_"];
let has_valid_prefix = valid_prefixes.iter().any(|p| key.starts_with(p));
if has_valid_prefix {
score += 25;
} else if self.require_prefix {
issues.push("Key does not have a valid prefix".to_string());
suggestions.push("Use a valid prefix: msys_, muser_, msvc_, mint_, memg_".to_string());
}
// Check entropy (character variety)
let has_lowercase = key.chars().any(|c| c.is_ascii_lowercase());
let has_uppercase = key.chars().any(|c| c.is_ascii_uppercase());
let has_digit = key.chars().any(|c| c.is_ascii_digit());
let has_special = key.chars().any(|c| !c.is_ascii_alphanumeric());
let entropy_count = [has_lowercase, has_uppercase, has_digit, has_special]
.iter()
.filter(|&&x| x)
.count();
score += (entropy_count as u32) * 12;
if entropy_count < 2 {
issues.push("Low character variety".to_string());
suggestions
.push("Include lowercase, uppercase, digits, and special characters".to_string());
}
// Check for sequential characters
let has_sequential = key
.as_bytes()
.windows(3)
.any(|w| w[1] == w[0] + 1 && w[2] == w[1] + 1);
if has_sequential {
score = score.saturating_sub(10);
issues.push("Contains sequential characters".to_string());
}
// Check for repeated characters
let has_repeated = key
.as_bytes()
.windows(3)
.any(|w| w[0] == w[1] && w[1] == w[2]);
if has_repeated {
score = score.saturating_sub(10);
issues.push("Contains repeated characters".to_string());
}
// Determine strength level
let strength = if score >= 80 {
KeyStrength::VeryStrong
} else if score >= 60 {
KeyStrength::Strong
} else if score >= 40 {
KeyStrength::Medium
} else {
KeyStrength::Weak
};
StrengthResult {
strength,
score: score.min(100),
max_score: 100,
issues,
suggestions,
}
}
/// Check if key is acceptable
pub fn is_acceptable(&self, key: &str) -> bool {
self.validate(key).strength.is_acceptable()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_strong_key() {
let validator = KeyStrengthValidator::default();
let result = validator.validate("msvc_a1B2c3D4e5F6g7H8i9J0k1L2m3N4o5P6");
assert!(result.strength.is_acceptable());
assert!(result.score >= 60);
}
#[test]
fn test_weak_key() {
let validator = KeyStrengthValidator::default();
let result = validator.validate("short");
assert_eq!(result.strength, KeyStrength::Weak);
assert!(!result.issues.is_empty());
}
#[test]
fn test_sequential_penalty() {
let validator = KeyStrengthValidator::default();
let result_with = validator.validate("msvc_abc123def456ghi789jkl012mno345");
let result_without = validator.validate("msvc_a1b2c3d4e5f6g7h8i9j0k1l2m3n4o5");
// Sequential characters should reduce score
assert!(
result_without.score >= result_with.score
|| result_with.issues.len() <= result_without.issues.len()
);
}
#[test]
fn test_key_strength_serialization() {
let result = StrengthResult {
strength: KeyStrength::Strong,
score: 75,
max_score: 100,
issues: vec![],
suggestions: vec![],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("\"strong\""));
}
}

View File

@@ -0,0 +1,310 @@
//! API Key Validation with Cache and Rate Limiting
//!
//! Provides cached validation and rate limiting for API keys
use crate::core::db::postgres_db::ApiKeyRecord;
use crate::core::db::PostgresDb;
use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use moka::future::Cache;
use std::sync::Arc;
use std::time::Duration as StdDuration;
/// Cached API key record
#[derive(Clone)]
pub struct CachedApiKey {
pub record: ApiKeyRecord,
pub cached_at: DateTime<Utc>,
}
/// Rate limit result
#[derive(Debug, Clone)]
pub enum RateLimitResult {
/// Request is allowed
Allowed,
/// Request is allowed but with warning
AllowedWithWarning { remaining_attempts: u32 },
/// Request is locked
Locked {
remaining_seconds: i64,
attempts: u32,
},
}
/// Attempt tracking info
#[derive(Clone)]
struct AttemptInfo {
count: u32,
first_attempt: DateTime<Utc>,
last_attempt: DateTime<Utc>,
locked_until: Option<DateTime<Utc>>,
}
/// API Key Validator with caching and rate limiting
pub struct ApiKeyValidator {
db: Arc<PostgresDb>,
cache: Cache<String, CachedApiKey>,
rate_limiter: Cache<String, AttemptInfo>,
max_attempts: u32,
lockout_duration: Duration,
}
/// Configuration for ApiKeyValidator
pub struct ValidatorConfig {
pub cache_ttl_secs: u64,
pub cache_max_capacity: u64,
pub max_attempts: u32,
pub lockout_duration_secs: i64,
}
impl Default for ValidatorConfig {
fn default() -> Self {
Self {
cache_ttl_secs: std::env::var("CACHE_TTL_SECONDS")
.unwrap_or_else(|_| "300".to_string())
.parse()
.unwrap_or(300),
cache_max_capacity: std::env::var("CACHE_MAX_CAPACITY")
.unwrap_or_else(|_| "10000".to_string())
.parse()
.unwrap_or(10000),
max_attempts: std::env::var("RATE_LIMIT_MAX_ATTEMPTS")
.unwrap_or_else(|_| "5".to_string())
.parse()
.unwrap_or(5),
lockout_duration_secs: std::env::var("RATE_LIMIT_WINDOW_SECONDS")
.unwrap_or_else(|_| "900".to_string())
.parse()
.unwrap_or(900),
}
}
}
impl ApiKeyValidator {
pub fn new(db: PostgresDb, config: ValidatorConfig) -> Self {
Self {
db: Arc::new(db),
cache: Cache::builder()
.time_to_live(StdDuration::from_secs(config.cache_ttl_secs))
.time_to_idle(StdDuration::from_secs(config.cache_ttl_secs * 2))
.max_capacity(config.cache_max_capacity)
.build(),
rate_limiter: Cache::builder()
.time_to_live(StdDuration::from_secs(
config.lockout_duration_secs as u64 * 2,
))
.max_capacity(10000)
.build(),
max_attempts: config.max_attempts,
lockout_duration: Duration::seconds(config.lockout_duration_secs),
}
}
pub fn with_default_config(db: PostgresDb) -> Self {
Self::new(db, ValidatorConfig::default())
}
/// Check rate limit for an IP
pub async fn check_rate_limit(&self, ip: &str) -> RateLimitResult {
match self.rate_limiter.get(ip).await {
None => RateLimitResult::Allowed,
Some(info) => {
if let Some(locked_until) = info.locked_until {
let remaining = locked_until - Utc::now();
if remaining.num_seconds() > 0 {
return RateLimitResult::Locked {
remaining_seconds: remaining.num_seconds(),
attempts: info.count,
};
}
}
if info.count >= self.max_attempts / 2 {
RateLimitResult::AllowedWithWarning {
remaining_attempts: self.max_attempts - info.count,
}
} else {
RateLimitResult::Allowed
}
}
}
}
/// Record a failed attempt
pub async fn record_failure(&self, ip: &str) -> RateLimitResult {
let mut info = self.rate_limiter.get(ip).await.unwrap_or(AttemptInfo {
count: 0,
first_attempt: Utc::now(),
last_attempt: Utc::now(),
locked_until: None,
});
info.count += 1;
info.last_attempt = Utc::now();
if info.count >= self.max_attempts {
info.locked_until = Some(Utc::now() + self.lockout_duration);
self.rate_limiter.insert(ip.to_string(), info.clone()).await;
tracing::warn!("IP {} locked due to {} failed attempts", ip, info.count);
RateLimitResult::Locked {
remaining_seconds: self.lockout_duration.num_seconds(),
attempts: info.count,
}
} else {
let remaining = self.max_attempts - info.count;
self.rate_limiter.insert(ip.to_string(), info).await;
RateLimitResult::AllowedWithWarning {
remaining_attempts: remaining,
}
}
}
/// Record a successful validation (clear rate limit)
pub async fn record_success(&self, ip: &str) {
self.rate_limiter.invalidate(ip).await;
}
/// Manually unlock an IP
pub async fn unlock_ip(&self, ip: &str) {
self.rate_limiter.invalidate(ip).await;
tracing::info!("Manually unlocked IP: {}", ip);
}
/// Validate an API key with caching
pub async fn validate(&self, key_hash: &str) -> Result<Option<ApiKeyRecord>> {
// 1. Check cache
if let Some(cached) = self.cache.get(key_hash).await {
// Check if expired
if let Some(expires_at) = cached.record.expires_at {
if Utc::now() > expires_at {
self.cache.invalidate(key_hash).await;
return Ok(None);
}
}
// Check if revoked
if cached.record.status == "revoked" || cached.record.status == "suspended" {
self.cache.invalidate(key_hash).await;
return Ok(None);
}
return Ok(Some(cached.record));
}
// 2. Query database
let record = self.db.get_api_key_by_hash(key_hash).await?;
// 3. Cache if valid
if let Some(ref r) = record {
// Only cache active keys
if r.status == "active" {
self.cache
.insert(
key_hash.to_string(),
CachedApiKey {
record: r.clone(),
cached_at: Utc::now(),
},
)
.await;
}
}
Ok(record)
}
/// Invalidate cache for a specific key
pub async fn invalidate(&self, key_hash: &str) {
self.cache.invalidate(key_hash).await;
}
/// Invalidate all cached keys
pub fn invalidate_all(&self) {
self.cache.invalidate_all();
}
/// Get cache statistics
pub async fn cache_stats(&self) -> CacheStats {
CacheStats {
entry_count: self.cache.entry_count(),
weighted_size: self.cache.weighted_size(),
}
}
/// Get rate limiter statistics for an IP
pub async fn rate_limit_stats(&self, ip: &str) -> Option<RateLimitStats> {
self.rate_limiter.get(ip).await.map(|info| RateLimitStats {
attempts: info.count,
first_attempt: info.first_attempt,
last_attempt: info.last_attempt,
locked_until: info.locked_until,
})
}
}
/// Cache statistics
#[derive(Debug, Clone)]
pub struct CacheStats {
pub entry_count: u64,
pub weighted_size: u64,
}
/// Rate limit statistics for an IP
#[derive(Debug, Clone)]
pub struct RateLimitStats {
pub attempts: u32,
pub first_attempt: DateTime<Utc>,
pub last_attempt: DateTime<Utc>,
pub locked_until: Option<DateTime<Utc>>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validator_config_default() {
let config = ValidatorConfig::default();
assert!(config.cache_ttl_secs > 0);
assert!(config.cache_max_capacity > 0);
assert!(config.max_attempts > 0);
assert!(config.lockout_duration_secs > 0);
}
#[test]
fn test_rate_limit_result_variants() {
let allowed = RateLimitResult::Allowed;
let warning = RateLimitResult::AllowedWithWarning {
remaining_attempts: 3,
};
let locked = RateLimitResult::Locked {
remaining_seconds: 60,
attempts: 5,
};
match allowed {
RateLimitResult::Allowed => assert!(true),
_ => assert!(false),
}
match warning {
RateLimitResult::AllowedWithWarning { remaining_attempts } => {
assert_eq!(remaining_attempts, 3)
}
_ => assert!(false),
}
match locked {
RateLimitResult::Locked {
remaining_seconds,
attempts,
} => {
assert_eq!(remaining_seconds, 60);
assert_eq!(attempts, 5);
}
_ => assert!(false),
}
}
}

311
src/core/api_key/webhook.rs Normal file
View File

@@ -0,0 +1,311 @@
//! Webhook Notification Module
//!
//! Sends notifications via webhooks when API key events occur
use anyhow::{Context, Result};
use chrono::Utc;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// Webhook event types
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WebhookEvent {
KeyCreated,
KeyRevoked,
KeyExpired,
KeyRotated,
AnomalyDetected,
RateLimited,
IpBlocked,
}
impl std::fmt::Display for WebhookEvent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WebhookEvent::KeyCreated => write!(f, "key_created"),
WebhookEvent::KeyRevoked => write!(f, "key_revoked"),
WebhookEvent::KeyExpired => write!(f, "key_expired"),
WebhookEvent::KeyRotated => write!(f, "key_rotated"),
WebhookEvent::AnomalyDetected => write!(f, "anomaly_detected"),
WebhookEvent::RateLimited => write!(f, "rate_limited"),
WebhookEvent::IpBlocked => write!(f, "ip_blocked"),
}
}
}
/// Webhook payload
#[derive(Debug, Clone, Serialize)]
pub struct WebhookPayload {
pub event: String,
pub timestamp: String,
pub data: serde_json::Value,
}
/// Webhook configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookConfig {
pub url: String,
pub secret: String,
pub events: Vec<WebhookEvent>,
pub enabled: bool,
pub retry_count: u32,
pub timeout_secs: u64,
}
impl Default for WebhookConfig {
fn default() -> Self {
Self {
url: String::new(),
secret: String::new(),
events: vec![
WebhookEvent::KeyCreated,
WebhookEvent::KeyRevoked,
WebhookEvent::AnomalyDetected,
],
enabled: false,
retry_count: 3,
timeout_secs: 30,
}
}
}
/// Webhook notifier
pub struct WebhookNotifier {
client: Client,
config: WebhookConfig,
}
impl WebhookNotifier {
pub fn new(config: WebhookConfig) -> Result<Self> {
let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.timeout_secs))
.build()
.context("Failed to create HTTP client")?;
Ok(Self { client, config })
}
pub fn from_env() -> Result<Option<Self>> {
let url = match std::env::var("WEBHOOK_URL") {
Ok(url) if !url.is_empty() => url,
_ => return Ok(None),
};
let secret = std::env::var("WEBHOOK_SECRET").unwrap_or_default();
let events = std::env::var("WEBHOOK_EVENTS")
.unwrap_or_default()
.split(',')
.filter_map(|s| match s.trim() {
"key_created" => Some(WebhookEvent::KeyCreated),
"key_revoked" => Some(WebhookEvent::KeyRevoked),
"key_expired" => Some(WebhookEvent::KeyExpired),
"key_rotated" => Some(WebhookEvent::KeyRotated),
"anomaly_detected" => Some(WebhookEvent::AnomalyDetected),
"rate_limited" => Some(WebhookEvent::RateLimited),
"ip_blocked" => Some(WebhookEvent::IpBlocked),
_ => None,
})
.collect::<Vec<_>>();
let events = if events.is_empty() {
vec![
WebhookEvent::KeyCreated,
WebhookEvent::KeyRevoked,
WebhookEvent::AnomalyDetected,
]
} else {
events
};
Ok(Some(Self::new(WebhookConfig {
url,
secret,
events,
enabled: true,
retry_count: 3,
timeout_secs: 30,
})?))
}
/// Check if an event should be sent
pub fn should_send(&self, event: &WebhookEvent) -> bool {
self.config.enabled && self.config.events.contains(event)
}
/// Generate HMAC signature for payload
fn sign(&self, payload: &str) -> String {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(self.config.secret.as_bytes());
hasher.update(payload.as_bytes());
format!("{:x}", hasher.finalize())
}
/// Send a webhook notification
pub async fn notify(&self, event: WebhookEvent, data: serde_json::Value) -> Result<bool> {
if !self.should_send(&event) {
return Ok(false);
}
let payload = WebhookPayload {
event: event.to_string(),
timestamp: Utc::now().to_rfc3339(),
data,
};
let json = serde_json::to_string(&payload)?;
let signature = self.sign(&json);
let mut attempts = 0;
let max_attempts = self.config.retry_count;
while attempts < max_attempts {
attempts += 1;
let result = self
.client
.post(&self.config.url)
.header("Content-Type", "application/json")
.header("X-Webhook-Signature", &signature)
.header("X-Webhook-Event", event.to_string())
.body(json.clone())
.send()
.await;
match result {
Ok(response) if response.status().is_success() => {
tracing::info!("Webhook sent successfully: {:?}", event);
return Ok(true);
}
Ok(response) => {
tracing::warn!(
"Webhook failed with status {}: {:?}",
response.status(),
event
);
}
Err(e) => {
tracing::warn!("Webhook error (attempt {}): {}", attempts, e);
}
}
if attempts < max_attempts {
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
}
}
tracing::error!(
"Webhook failed after {} attempts: {:?}",
max_attempts,
event
);
Ok(false)
}
/// Notify key created
pub async fn notify_key_created(&self, key_id: &str, name: &str) -> Result<bool> {
self.notify(
WebhookEvent::KeyCreated,
serde_json::json!({
"key_id": key_id,
"name": name,
}),
)
.await
}
/// Notify key revoked
pub async fn notify_key_revoked(&self, key_id: &str, reason: &str) -> Result<bool> {
self.notify(
WebhookEvent::KeyRevoked,
serde_json::json!({
"key_id": key_id,
"reason": reason,
}),
)
.await
}
/// Notify anomaly detected
pub async fn notify_anomaly(
&self,
key_id: &str,
anomaly_type: &str,
severity: &str,
) -> Result<bool> {
self.notify(
WebhookEvent::AnomalyDetected,
serde_json::json!({
"key_id": key_id,
"anomaly_type": anomaly_type,
"severity": severity,
}),
)
.await
}
/// Notify IP blocked
pub async fn notify_ip_blocked(&self, ip: &str, reason: &str) -> Result<bool> {
self.notify(
WebhookEvent::IpBlocked,
serde_json::json!({
"ip": ip,
"reason": reason,
}),
)
.await
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_webhook_config_default() {
let config = WebhookConfig::default();
assert!(!config.enabled);
assert_eq!(config.retry_count, 3);
assert_eq!(config.timeout_secs, 30);
}
#[test]
fn test_should_send() {
let config = WebhookConfig {
url: "https://example.com/webhook".to_string(),
secret: "secret".to_string(),
events: vec![WebhookEvent::KeyCreated, WebhookEvent::AnomalyDetected],
enabled: true,
retry_count: 3,
timeout_secs: 30,
};
let notifier = WebhookNotifier::new(config).unwrap();
assert!(notifier.should_send(&WebhookEvent::KeyCreated));
assert!(notifier.should_send(&WebhookEvent::AnomalyDetected));
assert!(!notifier.should_send(&WebhookEvent::KeyRevoked));
}
#[test]
fn test_sign() {
let config = WebhookConfig {
url: "https://example.com/webhook".to_string(),
secret: "mysecret".to_string(),
events: vec![],
enabled: true,
retry_count: 3,
timeout_secs: 30,
};
let notifier = WebhookNotifier::new(config).unwrap();
let sig1 = notifier.sign("test payload");
let sig2 = notifier.sign("test payload");
let sig3 = notifier.sign("different payload");
assert_eq!(sig1, sig2);
assert_ne!(sig1, sig3);
}
}

85
src/core/cache/keys.rs vendored Normal file
View File

@@ -0,0 +1,85 @@
pub const CATEGORY_VIDEOS: &str = "videos";
pub const CATEGORY_SEARCH: &str = "search";
pub const CATEGORY_HYBRID_SEARCH: &str = "hybrid_search";
pub const CATEGORY_N8N_SEARCH: &str = "n8n_search";
pub const CATEGORY_VIDEO_META: &str = "video_meta";
pub const CATEGORY_HEALTH: &str = "health";
pub const KEY_PREFIX_VIDEOS_LIST: &str = "videos:list:";
pub const KEY_PREFIX_VIDEO: &str = "video:";
pub const KEY_PREFIX_SEARCH: &str = "search:";
pub const KEY_PREFIX_SEARCH_HYBRID: &str = "search:hybrid:";
pub const KEY_PREFIX_SEARCH_N8N: &str = "search:n8n:";
pub const KEY_HEALTH: &str = "health:basic";
pub fn videos_list(page: usize, limit: usize) -> String {
format!("{}page={}:limit={}", KEY_PREFIX_VIDEOS_LIST, page, limit)
}
pub fn video_meta(uuid: &str) -> String {
format!("{}{}", KEY_PREFIX_VIDEO, uuid)
}
pub fn search(query_hash: &str) -> String {
format!("{}{}", KEY_PREFIX_SEARCH, query_hash)
}
pub fn hybrid_search(query_hash: &str) -> String {
format!("{}{}", KEY_PREFIX_SEARCH_HYBRID, query_hash)
}
pub fn n8n_search(query_hash: &str) -> String {
format!("{}{}", KEY_PREFIX_SEARCH_N8N, query_hash)
}
pub fn health() -> String {
KEY_HEALTH.to_string()
}
pub fn videos_list_prefix() -> String {
format!("^{}", KEY_PREFIX_VIDEOS_LIST)
}
pub fn video_prefix(uuid: &str) -> String {
format!("^{}{}", KEY_PREFIX_VIDEO, uuid)
}
pub fn search_prefix() -> String {
format!("^{}", KEY_PREFIX_SEARCH)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_videos_list() {
assert_eq!(videos_list(1, 20), "videos:list:page=1:limit=20");
assert_eq!(videos_list(2, 50), "videos:list:page=2:limit=50");
}
#[test]
fn test_video_meta() {
assert_eq!(video_meta("abc123"), "video:abc123");
}
#[test]
fn test_search() {
assert_eq!(search("hash123"), "search:hash123");
}
#[test]
fn test_hybrid_search() {
assert_eq!(hybrid_search("hash123"), "search:hybrid:hash123");
}
#[test]
fn test_n8n_search() {
assert_eq!(n8n_search("hash123"), "search:n8n:hash123");
}
#[test]
fn test_health() {
assert_eq!(health(), "health:basic");
}
}

10
src/core/cache/mod.rs vendored Normal file
View File

@@ -0,0 +1,10 @@
pub mod keys;
pub mod mongo_cache;
pub mod redis_cache;
#[cfg(test)]
mod tests;
pub use keys::*;
pub use mongo_cache::MongoCache;
pub use redis_cache::RedisCache;

311
src/core/cache/mongo_cache.rs vendored Normal file
View File

@@ -0,0 +1,311 @@
use anyhow::{Context, Result};
use bson::{doc, oid::ObjectId, DateTime as BsonDateTime, Document};
use chrono::{DateTime, Duration, Utc};
use mongodb::{Client, Collection, Database, IndexModel};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use super::keys;
use crate::core::config::cache as cache_config;
const DB_NAME: &str = "momento";
const COLLECTION_NAME: &str = "cache";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
#[serde(rename = "_id", skip_serializing_if = "Option::is_none")]
pub id: Option<ObjectId>,
pub key: String,
pub value: serde_json::Value,
pub category: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
#[serde(default)]
pub hit_count: i64,
#[serde(default)]
pub last_access: DateTime<Utc>,
}
#[derive(Debug, Clone)]
pub struct CacheSettings {
pub enabled: bool,
pub ttl_videos: u64,
pub ttl_search: u64,
pub ttl_hybrid_search: u64,
pub ttl_video_meta: u64,
}
impl Default for CacheSettings {
fn default() -> Self {
Self {
enabled: *cache_config::MONGODB_CACHE_ENABLED,
ttl_videos: *cache_config::MONGODB_CACHE_TTL_VIDEOS,
ttl_search: *cache_config::MONGODB_CACHE_TTL_SEARCH,
ttl_hybrid_search: *cache_config::MONGODB_CACHE_TTL_HYBRID_SEARCH,
ttl_video_meta: *cache_config::MONGODB_CACHE_TTL_VIDEO_META,
}
}
}
#[derive(Clone)]
pub struct MongoCache {
#[allow(dead_code)]
client: Client,
db: Database,
collection: Collection<Document>,
settings: CacheSettings,
initialized: Arc<RwLock<bool>>,
}
impl MongoCache {
pub async fn init() -> Result<Self> {
let uri = crate::core::config::MONGODB_URL.as_str();
let client = Client::with_uri_str(uri)
.await
.context("Failed to connect to MongoDB")?;
let db = client.database(DB_NAME);
let collection: Collection<Document> = db.collection(COLLECTION_NAME);
let settings = CacheSettings::default();
let cache = Self {
client,
db,
collection,
settings,
initialized: Arc::new(RwLock::new(false)),
};
cache.ensure_indexes().await?;
Ok(cache)
}
async fn ensure_indexes(&self) -> Result<()> {
let mut guard = self.initialized.write().await;
if *guard {
return Ok(());
}
let ttl_index = IndexModel::builder()
.keys(doc! { "expires_at": 1 })
.options(
mongodb::options::IndexOptions::builder()
.expire_after(std::time::Duration::from_secs(0))
.build(),
)
.build();
let key_index = IndexModel::builder()
.keys(doc! { "key": 1 })
.options(
mongodb::options::IndexOptions::builder()
.unique(true)
.build(),
)
.build();
let category_index = IndexModel::builder().keys(doc! { "category": 1 }).build();
self.collection
.create_indexes([ttl_index, key_index, category_index], None)
.await
.context("Failed to create cache indexes")?;
*guard = true;
tracing::info!("MongoDB cache indexes ensured");
Ok(())
}
pub fn is_enabled(&self) -> bool {
self.settings.enabled
}
pub fn ttl_videos(&self) -> u64 {
self.settings.ttl_videos
}
pub fn ttl_search(&self) -> u64 {
self.settings.ttl_search
}
pub fn ttl_hybrid_search(&self) -> u64 {
self.settings.ttl_hybrid_search
}
pub fn ttl_video_meta(&self) -> u64 {
self.settings.ttl_video_meta
}
pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
if !self.is_enabled() {
return Ok(None);
}
let filter = doc! { "key": key };
let result = self.collection.find_one(filter, None).await?;
if let Some(doc) = result {
if let Some(value_bson) = doc.get("value") {
let json_value: serde_json::Value = bson::from_bson(value_bson.clone())?;
let value: T = serde_json::from_value(json_value)?;
if let Ok(id) = doc.get_object_id("_id") {
let update = doc! {
"$inc": { "hit_count": 1i64 },
"$set": { "last_access": BsonDateTime::from_chrono(Utc::now()) }
};
if let Err(e) = self
.collection
.update_one(doc! { "_id": id }, update, None)
.await
{
tracing::warn!("Failed to update cache hit count: {}", e);
}
}
return Ok(Some(value));
}
}
Ok(None)
}
pub async fn set<T: Serialize>(
&self,
key: &str,
value: &T,
ttl_secs: u64,
category: &str,
) -> Result<()> {
if !self.is_enabled() {
return Ok(());
}
let now = Utc::now();
let expires_at = now + Duration::seconds(ttl_secs as i64);
let json_value = serde_json::to_value(value)?;
let bson_value = bson::to_bson(&json_value)?;
let filter = doc! { "key": key };
let update = doc! {
"$set": {
"value": bson_value,
"category": category,
"expires_at": BsonDateTime::from_chrono(expires_at),
"last_access": BsonDateTime::from_chrono(now),
},
"$setOnInsert": {
"key": key,
"created_at": BsonDateTime::from_chrono(now),
"hit_count": 0i64,
}
};
let options = mongodb::options::UpdateOptions::builder()
.upsert(true)
.build();
self.collection
.update_one(filter, update, options)
.await
.context("Failed to set cache entry")?;
Ok(())
}
pub async fn delete(&self, key: &str) -> Result<bool> {
if !self.is_enabled() {
return Ok(false);
}
let filter = doc! { "key": key };
let result = self.collection.delete_one(filter, None).await?;
Ok(result.deleted_count > 0)
}
pub async fn invalidate_category(&self, category: &str) -> Result<u64> {
if !self.is_enabled() {
return Ok(0);
}
let filter = doc! { "category": category };
let result = self.collection.delete_many(filter, None).await?;
let count = result.deleted_count;
tracing::debug!("Invalidated {} entries in category: {}", count, category);
Ok(count)
}
pub async fn invalidate_prefix(&self, prefix: &str) -> Result<u64> {
if !self.is_enabled() {
return Ok(0);
}
let regex_pattern = format!("^{}", prefix);
let filter = doc! { "key": { "$regex": &regex_pattern } };
let result = self.collection.delete_many(filter, None).await?;
let count = result.deleted_count;
tracing::debug!("Invalidated {} entries with prefix: {}", count, prefix);
Ok(count)
}
pub async fn get_or_fetch<F, Fut, T>(
&self,
key: &str,
ttl_secs: u64,
category: &str,
fetcher: F,
) -> Result<T>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
T: DeserializeOwned + Serialize,
{
if let Some(cached) = self.get::<T>(key).await? {
tracing::debug!("Cache hit for key: {}", key);
return Ok(cached);
}
tracing::debug!("Cache miss for key: {}", key);
let value = fetcher().await?;
if let Err(e) = self.set(key, &value, ttl_secs, category).await {
tracing::warn!("Failed to cache value: {}", e);
}
Ok(value)
}
pub async fn invalidate_videos_list(&self) -> Result<u64> {
self.invalidate_category(keys::CATEGORY_VIDEOS).await
}
pub async fn invalidate_video(&self, uuid: &str) -> Result<u64> {
let key = keys::video_meta(uuid);
let count = self.delete(&key).await? as u64;
let list_count = self.invalidate_videos_list().await?;
Ok(count + list_count)
}
pub async fn invalidate_all_search(&self) -> Result<u64> {
let count1 = self.invalidate_category(keys::CATEGORY_SEARCH).await?;
let count2 = self
.invalidate_category(keys::CATEGORY_HYBRID_SEARCH)
.await?;
let count3 = self.invalidate_category(keys::CATEGORY_N8N_SEARCH).await?;
Ok(count1 + count2 + count3)
}
pub async fn health_check(&self) -> Result<bool> {
self.db.run_command(doc! { "ping": 1 }, None).await?;
Ok(true)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cache_settings_default() {
let settings = CacheSettings::default();
assert!(settings.enabled);
assert_eq!(settings.ttl_videos, 300);
assert_eq!(settings.ttl_search, 300);
}
}

120
src/core/cache/tests.rs vendored Normal file
View File

@@ -0,0 +1,120 @@
use crate::core::cache::keys;
use crate::core::cache::mongo_cache::CacheSettings;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cache_settings_default() {
let settings = CacheSettings::default();
assert!(settings.enabled);
assert_eq!(settings.ttl_videos, 300);
assert_eq!(settings.ttl_search, 300);
assert_eq!(settings.ttl_hybrid_search, 600);
assert_eq!(settings.ttl_video_meta, 3600);
}
#[test]
fn test_cache_key_videos_list() {
let key = keys::videos_list(1, 20);
assert_eq!(key, "videos:list:page=1:limit=20");
let key2 = keys::videos_list(2, 50);
assert_eq!(key2, "videos:list:page=2:limit=50");
}
#[test]
fn test_cache_key_video_meta() {
let key = keys::video_meta("abc123");
assert_eq!(key, "video:abc123");
let uuid = "5dea6618a606e7c7";
let key = keys::video_meta(uuid);
assert_eq!(key, "video:5dea6618a606e7c7");
}
#[test]
fn test_cache_key_search() {
let key = keys::search("hash123");
assert_eq!(key, "search:hash123");
}
#[test]
fn test_cache_key_hybrid_search() {
let key = keys::hybrid_search("hash123");
assert_eq!(key, "search:hybrid:hash123");
}
#[test]
fn test_cache_key_n8n_search() {
let key = keys::n8n_search("hash123");
assert_eq!(key, "search:n8n:hash123");
}
#[test]
fn test_cache_key_health() {
let key = keys::health();
assert_eq!(key, "health:basic");
}
#[test]
fn test_cache_categories() {
assert_eq!(keys::CATEGORY_VIDEOS, "videos");
assert_eq!(keys::CATEGORY_SEARCH, "search");
assert_eq!(keys::CATEGORY_HYBRID_SEARCH, "hybrid_search");
assert_eq!(keys::CATEGORY_VIDEO_META, "video_meta");
assert_eq!(keys::CATEGORY_N8N_SEARCH, "n8n_search");
assert_eq!(keys::CATEGORY_HEALTH, "health");
}
#[test]
fn test_cache_key_prefixes() {
assert_eq!(keys::KEY_PREFIX_VIDEOS_LIST, "videos:list:");
assert_eq!(keys::KEY_PREFIX_VIDEO, "video:");
assert_eq!(keys::KEY_PREFIX_SEARCH, "search:");
assert_eq!(keys::KEY_PREFIX_SEARCH_HYBRID, "search:hybrid:");
assert_eq!(keys::KEY_PREFIX_SEARCH_N8N, "search:n8n:");
assert_eq!(keys::KEY_HEALTH, "health:basic");
}
#[test]
fn test_cache_ttl_values() {
let settings = CacheSettings::default();
assert!(settings.ttl_videos >= 60 && settings.ttl_videos <= 600);
assert!(settings.ttl_search >= 60 && settings.ttl_search <= 600);
assert!(settings.ttl_hybrid_search >= 60 && settings.ttl_hybrid_search <= 3600);
assert!(settings.ttl_video_meta >= 300 && settings.ttl_video_meta <= 7200);
}
#[test]
fn test_cache_key_videos_list_prefix_format() {
let key = keys::videos_list(1, 10);
assert!(key.starts_with("videos:list:"));
}
#[test]
fn test_cache_key_video_meta_prefix_format() {
let key = keys::video_meta("uuid123");
assert!(key.starts_with("video:"));
}
#[test]
fn test_cache_key_search_prefix_format() {
let key = keys::search("test");
assert!(key.starts_with("search:"));
}
#[test]
fn test_cache_key_hybrid_search_prefix_format() {
let key = keys::hybrid_search("test");
assert!(key.starts_with("search:hybrid:"));
}
#[test]
fn test_cache_key_n8n_search_prefix_format() {
let key = keys::n8n_search("test");
assert!(key.starts_with("search:n8n:"));
}
}

View File

@@ -1,14 +1,15 @@
use super::types::{Chunk, ChunkType};
use anyhow::Result;
use super::types::{Chunk, ChunkRule, ChunkType};
pub struct ChunkSplitter {
time_based_duration: f64,
fps: f64,
}
impl ChunkSplitter {
pub fn new(time_based_duration_seconds: f64) -> Self {
Self {
time_based_duration: time_based_duration_seconds,
fps: 24.0,
}
}
@@ -20,11 +21,14 @@ impl ChunkSplitter {
while current_time < duration {
let end_time = (current_time + self.time_based_duration).min(duration);
chunks.push(Chunk::new(
0, // file_id
uuid.to_string(),
index,
ChunkType::TimeBased,
ChunkRule::Rule1,
current_time,
end_time,
self.fps,
serde_json::json!({
"source": "time_based",
"duration": self.time_based_duration,
@@ -42,11 +46,14 @@ impl ChunkSplitter {
for (index, segment) in asr_segments.iter().enumerate() {
chunks.push(Chunk::new(
0, // file_id
uuid.to_string(),
index as u32,
ChunkType::Sentence,
ChunkRule::Rule1,
segment.start,
segment.end,
self.fps,
serde_json::json!({
"text": segment.text,
"speaker_id": segment.speaker_id,

View File

@@ -6,47 +6,140 @@ pub enum ChunkType {
TimeBased,
Sentence,
Cut,
Trace,
Story, // Parent chunk from story analysis
}
impl ChunkType {
pub fn as_str(&self) -> &'static str {
match self {
ChunkType::TimeBased => "time_based",
ChunkType::TimeBased => "time",
ChunkType::Sentence => "sentence",
ChunkType::Cut => "cut",
ChunkType::Trace => "trace",
ChunkType::Story => "story",
}
}
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChunkRule {
Rule1, // 直接轉換
Rule2, // 集合內容
}
impl ChunkRule {
pub fn as_str(&self) -> &'static str {
match self {
ChunkRule::Rule1 => "rule_1",
ChunkRule::Rule2 => "rule_2",
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
pub file_id: i32,
pub uuid: String,
pub chunk_id: String,
pub chunk_index: u32,
pub chunk_type: ChunkType,
pub rule: ChunkRule,
pub start_time: f64,
pub end_time: f64,
pub fps: f64,
pub start_frame: i64,
pub end_frame: i64,
pub text_content: Option<String>,
pub content: serde_json::Value,
pub metadata: Option<serde_json::Value>,
pub vector_id: Option<String>,
pub frame_count: i32,
pub pre_chunk_ids: Vec<i32>,
pub parent_chunk_id: Option<String>, // For parent-child chunk hierarchy
pub child_chunk_ids: Vec<String>, // Child chunk IDs (for parent chunks)
}
impl Chunk {
#[allow(clippy::too_many_arguments)]
pub fn new(
file_id: i32,
uuid: String,
chunk_index: u32,
chunk_type: ChunkType,
rule: ChunkRule,
start_time: f64,
end_time: f64,
fps: f64,
content: serde_json::Value,
) -> Self {
let start_frame = (start_time * fps) as i64;
let end_frame = (end_time * fps) as i64;
let chunk_id = format!("{}_{:04}", chunk_type.as_str(), chunk_index);
Self {
file_id,
uuid,
chunk_id: chunk_id.clone(),
chunk_index,
chunk_type,
rule,
start_time,
end_time,
fps,
start_frame,
end_frame,
text_content: None,
content,
metadata: None,
vector_id: None,
frame_count: 0,
pre_chunk_ids: vec![],
parent_chunk_id: None,
child_chunk_ids: vec![],
}
}
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
self.metadata = Some(metadata);
self
}
pub fn with_vector_id(mut self, vector_id: String) -> Self {
self.vector_id = Some(vector_id);
self
}
pub fn with_text_content(mut self, text: String) -> Self {
self.text_content = Some(text);
self
}
pub fn with_frame_count(mut self, count: i32) -> Self {
self.frame_count = count;
self
}
pub fn with_pre_chunk_ids(mut self, ids: Vec<i32>) -> Self {
self.pre_chunk_ids = ids;
self
}
pub fn with_parent_chunk_id(mut self, parent_id: String) -> Self {
self.parent_chunk_id = Some(parent_id);
self
}
pub fn with_child_chunk_ids(mut self, child_ids: Vec<String>) -> Self {
self.child_chunk_ids = child_ids;
self
}
pub fn is_parent_chunk(&self) -> bool {
!self.child_chunk_ids.is_empty()
}
pub fn is_child_chunk(&self) -> bool {
self.parent_chunk_id.is_some()
}
}

View File

@@ -60,14 +60,6 @@ pub static SERVER_PORT: Lazy<u16> = Lazy::new(|| {
pub static REDIS_KEY_PREFIX: Lazy<String> =
Lazy::new(|| env::var("MOMENTRY_REDIS_PREFIX").unwrap_or_else(|_| "momentry:".to_string()));
/// User data root path (sftpgo data directory)
/// This is the parent directory containing user directories like ./demo/, ./warren/, ./momentry/
/// Example: /Users/accusys/momentry/var/sftpgo/data
pub static USER_DATA_ROOT: Lazy<String> = Lazy::new(|| {
env::var("MOMENTRY_USER_DATA_ROOT")
.unwrap_or_else(|_| "/Users/accusys/momentry/var/sftpgo/data".to_string())
});
pub mod processor {
use super::*;

View File

@@ -32,9 +32,20 @@ pub trait VectorStore: Send + Sync {
pub mod mongodb_db;
pub mod postgres_db;
pub mod qdrant_db;
pub mod redis_client;
pub mod redis_db;
pub mod sync_db;
pub use mongodb_db::MongoDb;
pub use postgres_db::{PostgresDb, VideoRecord};
pub use qdrant_db::QdrantDb;
pub use postgres_db::{
Bm25Result, CreateApiKeyConfig, HybridSearchResult, MonitorJob, MonitorJobStats,
MonitorJobStatus, PostgresDb, ProcessorJobStatus, ProcessorResult, ProcessorType, VideoRecord,
VideoStatus,
};
pub use qdrant_db::{QdrantDb, VectorPayload};
pub use redis_client::{
JobErrorMessage, MonitorJobRedis, ProcessorStatus as RedisProcessorStatus, ProgressData,
ProgressMessage, RedisClient,
};
pub use redis_db::RedisDb;
pub use sync_db::SyncDb;

View File

@@ -1,53 +1,269 @@
use anyhow::Result;
use anyhow::{Context, Result};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use super::Database;
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
pub struct MongoDb {
cache: Arc<RwLock<MongoCache>>,
base_url: String,
}
#[derive(Debug, Default)]
pub struct MongoCache {
documents: std::collections::HashMap<String, serde_json::Value>,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct VideoDocument {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkDocument {
pub uuid: String,
pub file_path: String,
pub file_name: String,
pub probe: serde_json::Value,
pub asr: Option<serde_json::Value>,
pub asrx: Option<serde_json::Value>,
pub ocr: Option<serde_json::Value>,
pub yolo: Option<serde_json::Value>,
pub face: Option<serde_json::Value>,
pub pose: Option<serde_json::Value>,
pub created_at: String,
pub updated_at: String,
pub chunk_id: String,
pub chunk_index: u32,
pub chunk_type: String,
pub start_time: f64,
pub end_time: f64,
pub fps: f64,
pub start_frame: i64,
pub end_frame: i64,
pub content: serde_json::Value,
pub metadata: Option<serde_json::Value>,
pub vector_id: Option<String>,
pub parent_chunk_id: Option<String>,
pub child_chunk_ids: Vec<String>,
}
impl From<Chunk> for ChunkDocument {
fn from(chunk: Chunk) -> Self {
Self {
uuid: chunk.uuid,
chunk_id: chunk.chunk_id,
chunk_index: chunk.chunk_index,
chunk_type: chunk.chunk_type.as_str().to_string(),
start_time: chunk.start_time,
end_time: chunk.end_time,
fps: chunk.fps,
start_frame: chunk.start_frame,
end_frame: chunk.end_frame,
content: chunk.content,
metadata: chunk.metadata,
vector_id: chunk.vector_id,
parent_chunk_id: chunk.parent_chunk_id,
child_chunk_ids: chunk.child_chunk_ids,
}
}
}
impl MongoDb {
pub async fn store_video(&self, _doc: &VideoDocument) -> Result<()> {
// TODO: Implement MongoDB client
pub fn new() -> Self {
let base_url =
std::env::var("MONGODB_URL").unwrap_or_else(|_| "http://localhost:27017".to_string());
Self { base_url }
}
}
impl Default for MongoDb {
fn default() -> Self {
Self::new()
}
}
impl MongoDb {
pub async fn store_chunk(&self, chunk: &Chunk) -> Result<()> {
let doc: ChunkDocument = chunk.clone().into();
let client = reqwest::Client::new();
let url = format!("{}/momentry/chunks", self.base_url);
client
.post(&url)
.json(&doc)
.send()
.await
.context("Failed to store chunk in MongoDB")?;
Ok(())
}
pub async fn get_video(&self, _uuid: &str) -> Result<Option<VideoDocument>> {
// TODO: Implement MongoDB client
Ok(None)
pub async fn get_chunks_by_uuid(&self, uuid: &str) -> Result<Vec<Chunk>> {
let client = reqwest::Client::new();
let url = format!(
"{}/momentry/chunks?filter={{\"uuid\":\"{}\"}}",
self.base_url, uuid
);
let response = client
.get(&url)
.send()
.await
.context("Failed to get chunks from MongoDB")?;
#[derive(Deserialize)]
struct MongoResponse {
documents: Vec<ChunkDocument>,
}
let result: MongoResponse = response.json().await?;
let chunks: Vec<Chunk> = result
.documents
.into_iter()
.map(|doc| {
let chunk_type = match doc.chunk_type.as_str() {
"sentence" => ChunkType::Sentence,
"cut" => ChunkType::Cut,
"time_based" => ChunkType::TimeBased,
"trace" => ChunkType::Trace,
"story" => ChunkType::Story,
_ => ChunkType::Sentence,
};
Chunk {
file_id: 0,
uuid: doc.uuid,
chunk_id: doc.chunk_id,
chunk_index: doc.chunk_index,
chunk_type,
rule: ChunkRule::Rule1,
start_time: doc.start_time,
end_time: doc.end_time,
fps: doc.fps,
start_frame: doc.start_frame,
end_frame: doc.end_frame,
text_content: None,
content: doc.content,
metadata: doc.metadata,
vector_id: doc.vector_id,
frame_count: 0,
pre_chunk_ids: vec![],
parent_chunk_id: doc.parent_chunk_id,
child_chunk_ids: doc.child_chunk_ids,
}
})
.collect();
Ok(chunks)
}
pub async fn search_text(&self, query: &str) -> Result<Vec<Chunk>> {
let client = reqwest::Client::new();
let url = format!(
"{}/momentry/chunks?filter={{\"$text\":{{\"$search\":\"{}\"}}}}",
self.base_url, query
);
let response = client
.get(&url)
.send()
.await
.context("Failed to search text in MongoDB")?;
#[derive(Deserialize)]
struct MongoResponse {
documents: Vec<ChunkDocument>,
}
let result: MongoResponse = response.json().await?;
let chunks: Vec<Chunk> = result
.documents
.into_iter()
.map(|doc| {
let chunk_type = match doc.chunk_type.as_str() {
"sentence" => ChunkType::Sentence,
"cut" => ChunkType::Cut,
"time" => ChunkType::TimeBased,
"trace" => ChunkType::Trace,
"story" => ChunkType::Story,
_ => ChunkType::Sentence,
};
Chunk {
file_id: 0,
uuid: doc.uuid,
chunk_id: doc.chunk_id,
chunk_index: doc.chunk_index,
chunk_type,
rule: ChunkRule::Rule1,
start_time: doc.start_time,
end_time: doc.end_time,
fps: doc.fps,
start_frame: doc.start_frame,
end_frame: doc.end_frame,
text_content: None,
content: doc.content,
metadata: doc.metadata,
vector_id: doc.vector_id,
frame_count: 0,
pre_chunk_ids: vec![],
parent_chunk_id: doc.parent_chunk_id,
child_chunk_ids: doc.child_chunk_ids,
}
})
.collect();
Ok(chunks)
}
pub async fn get_all_chunks(&self) -> Result<Vec<Chunk>> {
let client = reqwest::Client::new();
let url = format!("{}/momentry/chunks", self.base_url);
let response = client
.get(&url)
.send()
.await
.context("Failed to get all chunks from MongoDB")?;
#[derive(Deserialize)]
struct MongoResponse {
documents: Vec<ChunkDocument>,
}
let result: MongoResponse = response.json().await?;
let chunks: Vec<Chunk> = result
.documents
.into_iter()
.map(|doc| {
let chunk_type = match doc.chunk_type.as_str() {
"sentence" => ChunkType::Sentence,
"cut" => ChunkType::Cut,
"time" => ChunkType::TimeBased,
"trace" => ChunkType::Trace,
"story" => ChunkType::Story,
_ => ChunkType::Sentence,
};
Chunk {
file_id: 0,
uuid: doc.uuid,
chunk_id: doc.chunk_id,
chunk_index: doc.chunk_index,
chunk_type,
rule: ChunkRule::Rule1,
start_time: doc.start_time,
end_time: doc.end_time,
fps: doc.fps,
start_frame: doc.start_frame,
end_frame: doc.end_frame,
text_content: None,
content: doc.content,
metadata: doc.metadata,
vector_id: doc.vector_id,
frame_count: 0,
pre_chunk_ids: vec![],
parent_chunk_id: doc.parent_chunk_id,
child_chunk_ids: doc.child_chunk_ids,
}
})
.collect();
Ok(chunks)
}
pub async fn get_chunk_count(&self) -> Result<i64> {
let chunks = self.get_all_chunks().await?;
Ok(chunks.len() as i64)
}
}
#[async_trait]
impl Database for MongoDb {
impl super::Database for MongoDb {
async fn init() -> Result<Self> {
// TODO: Initialize MongoDB client
Ok(Self {
cache: Arc::new(RwLock::new(MongoCache::default())),
})
Ok(Self::new())
}
}

View File

@@ -2669,8 +2669,9 @@ impl PostgresDb {
pub async fn get_processor_results_by_job(&self, job_id: i32) -> Result<Vec<ProcessorResult>> {
let rows = sqlx::query(
r#"
SELECT id, job_id, processor, status, started_at, completed_at, duration_secs,
error_message, output_data, retry_count, created_at, updated_at
SELECT id, job_id, processor, status, output_path, started_at, completed_at,
error_message, progress_total, progress_current, last_checkpoint,
created_at, updated_at, duration_secs
FROM processor_results
WHERE job_id = $1
ORDER BY created_at ASC
@@ -2685,6 +2686,10 @@ impl PostgresDb {
.map(|r| {
let status_str: String = r.get(3);
let processor_type_str: String = r.get(2);
let started_at: Option<chrono::NaiveDateTime> = r.get(5);
let completed_at: Option<chrono::NaiveDateTime> = r.get(6);
let created_at: chrono::NaiveDateTime = r.get(11);
let updated_at: Option<chrono::NaiveDateTime> = r.get(12);
ProcessorResult {
id: r.get(0),
job_id: r.get(1),
@@ -2692,14 +2697,14 @@ impl PostgresDb {
.unwrap_or(ProcessorType::Asr),
status: ProcessorJobStatus::from_db_str(&status_str)
.unwrap_or(ProcessorJobStatus::Pending),
started_at: r.get(4),
completed_at: r.get(5),
duration_secs: r.get(6),
started_at: started_at.map(|t| t.to_string()),
completed_at: completed_at.map(|t| t.to_string()),
duration_secs: r.get(13),
error_message: r.get(7),
output_data: r.get(8),
retry_count: r.get(9),
created_at: r.get(10),
updated_at: r.get(11),
output_data: None,
retry_count: 0,
created_at: created_at.to_string(),
updated_at: updated_at.map(|t| t.to_string()).unwrap_or_default(),
}
})
.collect();

View File

@@ -1,46 +1,330 @@
use anyhow::Result;
use anyhow::{Context, Result};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::{Database, SearchResult, VectorStore};
pub struct QdrantDb {
client: Client,
base_url: String,
api_key: String,
collection_name: String,
cache: Arc<RwLock<QdrantCache>>,
}
#[derive(Debug, Default)]
pub struct QdrantCache {
vectors: std::collections::HashMap<String, Vec<f32>>,
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorPayload {
pub uuid: String,
pub chunk_id: String,
pub chunk_type: String,
pub start_time: f64,
pub end_time: f64,
pub text: Option<String>,
}
impl QdrantDb {
pub async fn init_collection(&self) -> Result<()> {
// TODO: Implement actual Qdrant client
// This is a placeholder
pub fn new() -> Self {
let base_url =
std::env::var("QDRANT_URL").unwrap_or_else(|_| "http://localhost:6333".to_string());
let api_key = std::env::var("QDRANT_API_KEY")
.unwrap_or_else(|_| "Test3200Test3200Test3200".to_string());
let collection_name =
std::env::var("QDRANT_COLLECTION").unwrap_or_else(|_| "chunks_v3".to_string());
Self {
client: Client::new(),
base_url,
api_key,
collection_name,
}
}
}
impl Default for QdrantDb {
fn default() -> Self {
Self::new()
}
}
impl QdrantDb {
pub async fn init_collection(&self, vector_dim: usize) -> Result<()> {
let url = format!("{}/collections/{}", self.base_url, self.collection_name);
let response = self
.client
.get(&url)
.header("api-key", &self.api_key)
.send()
.await?;
if response.status().is_success() {
return Ok(());
}
let create_url = format!("{}/collections", self.base_url);
let body = serde_json::json!({
"vectors": {
"size": vector_dim,
"distance": "Cosine"
}
});
self.client
.post(&create_url)
.header("api-key", &self.api_key)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to create Qdrant collection")?;
Ok(())
}
pub async fn upsert_vector(&self, chunk_id: &str, vector: &[f32]) -> Result<()> {
let mut cache = self.cache.write().await;
cache.vectors.insert(chunk_id.to_string(), vector.to_vec());
pub async fn upsert_vector(
&self,
_chunk_id: &str,
vector: &[f32],
payload: VectorPayload,
) -> Result<()> {
let url = format!(
"{}/collections/{}/points",
self.base_url, self.collection_name
);
let mut payload_map = HashMap::new();
payload_map.insert("uuid".to_string(), serde_json::json!(payload.uuid));
payload_map.insert("chunk_id".to_string(), serde_json::json!(payload.chunk_id));
payload_map.insert(
"chunk_type".to_string(),
serde_json::json!(payload.chunk_type),
);
payload_map.insert(
"start_time".to_string(),
serde_json::json!(payload.start_time),
);
payload_map.insert("end_time".to_string(), serde_json::json!(payload.end_time));
if let Some(text) = payload.text {
payload_map.insert("text".to_string(), serde_json::json!(text));
}
let point_id = uuid::Uuid::new_v4().to_string();
let body = serde_json::json!({
"points": [{
"id": point_id,
"vector": vector,
"payload": payload_map
}]
});
self.client
.put(&url)
.header("api-key", &self.api_key)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to upsert vector in Qdrant")?;
Ok(())
}
pub async fn search(&self, query_vector: &[f32], limit: usize) -> Result<Vec<SearchResult>> {
let url = format!(
"{}/collections/{}/points/search",
self.base_url, self.collection_name
);
let body = serde_json::json!({
"vector": query_vector,
"limit": limit,
"with_payload": true
});
let response = self
.client
.post(&url)
.header("api-key", &self.api_key)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to search Qdrant")?;
#[derive(Deserialize)]
struct QdrantSearchResult {
result: Vec<QdrantPoint>,
}
#[derive(Deserialize)]
struct QdrantPoint {
#[allow(dead_code)]
id: serde_json::Value,
score: f64,
payload: HashMap<String, serde_json::Value>,
}
let result: QdrantSearchResult = response.json().await?;
let search_results: Vec<SearchResult> = result
.result
.into_iter()
.map(|r| {
let chunk_id = r
.payload
.get("chunk_id")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
SearchResult {
chunk_id,
score: r.score as f32,
}
})
.collect();
Ok(search_results)
}
pub async fn search_in_uuid(
&self,
query_vector: &[f64],
uuid: &str,
limit: usize,
) -> Result<Vec<SearchResult>> {
let url = format!(
"{}/collections/{}/points/search",
self.base_url, self.collection_name
);
let body = serde_json::json!({
"vector": query_vector,
"limit": limit,
"with_payload": true,
"filter": {
"must": [
{
"key": "uuid",
"match": {
"value": uuid
}
}
]
}
});
let response = self
.client
.post(&url)
.header("api-key", &self.api_key)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to search Qdrant")?;
#[derive(Deserialize)]
struct QdrantSearchResult {
result: Vec<QdrantPoint>,
}
#[derive(Deserialize)]
struct QdrantPoint {
#[allow(dead_code)]
id: serde_json::Value,
score: f64,
payload: HashMap<String, serde_json::Value>,
}
let result: QdrantSearchResult = response.json().await?;
let search_results: Vec<SearchResult> = result
.result
.into_iter()
.map(|r| {
let chunk_id = r
.payload
.get("chunk_id")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
SearchResult {
chunk_id,
score: r.score as f32,
}
})
.collect();
Ok(search_results)
}
pub async fn delete_by_uuid(&self, uuid: &str) -> Result<()> {
let url = format!(
"{}/collections/{}/points/delete",
self.base_url, self.collection_name
);
let body = serde_json::json!({
"filter": {
"must": [
{
"key": "uuid",
"match": {
"value": uuid
}
}
]
}
});
self.client
.post(&url)
.header("api-key", &self.api_key)
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to delete points from Qdrant")?;
Ok(())
}
pub async fn get_point_count(&self) -> Result<usize> {
let url = format!(
"{}/collections/{}/info",
self.base_url, self.collection_name
);
let response = self
.client
.get(&url)
.header("api-key", &self.api_key)
.send()
.await
.context("Failed to get collection info")?;
#[derive(Deserialize)]
struct CollectionInfo {
result: CollectionStatus,
}
#[derive(Deserialize)]
struct CollectionStatus {
points_count: usize,
}
let result: CollectionInfo = response.json().await?;
Ok(result.result.points_count)
}
}
#[async_trait]
impl Database for QdrantDb {
async fn init() -> Result<Self> {
let collection_name =
std::env::var("QDRANT_COLLECTION").unwrap_or_else(|_| "momentry_chunks".to_string());
let db = Self {
collection_name,
cache: Arc::new(RwLock::new(QdrantCache::default())),
};
db.init_collection().await?;
let db = Self::new();
db.init_collection(768).await?;
Ok(db)
}
}
@@ -48,41 +332,18 @@ impl Database for QdrantDb {
#[async_trait]
impl VectorStore for QdrantDb {
async fn store_vector(&self, chunk_id: &str, vector: &[f32]) -> Result<()> {
self.upsert_vector(chunk_id, vector).await
let payload = VectorPayload {
uuid: String::new(),
chunk_id: chunk_id.to_string(),
chunk_type: String::new(),
start_time: 0.0,
end_time: 0.0,
text: None,
};
self.upsert_vector(chunk_id, vector, payload).await
}
async fn search(&self, query_vector: &[f32], limit: usize) -> Result<Vec<SearchResult>> {
// Simple cosine similarity search (placeholder)
let cache = self.cache.read().await;
let mut results: Vec<SearchResult> = Vec::new();
for (chunk_id, vector) in &cache.vectors {
let similarity = cosine_similarity(query_vector, vector);
results.push(SearchResult {
chunk_id: chunk_id.clone(),
score: similarity,
});
}
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
results.truncate(limit);
Ok(results)
self.search(query_vector, limit).await
}
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() || a.is_empty() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot_product / (norm_a * norm_b)
}

View File

@@ -6,6 +6,7 @@ use tokio::sync::RwLock;
use super::Database;
pub struct RedisDb {
#[allow(dead_code)]
state: Arc<RwLock<RedisState>>,
}

155
src/core/db/sync_db.rs Normal file
View File

@@ -0,0 +1,155 @@
use anyhow::{Context, Result};
use serde_json::json;
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
use crate::core::db::mongodb_db::MongoDb;
use crate::core::db::postgres_db::PostgresDb;
use crate::core::db::qdrant_db::{QdrantDb, VectorPayload};
use crate::core::processor::asr::{AsrResult, AsrSegment};
pub struct SyncDb {
postgres: PostgresDb,
mongodb: MongoDb,
qdrant: QdrantDb,
}
impl SyncDb {
pub async fn new(postgres: PostgresDb, mongodb: MongoDb, qdrant: QdrantDb) -> Result<Self> {
Ok(Self {
postgres,
mongodb,
qdrant,
})
}
pub async fn store_chunk_with_vector(&self, mut chunk: Chunk, text: &str) -> Result<Chunk> {
let uuid = chunk.uuid.clone();
let chunk_id = chunk.chunk_id.clone();
let chunk_type = chunk.chunk_type.as_str().to_string();
let start_time = chunk.start_time;
let end_time = chunk.end_time;
let vector = self.embed_text(text).await?;
let vector_id = format!("vec_{}", chunk_id);
chunk = chunk.with_vector_id(vector_id.clone());
let postgres_result = self.postgres.store_chunk(&chunk).await;
if let Err(e) = &postgres_result {
tracing::warn!("Failed to store chunk in PostgreSQL: {}", e);
}
let mongo_result = self.mongodb.store_chunk(&chunk).await;
if let Err(e) = &mongo_result {
tracing::warn!("Failed to store chunk in MongoDB: {}", e);
}
let payload = VectorPayload {
uuid: uuid.clone(),
chunk_id: chunk_id.clone(),
chunk_type,
start_time,
end_time,
text: Some(text.to_string()),
};
let qdrant_result = self
.qdrant
.upsert_vector(&vector_id, &vector, payload)
.await;
if let Err(e) = &qdrant_result {
tracing::warn!("Failed to store vector in Qdrant: {}", e);
}
let pg_vector_result = self.postgres.store_vector(&vector_id, &vector, &uuid).await;
if let Err(e) = &pg_vector_result {
tracing::warn!("Failed to store vector in PostgreSQL: {}", e);
}
postgres_result?;
mongo_result?;
qdrant_result?;
Ok(chunk)
}
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
let client = reqwest::Client::new();
let response = client
.post("http://localhost:11434/api/embeddings")
.json(&json!({
"model": "nomic-embed-text",
"prompt": text
}))
.send()
.await
.context("Failed to call Ollama embedding API")?;
#[derive(serde::Deserialize)]
struct EmbeddingResponse {
embedding: Vec<f32>,
}
let embedding = response
.json::<EmbeddingResponse>()
.await
.context("Failed to parse embedding response")?;
Ok(embedding.embedding)
}
pub async fn process_asr_to_chunks(
&self,
uuid: &str,
asr_result: &AsrResult,
) -> Result<Vec<Chunk>> {
let mut chunks = Vec::new();
for (i, segment) in asr_result.segments.iter().enumerate() {
let segment: &AsrSegment = segment;
let content = json!({
"text": segment.text,
"text_normalized": segment.text.to_lowercase(),
});
let metadata = json!({
"language": asr_result.language,
"language_probability": asr_result.language_probability,
});
let chunk = Chunk::new(
0, // file_id - will be set later
uuid.to_string(),
i as u32,
ChunkType::Sentence,
ChunkRule::Rule1,
segment.start,
segment.end,
24.0, // fps
content,
)
.with_metadata(metadata);
chunks.push(chunk);
}
let mut stored_chunks = Vec::new();
for chunk in chunks {
let text = chunk
.content
.get("text")
.and_then(|t| t.as_str())
.unwrap_or("")
.to_string();
match self.store_chunk_with_vector(chunk, &text).await {
Ok(stored) => stored_chunks.push(stored),
Err(e) => {
tracing::error!("Failed to store chunk: {}", e);
}
}
}
Ok(stored_chunks)
}
}

View File

@@ -1,66 +1,80 @@
use anyhow::Result;
use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct Embedder {
model_path: String,
model: String,
client: Client,
base_url: String,
}
#[derive(Serialize)]
struct EmbedRequest {
model: String,
prompt: String,
}
#[derive(Deserialize, Debug)]
struct EmbedResponse {
embedding: Vec<f32>,
}
impl Embedder {
pub fn new(model_path: String) -> Self {
Self { model_path }
pub fn new(model: String) -> Self {
Self {
model,
client: Client::new(),
base_url: "http://localhost:11434".to_string(),
}
}
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
// TODO: Implement comic-embed-text model loading and inference
// This is a placeholder that generates a random 768-dimensional vector
//
// Implementation would use:
// - candle (Rust ML framework) or
// - ort (ONNX Runtime) to run the model
//
// Example with ort:
// let session = Session::builder()?
// .with_execution_providers([CPUExecutionProvider::default().build()])?
// .with_model_from_file(&self.model_path)?;
//
// // Preprocess text to tensor
// let input = preprocess_text(text);
//
// // Run inference
// let output = session.run(vec![input])?;
//
// // Extract embeddings
// let embedding = output[0].view()[..768].to_vec();
self.embed_with_prefix(text, "").await
}
let dim = 768;
let mut embedding = vec![0.0f32; dim];
pub async fn embed_document(&self, text: &str) -> Result<Vec<f32>> {
self.embed_with_prefix(text, "search_document: ").await
}
// Simple hash-based embedding for now
let hash = self.hash_text(text);
for i in 0..dim {
embedding[i] = ((hash >> i) & 1) as f32;
pub async fn embed_query(&self, text: &str) -> Result<Vec<f32>> {
self.embed_with_prefix(text, "search_query: ").await
}
async fn embed_with_prefix(&self, text: &str, prefix: &str) -> Result<Vec<f32>> {
let url = format!("{}/api/embeddings", self.base_url);
let prompt = format!("{}{}", prefix, text);
let response = self
.client
.post(&url)
.json(&EmbedRequest {
model: self.model.clone(),
prompt,
})
.send()
.await
.context("Failed to send embedding request to Ollama")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("Ollama API error ({}): {}", status, body);
}
// Normalize
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for v in &mut embedding {
*v /= norm;
}
}
let result: EmbedResponse = response
.json()
.await
.context("Failed to parse Ollama response")?;
Ok(embedding)
Ok(result.embedding)
}
pub async fn embed_chunk_content(&self, chunk: &crate::core::chunk::Chunk) -> Result<Vec<f32>> {
let text = serde_json::to_string(&chunk.content)?;
self.embed_text(&text).await
self.embed_document(&text).await
}
fn hash_text(&self, text: &str) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
text.hash(&mut hasher);
hasher.finish()
pub fn dimension(&self) -> usize {
768
}
}

View File

@@ -1,4 +1,7 @@
pub mod api_key;
pub mod cache;
pub mod chunk;
pub mod config;
pub mod db;
pub mod embedding;
pub mod overlay;

View File

@@ -1,7 +1,10 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::process::Command;
use std::time::Duration;
use super::executor::PythonExecutor;
const ASR_TIMEOUT: Duration = Duration::from_secs(3600);
#[derive(Debug, Serialize, Deserialize)]
pub struct AsrResult {
@@ -17,53 +20,33 @@ pub struct AsrSegment {
pub text: String,
}
pub async fn process_asr(video_path: &str, output_path: &str) -> Result<AsrResult> {
let script_path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("scripts")
.join("asr_processor.py");
pub async fn process_asr(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<AsrResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("asr_processor.py");
let venv_python = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("venv")
.join("bin")
.join("python");
tracing::info!("[ASR] Starting ASR processing: {}", video_path);
println!("[ASR] Starting ASR processing...");
println!("[ASR] Video: {}", video_path);
let output = Command::new(venv_python)
.arg(script_path)
.arg(video_path)
.arg(output_path)
.output()
.context("Failed to run ASR processor")?;
let stderr = String::from_utf8_lossy(&output.stderr);
for line in stderr.lines() {
if line.starts_with("ASR_START") {
println!("[ASR] Loading model...");
} else if line.starts_with("ASR_LANGUAGE:") {
let lang = line.trim_start_matches("ASR_LANGUAGE:");
println!("[ASR] Detected language: {}", lang);
} else if line.starts_with("ASR_PROGRESS:") {
let count = line.trim_start_matches("ASR_PROGRESS:");
println!("[ASR] Processed {} segments...", count);
} else if line.starts_with("ASR_COMPLETE:") {
let count = line.trim_start_matches("ASR_COMPLETE:");
println!("[ASR] Completed! Total: {} segments", count);
}
}
if !output.status.success() {
anyhow::bail!("ASR failed: {}", stderr);
}
executor
.run(
"asr_processor.py",
&[video_path, output_path],
uuid,
"ASR",
Some(ASR_TIMEOUT),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str = std::fs::read_to_string(output_path).context("Failed to read ASR output")?;
let result: AsrResult =
serde_json::from_str(&json_str).context("Failed to parse ASR output")?;
println!(
tracing::info!(
"[ASR] Result: {} segments, language: {:?}",
result.segments.len(),
result.language
@@ -71,3 +54,72 @@ pub async fn process_asr(video_path: &str, output_path: &str) -> Result<AsrResul
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_asr_result_serialization() {
let result = AsrResult {
language: Some("en".to_string()),
language_probability: Some(0.95),
segments: vec![
AsrSegment {
start: 0.0,
end: 2.5,
text: "Hello world".to_string(),
},
AsrSegment {
start: 2.5,
end: 5.0,
text: "Test speech".to_string(),
},
],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("Hello world"));
assert!(json.contains("en"));
}
#[test]
fn test_asr_result_deserialization() {
let json = r#"{
"language": "zh",
"language_probability": 0.98,
"segments": [
{"start": 0.0, "end": 1.5, "text": "測試"}
]
}"#;
let result: AsrResult = serde_json::from_str(json).unwrap();
assert_eq!(result.language, Some("zh".to_string()));
assert_eq!(result.language_probability, Some(0.98));
assert_eq!(result.segments.len(), 1);
assert_eq!(result.segments[0].text, "測試");
}
#[test]
fn test_asr_segment_default() {
let segment = AsrSegment {
start: 0.0,
end: 1.0,
text: String::new(),
};
assert_eq!(segment.start, 0.0);
assert_eq!(segment.end, 1.0);
assert!(segment.text.is_empty());
}
#[test]
fn test_asr_result_empty_segments() {
let result = AsrResult {
language: None,
language_probability: None,
segments: vec![],
};
assert!(result.language.is_none());
assert!(result.segments.is_empty());
}
}

View File

@@ -1,8 +1,16 @@
use anyhow::Result;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::process::Command;
use tokio::time::timeout;
use super::executor::PythonExecutor;
const ASRX_TIMEOUT: Duration = Duration::from_secs(7200);
#[derive(Debug, Serialize, Deserialize)]
pub struct AsrxResult {
pub language: Option<String>,
pub segments: Vec<AsrxSegment>,
}
@@ -11,18 +19,130 @@ pub struct AsrxSegment {
pub start: f64,
pub end: f64,
pub text: String,
pub speaker_id: String,
pub speaker_embedding: Option<Vec<f32>>,
pub speaker_id: Option<String>,
}
pub async fn process_asrx(video_path: &str, output_path: &str) -> Result<AsrxResult> {
// TODO: Implement speaker diarization
// Options:
// 1. Use pyannote.audio
// 2. Use whisperx
// 3. Use Python subprocess
pub async fn process_asrx(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<AsrxResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("asrx_processor.py");
println!("Processing speaker diarization for: {}", video_path);
tracing::info!("[ASRX] Starting speaker diarization: {}", video_path);
Ok(AsrxResult { segments: vec![] })
if !script_path.exists() {
tracing::warn!("[ASRX] Script not found, returning empty result");
return Ok(AsrxResult {
language: None,
segments: vec![],
});
}
let mut cmd = Command::new(executor.python_path());
cmd.arg(&script_path).arg(video_path).arg(output_path);
if let Some(u) = uuid {
cmd.arg("--uuid").arg(u);
}
cmd.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
let child = cmd.spawn().context("Failed to run ASRX processor")?;
let output = match timeout(ASRX_TIMEOUT, child.wait_with_output()).await {
Ok(Ok(output)) => output,
Ok(Err(e)) => return Err(e).context("Failed to run ASRX processor"),
Err(_) => anyhow::bail!("ASRX processing timed out after {:?}", ASRX_TIMEOUT),
};
let stderr = String::from_utf8_lossy(&output.stderr);
for line in stderr.lines() {
if line.starts_with("ASRX_START") {
tracing::info!("[ASRX] Loading model...");
} else if line.starts_with("ASRX_PROGRESS:") {
let count = line.trim_start_matches("ASRX_PROGRESS:");
tracing::info!("[ASRX] Processed {} segments...", count);
} else if line.starts_with("ASRX_COMPLETE:") {
let count = line.trim_start_matches("ASRX_COMPLETE:");
tracing::info!("[ASRX] Completed! Total: {} segments", count);
}
}
if !output.status.success() {
anyhow::bail!("ASRX failed: {}", stderr);
}
let json_str = std::fs::read_to_string(output_path).context("Failed to read ASRX output")?;
let result: AsrxResult =
serde_json::from_str(&json_str).context("Failed to parse ASRX output")?;
tracing::info!("[ASRX] Result: {} segments", result.segments.len());
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_asrx_result_serialization() {
let result = AsrxResult {
language: Some("en".to_string()),
segments: vec![AsrxSegment {
start: 0.0,
end: 2.5,
text: "Hello".to_string(),
speaker_id: Some("SPEAKER_00".to_string()),
}],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("Hello"));
assert!(json.contains("SPEAKER_00"));
}
#[test]
fn test_asrx_result_deserialization() {
let json = r#"{
"language": "zh",
"segments": [
{"start": 0.0, "end": 1.5, "text": "測試", "speaker_id": "SPEAKER_01"}
]
}"#;
let result: AsrxResult = serde_json::from_str(json).unwrap();
assert_eq!(result.language, Some("zh".to_string()));
assert_eq!(result.segments.len(), 1);
assert_eq!(
result.segments[0].speaker_id,
Some("SPEAKER_01".to_string())
);
}
#[test]
fn test_asrx_result_empty_segments() {
let result = AsrxResult {
language: None,
segments: vec![],
};
assert!(result.segments.is_empty());
assert!(result.language.is_none());
}
#[test]
fn test_asrx_segment_times() {
let segment = AsrxSegment {
start: 0.0,
end: 5.0,
text: "Test".to_string(),
speaker_id: None,
};
assert!(segment.end > segment.start);
}
}

View File

@@ -0,0 +1,77 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::executor::PythonExecutor;
const CAPTION_TIMEOUT: Duration = Duration::from_secs(7200);
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CaptionResult {
pub video_path: String,
pub total_frames: usize,
pub captions: Vec<FrameCaption>,
pub summary: CaptionSummary,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FrameCaption {
pub index: u32,
pub timestamp: f64,
pub caption: String,
pub source: String,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CaptionSummary {
pub avg_caption_length: f64,
pub gpt4v_count: usize,
pub llava_count: usize,
pub metadata_count: usize,
}
pub async fn process_caption(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<CaptionResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("caption_processor.py");
tracing::info!("[CAPTION] Starting caption generation: {}", video_path);
if !script_path.exists() {
tracing::warn!("[CAPTION] Script not found, returning empty result");
return Ok(CaptionResult {
video_path: video_path.to_string(),
total_frames: 0,
captions: vec![],
summary: CaptionSummary {
avg_caption_length: 0.0,
gpt4v_count: 0,
llava_count: 0,
metadata_count: 0,
},
});
}
executor
.run(
"caption_processor.py",
&[video_path, output_path],
uuid,
"CAPTION",
Some(CAPTION_TIMEOUT),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str = std::fs::read_to_string(output_path).context("Failed to read CAPTION output")?;
let result: CaptionResult =
serde_json::from_str(&json_str).context("Failed to parse CAPTION output")?;
tracing::info!("[CAPTION] Result: {} frames captioned", result.total_frames);
Ok(result)
}

127
src/core/processor/cut.rs Normal file
View File

@@ -0,0 +1,127 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::executor::PythonExecutor;
const CUT_TIMEOUT: Duration = Duration::from_secs(3600);
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CutResult {
pub frame_count: u64,
pub fps: f64,
pub scenes: Vec<CutScene>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct CutScene {
pub scene_number: u32,
pub start_frame: u64,
pub end_frame: u64,
pub start_time: f64,
pub end_time: f64,
}
pub async fn process_cut(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<CutResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("cut_processor.py");
tracing::info!("[CUT] Starting scene detection: {}", video_path);
if !script_path.exists() {
tracing::warn!("[CUT] Script not found, returning empty result");
return Ok(CutResult {
frame_count: 0,
fps: 0.0,
scenes: vec![],
});
}
executor
.run(
"cut_processor.py",
&[video_path, output_path],
uuid,
"CUT",
Some(CUT_TIMEOUT),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str = std::fs::read_to_string(output_path).context("Failed to read CUT output")?;
let result: CutResult =
serde_json::from_str(&json_str).context("Failed to parse CUT output")?;
tracing::info!("[CUT] Result: {} scenes detected", result.scenes.len());
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cut_result_serialization() {
let result = CutResult {
frame_count: 100,
fps: 30.0,
scenes: vec![CutScene {
scene_number: 1,
start_frame: 0,
end_frame: 30,
start_time: 0.0,
end_time: 1.0,
}],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("scene_number"));
assert!(json.contains("1"));
}
#[test]
fn test_cut_result_deserialization() {
let json = r#"{
"frame_count": 100,
"fps": 30.0,
"scenes": [
{"scene_number": 1, "start_frame": 0, "end_frame": 30, "start_time": 0.0, "end_time": 1.0},
{"scene_number": 2, "start_frame": 31, "end_frame": 60, "start_time": 1.033, "end_time": 2.0}
]
}"#;
let result: CutResult = serde_json::from_str(json).unwrap();
assert_eq!(result.frame_count, 100);
assert_eq!(result.scenes.len(), 2);
assert_eq!(result.scenes[1].scene_number, 2);
}
#[test]
fn test_cut_result_empty_scenes() {
let result = CutResult {
frame_count: 0,
fps: 0.0,
scenes: vec![],
};
assert!(result.scenes.is_empty());
}
#[test]
fn test_cut_scene_times() {
let scene = CutScene {
scene_number: 1,
start_frame: 0,
end_frame: 30,
start_time: 0.0,
end_time: 1.0,
};
assert!(scene.end_time > scene.start_time);
assert_eq!(scene.scene_number, 1);
}
}

View File

@@ -0,0 +1,395 @@
use anyhow::{Context, Result};
use std::path::PathBuf;
use std::process::Stdio;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::time::{sleep, timeout};
#[derive(Debug, Clone)]
pub struct RetryConfig {
pub max_attempts: u32,
pub initial_delay_ms: u64,
pub max_delay_ms: u64,
pub backoff_multiplier: f64,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay_ms: 1000,
max_delay_ms: 30000,
backoff_multiplier: 2.0,
}
}
}
impl RetryConfig {
pub fn new(max_attempts: u32) -> Self {
Self {
max_attempts,
..Default::default()
}
}
pub fn with_delay(mut self, delay_ms: u64) -> Self {
self.initial_delay_ms = delay_ms;
self
}
pub fn with_max_delay(mut self, max_delay_ms: u64) -> Self {
self.max_delay_ms = max_delay_ms;
self
}
}
pub fn validate_python_env() -> Result<()> {
let manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let venv_python = manifest.join("venv").join("bin").join("python");
if !venv_python.exists() {
anyhow::bail!(
"Python venv not found at {:?}\n\
Run: /opt/homebrew/bin/python3.11 -m venv venv",
venv_python
);
}
let rt = tokio::runtime::Runtime::new()?;
let output = rt
.block_on(async { Command::new(&venv_python).arg("--version").output().await })
.context("Failed to run Python")?;
if !output.status.success() {
anyhow::bail!("Python validation failed");
}
let version = String::from_utf8_lossy(&output.stdout);
tracing::info!("Python version: {}", version.trim());
if !version.contains("3.11") {
tracing::warn!("Expected Python 3.11, got: {}", version.trim());
}
let script_path = manifest.join("scripts");
if !script_path.exists() {
anyhow::bail!("Scripts directory not found at {:?}", script_path);
}
tracing::info!("Python environment validated successfully");
Ok(())
}
pub struct PythonExecutor {
venv_python: PathBuf,
scripts_dir: PathBuf,
}
impl PythonExecutor {
pub fn new() -> Result<Self> {
let manifest = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let venv_python = manifest.join("venv").join("bin").join("python");
let scripts_dir = manifest.join("scripts");
if !venv_python.exists() {
anyhow::bail!(
"Python venv not found at {:?}. Run: /opt/homebrew/bin/python3.11 -m venv venv",
venv_python
);
}
if !scripts_dir.exists() {
anyhow::bail!("Scripts directory not found at {:?}", scripts_dir);
}
Ok(Self {
venv_python,
scripts_dir,
})
}
pub fn validate_env(&self) -> Result<()> {
let rt = tokio::runtime::Runtime::new()?;
let output = rt
.block_on(async {
Command::new(&self.venv_python)
.arg("--version")
.output()
.await
})
.context("Failed to run Python")?;
if !output.status.success() {
anyhow::bail!("Python validation failed");
}
let version = String::from_utf8_lossy(&output.stdout);
if !version.contains("3.11") {
tracing::warn!("Python version mismatch: {}", version);
}
Ok(())
}
pub async fn run(
&self,
script_name: &str,
args: &[&str],
uuid: Option<&str>,
log_prefix: &str,
timeout_duration: Option<Duration>,
) -> Result<()> {
let script_path = self.scripts_dir.join(script_name);
if !script_path.exists() {
anyhow::bail!("Script not found: {:?}", script_path);
}
let mut cmd = Command::new(&self.venv_python);
cmd.arg(&script_path);
for arg in args {
cmd.arg(arg);
}
if let Some(u) = uuid {
cmd.arg("--uuid").arg(u);
}
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
tracing::info!("[{}] Starting: {:?}", log_prefix, script_name);
let mut child = cmd
.spawn()
.with_context(|| format!("Failed to run {}", script_name))?;
let stdout = child.stdout.take().context("Failed to capture stdout")?;
let stderr = child.stderr.take().context("Failed to capture stderr")?;
let mut stdout_reader = BufReader::new(stdout).lines();
let mut stderr_reader = BufReader::new(stderr).lines();
let run_future = async {
loop {
tokio::select! {
line = stdout_reader.next_line() => {
match line {
Ok(Some(line)) => {
if line.starts_with(&format!("{}_", log_prefix)) {
tracing::info!("[{}] {}", log_prefix, line);
}
}
Ok(None) => break,
Err(e) => tracing::warn!("[{}] stdout error: {}", log_prefix, e),
}
}
line = stderr_reader.next_line() => {
match line {
Ok(Some(line)) => {
if line.starts_with(&format!("{}_", log_prefix)) {
tracing::info!("[{}] {}", log_prefix, line);
}
}
Ok(None) => {}
Err(e) => tracing::warn!("[{}] stderr error: {}", log_prefix, e),
}
}
status = child.wait() => {
match status {
Ok(status) => {
if !status.success() {
tracing::error!("[{}] Process failed: {}", log_prefix, status);
return Err(anyhow::anyhow!("{} exited with: {}", script_name, status));
}
tracing::info!("[{}] Completed successfully", log_prefix);
}
Err(e) => tracing::error!("[{}] wait error: {}", log_prefix, e),
}
break;
}
}
}
Ok(())
};
if let Some(duration) = timeout_duration {
match timeout(duration, run_future).await {
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(e),
Err(_) => {
child.kill().await.context("Failed to kill process")?;
anyhow::bail!("{} timed out after {:?}", script_name, duration);
}
}
} else {
run_future.await?;
}
Ok(())
}
pub async fn run_with_output(
&self,
script_name: &str,
args: &[&str],
uuid: Option<&str>,
log_prefix: &str,
timeout_duration: Option<Duration>,
) -> Result<()> {
self.run(script_name, args, uuid, log_prefix, timeout_duration)
.await
}
pub async fn run_with_retry(
&self,
script_name: &str,
args: &[&str],
uuid: Option<&str>,
log_prefix: &str,
timeout_duration: Option<Duration>,
retry_config: Option<RetryConfig>,
) -> Result<()> {
let config = retry_config.unwrap_or_default();
let mut attempt = 0;
let mut delay_ms = config.initial_delay_ms;
loop {
attempt += 1;
match self
.run(script_name, args, uuid, log_prefix, timeout_duration)
.await
{
Ok(()) => {
if attempt > 1 {
tracing::info!(
"[{}] Succeeded on attempt {}/{}",
log_prefix,
attempt,
config.max_attempts
);
}
return Ok(());
}
Err(e) => {
if attempt >= config.max_attempts {
tracing::error!(
"[{}] Failed after {} attempts: {}",
log_prefix,
attempt,
e
);
return Err(e);
}
tracing::warn!(
"[{}] Attempt {}/{} failed: {}. Retrying in {}ms...",
log_prefix,
attempt,
config.max_attempts,
e,
delay_ms
);
sleep(Duration::from_millis(delay_ms)).await;
delay_ms = (delay_ms as f64 * config.backoff_multiplier) as u64;
delay_ms = delay_ms.min(config.max_delay_ms);
}
}
}
}
pub fn script_path(&self, script_name: &str) -> PathBuf {
self.scripts_dir.join(script_name)
}
pub fn python_path(&self) -> &PathBuf {
&self.venv_python
}
}
impl Default for PythonExecutor {
fn default() -> Self {
Self::new().expect("Failed to create PythonExecutor")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_python_executor_new_with_venv() {
let executor = PythonExecutor::new();
assert!(
executor.is_ok(),
"PythonExecutor should create successfully with venv"
);
}
#[test]
fn test_python_executor_paths() {
let executor = PythonExecutor::new().unwrap();
let python_path = executor.python_path();
assert!(
python_path.exists(),
"Python path should exist: {:?}",
python_path
);
assert!(
python_path.to_string_lossy().contains("venv"),
"Should be in venv"
);
}
#[test]
fn test_script_path() {
let executor = PythonExecutor::new().unwrap();
let script_path = executor.script_path("asr_processor.py");
assert!(script_path.to_string_lossy().contains("scripts"));
assert!(script_path.to_string_lossy().contains("asr_processor.py"));
}
#[test]
fn test_script_path_nonexistent() {
let executor = PythonExecutor::new().unwrap();
let path = executor.script_path("nonexistent_script.py");
assert!(!path.exists(), "Nonexistent script path should not exist");
}
#[test]
fn test_python_path_is_executable() {
let executor = PythonExecutor::new().unwrap();
let path = executor.python_path();
let metadata = std::fs::metadata(path);
assert!(metadata.is_ok(), "Python path should be accessible");
}
#[test]
fn test_retry_config_default() {
let config = RetryConfig::default();
assert_eq!(config.max_attempts, 3);
assert_eq!(config.initial_delay_ms, 1000);
assert_eq!(config.max_delay_ms, 30000);
assert_eq!(config.backoff_multiplier, 2.0);
}
#[test]
fn test_retry_config_builder() {
let config = RetryConfig::new(5).with_delay(2000).with_max_delay(60000);
assert_eq!(config.max_attempts, 5);
assert_eq!(config.initial_delay_ms, 2000);
assert_eq!(config.max_delay_ms, 60000);
}
#[tokio::test]
async fn test_retry_config_clone() {
let config = RetryConfig::default();
let cloned = config.clone();
assert_eq!(cloned.max_attempts, config.max_attempts);
}
}

View File

@@ -1,36 +1,145 @@
use anyhow::Result;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Serialize, Deserialize)]
use super::executor::PythonExecutor;
const FACE_TIMEOUT: Duration = Duration::from_secs(7200);
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FaceResult {
pub frame_count: u64,
pub fps: f64,
pub frames: Vec<FaceFrame>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FaceFrame {
pub frame: u64,
pub timestamp: f64,
pub faces: Vec<Face>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Face {
pub face_id: String,
pub face_id: Option<String>,
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
pub confidence: f32,
pub embedding: Option<Vec<f32>>,
}
pub async fn process_face(video_path: &str, output_path: &str) -> Result<FaceResult> {
// TODO: Implement face detection
// Options:
// 1. Use MTCNN or RetinaFace with ONNX
// 2. Use Python subprocess
pub async fn process_face(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<FaceResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("face_processor.py");
println!("Processing face detection for: {}", video_path);
tracing::info!("[FACE] Starting face detection: {}", video_path);
Ok(FaceResult { frames: vec![] })
if !script_path.exists() {
tracing::warn!("[FACE] Script not found, returning empty result");
return Ok(FaceResult {
frame_count: 0,
fps: 0.0,
frames: vec![],
});
}
executor
.run(
"face_processor.py",
&[video_path, output_path],
uuid,
"FACE",
Some(FACE_TIMEOUT),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str = std::fs::read_to_string(output_path).context("Failed to read FACE output")?;
let result: FaceResult =
serde_json::from_str(&json_str).context("Failed to parse FACE output")?;
tracing::info!("[FACE] Result: {} frames", result.frames.len());
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_face_result_serialization() {
let result = FaceResult {
frame_count: 100,
fps: 30.0,
frames: vec![FaceFrame {
frame: 0,
timestamp: 0.0,
faces: vec![Face {
face_id: Some("face_1".to_string()),
x: 100,
y: 100,
width: 50,
height: 60,
confidence: 0.95,
}],
}],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("face_1"));
assert!(json.contains("\"width\":50"));
}
#[test]
fn test_face_result_deserialization() {
let json = r#"{
"frame_count": 50,
"fps": 25.0,
"frames": [
{
"frame": 30,
"timestamp": 1.2,
"faces": [
{"face_id": "f1", "x": 10, "y": 20, "width": 30, "height": 40, "confidence": 0.85}
]
}
]
}"#;
let result: FaceResult = serde_json::from_str(json).unwrap();
assert_eq!(result.frame_count, 50);
assert_eq!(result.frames.len(), 1);
assert_eq!(result.frames[0].faces[0].x, 10);
}
#[test]
fn test_face_result_empty_frames() {
let result = FaceResult {
frame_count: 0,
fps: 0.0,
frames: vec![],
};
assert!(result.frames.is_empty());
}
#[test]
fn test_face_confidence() {
let face = Face {
face_id: None,
x: 0,
y: 0,
width: 10,
height: 10,
confidence: 0.5,
};
assert!(face.confidence >= 0.0 && face.confidence <= 1.0);
}
}

View File

@@ -1,13 +1,21 @@
pub mod asr;
pub mod asrx;
pub mod caption;
pub mod cut;
pub mod executor;
pub mod face;
pub mod ocr;
pub mod pose;
pub mod story;
pub mod yolo;
pub use asr::{process_asr, AsrResult, AsrSegment};
pub use asrx::process_asrx;
pub use face::process_face;
pub use ocr::process_ocr;
pub use pose::process_pose;
pub use yolo::process_yolo;
pub use asrx::{process_asrx, AsrxResult, AsrxSegment};
pub use caption::{process_caption, CaptionResult, CaptionSummary, FrameCaption};
pub use cut::{process_cut, CutResult, CutScene};
pub use executor::{validate_python_env, PythonExecutor, RetryConfig};
pub use face::{process_face, Face, FaceFrame, FaceResult};
pub use ocr::{process_ocr, OcrFrame, OcrResult, OcrText};
pub use pose::{process_pose, Bbox, Keypoint, PersonPose, PoseFrame, PoseResult};
pub use story::{process_story, StoryChildChunk, StoryParentChunk, StoryResult, StoryStats};
pub use yolo::{process_yolo, YoloFrame, YoloObject, YoloResult};

View File

@@ -1,19 +1,26 @@
use anyhow::Result;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Serialize, Deserialize)]
use super::executor::PythonExecutor;
const OCR_TIMEOUT: Duration = Duration::from_secs(7200);
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OcrResult {
pub frame_count: u64,
pub fps: f64,
pub frames: Vec<OcrFrame>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OcrFrame {
pub frame: u64,
pub timestamp: f64,
pub texts: Vec<OcrText>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OcrText {
pub text: String,
pub x: i32,
@@ -23,14 +30,116 @@ pub struct OcrText {
pub confidence: f32,
}
pub async fn process_ocr(video_path: &str, output_path: &str) -> Result<OcrResult> {
// TODO: Implement OCR processing
// Options:
// 1. Use tesseract
// 2. Use Python pytesseract via subprocess
// 3. Use Rust OCR library
pub async fn process_ocr(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<OcrResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("ocr_processor.py");
println!("Processing OCR for: {}", video_path);
tracing::info!("[OCR] Starting text recognition: {}", video_path);
Ok(OcrResult { frames: vec![] })
if !script_path.exists() {
tracing::warn!("[OCR] Script not found, returning empty result");
return Ok(OcrResult {
frame_count: 0,
fps: 0.0,
frames: vec![],
});
}
executor
.run(
"ocr_processor.py",
&[video_path, output_path],
uuid,
"OCR",
Some(OCR_TIMEOUT),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str = std::fs::read_to_string(output_path).context("Failed to read OCR output")?;
let result: OcrResult =
serde_json::from_str(&json_str).context("Failed to parse OCR output")?;
tracing::info!("[OCR] Result: {} frames", result.frames.len());
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ocr_result_serialization() {
let result = OcrResult {
frame_count: 100,
fps: 30.0,
frames: vec![OcrFrame {
frame: 0,
timestamp: 0.0,
texts: vec![OcrText {
text: "Hello".to_string(),
x: 10,
y: 20,
width: 100,
height: 30,
confidence: 0.95,
}],
}],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("Hello"));
assert!(json.contains("\"x\":10"));
}
#[test]
fn test_ocr_result_deserialization() {
let json = r#"{
"frame_count": 50,
"fps": 25.0,
"frames": [
{
"frame": 30,
"timestamp": 1.2,
"texts": [
{"text": "Test", "x": 0, "y": 0, "width": 50, "height": 20, "confidence": 0.88}
]
}
]
}"#;
let result: OcrResult = serde_json::from_str(json).unwrap();
assert_eq!(result.frame_count, 50);
assert_eq!(result.frames.len(), 1);
assert_eq!(result.frames[0].texts[0].text, "Test");
}
#[test]
fn test_ocr_result_empty_frames() {
let result = OcrResult {
frame_count: 0,
fps: 0.0,
frames: vec![],
};
assert!(result.frames.is_empty());
}
#[test]
fn test_ocr_text_confidence() {
let text = OcrText {
text: "OCR".to_string(),
x: 0,
y: 0,
width: 10,
height: 10,
confidence: 0.75,
};
assert!(text.confidence >= 0.0 && text.confidence <= 1.0);
}
}

View File

@@ -1,25 +1,32 @@
use anyhow::Result;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Serialize, Deserialize)]
use super::executor::PythonExecutor;
const POSE_TIMEOUT: Duration = Duration::from_secs(7200);
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PoseResult {
pub frame_count: u64,
pub fps: f64,
pub frames: Vec<PoseFrame>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PoseFrame {
pub frame: u64,
pub timestamp: f64,
pub persons: Vec<PersonPose>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct PersonPose {
pub keypoints: Vec<Keypoint>,
pub bbox: Bbox,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Keypoint {
pub name: String,
pub x: f32,
@@ -27,7 +34,7 @@ pub struct Keypoint {
pub confidence: f32,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Bbox {
pub x: i32,
pub y: i32,
@@ -35,13 +42,135 @@ pub struct Bbox {
pub height: i32,
}
pub async fn process_pose(video_path: &str, output_path: &str) -> Result<PoseResult> {
// TODO: Implement pose estimation
// Options:
// 1. Use MoveNet or PoseNet with ONNX
// 2. Use Python subprocess with ultralytics
pub async fn process_pose(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<PoseResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("pose_processor.py");
println!("Processing pose estimation for: {}", video_path);
tracing::info!("[POSE] Starting pose estimation: {}", video_path);
Ok(PoseResult { frames: vec![] })
if !script_path.exists() {
tracing::warn!("[POSE] Script not found, returning empty result");
return Ok(PoseResult {
frame_count: 0,
fps: 0.0,
frames: vec![],
});
}
executor
.run(
"pose_processor.py",
&[video_path, output_path],
uuid,
"POSE",
Some(POSE_TIMEOUT),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str = std::fs::read_to_string(output_path).context("Failed to read POSE output")?;
let result: PoseResult =
serde_json::from_str(&json_str).context("Failed to parse POSE output")?;
tracing::info!("[POSE] Result: {} frames", result.frames.len());
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pose_result_serialization() {
let result = PoseResult {
frame_count: 100,
fps: 30.0,
frames: vec![PoseFrame {
frame: 0,
timestamp: 0.0,
persons: vec![PersonPose {
keypoints: vec![Keypoint {
name: "nose".to_string(),
x: 100.0,
y: 50.0,
confidence: 0.9,
}],
bbox: Bbox {
x: 80,
y: 30,
width: 40,
height: 80,
},
}],
}],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("nose"));
assert!(json.contains("\"confidence\":0.9"));
}
#[test]
fn test_pose_result_deserialization() {
let json = r#"{
"frame_count": 50,
"fps": 25.0,
"frames": [
{
"frame": 30,
"timestamp": 1.2,
"persons": [
{
"keypoints": [{"name": "left_eye", "x": 100.5, "y": 50.2, "confidence": 0.85}],
"bbox": {"x": 90, "y": 40, "width": 20, "height": 30}
}
]
}
]
}"#;
let result: PoseResult = serde_json::from_str(json).unwrap();
assert_eq!(result.frame_count, 50);
assert_eq!(result.frames.len(), 1);
assert_eq!(result.frames[0].persons[0].keypoints[0].name, "left_eye");
}
#[test]
fn test_pose_result_empty_frames() {
let result = PoseResult {
frame_count: 0,
fps: 0.0,
frames: vec![],
};
assert!(result.frames.is_empty());
}
#[test]
fn test_keypoint_confidence() {
let kp = Keypoint {
name: "test".to_string(),
x: 0.0,
y: 0.0,
confidence: 0.75,
};
assert!(kp.confidence >= 0.0 && kp.confidence <= 1.0);
}
#[test]
fn test_bbox_dimensions() {
let bbox = Bbox {
x: 10,
y: 20,
width: 50,
height: 100,
};
assert!(bbox.width > 0);
assert!(bbox.height > 0);
}
}

250
src/core/processor/story.rs Normal file
View File

@@ -0,0 +1,250 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::executor::PythonExecutor;
const STORY_TIMEOUT: Duration = Duration::from_secs(3600);
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StoryResult {
pub child_chunks: Vec<StoryChildChunk>,
pub parent_chunks: Vec<StoryParentChunk>,
pub stats: StoryStats,
pub metadata: serde_json::Value,
pub parent_chunk_size: usize,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StoryStats {
pub total_child_chunks: usize,
pub total_parent_chunks: usize,
pub asr_children: usize,
pub cut_children: usize,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StoryChildChunk {
pub chunk_id: String,
pub chunk_type: String,
pub source: String,
pub start_time: f64,
pub end_time: f64,
pub text_content: Option<String>,
pub content: serde_json::Value,
pub child_chunk_ids: Vec<String>,
pub parent_chunk_id: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct StoryParentChunk {
pub chunk_id: String,
pub chunk_type: String,
pub source: String,
pub start_time: f64,
pub end_time: f64,
pub text_content: String,
pub content: serde_json::Value,
pub child_chunk_ids: Vec<String>,
pub parent_chunk_id: Option<String>,
}
pub async fn process_story(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<StoryResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("story_processor.py");
tracing::info!("[STORY] Starting story generation: {}", video_path);
if !script_path.exists() {
tracing::warn!("[STORY] Script not found, returning empty result");
return Ok(StoryResult {
child_chunks: vec![],
parent_chunks: vec![],
stats: StoryStats {
total_child_chunks: 0,
total_parent_chunks: 0,
asr_children: 0,
cut_children: 0,
},
metadata: serde_json::json!({}),
parent_chunk_size: 5,
});
}
executor
.run(
"story_processor.py",
&[video_path, output_path],
uuid,
"STORY",
Some(STORY_TIMEOUT),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str = std::fs::read_to_string(output_path).context("Failed to read STORY output")?;
let result: StoryResult =
serde_json::from_str(&json_str).context("Failed to parse STORY output")?;
tracing::info!(
"[STORY] Result: {} parent chunks, {} child chunks",
result.stats.total_parent_chunks,
result.stats.total_child_chunks
);
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_story_result_serialization() {
let result = StoryResult {
child_chunks: vec![StoryChildChunk {
chunk_id: "asr_0001".to_string(),
chunk_type: "sentence".to_string(),
source: "asr".to_string(),
start_time: 0.0,
end_time: 5.0,
text_content: Some("Hello world".to_string()),
content: serde_json::json!({}),
child_chunk_ids: vec![],
parent_chunk_id: Some("story_asr_0000".to_string()),
}],
parent_chunks: vec![StoryParentChunk {
chunk_id: "story_asr_0000".to_string(),
chunk_type: "story".to_string(),
source: "story_asr".to_string(),
start_time: 0.0,
end_time: 25.0,
text_content: "[0s-25s] Hello world...".to_string(),
content: serde_json::json!({
"description": "[0s-25s] Hello world...",
"child_count": 5
}),
child_chunk_ids: vec!["asr_0001".to_string()],
parent_chunk_id: None,
}],
stats: StoryStats {
total_child_chunks: 10,
total_parent_chunks: 2,
asr_children: 10,
cut_children: 0,
},
metadata: serde_json::json!({}),
parent_chunk_size: 5,
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("asr_0001"));
assert!(json.contains("story_asr_0000"));
assert!(json.contains("Hello world"));
}
#[test]
fn test_story_result_deserialization() {
let json = r#"{
"child_chunks": [{
"chunk_id": "asr_0001",
"chunk_type": "sentence",
"source": "asr",
"start_time": 0.0,
"end_time": 5.0,
"text_content": "Hello",
"content": {},
"child_chunk_ids": [],
"parent_chunk_id": null
}],
"parent_chunks": [{
"chunk_id": "story_asr_0000",
"chunk_type": "story",
"source": "story_asr",
"start_time": 0.0,
"end_time": 5.0,
"text_content": "Hello segment",
"content": {"description": "Hello segment"},
"child_chunk_ids": ["asr_0001"],
"parent_chunk_id": null
}],
"stats": {
"total_child_chunks": 1,
"total_parent_chunks": 1,
"asr_children": 1,
"cut_children": 0
},
"metadata": {},
"parent_chunk_size": 5
}"#;
let result: StoryResult = serde_json::from_str(json).unwrap();
assert_eq!(result.child_chunks.len(), 1);
assert_eq!(result.parent_chunks.len(), 1);
assert_eq!(result.stats.total_child_chunks, 1);
assert_eq!(result.stats.total_parent_chunks, 1);
assert_eq!(result.parent_chunks[0].child_chunk_ids[0], "asr_0001");
assert_eq!(result.child_chunks[0].parent_chunk_id, None);
}
#[test]
fn test_parent_child_relationship() {
let result = StoryResult {
child_chunks: vec![
StoryChildChunk {
chunk_id: "asr_0001".to_string(),
chunk_type: "sentence".to_string(),
source: "asr".to_string(),
start_time: 0.0,
end_time: 5.0,
text_content: Some("First".to_string()),
content: serde_json::json!({}),
child_chunk_ids: vec![],
parent_chunk_id: Some("story_asr_0000".to_string()),
},
StoryChildChunk {
chunk_id: "asr_0002".to_string(),
chunk_type: "sentence".to_string(),
source: "asr".to_string(),
start_time: 5.0,
end_time: 10.0,
text_content: Some("Second".to_string()),
content: serde_json::json!({}),
child_chunk_ids: vec![],
parent_chunk_id: Some("story_asr_0000".to_string()),
},
],
parent_chunks: vec![StoryParentChunk {
chunk_id: "story_asr_0000".to_string(),
chunk_type: "story".to_string(),
source: "story_asr".to_string(),
start_time: 0.0,
end_time: 10.0,
text_content: "Combined narrative".to_string(),
content: serde_json::json!({}),
child_chunk_ids: vec!["asr_0001".to_string(), "asr_0002".to_string()],
parent_chunk_id: None,
}],
stats: StoryStats {
total_child_chunks: 2,
total_parent_chunks: 1,
asr_children: 2,
cut_children: 0,
},
metadata: serde_json::json!({}),
parent_chunk_size: 5,
};
assert_eq!(result.parent_chunks[0].child_chunk_ids.len(), 2);
assert!(result
.child_chunks
.iter()
.all(|c| c.parent_chunk_id.is_some()));
assert!(result.parent_chunks[0].parent_chunk_id.is_none());
}
}

View File

@@ -1,19 +1,26 @@
use anyhow::Result;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Debug, Serialize, Deserialize)]
use super::executor::PythonExecutor;
const YOLO_TIMEOUT: Duration = Duration::from_secs(7200);
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct YoloResult {
pub frame_count: u64,
pub fps: f64,
pub frames: Vec<YoloFrame>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct YoloFrame {
pub frame: u64,
pub timestamp: f64,
pub objects: Vec<YoloObject>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct YoloObject {
pub class_name: String,
pub class_id: u32,
@@ -24,13 +31,123 @@ pub struct YoloObject {
pub confidence: f32,
}
pub async fn process_yolo(video_path: &str, output_path: &str) -> Result<YoloResult> {
// TODO: Implement YOLO processing
// Options:
// 1. Use ONNX Runtime (ort) with YOLO model
// 2. Use Python subprocess with ultralytics
pub async fn process_yolo(
video_path: &str,
output_path: &str,
uuid: Option<&str>,
) -> Result<YoloResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("yolo_processor.py");
println!("Processing YOLO for: {}", video_path);
tracing::info!("[YOLO] Starting object detection: {}", video_path);
Ok(YoloResult { frames: vec![] })
if !script_path.exists() {
tracing::warn!("[YOLO] Script not found, returning empty result");
return Ok(YoloResult {
frame_count: 0,
fps: 0.0,
frames: vec![],
});
}
executor
.run(
"yolo_processor.py",
&[video_path, output_path],
uuid,
"YOLO",
Some(YOLO_TIMEOUT),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str = std::fs::read_to_string(output_path).context("Failed to read YOLO output")?;
let result: YoloResult =
serde_json::from_str(&json_str).context("Failed to parse YOLO output")?;
tracing::info!(
"[YOLO] Result: {} frames, {:.2} fps",
result.frame_count,
result.fps
);
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_yolo_result_serialization() {
let result = YoloResult {
frame_count: 100,
fps: 30.0,
frames: vec![YoloFrame {
frame: 0,
timestamp: 0.0,
objects: vec![YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 100,
y: 200,
width: 50,
height: 100,
confidence: 0.95,
}],
}],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("person"));
assert!(json.contains("100"));
}
#[test]
fn test_yolo_result_deserialization() {
let json = r#"{
"frame_count": 50,
"fps": 25.0,
"frames": [
{
"frame": 10,
"timestamp": 0.4,
"objects": [
{"class_name": "car", "class_id": 2, "x": 0, "y": 0, "width": 100, "height": 80, "confidence": 0.87}
]
}
]
}"#;
let result: YoloResult = serde_json::from_str(json).unwrap();
assert_eq!(result.frame_count, 50);
assert_eq!(result.fps, 25.0);
assert_eq!(result.frames.len(), 1);
assert_eq!(result.frames[0].objects[0].class_name, "car");
}
#[test]
fn test_yolo_object_confidence_range() {
let obj = YoloObject {
class_name: "test".to_string(),
class_id: 0,
x: 0,
y: 0,
width: 10,
height: 10,
confidence: 0.5,
};
assert!(obj.confidence >= 0.0 && obj.confidence <= 1.0);
}
#[test]
fn test_yolo_result_empty_frames() {
let result = YoloResult {
frame_count: 0,
fps: 0.0,
frames: vec![],
};
assert!(result.frames.is_empty());
}
}

View File

@@ -1,5 +1,7 @@
pub mod file_manager;
pub mod output_dir;
pub mod uuid;
pub use file_manager::FileManager;
pub use output_dir::OutputDir;
pub use uuid::compute_uuid;

View File

@@ -0,0 +1,226 @@
use anyhow::{Context, Result};
use chrono::{DateTime, Datelike, Local, Timelike};
use std::path::{Path, PathBuf};
pub struct OutputDir {
base_path: PathBuf,
backup_enabled: bool,
backup_dir: PathBuf,
}
impl OutputDir {
pub fn new() -> Self {
let base_path = std::env::var("MOMENTRY_OUTPUT_DIR")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("./output"));
let backup_enabled = std::env::var("MOMENTRY_BACKUP_ENABLED")
.map(|v| v == "true")
.unwrap_or(false);
let backup_dir = std::env::var("MOMENTRY_BACKUP_DIR")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("/Users/accusys/momentry/backup/momentry"));
Self {
base_path,
backup_enabled,
backup_dir,
}
}
pub fn get_base_path(&self) -> &Path {
&self.base_path
}
pub fn get_backup_dir(&self) -> &Path {
&self.backup_dir
}
pub fn ensure_dir(&self) -> Result<()> {
std::fs::create_dir_all(&self.base_path).context(format!(
"Failed to create output directory: {:?}",
self.base_path
))?;
if self.backup_enabled {
std::fs::create_dir_all(&self.backup_dir).context(format!(
"Failed to create backup directory: {:?}",
self.backup_dir
))?;
}
Ok(())
}
pub fn get_output_path(&self, uuid: &str, extension: &str) -> PathBuf {
self.base_path.join(format!("{}.{}", uuid, extension))
}
fn get_timestamp() -> String {
let now = Local::now();
format!(
"{:04}{:02}{:02}_{:02}{:02}{:02}",
now.year(),
now.month(),
now.day(),
now.hour(),
now.minute(),
now.second()
)
}
pub fn get_backup_path(&self, uuid: &str, extension: &str) -> Option<PathBuf> {
if !self.backup_enabled {
return None;
}
let timestamp = Self::get_timestamp();
let filename = format!("momentry_data_{}_{}.{}", timestamp, uuid, extension);
Some(self.backup_dir.join(filename))
}
pub fn backup_file(&self, uuid: &str, extension: &str) -> Result<Option<PathBuf>> {
if !self.backup_enabled {
return Ok(None);
}
let source = self.get_output_path(uuid, extension);
if !source.exists() {
return Ok(None);
}
let backup_path = match self.get_backup_path(uuid, extension) {
Some(path) => path,
None => return Ok(None),
};
std::fs::copy(&source, &backup_path)
.context(format!("Failed to backup file to: {:?}", backup_path))?;
let sha256_path = backup_path.with_extension(format!("{}.sha256", extension));
let source_content = std::fs::read(&source)?;
let hash = format!("{:x}", md5::compute(&source_content));
std::fs::write(
&sha256_path,
format!(
"{} {}\n",
hash,
backup_path.file_name().unwrap().to_string_lossy()
),
)?;
Ok(Some(backup_path))
}
pub fn cleanup_old_backups(&self, days: u32) -> Result<u32> {
if !self.backup_enabled {
return Ok(0);
}
if !self.backup_dir.exists() {
return Ok(0);
}
let cutoff = Local::now() - chrono::Duration::days(days as i64);
let mut deleted_count = 0;
for entry in std::fs::read_dir(&self.backup_dir)? {
let entry = entry?;
let path = entry.path();
if !path.is_file() {
continue;
}
if let Some(name) = path.file_name() {
let name_str = name.to_string_lossy();
if name_str.starts_with("momentry_data_") && name_str.len() == 43 {
let date_part = &name_str[14..22];
if let Ok(date) =
DateTime::parse_from_str(&format!("{} 000000", date_part), "%Y%m%d %H%M%S")
{
if date.with_timezone(&Local) < cutoff {
std::fs::remove_file(&path)?;
deleted_count += 1;
let sha256_path = path.with_extension("sha256");
if sha256_path.exists() {
let _ = std::fs::remove_file(sha256_path);
}
}
}
}
}
}
Ok(deleted_count)
}
pub fn list_backups(&self) -> Result<Vec<BackupInfo>> {
if !self.backup_dir.exists() {
return Ok(vec![]);
}
let mut backups = Vec::new();
for entry in std::fs::read_dir(&self.backup_dir)? {
let entry = entry?;
let path = entry.path();
if !path.is_file() {
continue;
}
if let Some(name) = path.file_name() {
let name_str = name.to_string_lossy();
if name_str.starts_with("momentry_data_") && name_str.ends_with(".sha256") {
continue;
}
if name_str.starts_with("momentry_data_") {
let date_part = &name_str[14..22];
backups.push(BackupInfo {
filename: name_str.to_string(),
date: date_part.to_string(),
path: path.clone(),
});
}
}
}
backups.sort_by(|a, b| b.date.cmp(&a.date));
Ok(backups)
}
pub fn verify_backup(&self, backup_path: &Path) -> Result<bool> {
let sha256_path = backup_path.with_extension("sha256");
if !sha256_path.exists() {
return Ok(false);
}
let sha256_content = std::fs::read_to_string(&sha256_path)?;
let expected_hash = sha256_content.split_whitespace().next().unwrap_or("");
let source_content = std::fs::read(backup_path)?;
let actual_hash = format!("{:x}", md5::compute(&source_content));
Ok(expected_hash == actual_hash)
}
}
impl Default for OutputDir {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct BackupInfo {
pub filename: String,
pub date: String,
pub path: PathBuf,
}

View File

@@ -25,39 +25,36 @@ pub fn compute_uuid_from_path(full_path: &str) -> String {
compute_uuid(&parent, &filename)
}
/// Extract relative path from full path given data root
/// Returns (relative_path, filename)
pub fn extract_relative_path(full_path: &str, data_root: &str) -> (String, String) {
let full_path = PathBuf::from(full_path);
let data_root = PathBuf::from(data_root);
/// Extract username and filepath from relative path
/// Input: ./demo/video.mp4 or ./demo/path/to/video.mp4
/// Returns: (username, filepath) e.g., ("demo", "video.mp4") or ("demo", "path/to/video.mp4")
pub fn extract_user_from_relative_path(relative_path: &str) -> (String, String) {
// Remove leading ./
let path = relative_path.strip_prefix("./").unwrap_or(relative_path);
// Canonicalize both paths
let full_canonical = full_path.canonicalize().unwrap_or(full_path.clone());
let root_canonical = data_root.canonicalize().unwrap_or(data_root.clone());
let path_buf = PathBuf::from(path);
// Try to strip the data root prefix
let relative = full_canonical
.strip_prefix(&root_canonical)
.unwrap_or(&full_canonical);
// Separate into parent directory and filename
let filename = relative
.file_name()
.map(|n| n.to_string_lossy().to_string())
// First component is username
let mut components = path_buf.components();
let username = components
.next()
.map(|c| c.as_os_str().to_string_lossy().to_string())
.unwrap_or_default();
let parent = relative
.parent()
.map(|p| p.to_string_lossy().to_string())
.unwrap_or_default();
// Remaining path (filepath)
let filepath: String = components
.map(|c| c.as_os_str().to_string_lossy().to_string())
.collect::<Vec<_>>()
.join("/");
(parent, filename)
(username, filepath)
}
/// Compute UUID from full path using data root for relative path extraction
pub fn compute_uuid_from_path_with_root(full_path: &str, data_root: &str) -> String {
let (parent, filename) = extract_relative_path(full_path, data_root);
compute_uuid(&parent, &filename)
/// Compute UUID from relative path (like ./demo/video.mp4)
/// The username is extracted from the first path component
pub fn compute_uuid_from_relative_path(relative_path: &str) -> String {
let (username, filepath) = extract_user_from_relative_path(relative_path);
compute_uuid(&username, &filepath)
}
#[cfg(test)]
@@ -78,24 +75,26 @@ mod tests {
}
#[test]
fn test_relative_path_extraction() {
let (parent, filename) =
extract_relative_path("/data/sftpgo/data/demo/video.mp4", "/data/sftpgo/data");
assert_eq!(parent, "demo");
assert_eq!(filename, "video.mp4");
fn test_extract_user_from_relative_path() {
let (username, filepath) = extract_user_from_relative_path("./demo/video.mp4");
assert_eq!(username, "demo");
assert_eq!(filepath, "video.mp4");
let (username, filepath) = extract_user_from_relative_path("./demo/path/to/video.mp4");
assert_eq!(username, "demo");
assert_eq!(filepath, "path/to/video.mp4");
}
#[test]
fn test_uuid_with_data_root() {
let uuid1 = compute_uuid_from_path_with_root(
"/data/sftpgo/data/demo/video.mp4",
"/data/sftpgo/data",
);
let uuid2 = compute_uuid_from_path_with_root(
"/data/sftpgo/data/demo/video.mp4",
"/data/sftpgo/data",
);
fn test_uuid_from_relative_path() {
let uuid1 = compute_uuid_from_relative_path("./demo/video.mp4");
let uuid2 = compute_uuid_from_relative_path("./demo/video.mp4");
assert_eq!(uuid1, uuid2);
assert_eq!(uuid1.len(), 16);
// Different users with same filename should have different UUIDs
let uuid_demo = compute_uuid_from_relative_path("./demo/video.mp4");
let uuid_warren = compute_uuid_from_relative_path("./warren/video.mp4");
assert_ne!(uuid_demo, uuid_warren);
}
}