feat: backup architecture docs, source code, and scripts

This commit is contained in:
Warren
2026-04-25 17:15:45 +08:00
parent 59809dae1f
commit 1f84e5469f
368 changed files with 146329 additions and 261 deletions

936
src/api/face_recognition.rs Normal file
View File

@@ -0,0 +1,936 @@
use axum::{
extract::{Multipart, Path, Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::core::db::{schema, Database, PostgresDb};
use crate::core::processor::face_recognition::{
process_face_recognition, register_face, FaceRecognitionResult, FaceRegistrationResult,
};
/// JSON body accepted by the `recognize_faces` handler.
#[derive(Debug, Deserialize)]
pub struct FaceRecognitionRequest {
/// UUID of the video to process; used to look up the video record in the database.
pub video_uuid: String,
/// Forwarded to the recognition pipeline; treated as `true` when omitted.
pub enable_recognition: Option<bool>,
/// Forwarded to the recognition pipeline; treated as `true` when omitted.
pub enable_tracking: Option<bool>,
/// Forwarded to the recognition pipeline; treated as `true` when omitted.
pub enable_clustering: Option<bool>,
/// NOTE(review): not read by the `recognize_faces` handler in this file — confirm it is still needed.
pub database_path: Option<String>,
}
/// JSON body returned by the `recognize_faces` handler on success.
#[derive(Debug, Serialize)]
pub struct FaceRecognitionResponse {
/// Always `true` where this file constructs the response; failures are returned as HTTP errors instead.
pub success: bool,
/// Human-readable summary of the outcome.
pub message: String,
/// The pipeline output for the processed video.
pub result: Option<FaceRecognitionResult>,
/// UUID generated for this request; also used in the name of the JSON output file.
pub processing_id: String,
}
/// JSON shape describing a face registration.
///
/// NOTE(review): no handler visible in this file deserializes this type — `register_face_api`
/// reads `image`, `name` and `metadata` from multipart form fields instead. Confirm whether it
/// is used elsewhere.
#[derive(Debug, Deserialize)]
pub struct FaceRegistrationRequest {
/// UUID of the related video.
pub video_uuid: String,
/// Display name for the registered face.
pub name: String,
/// Optional free-form JSON attached to the registration.
pub metadata: Option<serde_json::Value>,
}
/// JSON body returned by the `register_face_api` handler.
#[derive(Debug, Serialize)]
pub struct FaceRegistrationApiResponse {
/// Mirrors the `success` flag of the underlying registration result.
pub success: bool,
/// Message from the registration result; a warning is appended when the database is unreachable.
pub message: String,
/// The full registration result produced by `register_face`.
pub result: Option<FaceRegistrationResult>,
}
/// JSON body accepted by the `search_faces` handler.
#[derive(Debug, Deserialize)]
pub struct FaceSearchRequest {
/// NOTE(review): required by deserialization but not read by `search_faces` in this file.
pub video_uuid: String,
/// Query embedding; sent to the database as a `[v1,v2,...]` literal cast to the vector type.
pub embedding: Vec<f32>,
/// Minimum similarity a stored face must reach; `0.6` when omitted.
pub similarity_threshold: Option<f64>,
/// Maximum number of matches to return; `10` when omitted.
pub limit: Option<i32>,
}
/// JSON body returned by the `search_faces` handler on success.
#[derive(Debug, Serialize)]
pub struct FaceSearchResponse {
/// Always `true` where this file constructs the response; failures are returned as HTTP errors instead.
pub success: bool,
/// Human-readable summary that includes the number of matches.
pub message: String,
/// Matching faces, in the order returned by the database query (closest first).
pub results: Vec<FaceSearchResult>,
}
/// One stored face returned by a similarity search.
#[derive(Debug, Serialize)]
pub struct FaceSearchResult {
/// Identifier of the stored face.
pub face_id: String,
/// Name stored with the face, if any.
pub name: Option<String>,
/// Computed in SQL as `1 - (embedding <=> query)`.
/// NOTE(review): presumably cosine similarity, if `<=>` is pgvector's cosine-distance operator — confirm.
pub similarity: f64,
/// `attributes` column of the stored face.
pub attributes: Option<serde_json::Value>,
/// `metadata` column of the stored face.
pub metadata: Option<serde_json::Value>,
}
/// Query-string parameters accepted by the `list_faces` handler.
#[derive(Debug, Deserialize)]
pub struct FaceListQuery {
/// NOTE(review): required by deserialization but not read by `list_faces` in this file.
pub video_uuid: String,
/// Page number; `1` when omitted.
pub page: Option<usize>,
/// Rows per page; `20` when omitted.
pub page_size: Option<usize>,
/// When `true` (the default) only rows with `is_active = TRUE` are listed.
pub active_only: Option<bool>,
}
/// JSON body returned by the `list_faces` handler on success.
#[derive(Debug, Serialize)]
pub struct FaceListResponse {
/// Always `true` where this file constructs the response; failures are returned as HTTP errors instead.
pub success: bool,
/// Human-readable summary that includes the total count.
pub message: String,
/// The faces on the requested page.
pub faces: Vec<FaceListItem>,
/// Total number of rows matching the filter, not just the rows on this page.
pub count: i64,
/// Page number that was served.
pub page: usize,
/// Page size that was applied.
pub page_size: usize,
}
/// One row of the face listing.
#[derive(Debug, Serialize)]
pub struct FaceListItem {
/// Identifier of the stored face.
pub face_id: String,
/// Name stored with the face, if any.
pub name: Option<String>,
/// Creation timestamp formatted as RFC 3339.
pub created_at: String,
/// Last-update timestamp formatted as RFC 3339.
pub updated_at: String,
/// `false` once the face has been soft-deleted.
pub is_active: bool,
/// `metadata` column of the stored face.
pub metadata: Option<serde_json::Value>,
}
/// Builds the router that exposes every `/api/v1/face/*` endpoint defined in this module.
pub fn face_recognition_routes() -> Router<crate::api::server::AppState> {
    Router::new()
        .route("/api/v1/face/recognize", post(recognize_faces))
        .route("/api/v1/face/register", post(register_face_api))
        .route("/api/v1/face/search", post(search_faces))
        .route("/api/v1/face/list", get(list_faces))
        // GET and DELETE share one path, so both are registered on a single method router.
        .route(
            "/api/v1/face/:face_id",
            get(get_face_details).delete(delete_face),
        )
        .route(
            "/api/v1/face/results/:video_uuid",
            get(get_recognition_results),
        )
}
/// Runs the face-recognition pipeline for one stored video.
///
/// Looks the video up by `video_uuid`, processes its file and then tries to persist the
/// result. A persistence failure is only logged; the recognition result is still returned.
///
/// # Errors
/// * `404` when no video with the given UUID exists.
/// * `500` when the database is unreachable, the lookup fails, or processing fails.
async fn recognize_faces(
    State(_state): State<crate::api::server::AppState>,
    Json(request): Json<FaceRecognitionRequest>,
) -> Result<Json<FaceRecognitionResponse>, (StatusCode, String)> {
    let processing_id = Uuid::new_v4().to_string();
    tracing::info!(
        "[FACE_RECOGNITION] Starting recognition for video: {}, processing_id: {}",
        request.video_uuid,
        processing_id
    );

    // Resolve the video's file path through the database.
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;
    let video_record = db
        .get_video_by_uuid(&request.video_uuid)
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to fetch video: {}", e),
            )
        })?
        .ok_or_else(|| {
            (
                StatusCode::NOT_FOUND,
                format!("Video not found: {}", request.video_uuid),
            )
        })?;

    let video_path = video_record.file_path;
    let output_path = format!(
        "{}/face_recognition_{}.json",
        crate::core::config::OUTPUT_DIR.as_str(),
        processing_id
    );

    // Each stage of the pipeline is enabled unless the caller explicitly turns it off.
    let recognition_on = request.enable_recognition.unwrap_or(true);
    let tracking_on = request.enable_tracking.unwrap_or(true);
    let clustering_on = request.enable_clustering.unwrap_or(true);
    let result = process_face_recognition(
        &video_path,
        &output_path,
        Some(&processing_id),
        recognition_on,
        tracking_on,
        clustering_on,
    )
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Face recognition failed: {}", e),
        )
    })?;

    // Persisting is best-effort: the caller still gets the result if it fails.
    if let Err(e) = store_recognition_results(&db, &request.video_uuid, &result).await {
        tracing::warn!("Failed to store recognition results: {}", e);
    }

    let message = format!("Face recognition completed for {}", request.video_uuid);
    Ok(Json(FaceRecognitionResponse {
        success: true,
        message,
        result: Some(result),
        processing_id,
    }))
}
/// Registers a face from a multipart upload and stores it in the `face_identities` table.
///
/// Recognised form fields:
/// * `image` (required) – raw image bytes.
/// * `name` (required) – display name for the face.
/// * `metadata` (optional) – a JSON document.
///
/// Unknown fields are ignored. The image is kept in memory until the whole form has been
/// parsed and validated; only then is it written to a temporary file, which is removed
/// again on every path. (Previously the file was written while parsing, so any later
/// parse/validation error left it behind in `/tmp`.)
///
/// # Errors
/// * `400` for a malformed form, invalid `metadata` JSON, or a missing required field.
/// * `500` when the temporary file cannot be written or registration fails.
///
/// A database failure after a successful registration does not fail the request: it is
/// logged, and an unreachable database is additionally reported in the response message.
async fn register_face_api(
    State(_state): State<crate::api::server::AppState>,
    mut multipart: Multipart,
) -> Result<Json<FaceRegistrationApiResponse>, (StatusCode, String)> {
    let mut image_data = None;
    let mut name: Option<String> = None;
    let mut metadata: Option<serde_json::Value> = None;

    // Parse multipart form data. Nothing touches the filesystem in this loop, so an
    // early return here cannot leak a temporary file.
    while let Some(field) = multipart.next_field().await.map_err(|e| {
        (
            StatusCode::BAD_REQUEST,
            format!("Failed to parse form data: {}", e),
        )
    })? {
        let field_name = field.name().unwrap_or("").to_string();
        match field_name.as_str() {
            "image" => {
                let data = field.bytes().await.map_err(|e| {
                    (
                        StatusCode::BAD_REQUEST,
                        format!("Failed to read image data: {}", e),
                    )
                })?;
                // If the field is repeated, the last image wins.
                image_data = Some(data);
            }
            "name" => {
                let value = field.text().await.map_err(|e| {
                    (
                        StatusCode::BAD_REQUEST,
                        format!("Failed to read name: {}", e),
                    )
                })?;
                name = Some(value);
            }
            "metadata" => {
                let value = field.text().await.map_err(|e| {
                    (
                        StatusCode::BAD_REQUEST,
                        format!("Failed to read metadata: {}", e),
                    )
                })?;
                metadata = Some(serde_json::from_str(&value).map_err(|e| {
                    (
                        StatusCode::BAD_REQUEST,
                        format!("Invalid JSON metadata: {}", e),
                    )
                })?);
            }
            _ => {}
        }
    }

    // Validate required fields before creating any file.
    let image_data = image_data
        .ok_or_else(|| (StatusCode::BAD_REQUEST, "Image is required".to_string()))?;
    let name = name.ok_or_else(|| (StatusCode::BAD_REQUEST, "Name is required".to_string()))?;

    // `register_face` works on a file path, so spill the upload to a uniquely named file.
    let image_path = format!("/tmp/face_registration_{}.jpg", Uuid::new_v4());
    if let Err(e) = tokio::fs::write(&image_path, &image_data).await {
        // A failed write may still have created a partial file.
        let _ = tokio::fs::remove_file(&image_path).await;
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to save image: {}", e),
        ));
    }

    // Register the face, then remove the temporary file whether or not that succeeded.
    let registration = register_face(&image_path, &name, metadata.clone()).await;
    let _ = tokio::fs::remove_file(&image_path).await;
    let result = registration.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Face registration failed: {}", e),
        )
    })?;

    // Store in PostgreSQL face_identities table
    if result.success && !result.embedding.is_empty() {
        let db = match PostgresDb::init().await {
            Ok(db) => db,
            Err(e) => {
                tracing::warn!("[FACE_REGISTRATION] Failed to connect to DB: {}", e);
                // Return success even if DB write fails (embedding is still in JSON output)
                return Ok(Json(FaceRegistrationApiResponse {
                    success: result.success,
                    message: format!("{} (Warning: DB write failed: {})", result.message, e),
                    result: Some(result),
                }));
            }
        };

        // Convert embedding to PostgreSQL vector format: "[v1,v2,...]"
        let embedding_str = format!(
            "[{}]",
            result
                .embedding
                .iter()
                .map(|v| v.to_string())
                .collect::<Vec<_>>()
                .join(",")
        );

        let face_identities_table = schema::table_name("face_identities");
        let attrs_json =
            serde_json::to_string(&result.attributes).unwrap_or_else(|_| "{}".to_string());
        // Use public.vector type to work across schemas
        let vector_type = if schema::SCHEMA_PREFIX.as_str().is_empty() {
            "vector"
        } else {
            "public.vector"
        };
        let insert_query = format!(
            r#"
            INSERT INTO {} (face_id, name, embedding, attributes, metadata, is_active)
            VALUES ($1, $2, $3::{}, $4::jsonb, $5, TRUE)
            ON CONFLICT (face_id) DO UPDATE SET
                name = EXCLUDED.name,
                embedding = EXCLUDED.embedding,
                attributes = EXCLUDED.attributes,
                metadata = COALESCE(EXCLUDED.metadata, {}.metadata),
                updated_at = CURRENT_TIMESTAMP,
                is_active = TRUE
            "#,
            face_identities_table, vector_type, face_identities_table
        );
        let metadata_json = metadata.unwrap_or_else(|| serde_json::json!({}));

        match sqlx::query(&insert_query)
            .bind(&result.face_id)
            .bind(&name)
            .bind(&embedding_str)
            .bind(&attrs_json)
            .bind(&metadata_json)
            .execute(db.pool())
            .await
        {
            Ok(_) => {
                tracing::info!(
                    "[FACE_REGISTRATION] Stored face '{}' (face_id={}) in DB",
                    name,
                    result.face_id
                );
            }
            Err(e) => {
                // Best-effort: the registration itself already succeeded.
                tracing::warn!("[FACE_REGISTRATION] Failed to store face in DB: {}", e);
            }
        }
    }

    Ok(Json(FaceRegistrationApiResponse {
        success: result.success,
        message: result.message.clone(),
        result: Some(result),
    }))
}
/// Finds stored faces whose embedding is similar to the one in the request.
///
/// Only active faces with a non-null embedding are considered. Results are ordered by
/// ascending `<=>` distance and filtered to `1 - distance >= similarity_threshold`
/// (default `0.6`); at most `limit` rows (default `10`) are returned.
///
/// # Errors
/// * `400` when the embedding is empty or contains a non-finite value, or when `limit`
///   is negative. These inputs used to reach the database and come back as a `500`
///   carrying the database's error text.
/// * `500` when the database is unreachable or the query fails.
async fn search_faces(
    State(_state): State<crate::api::server::AppState>,
    Json(request): Json<FaceSearchRequest>,
) -> Result<Json<FaceSearchResponse>, (StatusCode, String)> {
    // Reject inputs the database cannot accept before opening a connection.
    if request.embedding.is_empty() {
        return Err((
            StatusCode::BAD_REQUEST,
            "Embedding must not be empty".to_string(),
        ));
    }
    if request.embedding.iter().any(|v| !v.is_finite()) {
        return Err((
            StatusCode::BAD_REQUEST,
            "Embedding must contain only finite numbers".to_string(),
        ));
    }
    let similarity_threshold = request.similarity_threshold.unwrap_or(0.6);
    let limit = request.limit.unwrap_or(10);
    if limit < 0 {
        return Err((
            StatusCode::BAD_REQUEST,
            "Limit must not be negative".to_string(),
        ));
    }

    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    // Convert embedding to PostgreSQL vector format: "[v1,v2,...]"
    let embedding_str = format!(
        "[{}]",
        request
            .embedding
            .iter()
            .map(|v| v.to_string())
            .collect::<Vec<_>>()
            .join(",")
    );

    let face_identities_table = schema::table_name("face_identities");
    // Qualify the vector type when a schema prefix is in use so the cast resolves.
    let vector_type = if schema::SCHEMA_PREFIX.as_str().is_empty() {
        "vector"
    } else {
        "public.vector"
    };
    let query = format!(
        r#"
        SELECT
            face_id,
            name,
            1 - (embedding <=> $1::{}) as similarity,
            attributes,
            metadata
        FROM {}
        WHERE is_active = TRUE
          AND embedding IS NOT NULL
          AND 1 - (embedding <=> $1::{}) >= $2
        ORDER BY embedding <=> $1::{}
        LIMIT $3
        "#,
        vector_type, face_identities_table, vector_type, vector_type
    );

    let rows = sqlx::query_as::<
        _,
        (
            String,
            Option<String>,
            f64,
            Option<serde_json::Value>,
            Option<serde_json::Value>,
        ),
    >(query.as_str())
    .bind(&embedding_str)
    .bind(similarity_threshold)
    .bind(limit)
    .fetch_all(db.pool())
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to search faces: {}", e),
        )
    })?;

    let results: Vec<FaceSearchResult> = rows
        .into_iter()
        .map(
            |(face_id, name, similarity, attributes, metadata)| FaceSearchResult {
                face_id,
                name,
                similarity,
                attributes,
                metadata,
            },
        )
        .collect();

    Ok(Json(FaceSearchResponse {
        success: true,
        message: format!("Found {} similar faces", results.len()),
        results,
    }))
}
/// Returns one page of registered faces, newest first, plus the total matching count.
///
/// `page` is 1-based and defaults to `1`; `page=0` is treated as the first page.
/// (Previously `page - 1` underflowed on `page=0`: a panic in debug builds and a negative
/// `OFFSET` in release builds.) `page_size` defaults to `20`. `active_only` defaults to
/// `true` and restricts the listing to rows with `is_active = TRUE`.
///
/// # Errors
/// `500` when the database is unreachable or either query fails.
async fn list_faces(
    State(_state): State<crate::api::server::AppState>,
    Query(query): Query<FaceListQuery>,
) -> Result<Json<FaceListResponse>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    let page = query.page.unwrap_or(1).max(1);
    let page_size = query.page_size.unwrap_or(20);
    // Clamp into i64 and saturate the product so oversized inputs can neither wrap
    // negative nor overflow. LIMIT is bound as i64 (the old `as i32` cast could truncate).
    let limit = page_size.min(i64::MAX as usize) as i64;
    let offset = ((page - 1).min(i64::MAX as usize) as i64).saturating_mul(limit);
    let active_only = query.active_only.unwrap_or(true);

    // Build query
    let mut where_clause = "WHERE 1=1".to_string();
    if active_only {
        where_clause.push_str(" AND is_active = TRUE");
    }
    let face_identities_table = schema::table_name("face_identities");
    let count_query = format!(
        "SELECT COUNT(*) FROM {} {}",
        face_identities_table, where_clause
    );
    let list_query = format!(
        "SELECT face_id, name, created_at, updated_at, is_active, metadata FROM {} {} ORDER BY created_at DESC LIMIT $1 OFFSET $2",
        face_identities_table, where_clause
    );

    // Total number of rows matching the filter (not just this page).
    let total: i64 = sqlx::query_scalar(&count_query)
        .fetch_one(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to count faces: {}", e),
            )
        })?;

    let rows = sqlx::query_as::<
        _,
        (
            String,
            Option<String>,
            chrono::DateTime<chrono::Utc>,
            chrono::DateTime<chrono::Utc>,
            bool,
            Option<serde_json::Value>,
        ),
    >(&list_query)
    .bind(limit)
    .bind(offset)
    .fetch_all(db.pool())
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to list faces: {}", e),
        )
    })?;

    let faces: Vec<FaceListItem> = rows
        .into_iter()
        .map(
            |(face_id, name, created_at, updated_at, is_active, metadata)| FaceListItem {
                face_id,
                name,
                created_at: created_at.to_rfc3339(),
                updated_at: updated_at.to_rfc3339(),
                is_active,
                metadata,
            },
        )
        .collect();

    Ok(Json(FaceListResponse {
        success: true,
        message: format!("Found {} faces", total),
        faces,
        count: total,
        page,
        page_size,
    }))
}
/// Returns the stored details of one face as a JSON object.
///
/// The response reports `has_embedding` rather than the embedding itself. That flag is
/// computed in SQL (`embedding IS NOT NULL`) so the vector column is never transferred
/// or decoded; the previous version decoded the column as `Option<String>`, which relies
/// on the driver accepting a custom `vector` type as text.
///
/// # Errors
/// * `404` when no row has the given `face_id`.
/// * `500` when the database is unreachable or the query fails.
async fn get_face_details(
    State(_state): State<crate::api::server::AppState>,
    Path(face_id): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    let face_identities_table = schema::table_name("face_identities");
    let query = format!(
        r#"
        SELECT
            face_id,
            name,
            (embedding IS NOT NULL) AS has_embedding,
            attributes,
            metadata,
            created_at,
            updated_at,
            is_active
        FROM {}
        WHERE face_id = $1
        "#,
        face_identities_table
    );

    let face: Option<(
        String,
        Option<String>,
        bool,
        Option<serde_json::Value>,
        Option<serde_json::Value>,
        chrono::DateTime<chrono::Utc>,
        chrono::DateTime<chrono::Utc>,
        bool,
    )> = sqlx::query_as(&query)
        .bind(&face_id)
        .fetch_optional(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to fetch face details: {}", e),
            )
        })?;

    match face {
        Some((
            face_id,
            name,
            has_embedding,
            attributes,
            metadata,
            created_at,
            updated_at,
            is_active,
        )) => {
            let response = serde_json::json!({
                "success": true,
                "face_id": face_id,
                "name": name,
                "has_embedding": has_embedding,
                "attributes": attributes,
                "metadata": metadata,
                "created_at": created_at.to_rfc3339(),
                "updated_at": updated_at.to_rfc3339(),
                "is_active": is_active
            });
            Ok(Json(response))
        }
        None => Err((
            StatusCode::NOT_FOUND,
            format!("Face not found: {}", face_id),
        )),
    }
}
/// Soft-deletes a face by flipping `is_active` to `FALSE`.
///
/// Only rows that are currently active are updated, so deleting the same face twice
/// yields `404` the second time.
///
/// # Errors
/// * `404` when the face does not exist or is already inactive.
/// * `500` when the database is unreachable or the update fails.
async fn delete_face(
    State(_state): State<crate::api::server::AppState>,
    Path(face_id): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    // Soft delete by marking as inactive
    let sql = format!(
        r#"
UPDATE {}
SET is_active = FALSE, updated_at = CURRENT_TIMESTAMP
WHERE face_id = $1 AND is_active = TRUE
RETURNING face_id, name
"#,
        schema::table_name("face_identities")
    );
    let updated_row: Option<(String, Option<String>)> = sqlx::query_as(&sql)
        .bind(&face_id)
        .fetch_optional(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to delete face: {}", e),
            )
        })?;

    let (deleted_id, name) = updated_row.ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            format!("Face not found or already deleted: {}", face_id),
        )
    })?;

    // Prefer the stored name in the message; fall back to the id when there is none.
    let message = format!(
        "Face '{}' deleted successfully",
        name.as_deref().unwrap_or(deleted_id.as_str())
    );
    Ok(Json(serde_json::json!({
        "success": true,
        "message": message,
        "face_id": deleted_id
    })))
}
/// Returns the most recent stored recognition summary for a video.
///
/// # Errors
/// * `404` when no result row exists for `video_uuid`.
/// * `500` when the database is unreachable or the query fails.
async fn get_recognition_results(
    State(_state): State<crate::api::server::AppState>,
    Path(video_uuid): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
    // Column order of the SELECT below.
    type ResultRow = (
        String,
        i64,
        f64,
        i32,
        i32,
        i32,
        serde_json::Value,
        Option<f64>,
        chrono::DateTime<chrono::Utc>,
    );

    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    let sql = r#"
SELECT
video_uuid,
frame_count,
fps,
total_faces,
recognized_faces,
clusters_count,
result_data,
processing_time_secs,
created_at
FROM face_recognition_results
WHERE video_uuid = $1
ORDER BY created_at DESC
LIMIT 1
"#;
    let latest: Option<ResultRow> = sqlx::query_as(sql)
        .bind(&video_uuid)
        .fetch_optional(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to fetch recognition results: {}", e),
            )
        })?;

    let row = latest.ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            format!("No recognition results found for video: {}", video_uuid),
        )
    })?;

    let (
        stored_uuid,
        frame_count,
        fps,
        total_faces,
        recognized_faces,
        clusters_count,
        result_data,
        processing_time_secs,
        created_at,
    ) = row;

    Ok(Json(serde_json::json!({
        "success": true,
        "video_uuid": stored_uuid,
        "frame_count": frame_count,
        "fps": fps,
        "total_faces": total_faces,
        "recognized_faces": recognized_faces,
        "clusters_count": clusters_count,
        "result_data": result_data,
        "processing_time_secs": processing_time_secs,
        "created_at": created_at.to_rfc3339()
    })))
}
/// Persists a recognition run: one summary row, one row per detected face that has an
/// embedding, and one row per face cluster.
///
/// All statements run inside a single transaction, so a failure part-way through leaves
/// no partial data behind (previously each statement committed on its own). Every insert
/// is an upsert, so re-processing a video overwrites its earlier rows.
///
/// # Errors
/// Returns the first serialization or database error; the transaction is rolled back
/// when it is dropped without being committed.
async fn store_recognition_results(
    db: &PostgresDb,
    video_uuid: &str,
    result: &FaceRecognitionResult,
) -> Result<(), anyhow::Error> {
    /// Renders values as a vector literal of the form `[v1,v2,...]`.
    fn vector_literal<I>(values: I) -> String
    where
        I: IntoIterator,
        I::Item: ToString,
    {
        let parts: Vec<String> = values.into_iter().map(|v| v.to_string()).collect();
        format!("[{}]", parts.join(","))
    }

    let total_faces = result.frames.iter().map(|f| f.faces.len()).sum::<usize>();
    let recognized_faces = result
        .frames
        .iter()
        .flat_map(|f| &f.faces)
        .filter(|face| face.identity.is_some())
        .count();

    // Map each face id to the first cluster that lists it. Built once so the per-face
    // lookup below no longer rescans every cluster (and clones a String per cluster).
    let mut cluster_of_face: std::collections::HashMap<&str, _> =
        std::collections::HashMap::new();
    for cluster in &result.face_clusters {
        for id in &cluster.face_ids {
            cluster_of_face
                .entry(id.as_str())
                .or_insert(&cluster.cluster_id);
        }
    }

    let mut tx = db.pool().begin().await?;

    let summary_query = r#"
        INSERT INTO face_recognition_results (
            video_uuid,
            frame_count,
            fps,
            total_faces,
            recognized_faces,
            clusters_count,
            result_data
        ) VALUES ($1, $2, $3, $4, $5, $6, $7)
        ON CONFLICT (video_uuid) DO UPDATE SET
            frame_count = EXCLUDED.frame_count,
            fps = EXCLUDED.fps,
            total_faces = EXCLUDED.total_faces,
            recognized_faces = EXCLUDED.recognized_faces,
            clusters_count = EXCLUDED.clusters_count,
            result_data = EXCLUDED.result_data,
            updated_at = CURRENT_TIMESTAMP
    "#;
    sqlx::query(summary_query)
        .bind(video_uuid)
        .bind(result.frame_count as i64)
        .bind(result.fps)
        .bind(total_faces as i32)
        .bind(recognized_faces as i32)
        .bind(result.face_clusters.len() as i32)
        .bind(serde_json::to_value(result)?)
        .execute(&mut *tx)
        .await?;

    // Store individual face detections. Faces without an embedding are skipped.
    let detection_query = r#"
        INSERT INTO face_detections (
            video_uuid,
            frame_number,
            timestamp_secs,
            face_id,
            x,
            y,
            width,
            height,
            confidence,
            embedding,
            attributes,
            identity_confidence,
            cluster_id
        ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10::vector, $11, $12, $13)
        ON CONFLICT (video_uuid, frame_number, x, y, width, height) DO UPDATE SET
            face_id = EXCLUDED.face_id,
            confidence = EXCLUDED.confidence,
            embedding = EXCLUDED.embedding,
            attributes = EXCLUDED.attributes,
            identity_confidence = EXCLUDED.identity_confidence,
            cluster_id = EXCLUDED.cluster_id
    "#;
    for frame in &result.frames {
        for face in &frame.faces {
            if let Some(embedding) = &face.embedding {
                let embedding_str = vector_literal(embedding.iter());
                let identity_confidence = face.identity.as_ref().map(|id| id.confidence as f64);
                // A face without an id is looked up under the empty string, as before.
                let lookup_key = face.face_id.as_deref().unwrap_or("");
                let cluster_id = cluster_of_face.get(lookup_key).map(|id| (*id).clone());

                sqlx::query(detection_query)
                    .bind(video_uuid)
                    .bind(frame.frame as i64)
                    .bind(frame.timestamp)
                    .bind(face.face_id.as_deref())
                    .bind(face.x)
                    .bind(face.y)
                    .bind(face.width)
                    .bind(face.height)
                    .bind(face.confidence as f64)
                    .bind(&embedding_str)
                    .bind(serde_json::to_value(&face.attributes)?)
                    .bind(identity_confidence)
                    .bind(cluster_id)
                    .execute(&mut *tx)
                    .await?;
            }
        }
    }

    // Store face clusters
    let cluster_query = r#"
        INSERT INTO face_clusters (
            cluster_id,
            video_uuid,
            centroid,
            size,
            representative_face_id,
            metadata
        ) VALUES ($1, $2, $3::vector, $4, $5, $6)
        ON CONFLICT (cluster_id) DO UPDATE SET
            centroid = EXCLUDED.centroid,
            size = EXCLUDED.size,
            representative_face_id = EXCLUDED.representative_face_id,
            metadata = EXCLUDED.metadata
    "#;
    for cluster in &result.face_clusters {
        let centroid_str = vector_literal(cluster.centroid.iter());
        sqlx::query(cluster_query)
            .bind(&cluster.cluster_id)
            .bind(video_uuid)
            .bind(&centroid_str)
            .bind(cluster.size as i32)
            .bind(cluster.representative_face_id.as_deref())
            .bind(&cluster.metadata)
            .execute(&mut *tx)
            .await?;
    }

    tx.commit().await?;
    Ok(())
}

288
src/api/identities.rs Normal file
View File

@@ -0,0 +1,288 @@
use axum::{
extract::{Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::{schema, Database, PostgresDb};
/// JSON body accepted by the `register_from_person` handler.
#[derive(Debug, Deserialize)]
pub struct RegisterFromPersonRequest {
/// UUID of the video the person belongs to.
pub video_uuid: String,
/// Identifier of the person within that video.
pub person_id: String,
/// Name of the global identity to link to; the identity is created if no row has this name.
pub identity_name: String,
/// Stored only when a new identity is created; an empty JSON object is used when omitted.
pub metadata: Option<serde_json::Value>,
}
/// JSON body returned by the `register_from_person` handler on success.
#[derive(Debug, Serialize)]
pub struct RegisterFromPersonResponse {
/// Always `true` where this file constructs the response; failures are returned as HTTP errors instead.
pub success: bool,
/// Human-readable summary naming the identity and the person.
pub message: String,
/// Database id of the identity that was found or created.
pub identity_id: i32,
/// Echo of the requested identity name.
pub identity_name: String,
/// Echo of the requested person id.
pub person_id: String,
}
/// Builds the router for the global-identity endpoints under `/api/v1/identities`.
pub fn identity_routes() -> Router<crate::api::server::AppState> {
    let router = Router::new();
    let router = router.route("/api/v1/identities/from-person", post(register_from_person));
    router.route("/api/v1/identities", get(list_identities))
}
/// Promotes a person detected in one video to a global identity.
///
/// Inside one transaction this:
/// 1. verifies that the person exists in the given video,
/// 2. finds the identity by name, or creates it (without an embedding),
/// 3. inserts an identity binding of type `person_id` (ignored if it already exists),
/// 4. renames the person to the identity name.
///
/// Returning early drops the transaction without committing it.
///
/// # Errors
/// * `404` when the person is not found in the video.
/// * `500` for any database failure.
async fn register_from_person(
    State(_state): State<crate::api::server::AppState>,
    Json(req): Json<RegisterFromPersonRequest>,
) -> Result<Json<RegisterFromPersonResponse>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("DB error: {}", e),
        )
    })?;
    let mut tx = db.pool().begin().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Tx error: {}", e),
        )
    })?;

    // 1. Check if Person exists
    let person_row: Option<(i32, Option<String>)> = sqlx::query_as(
        "SELECT id, name FROM person_identities WHERE person_id = $1 AND video_uuid = $2",
    )
    .bind(&req.person_id)
    .bind(&req.video_uuid)
    .fetch_optional(&mut *tx)
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Query error: {}", e),
        )
    })?;
    let (person_db_id, _old_name) = person_row.ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            format!(
                "Person '{}' not found in video '{}'",
                req.person_id, req.video_uuid
            ),
        )
    })?;

    // 2. Check if Identity exists
    let existing_identity: Option<i32> =
        sqlx::query_scalar("SELECT id FROM identities WHERE name = $1")
            .bind(&req.identity_name)
            .fetch_optional(&mut *tx)
            .await
            .map_err(|e| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Query error: {}", e),
                )
            })?;

    let final_identity_id: i32 = match existing_identity {
        Some(found) => found,
        None => {
            // Create new Identity. The empty string becomes a NULL embedding via NULLIF.
            let meta_json = req.metadata.clone().unwrap_or(serde_json::json!({}));
            sqlx::query_scalar(
                r#"
INSERT INTO identities (name, embedding, metadata)
VALUES ($1, NULLIF($2, '')::public.vector, $3)
RETURNING id
"#,
            )
            .bind(&req.identity_name)
            .bind("".to_string()) // No embedding for now via this API
            .bind(&meta_json)
            .fetch_one(&mut *tx)
            .await
            .map_err(|e| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Insert identity error: {}", e),
                )
            })?
        }
    };

    // 3. Create Binding
    // Columns: id, identity_id, identity_type, identity_value, confidence, metadata, created_at
    let binding_meta = serde_json::json!({"auto_updated": true});
    sqlx::query(
        r#"
INSERT INTO identity_bindings (identity_id, identity_type, identity_value, confidence, metadata)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
"#,
    )
    .bind(final_identity_id)
    .bind("person_id") // identity_type
    .bind(&req.person_id) // identity_value
    .bind(1.0) // confidence
    .bind(&binding_meta)
    .execute(&mut *tx)
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Binding error: {}", e),
        )
    })?;

    // 4. Update Person Name
    sqlx::query("UPDATE person_identities SET name = $1 WHERE id = $2")
        .bind(&req.identity_name)
        .bind(person_db_id)
        .execute(&mut *tx)
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Update person error: {}", e),
            )
        })?;

    tx.commit().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Commit error: {}", e),
        )
    })?;

    let message = format!(
        "Successfully registered identity '{}' and linked to person '{}'",
        req.identity_name, req.person_id
    );
    Ok(Json(RegisterFromPersonResponse {
        success: true,
        message,
        identity_id: final_identity_id,
        identity_name: req.identity_name,
        person_id: req.person_id,
    }))
}
/// Lists global identities, highest id first, one page at a time.
///
/// `page` is 1-based and defaults to `1`; `page=0` is treated as the first page.
/// (Previously `page - 1` underflowed on `page=0`: a panic in debug builds and a negative
/// `OFFSET` in release builds.) `page_size` defaults to `20`. `count` in the response is
/// the total number of identities, not the number on this page.
///
/// # Errors
/// `500` when the database is unreachable or either query fails.
async fn list_identities(
    State(_state): State<crate::api::server::AppState>,
    Query(query): Query<ListIdentitiesQuery>,
) -> Result<Json<IdentityListResponse>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("DB error: {}", e),
        )
    })?;

    let page = query.page.unwrap_or(1).max(1);
    let page_size = query.page_size.unwrap_or(20);
    // Clamp into i64 and saturate the product so oversized inputs can neither wrap
    // negative nor overflow.
    let limit = page_size.min(i64::MAX as usize) as i64;
    let offset = ((page - 1).min(i64::MAX as usize) as i64).saturating_mul(limit);

    // Get the total count
    let count_sql = "SELECT COUNT(*) FROM identities";
    let total: i64 = sqlx::query_scalar(count_sql)
        .fetch_one(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Count error: {}", e),
            )
        })?;

    let sql = "SELECT id, name, metadata FROM identities ORDER BY id DESC LIMIT $1 OFFSET $2";
    let rows: Vec<(i32, String, Option<serde_json::Value>)> = sqlx::query_as(sql)
        .bind(limit)
        .bind(offset)
        .fetch_all(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Query error: {}", e),
            )
        })?;

    let identities: Vec<IdentityResponse> = rows
        .into_iter()
        .map(|(id, name, metadata)| IdentityResponse { id, name, metadata })
        .collect();

    Ok(Json(IdentityListResponse {
        identities,
        count: total,
        page,
        page_size,
    }))
}
/// Query-string parameters accepted by the `list_identities` handler in this file.
#[derive(Debug, Deserialize)]
pub struct ListIdentitiesQuery {
/// Page number; `1` when omitted.
pub page: Option<usize>,
/// Rows per page; `20` when omitted.
pub page_size: Option<usize>,
}
/// One global identity as returned by the listing endpoint.
#[derive(Debug, Serialize)]
pub struct IdentityResponse {
/// Database id of the identity.
pub id: i32,
/// Name of the identity.
pub name: String,
/// `metadata` column of the identity, if set.
pub metadata: Option<serde_json::Value>,
}
/// JSON body returned by the `list_identities` handler in this file.
#[derive(Debug, Serialize)]
pub struct IdentityListResponse {
/// The identities on the requested page.
pub identities: Vec<IdentityResponse>,
/// Total number of identities, not just the rows on this page.
pub count: i64,
/// Page number that was served.
pub page: usize,
/// Page size that was applied.
pub page_size: usize,
}

412
src/api/identity_binding.rs Normal file
View File

@@ -0,0 +1,412 @@
use axum::{
extract::{Path, Query},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::{Database, PostgresDb};
use crate::core::person_identity::{BindIdentityRequest, Identity, UnbindIdentityRequest};
/// Generic success envelope used by the handlers in this file.
#[derive(Debug, Clone, Serialize)]
pub struct ApiResponse<T: Serialize> {
/// Set to `true` by every handler visible in this file; failures are returned as HTTP errors instead.
pub success: bool,
/// Human-readable summary of the outcome.
pub message: String,
/// Payload of the response; `None` for operations that return no data.
pub data: Option<T>,
}
// ============================================================================
// API Handlers
// ============================================================================
async fn get_db() -> Result<PostgresDb, (StatusCode, Json<serde_json::Value>)> {
PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB init failed: {}", e) })),
)
})
}
/// Lists identities (people), filtered by an optional search string.
///
/// Defaults: `limit` = 100, `offset` = 0, `search` = empty string.
///
/// # Errors
/// `500` when the database is unreachable or the query fails.
pub async fn list_identities(
    Query(params): Query<ListIdentitiesParams>,
) -> Result<Json<ApiResponse<Vec<Identity>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    let page_limit = params.limit.unwrap_or(100);
    let page_offset = params.offset.unwrap_or(0);
    let search_term = params.search.unwrap_or_default();

    let identities = match db
        .list_identities(&search_term, page_limit, page_offset)
        .await
    {
        Ok(found) => found,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": e.to_string() })),
            ))
        }
    };

    let message = format!("Found {} identities", identities.len());
    Ok(Json(ApiResponse {
        success: true,
        message,
        data: Some(identities),
    }))
}
/// Binds a signal (face / speaker) to an identity.
///
/// The identity is chosen by `identity_id` when present; otherwise it is fetched or
/// created by `name`. `source` defaults to `"manual"`, and the binding is stored with
/// confidence `1.0`.
///
/// # Errors
/// * `400` when `identity_id` matches no identity, or when neither `identity_id` nor
///   `name` is supplied.
/// * `500` when the database is unreachable or any database call fails. Lookup/creation
///   failures used to be swallowed and reported as the `400` above, which hid outages
///   behind a client-error status.
pub async fn bind_identity(
    Json(req): Json<BindIdentityRequest>,
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
    /// Wraps any displayable error as a `500` with a JSON `error` field.
    fn internal_error<E: std::fmt::Display>(e: E) -> (StatusCode, Json<serde_json::Value>) {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e.to_string() })),
        )
    }

    let db = get_db().await?;

    let identity = if let Some(id_id) = req.identity_id {
        db.get_identity_by_id(id_id).await.map_err(internal_error)?
    } else if let Some(name) = &req.name {
        Some(
            db.get_or_create_identity(name)
                .await
                .map_err(internal_error)?,
        )
    } else {
        None
    };
    let identity = match identity {
        Some(t) => t,
        None => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(serde_json::json!({ "error": "Identity not found or name required" })),
            ));
        }
    };

    let source = req.source.unwrap_or_else(|| "manual".to_string());
    db.bind_identity(
        identity.id as i64,
        &req.binding_type,
        &req.binding_value,
        &source,
        1.0,
    )
    .await
    .map_err(internal_error)?;

    Ok(Json(ApiResponse {
        success: true,
        message: format!(
            "Bound {} '{}' to Identity '{}'",
            req.binding_type, req.binding_value, identity.name
        ),
        data: None,
    }))
}
/// 解綁身份
pub async fn unbind_identity(
Json(req): Json<UnbindIdentityRequest>,
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
let db = get_db().await?;
db.unbind_identity(&req.binding_type, &req.binding_value)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
Ok(Json(ApiResponse {
success: true,
message: format!("Unbound {} '{}'", req.binding_type, req.binding_value),
data: None,
}))
}
/// Looks up the identity (person) that a machine-generated id is bound to.
///
/// # Errors
/// * `404` when no identity is bound to the given type/value pair.
/// * `500` when the database is unreachable or the lookup fails.
pub async fn get_identity_info(
    Path((binding_type, binding_value)): Path<(String, String)>,
) -> Result<Json<ApiResponse<Identity>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    let lookup = db
        .get_identity_by_binding(&binding_type, &binding_value)
        .await;
    match lookup {
        Ok(Some(identity)) => Ok(Json(ApiResponse {
            success: true,
            message: "Identity info retrieved".to_string(),
            data: Some(identity),
        })),
        Ok(None) => Err((
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({ "error": "Identity not found" })),
        )),
        Err(e) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e.to_string() })),
        )),
    }
}
/// Lists signals that are not yet bound to an identity (the to-be-labelled list).
///
/// # Errors
/// `500` when the database is unreachable or the query fails.
pub async fn list_unbound_signals(
    Query(params): Query<ListSignalsParams>,
) -> Result<Json<ApiResponse<Vec<String>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    let signals = match db
        .list_unbound_signals(&params.uuid, &params.binding_type)
        .await
    {
        Ok(found) => found,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": e.to_string() })),
            ))
        }
    };

    let message = format!(
        "Found {} unbound {} signals",
        signals.len(),
        params.binding_type
    );
    Ok(Json(ApiResponse {
        success: true,
        message,
        data: Some(signals),
    }))
}
/// Returns every chunk in which a given signal (face id or speaker id) appears — its timeline.
///
/// # Errors
/// `500` when the database is unreachable or the query fails.
pub async fn get_signal_timeline(
    Path((uuid, binding_type, binding_value)): Path<(String, String, String)>,
) -> Result<Json<ApiResponse<Vec<serde_json::Value>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    let chunks = match db
        .get_chunks_by_signal(&uuid, &binding_type, &binding_value)
        .await
    {
        Ok(found) => found,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": e.to_string() })),
            ))
        }
    };

    let message = format!(
        "Found {} chunks for {} '{}'",
        chunks.len(),
        binding_type,
        binding_value
    );
    Ok(Json(ApiResponse {
        success: true,
        message,
        data: Some(chunks),
    }))
}
/// JSON body accepted by the `suggest_audio_visual_bindings` handler.
#[derive(Debug, Deserialize)]
pub struct AVSuggestRequest {
/// UUID of the video whose unbound face and speaker signals are compared.
pub video_uuid: String,
/// Minimum overlap score a face/speaker pair must reach to be suggested.
pub overlap_threshold: Option<f64>, // default 0.6
}
/// One suggested face-to-speaker pairing.
#[derive(Debug, Serialize)]
pub struct AVSuggestion {
/// Face signal id of the pair.
pub face_id: String,
/// Speaker signal id of the pair.
pub speaker_id: String,
/// Overlap score computed for the pair's timelines.
pub overlap_score: f64,
/// Id of the identity the face is bound to, if any.
pub face_talent_id: Option<i32>,
/// Id of the identity the speaker is bound to, if any.
pub speaker_talent_id: Option<i32>,
}
/// Suggests Face-Speaker bindings based on temporal overlap.
///
/// For every (face, speaker) pair returned by `list_unbound_signals` for the
/// requested video, the two signal timelines are compared with
/// `calculate_overlap`. Pairs scoring at least `overlap_threshold` (default
/// 0.6) are returned, highest score first. A pair is dropped when both signals
/// are already bound to *different* identities.
///
/// # Errors
/// Propagates the error returned by `get_db()`. A failure to list face or
/// speaker signals yields `500 Internal Server Error`. Failures while fetching
/// an individual timeline or identity are treated as "no data" (best effort).
pub async fn suggest_audio_visual_bindings(
    Json(req): Json<AVSuggestRequest>,
) -> Result<Json<ApiResponse<Vec<AVSuggestion>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;
    let threshold = req.overlap_threshold.unwrap_or(0.6);

    // 1. Collect the candidate face and speaker signals.
    let face_signals = db
        .list_unbound_signals(&req.video_uuid, "face")
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": format!("Face signals: {}", e) })),
            )
        })?;
    let speaker_signals = db
        .list_unbound_signals(&req.video_uuid, "speaker")
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": format!("Speaker signals: {}", e) })),
            )
        })?;

    // 2. Fetch each speaker timeline exactly once. They do not depend on the
    //    face being examined, so fetching them inside the pair loop would
    //    issue faces x speakers redundant queries.
    let mut speaker_timelines = Vec::with_capacity(speaker_signals.len());
    for speaker_id in &speaker_signals {
        let timeline = db
            .get_chunks_by_signal(&req.video_uuid, "speaker", speaker_id)
            .await
            .unwrap_or_default();
        speaker_timelines.push(timeline);
    }

    // 3. Score every pair.
    let mut suggestions = Vec::new();
    for face_id in &face_signals {
        // One fetch per face; reused for every speaker below.
        let face_timeline = db
            .get_chunks_by_signal(&req.video_uuid, "face", face_id)
            .await
            .unwrap_or_default();

        for (speaker_id, speaker_timeline) in speaker_signals.iter().zip(&speaker_timelines) {
            // Overlap relies on the chunk content JSON carrying
            // `start_time` / `end_time`.
            let overlap = calculate_overlap(&face_timeline, speaker_timeline);
            if overlap < threshold {
                continue;
            }

            // Check whether the two signals are already bound to identities.
            let face_identity = db
                .get_identity_by_binding("face", face_id)
                .await
                .ok()
                .flatten();
            let speaker_identity = db
                .get_identity_by_binding("speaker", speaker_id)
                .await
                .ok()
                .flatten();

            // Both bound, but to different identities: conflicting, skip.
            if let (Some(fi), Some(si)) = (&face_identity, &speaker_identity) {
                if fi.id != si.id {
                    continue;
                }
            }

            suggestions.push(AVSuggestion {
                face_id: face_id.clone(),
                speaker_id: speaker_id.clone(),
                overlap_score: overlap,
                face_talent_id: face_identity.as_ref().map(|i| i.id as i32),
                speaker_talent_id: speaker_identity.as_ref().map(|i| i.id as i32),
            });
        }
    }

    // Highest score first. Incomparable scores are treated as equal instead of
    // panicking the request handler.
    suggestions.sort_by(|a, b| {
        b.overlap_score
            .partial_cmp(&a.overlap_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    Ok(Json(ApiResponse {
        success: true,
        message: format!("Found {} AV suggestions", suggestions.len()),
        data: Some(suggestions),
    }))
}
/// Scores how much a face timeline and a speaker timeline overlap in time.
///
/// Each chunk contributes a `(start_time, end_time)` range read from its
/// `content` object; chunks lacking either number are ignored. The result is
/// the summed pairwise intersection divided by the shorter of the two total
/// durations, clamped to `1.0`. Returns `0.0` when either side has no usable
/// duration.
///
/// Ranges inside one timeline are not merged, so overlapping ranges on the
/// same side can be counted more than once; the clamp keeps the score in
/// `0.0..=1.0`.
fn calculate_overlap(
    face_chunks: &[serde_json::Value],
    speaker_chunks: &[serde_json::Value],
) -> f64 {
    let face_ranges = av_chunk_time_ranges(face_chunks);
    let speaker_ranges = av_chunk_time_ranges(speaker_chunks);

    let mut overlap_duration = 0.0;
    for &(fs, fe) in &face_ranges {
        for &(ss, se) in &speaker_ranges {
            let start = fs.max(ss);
            let end = fe.min(se);
            if start < end {
                overlap_duration += end - start;
            }
        }
    }

    // Normalise by the shorter total *duration*. Summing the end times instead
    // would inflate the denominator for every range that does not start at 0.
    let min_duration = av_total_duration(&face_ranges).min(av_total_duration(&speaker_ranges));
    if min_duration > 0.0 {
        (overlap_duration / min_duration).min(1.0)
    } else {
        0.0
    }
}

/// Extracts `(start_time, end_time)` from each chunk's `content` object,
/// skipping chunks where either value is missing or not a number.
fn av_chunk_time_ranges(chunks: &[serde_json::Value]) -> Vec<(f64, f64)> {
    chunks
        .iter()
        .filter_map(|chunk| {
            let content = chunk.get("content")?;
            let start = content.get("start_time")?.as_f64()?;
            let end = content.get("end_time")?.as_f64()?;
            Some((start, end))
        })
        .collect()
}

/// Sums the lengths of the given ranges; inverted ranges contribute nothing.
fn av_total_duration(ranges: &[(f64, f64)]) -> f64 {
    ranges
        .iter()
        .map(|&(start, end)| (end - start).max(0.0))
        .sum()
}
// ============================================================================
// Router Setup
// ============================================================================
/// Optional query parameters for listing identities: a search term plus
/// limit/offset pagination.
#[derive(Debug, Deserialize)]
pub struct ListIdentitiesParams {
    /// Optional search term.
    pub search: Option<String>,
    /// Optional maximum number of rows.
    pub limit: Option<i32>,
    /// Optional number of rows to skip.
    pub offset: Option<i32>,
}
/// Query parameters accepted by `list_unbound_signals`.
#[derive(Debug, Deserialize)]
pub struct ListSignalsParams {
    /// UUID forwarded to the database lookup.
    pub uuid: String,
    /// Kind of signal to list.
    pub binding_type: String, // "face" or "speaker"
}
/// Builds the router for the identity binding endpoints: bind/unbind, identity
/// lookup, signal discovery, signal timeline and audio-visual suggestions.
pub fn identity_binding_routes() -> Router<crate::api::server::AppState> {
    // Binding management and identity lookup.
    let identity_routes: Router<crate::api::server::AppState> = Router::new()
        .route("/api/v1/identities/bind", post(bind_identity))
        .route("/api/v1/identities/unbind", post(unbind_identity))
        .route(
            "/api/v1/identity/:binding_type/:binding_value",
            get(get_identity_info),
        );

    let with_signal_routes = identity_routes
        // Signal discovery
        .route("/api/v1/signals/unbound", get(list_unbound_signals))
        // Signal timeline
        .route(
            "/api/v1/signals/:uuid/:binding_type/:binding_value/timeline",
            get(get_signal_timeline),
        );

    with_signal_routes.route(
        "/api/v1/identities/suggest-av",
        post(suggest_audio_visual_bindings),
    )
}

264
src/api/n8n_search.rs Normal file
View File

@@ -0,0 +1,264 @@
use crate::core::db::{Bm25Result, PostgresDb};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Input for `n8n_search_smart`.
#[derive(Debug, Deserialize)]
pub struct SmartSearchRequest {
    /// Free-text user query.
    pub query: String,
    /// Optional video UUID; passed to the BM25 and person searches.
    pub uuid: Option<String>,
    /// Maximum number of hits to return; the handler defaults to 10.
    pub limit: Option<usize>,
}
/// Output of `n8n_search_smart`.
#[derive(Debug, Serialize)]
pub struct SmartSearchResponse {
    /// The query as received.
    pub query: String,
    /// The extracted query dimensions, serialised as JSON.
    pub parsed_dimensions: serde_json::Value,
    /// Hits sorted by descending score and truncated to the limit.
    pub hits: Vec<serde_json::Value>,
    /// Number of entries in `hits` (counted after truncation).
    pub total: usize,
}
/// Query dimensions the LLM is asked to extract (see the prompt in
/// `parse_query_with_llm`). Absent dimensions are `None`.
#[derive(Debug, Deserialize, Serialize)]
struct LlmDimensionResponse {
    /// Person / subject.
    pub who: Option<String>,
    /// Action / event.
    pub what: Option<String>,
    /// Time.
    pub when: Option<String>,
    /// Location.
    pub r#where: Option<String>,
    /// Reason / intent.
    pub why: Option<String>,
    /// Specific keywords; defaults to an empty list when the field is missing.
    #[serde(default)]
    pub keywords: Vec<String>,
}
/// POST /api/v1/n8n/search/smart
///
/// Runs a "smart" search in three steps:
/// 1. Asks the LLM to split the query into dimensions; if that fails, falls
///    back to `extract_keywords` on the raw query.
/// 2. Runs a BM25 search on the joined keywords (score boost 1.0).
/// 3. If a `who` dimension exists, looks up person candidates and re-runs BM25
///    with the top candidate's display name (score boost 1.5), tagging those
///    hits with `matched_person`.
///
/// Hits are de-duplicated by chunk id (first occurrence wins), sorted by
/// descending score and truncated to `limit` (default 10).
///
/// Database errors in steps 2 and 3 are swallowed (best effort), so the
/// function currently always returns `Ok`.
pub async fn n8n_search_smart(
    db: &PostgresDb,
    req: SmartSearchRequest,
) -> Result<SmartSearchResponse, Box<dyn std::error::Error + Send + Sync>> {
    let limit = req.limit.unwrap_or(10);
    let video_uuid = req.uuid.clone();

    // 1. Call LLM to extract 5W1H (fallback to keywords if the LLM fails).
    let dimensions = match parse_query_with_llm(&req.query).await {
        Some(dims) => dims,
        None => LlmDimensionResponse {
            who: None,
            what: None,
            when: None,
            r#where: None,
            why: None,
            keywords: extract_keywords(&req.query),
        },
    };

    let keywords = dimensions.keywords.join(" ");

    // 2. Multi-dimensional search.
    let mut hits: Vec<serde_json::Value> = Vec::new();
    let mut seen_chunk_ids: HashSet<String> = HashSet::new();

    // Appends a BM25 result unless its chunk id was already collected.
    fn add_hit(
        hits: &mut Vec<serde_json::Value>,
        seen_chunk_ids: &mut HashSet<String>,
        sr: Bm25Result,
        boost: f32,
    ) {
        if seen_chunk_ids.insert(sr.chunk_id.clone()) {
            let score = sr.bm25_score * boost;
            let val = serde_json::json!({
                "id": sr.chunk_id,
                "vid": sr.uuid,
                "start": sr.start_time,
                "end": sr.end_time,
                "text": sr.text,
                "score": score,
                "chunk_type": sr.chunk_type
            });
            hits.push(val);
        }
    }

    // A. Keyword search (BM25).
    if !keywords.is_empty() {
        if let Ok(results) = db
            .search_bm25(&keywords, video_uuid.as_deref(), limit)
            .await
        {
            for sr in results {
                add_hit(&mut hits, &mut seen_chunk_ids, sr, 1.0);
            }
        }
    }

    // B. Who search (person matching).
    if let Some(who_query) = &dimensions.who {
        if let Ok(persons) = db.search_person_candidates(who_query, &video_uuid, 5).await {
            // Only the best candidate is used; `first()` avoids indexing.
            if let Some(top) = persons.first() {
                let person_id = top
                    .get("candidate_id")
                    .and_then(|v| v.as_str())
                    .map(String::from);
                if let Some(pid) = person_id {
                    // Heuristic: search BM25 for the person's display name,
                    // falling back to the raw `who` text.
                    let person_name = top
                        .get("display_name")
                        .and_then(|v| v.as_str())
                        .unwrap_or(who_query);
                    if let Ok(results) = db
                        .search_bm25(person_name, video_uuid.as_deref(), limit)
                        .await
                    {
                        for sr in results {
                            let id = sr.chunk_id.clone();
                            if seen_chunk_ids.insert(id) {
                                // Person-matched chunks get a 1.5x boost.
                                let score = sr.bm25_score * 1.5;
                                let val = serde_json::json!({
                                    "id": sr.chunk_id,
                                    "vid": sr.uuid,
                                    "start": sr.start_time,
                                    "end": sr.end_time,
                                    "text": sr.text,
                                    "score": score,
                                    "matched_person": pid
                                });
                                hits.push(val);
                            }
                        }
                    }
                }
            }
        }
    }

    // Sort by score, highest first. Incomparable scores are treated as equal
    // rather than panicking.
    hits.sort_by(|a, b| {
        let score_a = a.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
        let score_b = b.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
        score_b
            .partial_cmp(&score_a)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Limit.
    hits.truncate(limit);
    let total = hits.len();

    Ok(SmartSearchResponse {
        query: req.query,
        parsed_dimensions: serde_json::json!(dimensions),
        hits,
        total,
    })
}
/// Fallback keyword extraction used when the LLM is unavailable.
///
/// The query is lower-cased, every non-alphanumeric character is treated as a
/// separator, and the remaining tokens are returned in order with common stop
/// words removed.
fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "who", "what", "where", "when", "why", "how", "is", "the", "a", "an", "and", "or", "of",
        "in", "to", "for", "with", "by", "on", "at", "from", "up", "about", "into", "over",
        "after",
    ];

    let lowered = query.to_lowercase();
    let mut keywords = Vec::new();
    // Splitting on non-alphanumerics yields empty pieces between adjacent
    // separators; those are skipped together with the stop words.
    for token in lowered.split(|c: char| !c.is_alphanumeric()) {
        if token.is_empty() || STOP_WORDS.contains(&token) {
            continue;
        }
        keywords.push(token.to_string());
    }
    keywords
}
async fn parse_query_with_llm(query: &str) -> Option<LlmDimensionResponse> {
let client = reqwest::Client::new();
// Test connectivity first
if let Ok(resp) = client.get("http://127.0.0.1:8081/health").send().await {
tracing::info!("LLM Health Check: {}", resp.status());
} else {
tracing::error!("LLM Server is unreachable at 127.0.0.1:8081");
}
// We use the OpenAI-compatible endpoint provided by llama.cpp server (default port 8081)
let prompt = format!(
r#"Analyze the user query and extract the following dimensions into a JSON object.
If a dimension is not present, use null.
Dimensions: "who" (person/subject), "what" (action/event), "where" (location), "when" (time), "why" (reason/intent), "keywords" (array of specific keywords).
User Query: "{}"
Output ONLY the JSON object.
"#,
query
);
let payload = serde_json::json!({
"model": "gemma4",
"messages": [
{
"role": "user",
"content": prompt
}
],
"temperature": 0.1,
"stream": false
});
if let Ok(response) = client
.post("http://127.0.0.1:8081/v1/chat/completions")
.json(&payload)
.timeout(std::time::Duration::from_secs(60))
.send()
.await
{
tracing::info!("LLM Response Status: {}", response.status());
if let Ok(json_resp) = response.json::<serde_json::Value>().await {
tracing::info!("LLM Response Body: {}", json_resp);
if let Some(choices) = json_resp.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.get(0) {
if let Some(content) = choice
.get("message")
.and_then(|m| m.get("content"))
.and_then(|c| c.as_str())
{
if let Some(start) = content.find("{") {
if let Some(end) = content.rfind("}") {
let json_str = &content[start..=end];
if let Ok(dims) =
serde_json::from_str::<LlmDimensionResponse>(json_str)
{
tracing::info!("Parsed LLM Dimensions: {:?}", dims);
return Some(dims);
} else {
tracing::warn!("Failed to parse LLM JSON: {}", json_str);
}
}
}
}
}
}
} else {
tracing::warn!("Failed to parse LLM response JSON");
}
} else {
tracing::warn!("LLM request failed or timed out");
}
None
}

3281
src/api/person_identity.rs Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

195
src/api/search.rs Normal file
View File

@@ -0,0 +1,195 @@
//! Smart Search API
//! Implements the 5W1H search capability using semantic vectors.
use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router};
use serde::{Deserialize, Serialize};
use serde_json;
use tracing;
use crate::core::db::PostgresDb;
// --- Request / Response Structures ---
/// Request body for `smart_search`.
#[derive(Debug, Deserialize)]
pub struct SmartSearchRequest {
    /// UUID passed to the semantic parent-chunk search.
    pub uuid: String,
    /// Free-text query; embedded before searching.
    pub query: String,
    /// Number of parent chunks to retrieve; the handler defaults to 5.
    pub limit: Option<usize>,
}
/// One child chunk returned by `smart_search`, enriched with fields copied
/// from its parent chunk when the parent is among the search results.
#[derive(Debug, Serialize)]
pub struct SearchResult {
    /// Child chunk id.
    pub id: i32,
    /// Id of the parent chunk this child belongs to.
    pub parent_id: i32,
    /// Parent's scene order; `None` if the parent was not found.
    pub scene_order: Option<i32>,
    // Primary: frame-accurate position (authoritative unit)
    pub start_frame: i64,
    pub end_frame: i64,
    pub fps: f64,
    // Reference: time derived from frames (subject to FPS variation, not precise)
    pub start_time: f64,
    pub end_time: f64,
    pub raw_text: Option<String>, // Text content of the child chunk
    pub summary: Option<String>,  // Summary from parent context
    /// Parent's metadata; `None` if the parent was not found.
    pub metadata: Option<serde_json::Value>,
    /// Parent's similarity score; used as the sort key for results.
    pub similarity: Option<f64>,
}
/// Response body of `smart_search`.
#[derive(Debug, Serialize)]
pub struct SmartSearchResponse {
    /// The query as received.
    pub query: String,
    /// Child chunks sorted by descending parent similarity.
    pub results: Vec<SearchResult>,
    /// Name of the search strategy that produced the results.
    pub strategy: String,
}
// --- API Handler ---
/// Semantic "drill-down" search handler.
///
/// 1. Embeds the query with `get_ollama_embedding`.
/// 2. Finds up to `limit` (default 5) parent chunks by semantic similarity.
/// 3. Fetches the children of those parents (the call passes 10 as the
///    per-parent cap).
/// 4. Returns the children, each enriched with its parent's scene order,
///    summary, metadata and similarity, sorted by descending similarity and
///    truncated to `limit * 5` entries.
///
/// When no parent matches, an empty result with strategy
/// `"semantic_vector_search"` is returned.
///
/// # Errors
/// Embedding or database failures are logged and reported as
/// `500 Internal Server Error` with the error text under `"error"`.
pub async fn smart_search(
    State(state): State<crate::api::server::AppState>,
    Json(req): Json<SmartSearchRequest>,
) -> Result<Json<SmartSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
    let db = &state.db;
    let limit = req.limit.unwrap_or(5);

    // 1. Generate Embedding using Ollama
    let embedding = get_ollama_embedding(&req.query).await.map_err(
        |e| -> (StatusCode, Json<serde_json::Value>) {
            tracing::error!("Embedding failed: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": e.to_string() })),
            )
        },
    )?;

    // 2. Search Database (Drill-Down: Find Parents First)
    let db_parents: Vec<crate::core::db::postgres_db::SemanticSearchResult> = db
        .search_parent_chunks_semantic(&req.uuid, &embedding, limit)
        .await
        .map_err(
            |e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
                tracing::error!("DB search failed: {}", e);
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({ "error": e.to_string() })),
                )
            },
        )?;

    if db_parents.is_empty() {
        return Ok(Json(SmartSearchResponse {
            query: req.query,
            results: vec![],
            strategy: "semantic_vector_search".to_string(),
        }));
    }

    // Collect Parent IDs
    let parent_ids: Vec<i32> = db_parents.iter().map(|p| p.id).collect();

    // 3. Fetch Children for these Parents (Drill Down)
    let children: Vec<crate::core::db::postgres_db::ChildChunkResult> = db
        .get_children_for_parents(&parent_ids, 10) // Fetch top 10 children per parent
        .await
        .map_err(
            |e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
                tracing::error!("Fetching children failed: {}", e);
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({ "error": e.to_string() })),
                )
            },
        )?;

    // 4. Map Parents to a lookup table
    let parent_map: std::collections::HashMap<
        i32,
        &crate::core::db::postgres_db::SemanticSearchResult,
    > = db_parents.iter().map(|p| (p.id, p)).collect();

    // 5. Map Children to the API response struct
    let mut results: Vec<SearchResult> = children
        .into_iter()
        .map(|c| {
            let parent = parent_map.get(&c.parent_id);
            SearchResult {
                id: c.id,
                parent_id: c.parent_id,
                scene_order: parent.map(|p| p.scene_order),
                start_frame: c.start_frame,
                end_frame: c.end_frame,
                fps: c.fps,
                start_time: c.start_time,
                end_time: c.end_time,
                raw_text: Some(c.raw_text),
                summary: parent.map(|p| p.summary.clone()),
                metadata: parent.map(|p| p.metadata.clone()),
                similarity: parent.and_then(|p| p.similarity),
            }
        })
        .collect();

    // 6. Sort results by similarity (descending)
    // All children of a parent share the parent's similarity, so the stable
    // sort keeps the children of one parent together.
    results.sort_by(|a, b| {
        b.similarity
            .partial_cmp(&a.similarity)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // 7. Limit the final results: allow several children per parent.
    // `saturating_mul` keeps a huge client-supplied `limit` from overflowing.
    let max_results = limit.saturating_mul(5);
    results.truncate(max_results);

    // 8. Format Response
    let response = SmartSearchResponse {
        query: req.query,
        results,
        strategy: "drill_down_semantic_search".to_string(),
    };
    Ok(Json(response))
}
// --- Helper: Ollama Embedding ---
/// Requests an embedding for `text` from the local Ollama server
/// (`http://localhost:11434/api/embeddings`, model `nomic-embed-text`).
///
/// # Errors
/// Returns an error when the request fails, the server answers with a
/// non-success HTTP status, the body is not JSON, the `embedding` array is
/// missing, or the array contains a non-numeric element.
async fn get_ollama_embedding(
    text: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
    let client = reqwest::Client::new();
    let payload = serde_json::json!({
        "model": "nomic-embed-text",
        "prompt": text
    });

    let res = client
        .post("http://localhost:11434/api/embeddings")
        .json(&payload)
        .send()
        .await?
        // Surface HTTP failures directly instead of a misleading
        // "No embedding found" further down.
        .error_for_status()?
        .json::<serde_json::Value>()
        .await?;

    // Parse embedding array from response. A non-numeric element is an error:
    // replacing it with 0.0 would silently distort the query vector.
    let embedding = res["embedding"]
        .as_array()
        .ok_or("No embedding found in Ollama response")?
        .iter()
        .map(|v| {
            v.as_f64()
                .map(|x| x as f32)
                .ok_or("Non-numeric value in Ollama embedding")
        })
        .collect::<Result<Vec<f32>, _>>()?;

    Ok(embedding)
}
// --- Router Setup ---
/// Builds the search router: `POST /smart` is served by `smart_search`.
pub fn search_routes() -> Router<crate::api::server::AppState> {
    let smart_handler = post(smart_search);
    Router::new().route("/smart", smart_handler)
}

195
src/api/search.rs.bak Normal file
View File

@@ -0,0 +1,195 @@
//! Smart Search API
//! Implements the 5W1H search capability using semantic vectors.
use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router};
use serde::{Deserialize, Serialize};
use serde_json;
use tracing;
use crate::core::db::PostgresDb;
// --- Request / Response Structures ---
/// Request body for `smart_search`.
#[derive(Debug, Deserialize)]
pub struct SmartSearchRequest {
    /// UUID passed to the semantic parent-chunk search.
    pub uuid: String,
    /// Free-text query; embedded before searching.
    pub query: String,
    /// Number of parent chunks to retrieve; the handler defaults to 5.
    pub limit: Option<usize>,
}
/// One child chunk returned by `smart_search`, enriched with fields copied
/// from its parent chunk when the parent is among the search results.
#[derive(Debug, Serialize)]
pub struct SearchResult {
    /// Child chunk id.
    pub id: i32,
    /// Id of the parent chunk this child belongs to.
    pub parent_id: i32,
    /// Parent's scene order; `None` if the parent was not found.
    pub scene_order: Option<i32>,
    // Primary: frame-accurate position (authoritative unit)
    pub start_frame: i64,
    pub end_frame: i64,
    pub fps: f64,
    // Reference: time derived from frames (subject to FPS variation, not precise)
    pub start_time: f64,
    pub end_time: f64,
    pub raw_text: Option<String>, // Text content of the child chunk
    pub summary: Option<String>,  // Summary from parent context
    /// Parent's metadata; `None` if the parent was not found.
    pub metadata: Option<serde_json::Value>,
    /// Parent's similarity score; used as the sort key for results.
    pub similarity: Option<f64>,
}
/// Response body of `smart_search`.
#[derive(Debug, Serialize)]
pub struct SmartSearchResponse {
    /// The query as received.
    pub query: String,
    /// Child chunks sorted by descending parent similarity.
    pub results: Vec<SearchResult>,
    /// Name of the search strategy that produced the results.
    pub strategy: String,
}
// --- API Handler ---
/// Semantic "drill-down" search handler.
///
/// 1. Embeds the query with `get_ollama_embedding`.
/// 2. Finds up to `limit` (default 5) parent chunks by semantic similarity.
/// 3. Fetches the children of those parents (the call passes 10 as the
///    per-parent cap).
/// 4. Returns the children, each enriched with its parent's scene order,
///    summary, metadata and similarity, sorted by descending similarity and
///    truncated to `limit * 5` entries.
///
/// When no parent matches, an empty result with strategy
/// `"semantic_vector_search"` is returned.
///
/// # Errors
/// Embedding or database failures are logged and reported as
/// `500 Internal Server Error` with the error text under `"error"`.
pub async fn smart_search(
    State(state): State<crate::api::server::AppState>,
    Json(req): Json<SmartSearchRequest>,
) -> Result<Json<SmartSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
    let db = &state.db;
    let limit = req.limit.unwrap_or(5);

    // 1. Generate Embedding using Ollama
    let embedding = get_ollama_embedding(&req.query).await.map_err(
        |e| -> (StatusCode, Json<serde_json::Value>) {
            tracing::error!("Embedding failed: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": e.to_string() })),
            )
        },
    )?;

    // 2. Search Database (Drill-Down: Find Parents First)
    let db_parents: Vec<crate::core::db::postgres_db::SemanticSearchResult> = db
        .search_parent_chunks_semantic(&req.uuid, &embedding, limit)
        .await
        .map_err(
            |e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
                tracing::error!("DB search failed: {}", e);
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({ "error": e.to_string() })),
                )
            },
        )?;

    if db_parents.is_empty() {
        return Ok(Json(SmartSearchResponse {
            query: req.query,
            results: vec![],
            strategy: "semantic_vector_search".to_string(),
        }));
    }

    // Collect Parent IDs
    let parent_ids: Vec<i32> = db_parents.iter().map(|p| p.id).collect();

    // 3. Fetch Children for these Parents (Drill Down)
    let children: Vec<crate::core::db::postgres_db::ChildChunkResult> = db
        .get_children_for_parents(&parent_ids, 10) // Fetch top 10 children per parent
        .await
        .map_err(
            |e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
                tracing::error!("Fetching children failed: {}", e);
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({ "error": e.to_string() })),
                )
            },
        )?;

    // 4. Map Parents to a lookup table
    let parent_map: std::collections::HashMap<
        i32,
        &crate::core::db::postgres_db::SemanticSearchResult,
    > = db_parents.iter().map(|p| (p.id, p)).collect();

    // 5. Map Children to the API response struct
    let mut results: Vec<SearchResult> = children
        .into_iter()
        .map(|c| {
            let parent = parent_map.get(&c.parent_id);
            SearchResult {
                id: c.id,
                parent_id: c.parent_id,
                scene_order: parent.map(|p| p.scene_order),
                start_frame: c.start_frame,
                end_frame: c.end_frame,
                fps: c.fps,
                start_time: c.start_time,
                end_time: c.end_time,
                raw_text: Some(c.raw_text),
                summary: parent.map(|p| p.summary.clone()),
                metadata: parent.map(|p| p.metadata.clone()),
                similarity: parent.and_then(|p| p.similarity),
            }
        })
        .collect();

    // 6. Sort results by similarity (descending)
    // All children of a parent share the parent's similarity, so the stable
    // sort keeps the children of one parent together.
    results.sort_by(|a, b| {
        b.similarity
            .partial_cmp(&a.similarity)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // 7. Limit the final results: allow several children per parent.
    // `saturating_mul` keeps a huge client-supplied `limit` from overflowing.
    let max_results = limit.saturating_mul(5);
    results.truncate(max_results);

    // 8. Format Response
    let response = SmartSearchResponse {
        query: req.query,
        results,
        strategy: "drill_down_semantic_search".to_string(),
    };
    Ok(Json(response))
}
// --- Helper: Ollama Embedding ---
/// Requests an embedding for `text` from the local Ollama server
/// (`http://localhost:11434/api/embeddings`, model `nomic-embed-text`).
///
/// # Errors
/// Returns an error when the request fails, the server answers with a
/// non-success HTTP status, the body is not JSON, the `embedding` array is
/// missing, or the array contains a non-numeric element.
async fn get_ollama_embedding(
    text: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
    let client = reqwest::Client::new();
    let payload = serde_json::json!({
        "model": "nomic-embed-text",
        "prompt": text
    });

    let res = client
        .post("http://localhost:11434/api/embeddings")
        .json(&payload)
        .send()
        .await?
        // Surface HTTP failures directly instead of a misleading
        // "No embedding found" further down.
        .error_for_status()?
        .json::<serde_json::Value>()
        .await?;

    // Parse embedding array from response. A non-numeric element is an error:
    // replacing it with 0.0 would silently distort the query vector.
    let embedding = res["embedding"]
        .as_array()
        .ok_or("No embedding found in Ollama response")?
        .iter()
        .map(|v| {
            v.as_f64()
                .map(|x| x as f32)
                .ok_or("Non-numeric value in Ollama embedding")
        })
        .collect::<Result<Vec<f32>, _>>()?;

    Ok(embedding)
}
// --- Router Setup ---
/// Builds the search router: `POST /smart` is served by `smart_search`.
pub fn search_routes() -> Router<crate::api::server::AppState> {
    let smart_handler = post(smart_search);
    Router::new().route("/smart", smart_handler)
}

View File

@@ -30,6 +30,33 @@ use super::universal_search;
use super::visual_chunk_search;
use crate::core::chunk::types::Chunk;
/// API key handed out by the `login` handler for the built-in `demo` account.
///
/// NOTE(review): this looks like a live credential committed to source
/// control. Confirm whether it should be rotated and loaded from
/// configuration or the environment instead.
static DEMO_USER_API_KEY: &str = "muser_68600856036340bcafc01930eb4bd839_1774418104_97221b69";
/// Returns the lowercase hex encoding of the SHA-256 digest of `password`.
///
/// NOTE(review): the digest is unsalted and SHA-256 is a fast hash. Confirm
/// this is not used for stored credentials; a dedicated password-hashing
/// scheme would be expected there.
fn hash_password(password: &str) -> String {
    let digest = {
        let mut sha = Sha256::new();
        sha.update(password.as_bytes());
        sha.finalize()
    };
    format!("{:x}", digest)
}
/// Request body for the `login` handler.
#[derive(Debug, Deserialize)]
struct LoginRequest {
    /// Account name.
    username: String,
    /// Plain-text password as sent by the client.
    password: String,
}
/// Response body of the `login` handler.
#[derive(Debug, Serialize)]
struct LoginResponse {
    /// `true` when the credentials were accepted.
    success: bool,
    /// Human-readable outcome message.
    message: Option<String>,
    /// API key for subsequent requests; `None` on failed login.
    api_key: Option<String>,
    /// Logged-in user; `None` on failed login.
    user: Option<UserInfo>,
}
/// User details embedded in a successful `LoginResponse`.
#[derive(Debug, Serialize)]
struct UserInfo {
    /// Account name.
    username: String,
}
// Global State
/// Write-once cell holding an `Instant`.
/// NOTE(review): presumably the server start time used for uptime reporting;
/// confirm against the code that initialises it.
static SERVER_START: OnceCell<Instant> = OnceCell::new();
@@ -334,6 +361,12 @@ struct VideoInfoResponse {
duration: f64,
width: u32,
height: u32,
status: String,
processing_status: Option<String>,
created_at: Option<String>,
registration_time: Option<String>,
file_size: Option<i64>,
probe_json: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -348,6 +381,9 @@ struct VideosResponse {
struct VideosQuery {
    /// Page number; `list_videos` defaults to 1.
    page: Option<usize>,
    /// Page size; `list_videos` defaults to 20.
    page_size: Option<usize>,
    /// Optional search text; when present `list_videos` calls `search_videos`.
    q: Option<String>,
    /// Optional status filter. `list_videos` maps "pending"/"unprocessed" to
    /// unprocessed and "completed"/"ready"/"processed" to processed; any other
    /// value applies no filter.
    status: Option<String>,
    /// When present, `list_videos` returns only the video with this UUID.
    uuid: Option<String>,
}
#[derive(Clone)]
@@ -408,7 +444,10 @@ async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthRe
let qdrant = check_qdrant().await;
let mongodb = check_mongodb(&state.mongo_cache).await;
let overall_status = if postgres.status == "ok" && redis.status == "ok" && qdrant.status == "ok"
let overall_status = if postgres.status == "ok"
&& redis.status == "ok"
&& qdrant.status == "ok"
&& mongodb.status == "ok"
{
"ok"
} else {
@@ -428,6 +467,30 @@ async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthRe
})
}
/// Demo login handler.
///
/// Accepts only the fixed `demo` / `demo` credentials and answers with
/// `DEMO_USER_API_KEY`. Any other combination gets `success: false` with no
/// key or user. The outcome is carried in the JSON body only.
async fn login(Json(req): Json<LoginRequest>) -> Json<LoginResponse> {
    let authenticated = req.username == "demo" && req.password == "demo";

    // Guard clause: reject anything that is not the demo account.
    if !authenticated {
        return Json(LoginResponse {
            success: false,
            message: Some("Invalid username or password".to_string()),
            api_key: None,
            user: None,
        });
    }

    let user = UserInfo {
        username: "demo".to_string(),
    };
    Json(LoginResponse {
        success: true,
        message: Some("Login successful".to_string()),
        api_key: Some(DEMO_USER_API_KEY.to_string()),
        user: Some(user),
    })
}
/// Logout handler: always answers `{"success": true}` and performs no other
/// work.
async fn logout() -> Json<serde_json::Value> {
    let body = serde_json::json!({ "success": true });
    Json(body)
}
async fn check_postgres() -> ServiceStatus {
let start = Instant::now();
match PostgresDb::init().await {
@@ -709,6 +772,7 @@ async fn register(
user_id: None,
job_id: None,
created_at: String::new(),
registration_time: None,
};
let video_id = db
@@ -1874,9 +1938,51 @@ async fn list_videos(
let page = params.page.unwrap_or(1);
let page_size = params.page_size.unwrap_or(20);
let offset = ((page - 1) as i64) * (page_size as i64);
let status_filter = params.status.clone();
let query_filter = params.q.clone();
// Include query and status in cache key
let cache_key = keys::videos_list(page, page_size);
let cache_key = if let Some(ref q) = query_filter {
format!("{}:q:{}", cache_key, q)
} else {
cache_key
};
let ttl = state.mongo_cache.ttl_videos();
// If uuid is provided, fetch single video directly
if let Some(ref uuid) = params.uuid {
let db = PostgresDb::init()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let video = db.get_video_by_uuid(uuid).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(v) = video {
return Ok(Json(VideosResponse {
videos: vec![VideoInfoResponse {
uuid: v.uuid,
file_path: v.file_path,
file_name: v.file_name,
duration: v.duration,
width: v.width,
height: v.height,
status: v.status.as_str().to_string(),
processing_status: None,
created_at: Some(v.created_at),
registration_time: v.registration_time,
file_size: None,
probe_json: v.probe_json,
}],
count: 1,
page,
page_size,
}));
} else {
return Err(StatusCode::NOT_FOUND);
}
}
tracing::info!(
"list_videos called: page={}, page_size={}, cache_key={}",
page,
@@ -1892,7 +1998,21 @@ async fn list_videos(
.await
.map_err(|e| anyhow::anyhow!("PG init failed: {}", e))?;
let (videos, count) = db.list_videos(page_size as i32, offset).await?;
// Map status parameter to is_processed filter
let is_processed = match status_filter.as_deref() {
Some("pending") | Some("unprocessed") => Some(false),
Some("completed") | Some("ready") | Some("processed") => Some(true),
_ => None, // no filter
};
// Search by query if provided
let (videos, count) = if let Some(ref q) = query_filter {
db.search_videos(Some(q.as_str()), is_processed, page_size as i32, offset).await?
} else if let Some(processed) = is_processed {
db.search_videos(None, Some(processed), page_size as i32, offset).await?
} else {
db.list_videos(page_size as i32, offset).await?
};
tracing::info!("Got {} videos from DB", videos.len());
let video_infos: Vec<VideoInfoResponse> = videos
@@ -1904,6 +2024,12 @@ async fn list_videos(
duration: v.duration,
width: v.width,
height: v.height,
status: v.status.as_str().to_string(),
processing_status: None,
created_at: Some(v.created_at),
registration_time: v.registration_time,
file_size: None,
probe_json: v.probe_json,
})
.collect();
@@ -2328,19 +2454,15 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> {
.with_state(state.clone());
let cors = CorsLayer::new()
.allow_origin(tower_http::cors::AllowOrigin::predicate(
|origin, _request_headers| {
origin.as_bytes().ends_with(b"localhost")
|| origin.as_bytes().ends_with(b"momentry.ddns.net")
|| origin.as_bytes().ends_with(b"127.0.0.1")
},
))
.allow_origin(tower_http::cors::AllowOrigin::any())
.allow_methods(Any)
.allow_headers(Any);
let app = Router::new()
.route("/health", get(health))
.route("/health/detailed", get(health_detailed))
.route("/api/v1/auth/login", post(login))
.route("/api/v1/auth/logout", post(logout))
.route("/api/v1/stats/ingest", get(get_ingest_stats))
.route("/api/v1/stats/sftpgo", get(get_sftpgo_status))
.route("/api/v1/stats/inference", get(get_inference_health))

814
src/api/universal_search.rs Normal file
View File

@@ -0,0 +1,814 @@
//! Universal Search API
//! Unified search across chunks, frames, and persons.
use axum::{
extract::{Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::{Database, PostgresDb};
/// Request body for `universal_search`.
#[derive(Debug, Deserialize)]
pub struct UniversalSearchRequest {
    /// Free-text query.
    pub query: String,
    /// Optional UUID.
    pub uuid: Option<String>,
    /// Result kinds to search; an empty (or omitted) list means all three.
    #[serde(default)]
    pub types: Vec<String>, // chunk, frame, person
    /// Optional two-element time range.
    pub time_range: Option<[f64; 2]>,
    /// Optional additional filters.
    pub filters: Option<SearchFilters>,
    /// Page size; the handler defaults to 20.
    pub limit: Option<usize>,
    /// Number of leading results to skip; the handler defaults to 0.
    pub offset: Option<usize>,
}
/// Optional filters attached to a `UniversalSearchRequest`. Every field is
/// optional.
#[derive(Debug, Deserialize, Clone)]
pub struct SearchFilters {
    pub person_id: Option<String>,
    pub object_class: Option<Vec<String>>,
    pub ocr_text: Option<String>,
    pub has_face: Option<bool>,
    pub speaker_id: Option<String>,
    // Visual chunk filters
    pub min_confidence: Option<f32>,
    pub min_unique_classes: Option<u32>,
    pub min_spatial_density: Option<f32>,
    pub max_spatial_density: Option<f32>,
    pub required_object_classes: Option<Vec<String>>,
}
/// Response body of `universal_search`.
#[derive(Debug, Serialize)]
pub struct UniversalSearchResponse {
    /// The query as received.
    pub query: String,
    /// The requested page of results, sorted by descending score.
    pub results: Vec<SearchResult>,
    /// Number of results found before pagination.
    pub total: usize,
    /// Handler wall-clock time in milliseconds.
    pub took_ms: u64,
}
/// A single universal-search hit. Serialised with a `"type"` tag of
/// `"chunk"`, `"frame"` or `"person"`; every variant carries a `score` used
/// for ranking.
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "type")]
pub enum SearchResult {
    /// A chunk hit.
    #[serde(rename = "chunk")]
    Chunk {
        chunk_id: String,
        chunk_type: String,
        // Primary: frame-accurate position
        start_frame: i64,
        end_frame: i64,
        // Reference: time derived from frames (subject to FPS variation)
        start_time: f64,
        end_time: f64,
        score: f64,
        text: Option<String>,
        speaker_id: Option<String>,
        metadata: Option<serde_json::Value>,
    },
    /// A single-frame hit.
    #[serde(rename = "frame")]
    Frame {
        // Primary: exact frame number
        frame_number: i64,
        // Reference: time derived from frame (subject to FPS variation)
        timestamp: f64,
        score: f64,
        objects: Option<Vec<serde_json::Value>>,
        ocr_texts: Option<Vec<String>>,
        faces: Option<Vec<serde_json::Value>>,
        pose_persons: Option<Vec<serde_json::Value>>,
    },
    /// A person hit.
    #[serde(rename = "person")]
    Person {
        person_id: String,
        name: Option<String>,
        speaker_id: Option<String>,
        appearance_count: i32,
        score: f64,
        first_appearance_time: Option<f64>,
        last_appearance_time: Option<f64>,
    },
}
/// Builds the router for the universal search endpoints: unified search,
/// frame search and person search.
pub fn universal_search_routes() -> Router<crate::api::server::AppState> {
    let router: Router<crate::api::server::AppState> =
        Router::new().route("/api/v1/search/universal", post(universal_search));
    let router = router.route("/api/v1/search/frames", post(search_frames));
    router.route("/api/v1/search/persons", get(search_persons))
}
/// Unified search across all data types
pub async fn universal_search(
State(_state): State<crate::api::server::AppState>,
Json(req): Json<UniversalSearchRequest>,
) -> Result<Json<UniversalSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let start_time = std::time::Instant::now();
let db = PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
)
})?;
let limit = req.limit.unwrap_or(20);
let offset = req.offset.unwrap_or(0);
let types = if req.types.is_empty() {
vec![
"chunk".to_string(),
"frame".to_string(),
"person".to_string(),
]
} else {
req.types.clone()
};
let mut results = Vec::new();
// Search chunks
if types.contains(&"chunk".to_string()) {
let chunk_results = search_chunks(&db, &req).await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
results.extend(chunk_results);
}
// Search frames
if types.contains(&"frame".to_string()) {
let frame_results = search_frames_internal(&db, &req).await.unwrap_or_default();
results.extend(frame_results);
}
// Search persons
if types.contains(&"person".to_string()) {
let person_results = search_persons_internal(&db, &req).await.unwrap_or_default();
results.extend(person_results);
}
// Sort by score descending
results.sort_by(|a, b| {
let score_a = match a {
SearchResult::Chunk { score, .. } => *score,
SearchResult::Frame { score, .. } => *score,
SearchResult::Person { score, .. } => *score,
};
let score_b = match b {
SearchResult::Chunk { score, .. } => *score,
SearchResult::Frame { score, .. } => *score,
SearchResult::Person { score, .. } => *score,
};
score_b
.partial_cmp(&score_a)
.unwrap_or(std::cmp::Ordering::Equal)
});
let total = results.len();
let end = std::cmp::min(offset + limit, results.len());
let paginated = if offset < results.len() {
results[offset..end].to_vec()
} else {
vec![]
};
let took = start_time.elapsed().as_millis() as u64;
Ok(Json(UniversalSearchResponse {
query: req.query,
results: paginated,
total,
took_ms: took,
}))
}
/// Search frames by YOLO objects, OCR text, or face IDs
pub async fn search_frames(
State(_state): State<crate::api::server::AppState>,
Json(req): Json<FrameSearchRequest>,
) -> Result<Json<FrameSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let db = PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
)
})?;
let frames = search_frames_internal_v2(&db, &req).await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("Search error: {}", e) })),
)
})?;
let frames_count = frames.len();
Ok(Json(FrameSearchResponse {
frames,
total: frames_count,
}))
}
/// Search persons by name or speaker_id
pub async fn search_persons(
State(_state): State<crate::api::server::AppState>,
Query(query): Query<PersonSearchQuery>,
) -> Result<Json<PersonSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let db = PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
)
})?;
let limit = query.limit.unwrap_or(20);
let persons = search_persons_by_query(
&db,
&query.query,
query.min_appearances,
query.max_age,
limit,
)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("Search error: {}", e) })),
)
})?;
let persons_count = persons.len();
Ok(Json(PersonSearchResponse {
persons,
total: persons_count,
}))
}
// --- Internal search functions ---
/// Request body of the frame search endpoint (`search_frames`).
///
/// Every field is optional; each field that is present adds one condition to
/// the frame query built by `search_frames_internal_v2`.
#[derive(Debug, Deserialize)]
pub struct FrameSearchRequest {
/// Restrict results to frames of the video with this UUID.
pub uuid: Option<String>,
/// Case-insensitive substring matched against the text form of `yolo_objects`.
pub object_class: Option<String>,
/// Case-insensitive substring matched against the text form of `ocr_results`.
pub ocr_text: Option<String>,
/// Case-sensitive substring matched against the text form of `face_results`.
pub face_id: Option<String>,
/// Inclusive `[start, end]` bounds applied to the frame timestamp.
pub time_range: Option<[f64; 2]>,
/// Maximum number of frames to return; 50 when omitted.
pub limit: Option<usize>,
}
/// Response body of the frame search endpoint (`search_frames`).
#[derive(Debug, Serialize)]
pub struct FrameSearchResponse {
/// Matching frames, ordered by ascending timestamp.
pub frames: Vec<FrameResult>,
/// Number of entries in `frames` (the returned count, after `limit`).
pub total: usize,
}
/// One frame row returned by the frame search.
///
/// The optional fields are `None` when the source column is NULL or does not
/// contain the expected key.
#[derive(Debug, Serialize)]
pub struct FrameResult {
/// Frame number as stored in the `frames` table.
pub frame_number: i64,
/// Frame timestamp as stored in the `frames` table (presumably seconds — TODO confirm).
pub timestamp: f64,
/// UUID of the video the frame belongs to.
pub uuid: String,
/// The `objects` array of the frame's `yolo_objects` column.
pub objects: Option<Vec<serde_json::Value>>,
/// The `text` string of every entry in the `texts` array of `ocr_results`.
pub ocr_texts: Option<Vec<String>>,
/// The `faces` array of the frame's `face_results` column.
pub faces: Option<Vec<serde_json::Value>>,
/// The `persons` array of the frame's `pose_results` column.
pub pose_persons: Option<Vec<serde_json::Value>>,
}
/// Query-string parameters of the person search endpoint (`search_persons`).
#[derive(Debug, Deserialize)]
pub struct PersonSearchQuery {
/// Video UUID. Required for deserialization.
/// NOTE(review): `search_persons` does not read this field — confirm whether results should be video-scoped.
pub video_uuid: String,
/// Case-insensitive substring matched against name, character name, aliases, person id and speaker id.
pub query: Option<String>,
/// Keep only persons whose `appearance_count` is at least this value.
pub min_appearances: Option<i32>,
/// Keep only persons whose `age` is at most this value (added to find children); persons with no recorded age are excluded.
pub max_age: Option<i32>,
/// Maximum number of persons to return; 20 when omitted.
pub limit: Option<usize>,
}
/// Response body of the person search endpoint (`search_persons`).
#[derive(Debug, Serialize)]
pub struct PersonSearchResponse {
/// Matching persons, ordered by descending appearance count.
pub persons: Vec<PersonResult>,
/// Number of entries in `persons` (the returned count, after `limit`).
pub total: usize,
}
/// One row of the `person_identities` table as returned by the person search.
#[derive(Debug, Serialize)]
pub struct PersonResult {
/// Identifier of the person record.
pub person_id: String,
/// Stored name, if any.
pub name: Option<String>,
/// Stored character name, if any.
pub character_name: Option<String>,
/// String entries of the `aliases` JSON array; `None` when the column is NULL or not an array.
pub aliases: Option<Vec<String>>,
/// Stored age, if any.
pub age: Option<i32>,
/// Stored gender, if any.
pub gender: Option<String>,
/// Speaker identifier associated with the person, if any.
pub speaker_id: Option<String>,
/// Stored appearance count; results are sorted by it in descending order.
pub appearance_count: i32,
/// Time of the first appearance, if recorded (presumably seconds into the video — TODO confirm).
pub first_appearance_time: Option<f64>,
/// Time of the last appearance, if recorded (presumably seconds into the video — TODO confirm).
pub last_appearance_time: Option<f64>,
}
/// Chunk leg of the universal search.
///
/// Builds a `SELECT` over the `chunks` table for the video given by
/// `req.uuid` and narrows it by the optional time range, free-text query and
/// the filters in `req.filters`. Rows are ordered by `start_time` and capped
/// at `req.limit` (default 20).
///
/// Scoring is a simple heuristic: 0.9 when the query occurs
/// (case-insensitively) in the chunk text, 0.5 otherwise.
///
/// # Errors
/// Returns an error when `req.uuid` is missing or when the query fails.
///
/// User-supplied strings are embedded as SQL literals with single quotes
/// doubled. `%` and `_` inside them still act as LIKE wildcards.
async fn search_chunks(
    db: &PostgresDb,
    req: &UniversalSearchRequest,
) -> Result<Vec<SearchResult>, anyhow::Error> {
    // uuid is required for chunk search - chunk_id is only unique within a video
    let uuid = match &req.uuid {
        Some(u) => u.replace('\'', "''"),
        None => return Err(anyhow::anyhow!("uuid is required for chunk search")),
    };
    let mut sql = format!(
        "SELECT chunk_id, chunk_type, start_time, end_time, start_frame, end_frame, text_content, content FROM chunks WHERE uuid = '{}'",
        uuid
    );
    if let Some(tr) = &req.time_range {
        sql.push_str(&format!(
            " AND start_time >= {} AND end_time <= {}",
            tr[0], tr[1]
        ));
    }
    if !req.query.is_empty() {
        let q = req.query.replace('\'', "''");
        sql.push_str(&format!(
            " AND (text_content ILIKE '%{}%' OR content::text ILIKE '%{}%')",
            q, q
        ));
    }
    if let Some(ref filters) = req.filters {
        if let Some(ref speaker_id) = filters.speaker_id {
            sql.push_str(&format!(
                " AND content->>'speaker_id' = '{}'",
                speaker_id.replace('\'', "''")
            ));
        }
        if let Some(ref person_id) = filters.person_id {
            sql.push_str(&format!(
                " AND content::text LIKE '%{}%'",
                person_id.replace('\'', "''")
            ));
        }
        // Visual chunk filters
        if let Some(min_confidence) = filters.min_confidence {
            sql.push_str(&format!(
                " AND (content->'metadata'->>'avg_confidence')::float >= {}",
                min_confidence
            ));
        }
        if let Some(min_unique_classes) = filters.min_unique_classes {
            sql.push_str(&format!(
                " AND jsonb_array_length(content->'metadata'->'unique_classes') >= {}",
                min_unique_classes
            ));
        }
        if let Some(min_density) = filters.min_spatial_density {
            sql.push_str(&format!(
                " AND (content->'metadata'->>'spatial_density')::float >= {}",
                min_density
            ));
        }
        if let Some(max_density) = filters.max_spatial_density {
            sql.push_str(&format!(
                " AND (content->'metadata'->>'spatial_density')::float <= {}",
                max_density
            ));
        }
        if let Some(ref required_classes) = filters.required_object_classes {
            if !required_classes.is_empty() {
                let class_conditions: Vec<String> = required_classes
                    .iter()
                    .map(|class| {
                        // Let serde_json produce the containment operand so that
                        // quotes and backslashes in the class name are JSON-escaped,
                        // then escape the result as an SQL string literal.
                        let operand = serde_json::json!([{ "class_name": class }]).to_string();
                        format!(
                            "content->'keyframe_objects' @> '{}'",
                            operand.replace('\'', "''")
                        )
                    })
                    .collect();
                // A chunk qualifies when it contains ANY of the listed classes.
                sql.push_str(&format!(" AND ({})", class_conditions.join(" OR ")));
            }
        }
    }
    sql.push_str(" ORDER BY start_time ASC");
    sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
    let rows: Vec<(
        String,
        String,
        f64,
        f64,
        i64,
        i64,
        Option<String>,
        Option<serde_json::Value>,
    )> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;

    // Lower-case the query once instead of once per row.
    let needle = req.query.to_lowercase();
    let results: Vec<SearchResult> = rows
        .into_iter()
        .map(
            |(
                chunk_id,
                chunk_type,
                start_time,
                end_time,
                start_frame,
                end_frame,
                text_content,
                content,
            )| {
                // Prefer the text column; fall back to `content.text`.
                let text = text_content.or_else(|| {
                    content
                        .as_ref()
                        .and_then(|c| c.get("text").and_then(|v| v.as_str()).map(String::from))
                });
                let speaker_id = content.as_ref().and_then(|c| {
                    c.get("speaker_id")
                        .and_then(|v| v.as_str())
                        .map(String::from)
                });
                // Simple scoring: 0.9 if the query occurs in the text, 0.5 otherwise
                let score = if !needle.is_empty()
                    && text
                        .as_ref()
                        .map_or(false, |t| t.to_lowercase().contains(&needle))
                {
                    0.9
                } else {
                    0.5
                };
                SearchResult::Chunk {
                    chunk_id,
                    chunk_type,
                    start_time,
                    end_time,
                    start_frame,
                    end_frame,
                    score,
                    text,
                    speaker_id,
                    metadata: content,
                }
            },
        )
        .collect();
    Ok(results)
}
async fn search_frames_internal(
db: &PostgresDb,
req: &UniversalSearchRequest,
) -> Result<Vec<SearchResult>, anyhow::Error> {
let table = "frames";
let video_table = "videos";
let mut sql = format!(
"SELECT f.frame_number, f.timestamp, f.yolo_objects, f.ocr_results, f.face_results, f.pose_results, v.uuid
FROM {} f JOIN {} v ON f.file_id = v.id WHERE 1=1",
table, video_table
);
if let Some(uuid) = &req.uuid {
sql.push_str(&format!(" AND v.uuid = '{}'", uuid));
}
if let Some(tr) = &req.time_range {
sql.push_str(&format!(
" AND f.timestamp >= {} AND f.timestamp <= {}",
tr[0], tr[1]
));
}
if let Some(ref filters) = req.filters {
if let Some(ref classes) = filters.object_class {
for class in classes {
sql.push_str(&format!(" AND f.yolo_objects::text ILIKE '%{}%'", class));
}
}
if let Some(ref ocr) = filters.ocr_text {
sql.push_str(&format!(" AND f.ocr_results::text ILIKE '%{}%'", ocr));
}
if let Some(true) = filters.has_face {
sql.push_str(
" AND f.face_results IS NOT NULL AND jsonb_array_length(f.face_results) > 0",
);
}
if let Some(ref person_id) = filters.person_id {
sql.push_str(&format!(" AND f.face_results::text LIKE '%{}%'", person_id));
}
}
if !req.query.is_empty() {
// Search across all frame data
sql.push_str(&format!(
" AND (f.yolo_objects::text ILIKE '%{}%' OR f.ocr_results::text ILIKE '%{}%' OR f.face_results::text ILIKE '%{}%')",
req.query, req.query, req.query
));
}
sql.push_str(" ORDER BY f.timestamp ASC");
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
let rows: Vec<(
i64,
f64,
Option<serde_json::Value>,
Option<serde_json::Value>,
Option<serde_json::Value>,
Option<serde_json::Value>,
String,
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
let results: Vec<SearchResult> = rows
.into_iter()
.map(|(frame_number, timestamp, yolo, ocr, face, pose, _uuid)| {
let objects = yolo.as_ref().and_then(|v| {
v.get("objects")
.map(|o| o.as_array().cloned().unwrap_or_default())
});
let ocr_texts = ocr.as_ref().and_then(|v| {
v.get("texts").and_then(|t| {
t.as_array().map(|arr| {
arr.iter()
.filter_map(|item| {
item.get("text").and_then(|x| x.as_str()).map(String::from)
})
.collect()
})
})
});
let faces = face.as_ref().and_then(|v| {
v.get("faces")
.map(|f| f.as_array().cloned().unwrap_or_default())
});
let pose_persons = pose.as_ref().and_then(|v| {
v.get("persons")
.map(|p| p.as_array().cloned().unwrap_or_default())
});
SearchResult::Frame {
frame_number,
timestamp,
score: 0.7,
objects: objects.map(|arr| arr.iter().map(|v| v.clone()).collect()),
ocr_texts,
faces,
pose_persons,
}
})
.collect();
Ok(results)
}
async fn search_persons_internal(
db: &PostgresDb,
req: &UniversalSearchRequest,
) -> Result<Vec<SearchResult>, anyhow::Error> {
let table = "person_identities";
let mut sql = format!(
"SELECT person_id, name, speaker_id, appearance_count, first_appearance_time, last_appearance_time FROM {} WHERE 1=1",
table
);
if !req.query.is_empty() {
sql.push_str(&format!(
" AND (name ILIKE '%{}%' OR person_id ILIKE '%{}%' OR speaker_id ILIKE '%{}%')",
req.query, req.query, req.query
));
}
if let Some(ref filters) = req.filters {
if let Some(ref speaker_id) = filters.speaker_id {
sql.push_str(&format!(" AND speaker_id = '{}'", speaker_id));
}
if let Some(ref person_id) = filters.person_id {
sql.push_str(&format!(" AND person_id = '{}'", person_id));
}
}
sql.push_str(" ORDER BY appearance_count DESC");
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
let rows: Vec<(
String,
Option<String>,
Option<String>,
i32,
Option<f64>,
Option<f64>,
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
let results: Vec<SearchResult> = rows
.into_iter()
.map(
|(person_id, name, speaker_id, appearance_count, first_time, last_time)| {
let score = if !req.query.is_empty()
&& name.as_ref().map_or(false, |n| {
n.to_lowercase().contains(&req.query.to_lowercase())
}) {
0.95
} else {
0.5
};
SearchResult::Person {
person_id,
name,
speaker_id,
appearance_count,
score,
first_appearance_time: first_time,
last_appearance_time: last_time,
}
},
)
.collect();
Ok(results)
}
async fn search_frames_internal_v2(
db: &PostgresDb,
req: &FrameSearchRequest,
) -> Result<Vec<FrameResult>, anyhow::Error> {
let table = "frames";
let video_table = "videos";
let mut sql = format!(
"SELECT f.frame_number, f.timestamp, f.yolo_objects, f.ocr_results, f.face_results, f.pose_results, v.uuid
FROM {} f JOIN {} v ON f.file_id = v.id WHERE 1=1",
table, video_table
);
if let Some(uuid) = &req.uuid {
sql.push_str(&format!(" AND v.uuid = '{}'", uuid));
}
if let Some(tr) = &req.time_range {
sql.push_str(&format!(
" AND f.timestamp >= {} AND f.timestamp <= {}",
tr[0], tr[1]
));
}
if let Some(ref class) = req.object_class {
sql.push_str(&format!(" AND f.yolo_objects::text ILIKE '%{}%'", class));
}
if let Some(ref ocr) = req.ocr_text {
sql.push_str(&format!(" AND f.ocr_results::text ILIKE '%{}%'", ocr));
}
if let Some(ref face_id) = req.face_id {
sql.push_str(&format!(" AND f.face_results::text LIKE '%{}%'", face_id));
}
sql.push_str(" ORDER BY f.timestamp ASC");
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(50)));
let rows: Vec<(
i64,
f64,
Option<serde_json::Value>,
Option<serde_json::Value>,
Option<serde_json::Value>,
Option<serde_json::Value>,
String,
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
let results: Vec<FrameResult> = rows
.into_iter()
.map(|(frame_number, timestamp, yolo, ocr, face, pose, uuid)| {
let objects = yolo.as_ref().and_then(|v| {
v.get("objects")
.map(|o| o.as_array().cloned().unwrap_or_default())
});
let ocr_texts = ocr.as_ref().and_then(|v| {
v.get("texts").and_then(|t| {
t.as_array().map(|arr| {
arr.iter()
.filter_map(|item| {
item.get("text").and_then(|x| x.as_str()).map(String::from)
})
.collect()
})
})
});
let faces = face.as_ref().and_then(|v| {
v.get("faces")
.map(|f| f.as_array().cloned().unwrap_or_default())
});
let pose_persons = pose.as_ref().and_then(|v| {
v.get("persons")
.map(|p| p.as_array().cloned().unwrap_or_default())
});
FrameResult {
frame_number,
timestamp,
uuid,
objects: objects.map(|arr| arr.iter().map(|v| v.clone()).collect()),
ocr_texts,
faces,
pose_persons,
}
})
.collect();
Ok(results)
}
/// Runs the dedicated person search over `person_identities`.
///
/// * `query` — optional case-insensitive substring matched against name,
///   character name, aliases (as text), person id and speaker id.
/// * `min_appearances` — optional lower bound on `appearance_count`.
/// * `max_age` — optional upper bound on `age`; rows with a NULL age are
///   excluded when it is set.
/// * `limit` — maximum number of rows returned.
///
/// Rows are ordered by descending appearance count.
///
/// # Errors
/// Returns an error when the query fails.
///
/// The query text is embedded as an SQL literal with single quotes doubled,
/// matching `search_chunks`. `%` and `_` inside it still act as LIKE
/// wildcards.
async fn search_persons_by_query(
    db: &PostgresDb,
    query: &Option<String>,
    min_appearances: Option<i32>,
    max_age: Option<i32>,
    limit: usize,
) -> Result<Vec<PersonResult>, anyhow::Error> {
    let table = "person_identities";
    let mut sql = format!(
        "SELECT person_id, name, character_name, aliases, age, gender, speaker_id, appearance_count, first_appearance_time, last_appearance_time FROM {} WHERE 1=1",
        table
    );
    if let Some(ref raw) = query {
        // Search name, character_name, aliases (cast to text), person_id, speaker_id
        let q = raw.replace('\'', "''");
        sql.push_str(&format!(
            " AND (name ILIKE '%{}%' OR character_name ILIKE '%{}%' OR aliases::text ILIKE '%{}%' OR person_id ILIKE '%{}%' OR speaker_id ILIKE '%{}%')",
            q, q, q, q, q
        ));
    }
    if let Some(min) = min_appearances {
        sql.push_str(&format!(" AND appearance_count >= {}", min));
    }
    if let Some(max_a) = max_age {
        // Strictly filter for age <= max_age.
        // Note: This excludes entries with NULL age.
        sql.push_str(&format!(" AND age <= {}", max_a));
    }
    sql.push_str(" ORDER BY appearance_count DESC");
    sql.push_str(&format!(" LIMIT {}", limit));
    let rows: Vec<(
        String,
        Option<String>,
        Option<String>,
        Option<serde_json::Value>,
        Option<i32>,
        Option<String>,
        Option<String>,
        i32,
        Option<f64>,
        Option<f64>,
    )> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
    let results: Vec<PersonResult> = rows
        .into_iter()
        .map(
            |(
                person_id,
                name,
                character_name,
                aliases_json,
                age,
                gender,
                speaker_id,
                appearance_count,
                first_time,
                last_time,
            )| {
                // Keep only the string entries of the aliases array.
                let aliases = aliases_json.and_then(|v| {
                    v.as_array().map(|arr| {
                        arr.iter()
                            .filter_map(|val| val.as_str().map(String::from))
                            .collect()
                    })
                });
                PersonResult {
                    person_id,
                    name,
                    character_name,
                    aliases,
                    age,
                    gender,
                    speaker_id,
                    appearance_count,
                    first_appearance_time: first_time,
                    last_appearance_time: last_time,
                }
            },
        )
        .collect();
    Ok(results)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `SearchFilters` value carrying only the visual-chunk filters can be
    /// built and keeps the confidence threshold it was given.
    #[test]
    fn test_search_filters_with_visual() {
        let required = vec![String::from("person")];
        let visual_only = SearchFilters {
            person_id: None,
            object_class: None,
            ocr_text: None,
            has_face: None,
            speaker_id: None,
            min_confidence: Some(0.8),
            min_unique_classes: Some(3),
            min_spatial_density: Some(0.5),
            max_spatial_density: Some(0.9),
            required_object_classes: Some(required),
        };

        assert_eq!(visual_only.min_confidence, Some(0.8));
    }
}

View File

@@ -0,0 +1,504 @@
//! Visual chunk search functionality.
//!
//! This module provides search capabilities for visual chunks based on:
//! - Object classes (e.g., "person", "car", "envelope")
//! - Confidence thresholds
//! - Object counts
//! - Spatial density
//! - Object relationships
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
use crate::core::db::PostgresDb;
use anyhow::Result;
use serde_json::Value;
use std::collections::HashMap;
/// Criteria for searching visual chunks.
///
/// All criteria are optional and are combined with AND by
/// `search_visual_chunks`. The derived `Default` leaves every criterion
/// unset, which matches every visual chunk.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct VisualChunkSearchCriteria {
    /// Minimum average confidence across frames
    pub min_avg_confidence: Option<f32>,
    /// Minimum number of frames with objects
    pub min_frames_with_objects: Option<u32>,
    /// Minimum number of unique object classes
    pub min_unique_classes: Option<u32>,
    /// Object classes that must ALL be present among a chunk's keyframe
    /// objects (empty means no class requirement)
    pub required_classes: Vec<String>,
    /// Per-class `(min, max)` object-count bounds, both inclusive. A `min` of
    /// 0 disables the lower bound and a `max` of `u32::MAX` disables the
    /// upper bound.
    pub class_counts: HashMap<String, (u32, u32)>,
    /// Optional `(start, end)` window; a chunk must lie entirely inside it
    pub time_range: Option<(f64, f64)>,
}
/// Search visual chunks based on criteria
pub async fn search_visual_chunks(
db: &PostgresDb,
uuid: &str,
criteria: &VisualChunkSearchCriteria,
) -> Result<Vec<Chunk>> {
// First, get all visual chunks for this video
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
// Apply filters
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check min avg confidence
if let Some(min_avg_confidence) = criteria.min_avg_confidence {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(avg_confidence) = metadata.get("avg_confidence") {
if let Some(conf) = avg_confidence.as_f64() {
if conf < min_avg_confidence as f64 {
return false;
}
}
}
}
}
}
// Check min frames with objects
if let Some(min_frames) = criteria.min_frames_with_objects {
if let Some(stats) = &chunk.visual_stats {
if let Some(frames_with_objects) = stats.get("frames_with_objects") {
if let Some(count) = frames_with_objects.as_u64() {
if count < min_frames as u64 {
return false;
}
}
}
}
}
// Check min unique classes
if let Some(min_unique_classes) = criteria.min_unique_classes {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(unique_classes) = metadata.get("unique_classes") {
if let Some(classes) = unique_classes.as_array() {
if (classes.len() as u32) < min_unique_classes {
return false;
}
}
}
}
}
}
// Check required classes
if !criteria.required_classes.is_empty() {
if let Some(content) = &chunk.content.as_object() {
if let Some(keyframe_objects) = content.get("keyframe_objects") {
if let Some(objects) = keyframe_objects.as_array() {
let mut found_all = true;
for required_class in &criteria.required_classes {
let mut found = false;
for obj in objects {
if let Some(class_name) = obj.get("class_name") {
if let Some(class_str) = class_name.as_str() {
if class_str == required_class {
found = true;
break;
}
}
}
}
if !found {
found_all = false;
break;
}
}
if !found_all {
return false;
}
}
}
}
}
// Check class counts
if !criteria.class_counts.is_empty() {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(object_counts) = metadata.get("object_counts") {
for (class, (min, max)) in &criteria.class_counts {
if let Some(count_value) = object_counts.get(class) {
if let Some(count) = count_value.as_u64() {
if *min > 0 && count < *min as u64 {
return false;
}
if *max < u32::MAX && count > *max as u64 {
return false;
}
}
} else if *min > 0 {
return false;
}
}
} else if criteria.class_counts.values().any(|(min, _)| *min > 0) {
return false;
}
}
}
}
// Check time range
if let Some((start_time, end_time)) = criteria.time_range {
// Calculate chunk time from frames
let chunk_start_time = chunk.start_frame as f64 / chunk.fps;
let chunk_end_time = chunk.end_frame as f64 / chunk.fps;
if chunk_start_time < start_time || chunk_end_time > end_time {
return false;
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Get all visual chunks for a video UUID.
///
/// Selects the rows of `chunks` whose `chunk_type` is `'visual'`, ordered by
/// `start_frame`, and converts each row into a `Chunk`. Relationship fields
/// (`pre_chunk_ids`, `parent_chunk_id`, `child_chunk_ids`) are left empty.
///
/// # Errors
/// Returns an error when the query fails.
async fn get_visual_chunks_by_uuid(db: &PostgresDb, uuid: &str) -> Result<Vec<Chunk>> {
    /// Column layout of one selected row, in `SELECT` order.
    type ChunkRow = (
        i32,            // file_id
        String,         // uuid
        String,         // chunk_id
        i32,            // chunk_index
        String,         // chunk_type
        f64,            // fps
        i64,            // start_frame
        i64,            // end_frame
        Option<String>, // text_content
        Value,          // content
        Option<Value>,  // metadata
        Option<String>, // vector_id
        Option<Value>,  // visual_stats
    );

    // Single quotes are doubled so the UUID is safe inside the SQL literal.
    let escaped_uuid = uuid.replace('\'', "''");
    let sql = format!(
        "SELECT file_id, uuid, chunk_id, chunk_index, chunk_type, fps, start_frame, end_frame, text_content, content, metadata, vector_id, visual_stats FROM chunks WHERE uuid = '{}' AND chunk_type = 'visual' ORDER BY start_frame ASC",
        escaped_uuid
    );
    let rows: Vec<ChunkRow> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;

    let chunks = rows
        .into_iter()
        .map(
            |(
                file_id,
                row_uuid,
                chunk_id,
                chunk_index,
                type_label,
                fps,
                start_frame,
                end_frame,
                text_content,
                content,
                metadata,
                vector_id,
                visual_stats,
            )| {
                // Unknown labels fall back to `TimeBased`.
                let chunk_type = match type_label.as_str() {
                    "visual" => ChunkType::Visual,
                    "sentence" => ChunkType::Sentence,
                    "time_based" => ChunkType::TimeBased,
                    "cut" => ChunkType::Cut,
                    "trace" => ChunkType::Trace,
                    "story" => ChunkType::Story,
                    _ => ChunkType::TimeBased,
                };
                // Calculate frame_count
                let frame_count = (end_frame - start_frame) as i32;
                Chunk {
                    file_id,
                    uuid: row_uuid,
                    chunk_id,
                    chunk_index: chunk_index as u32,
                    chunk_type,
                    rule: ChunkRule::Rule2, // Visual chunks use Rule2
                    fps,
                    start_frame,
                    end_frame,
                    text_content,
                    content,
                    metadata,
                    vector_id,
                    frame_count,
                    pre_chunk_ids: Vec::new(),
                    parent_chunk_id: None,
                    child_chunk_ids: Vec::new(),
                    visual_stats,
                }
            },
        )
        .collect();
    Ok(chunks)
}
/// Search visual chunks by object class
pub async fn search_visual_chunks_by_class(
db: &PostgresDb,
uuid: &str,
object_class: &str,
min_count: Option<u32>,
max_count: Option<u32>,
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check if chunk contains the object class
let mut contains_class = false;
if let Some(content) = &chunk.content.as_object() {
if let Some(keyframe_objects) = content.get("keyframe_objects") {
if let Some(objects) = keyframe_objects.as_array() {
for obj in objects {
if let Some(class_name) = obj.get("class_name") {
if let Some(class_str) = class_name.as_str() {
if class_str == object_class {
contains_class = true;
break;
}
}
}
}
}
}
}
if !contains_class {
return false;
}
// Check count in visual_stats
if let Some(stats) = &chunk.visual_stats {
if let Some(count) = stats.get(object_class) {
if let Some(c) = count.as_u64() {
if let Some(min) = min_count {
if c < min as u64 {
return false;
}
}
if let Some(max) = max_count {
if c > max as u64 {
return false;
}
}
}
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Search visual chunks by spatial density
pub async fn search_visual_chunks_by_density(
db: &PostgresDb,
uuid: &str,
min_density: f32,
max_density: Option<f32>,
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(density_value) = metadata.get("spatial_density") {
if let Some(density) = density_value.as_f64() {
if density < min_density as f64 {
return false;
}
if let Some(max_dens) = max_density {
if density > max_dens as f64 {
return false;
}
}
return true;
}
}
}
}
false
})
.collect();
Ok(filtered_chunks)
}
/// Find chunks containing specific object combinations
pub async fn search_visual_chunks_by_combination(
db: &PostgresDb,
uuid: &str,
combination: &[(&str, u32)],
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check if all required combinations are present
for (object_class, min_count) in combination {
let mut found = false;
if let Some(stats) = &chunk.visual_stats {
if let Some(object_counts) = stats.get("object_counts") {
if let Some(count_value) = object_counts.get(*object_class) {
if let Some(count) = count_value.as_u64() {
if count >= *min_count as u64 {
found = true;
}
}
}
}
}
if !found {
return false;
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Get visual chunk statistics.
///
/// Aggregates over the visual chunks of video `uuid` and returns a map with
/// the keys `total_chunks`, `avg_confidence`, `min_confidence`,
/// `max_confidence`, `total_objects` and `avg_density`.
///
/// SQL aggregates other than `COUNT` yield NULL when no row contributes a
/// value (for example a video without visual chunks). Every such aggregate is
/// decoded as an `Option` and reported as `0`, so the call succeeds in that
/// case too.
///
/// # Errors
/// Returns an error when the query fails.
pub async fn get_visual_chunk_statistics(
    db: &PostgresDb,
    uuid: &str,
) -> Result<HashMap<String, Value>> {
    let sql = format!(
        "SELECT
            COUNT(*) as total_chunks,
            AVG((content->'metadata'->>'avg_confidence')::float) as avg_confidence,
            MIN((content->'metadata'->>'avg_confidence')::float) as min_confidence,
            MAX((content->'metadata'->>'avg_confidence')::float) as max_confidence,
            SUM((content->'metadata'->>'object_count')::int) as total_objects,
            AVG((content->'metadata'->>'spatial_density')::float) as avg_density
        FROM chunks
        WHERE uuid = '{}'
        AND chunk_type = 'visual'",
        uuid.replace('\'', "''")
    );
    // `total_objects` must be optional: SUM is NULL when nothing is summed,
    // and decoding NULL into a plain `i64` fails.
    let row: (
        i64,
        Option<f64>,
        Option<f64>,
        Option<f64>,
        Option<i64>,
        Option<f64>,
    ) = sqlx::query_as(&sql).fetch_one(db.pool()).await?;
    let (total_chunks, avg_confidence, min_confidence, max_confidence, total_objects, avg_density) =
        row;

    let mut stats = HashMap::new();
    stats.insert("total_chunks".to_string(), Value::from(total_chunks));
    stats.insert(
        "avg_confidence".to_string(),
        Value::from(avg_confidence.unwrap_or(0.0)),
    );
    stats.insert(
        "min_confidence".to_string(),
        Value::from(min_confidence.unwrap_or(0.0)),
    );
    stats.insert(
        "max_confidence".to_string(),
        Value::from(max_confidence.unwrap_or(0.0)),
    );
    stats.insert(
        "total_objects".to_string(),
        Value::from(total_objects.unwrap_or(0)),
    );
    stats.insert(
        "avg_density".to_string(),
        Value::from(avg_density.unwrap_or(0.0)),
    );
    Ok(stats)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default criteria leave every filter unset.
    #[test]
    fn test_visual_chunk_search_criteria_default() {
        let defaults = VisualChunkSearchCriteria::default();

        assert_eq!(defaults.min_avg_confidence, None);
        assert_eq!(defaults.min_frames_with_objects, None);
        assert_eq!(defaults.min_unique_classes, None);
        assert!(defaults.required_classes.is_empty());
        assert!(defaults.class_counts.is_empty());
        assert_eq!(defaults.time_range, None);
    }

    /// Fields assigned on top of the defaults keep their values.
    #[test]
    fn test_visual_chunk_search_criteria_with_values() {
        let configured = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.8),
            min_frames_with_objects: Some(10),
            min_unique_classes: Some(3),
            required_classes: vec!["person".to_string(), "car".to_string()],
            time_range: Some((0.0, 60.0)),
            ..VisualChunkSearchCriteria::default()
        };

        assert_eq!(configured.min_avg_confidence, Some(0.8));
        assert_eq!(configured.min_frames_with_objects, Some(10));
        assert_eq!(configured.min_unique_classes, Some(3));
        assert_eq!(configured.required_classes.len(), 2);
        assert_eq!(configured.time_range, Some((0.0, 60.0)));
    }

    /// Criteria survive a JSON round trip.
    #[test]
    fn test_visual_chunk_search_criteria_serialization() {
        let original = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.85),
            min_frames_with_objects: Some(5),
            min_unique_classes: Some(2),
            required_classes: vec!["person".to_string()],
            class_counts: HashMap::new(),
            time_range: Some((10.0, 30.0)),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("min_avg_confidence"));
        assert!(encoded.contains("required_classes"));

        let decoded: VisualChunkSearchCriteria = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.min_avg_confidence, Some(0.85));
        assert_eq!(decoded.required_classes.len(), 1);
    }

    /// Per-class `(min, max)` bounds are stored under the class name.
    #[test]
    fn test_visual_chunk_search_criteria_with_class_counts() {
        let mut bounded = VisualChunkSearchCriteria::default();
        for (class, bounds) in [("person", (5, 20)), ("car", (1, 10))] {
            bounded.class_counts.insert(class.to_string(), bounds);
        }

        assert_eq!(bounded.class_counts.len(), 2);
        assert_eq!(bounded.class_counts.get("person"), Some(&(5, 20)));
        assert_eq!(bounded.class_counts.get("car"), Some(&(1, 10)));
    }

    /// The label-to-`ChunkType` mapping falls back to `TimeBased`.
    #[test]
    fn test_chunk_type_conversion() {
        // Same mapping as the one applied when chunk rows are loaded.
        fn parse_label(label: &str) -> ChunkType {
            match label {
                "visual" => ChunkType::Visual,
                "sentence" => ChunkType::Sentence,
                "time_based" => ChunkType::TimeBased,
                "cut" => ChunkType::Cut,
                "trace" => ChunkType::Trace,
                "story" => ChunkType::Story,
                _ => ChunkType::TimeBased,
            }
        }

        let expectations = vec![
            ("visual", ChunkType::Visual),
            ("sentence", ChunkType::Sentence),
            ("time_based", ChunkType::TimeBased),
            ("cut", ChunkType::Cut),
            ("trace", ChunkType::Trace),
            ("story", ChunkType::Story),
            ("unknown", ChunkType::TimeBased), // Default fallback
        ];
        for (label, expected) in expectations {
            assert_eq!(parse_label(label), expected);
        }
    }
}

147
src/api/who.rs Normal file
View File

@@ -0,0 +1,147 @@
//! Who API - identity recognition and ID mapping interface (Video-Scoped)
use axum::{
extract::State,
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::Database;
// --- Request / Response Structures ---
/// Query-string parameters of `GET /api/v1/who`.
///
/// `get_who_identity` reads only `uuid` and `chunk_id`.
/// NOTE(review): `face_id` and `speaker_id` are accepted but not used by that handler — confirm.
#[derive(Debug, Deserialize)]
pub struct WhoQuery {
/// Face identifier (currently not read by the handler).
pub face_id: Option<String>,
/// Speaker identifier (currently not read by the handler).
pub speaker_id: Option<String>,
/// Video UUID; the request is rejected with 400 when it is missing.
pub uuid: Option<String>,
/// Chunk identifier; together with `uuid` it selects the per-chunk lookup.
pub chunk_id: Option<String>,
}
/// Request body of `POST /api/v1/who/candidates`.
#[derive(Debug, Deserialize)]
pub struct WhoCandidatesRequest {
/// Search text; the handler wraps it in `%` on both sides before querying.
pub query: String,
/// Optional video UUID forwarded to the candidate search.
pub video_uuid: Option<String>,
/// Maximum number of candidates; 20 when omitted.
pub limit: Option<i32>,
}
/// Request body of `POST /api/v1/who`; forwarded to `create_or_update_person`.
#[derive(Debug, Deserialize)]
pub struct DefinePersonRequest {
/// UUID of the video the identity belongs to.
pub uuid: String,
/// Optional identity id (presumably selects an existing identity to update — TODO confirm).
pub identity_id: Option<i32>,
/// Name to assign to the identity.
pub name: String,
/// Face ids to associate; treated as an empty list when omitted.
pub face_ids: Option<Vec<String>>,
/// Speaker ids to associate; treated as an empty list when omitted.
pub speaker_ids: Option<Vec<String>>,
}
/// Identity record returned by `POST /api/v1/who`.
#[derive(Debug, Serialize)]
pub struct WhoIdentity {
/// Numeric id of the identity.
pub identity_id: i32,
/// UUID of the video the identity belongs to.
pub uuid: String,
/// Name of the identity.
pub name: String,
/// Optional tags attached to the identity.
pub tags: Option<Vec<String>>,
/// Face ids associated with the identity.
pub face_ids: Vec<String>,
/// Speaker ids associated with the identity.
pub speaker_ids: Vec<String>,
}
// --- API Handlers ---
/// GET /api/v1/who
///
/// Dispatches on the supplied query parameters:
/// * `uuid` and `chunk_id` — returns the identity information for that chunk;
/// * `uuid` only — listing is not implemented yet, a placeholder message is
///   returned;
/// * no `uuid` — `400 Bad Request`.
///
/// A failing lookup is reported as `500 Internal Server Error` with a JSON
/// `{ "error": ... }` body.
pub async fn get_who_identity(
    State(state): State<crate::api::server::AppState>,
    axum::extract::Query(query): axum::extract::Query<WhoQuery>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let db = &state.db;

    match (&query.uuid, &query.chunk_id) {
        // Priority 1: Query by Chunk (UUID + Chunk ID)
        (Some(uuid), Some(chunk_id)) => {
            match db.get_who_info_by_chunk(uuid, chunk_id).await {
                Ok(info) => Ok(Json(info)),
                Err(e) => Err((
                    StatusCode::INTERNAL_SERVER_ERROR,
                    Json(serde_json::json!({ "error": e.to_string() })),
                )),
            }
        }
        // Priority 2: List all for a specific UUID
        (Some(_), None) => {
            // TODO: Implement list_all_persons(uuid)
            Ok(Json(serde_json::json!({ "message": "List all pending" })))
        }
        (None, _) => Err((
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({ "error": "Missing uuid" })),
        )),
    }
}
/// POST /api/v1/who/candidates
/// Search person_identities table for n8n workflow
pub async fn get_who_candidates(
State(state): State<crate::api::server::AppState>,
Json(req): Json<WhoCandidatesRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
let db = &state.db;
let limit = req.limit.unwrap_or(20);
let query_str = format!("%{}%", req.query);
let results = db
.search_person_candidates(&query_str, &req.video_uuid, limit)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
// Format for n8n
let response = serde_json::json!({
"query": req.query,
"items": results,
"total": results.len()
});
Ok(Json(response))
}
/// POST /api/v1/who
pub async fn define_person(
State(state): State<crate::api::server::AppState>,
Json(req): Json<DefinePersonRequest>,
) -> Result<Json<WhoIdentity>, (StatusCode, Json<serde_json::Value>)> {
let db = &state.db;
let identity = db
.create_or_update_person(
&req.uuid,
req.identity_id,
req.name.clone(),
req.face_ids.unwrap_or_default(),
req.speaker_ids.unwrap_or_default(),
)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
Ok(Json(identity))
}
// --- Router Setup ---
/// Builds the router for the Who API.
///
/// * `GET  /api/v1/who` — `get_who_identity`
/// * `POST /api/v1/who` — `define_person`
/// * `POST /api/v1/who/candidates` — `get_who_candidates`
pub fn who_routes() -> Router<crate::api::server::AppState> {
    let identity_methods = get(get_who_identity).post(define_person);
    let candidate_methods = post(get_who_candidates);

    Router::new()
        .route("/api/v1/who", identity_methods)
        .route("/api/v1/who/candidates", candidate_methods)
}