feat: backup architecture docs, source code, and scripts
This commit is contained in:
936
src/api/face_recognition.rs
Normal file
936
src/api/face_recognition.rs
Normal file
@@ -0,0 +1,936 @@
|
||||
use axum::{
|
||||
extract::{Multipart, Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::db::{schema, Database, PostgresDb};
|
||||
use crate::core::processor::face_recognition::{
|
||||
process_face_recognition, register_face, FaceRecognitionResult, FaceRegistrationResult,
|
||||
};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FaceRecognitionRequest {
|
||||
pub video_uuid: String,
|
||||
pub enable_recognition: Option<bool>,
|
||||
pub enable_tracking: Option<bool>,
|
||||
pub enable_clustering: Option<bool>,
|
||||
pub database_path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceRecognitionResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub result: Option<FaceRecognitionResult>,
|
||||
pub processing_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FaceRegistrationRequest {
|
||||
pub video_uuid: String,
|
||||
pub name: String,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceRegistrationApiResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub result: Option<FaceRegistrationResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FaceSearchRequest {
|
||||
pub video_uuid: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub similarity_threshold: Option<f64>,
|
||||
pub limit: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceSearchResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub results: Vec<FaceSearchResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceSearchResult {
|
||||
pub face_id: String,
|
||||
pub name: Option<String>,
|
||||
pub similarity: f64,
|
||||
pub attributes: Option<serde_json::Value>,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FaceListQuery {
|
||||
pub video_uuid: String,
|
||||
pub page: Option<usize>,
|
||||
pub page_size: Option<usize>,
|
||||
pub active_only: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceListResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub faces: Vec<FaceListItem>,
|
||||
pub count: i64,
|
||||
pub page: usize,
|
||||
pub page_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceListItem {
|
||||
pub face_id: String,
|
||||
pub name: Option<String>,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
pub is_active: bool,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
pub fn face_recognition_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/face/recognize", post(recognize_faces))
|
||||
.route("/api/v1/face/register", post(register_face_api))
|
||||
.route("/api/v1/face/search", post(search_faces))
|
||||
.route("/api/v1/face/list", get(list_faces))
|
||||
.route("/api/v1/face/:face_id", get(get_face_details))
|
||||
.route("/api/v1/face/:face_id", axum::routing::delete(delete_face))
|
||||
.route(
|
||||
"/api/v1/face/results/:video_uuid",
|
||||
get(get_recognition_results),
|
||||
)
|
||||
}
|
||||
|
||||
async fn recognize_faces(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(request): Json<FaceRecognitionRequest>,
|
||||
) -> Result<Json<FaceRecognitionResponse>, (StatusCode, String)> {
|
||||
let processing_id = Uuid::new_v4().to_string();
|
||||
|
||||
tracing::info!(
|
||||
"[FACE_RECOGNITION] Starting recognition for video: {}, processing_id: {}",
|
||||
request.video_uuid,
|
||||
processing_id
|
||||
);
|
||||
|
||||
// Get video path from database
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let video_record = match db.get_video_by_uuid(&request.video_uuid).await {
|
||||
Ok(Some(record)) => record,
|
||||
Ok(None) => {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Video not found: {}", request.video_uuid),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to fetch video: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let video_path = video_record.file_path;
|
||||
let output_path = format!(
|
||||
"{}/face_recognition_{}.json",
|
||||
crate::core::config::OUTPUT_DIR.as_str(),
|
||||
processing_id
|
||||
);
|
||||
|
||||
// Process face recognition
|
||||
let result = match process_face_recognition(
|
||||
&video_path,
|
||||
&output_path,
|
||||
Some(&processing_id),
|
||||
request.enable_recognition.unwrap_or(true),
|
||||
request.enable_tracking.unwrap_or(true),
|
||||
request.enable_clustering.unwrap_or(true),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Face recognition failed: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Store results in database
|
||||
if let Err(e) = store_recognition_results(&db, &request.video_uuid, &result).await {
|
||||
tracing::warn!("Failed to store recognition results: {}", e);
|
||||
}
|
||||
|
||||
Ok(Json(FaceRecognitionResponse {
|
||||
success: true,
|
||||
message: format!("Face recognition completed for {}", request.video_uuid),
|
||||
result: Some(result),
|
||||
processing_id,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn register_face_api(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
mut multipart: Multipart,
|
||||
) -> Result<Json<FaceRegistrationApiResponse>, (StatusCode, String)> {
|
||||
let mut image_path: Option<String> = None;
|
||||
let mut name: Option<String> = None;
|
||||
let mut metadata: Option<serde_json::Value> = None;
|
||||
|
||||
// Parse multipart form data
|
||||
while let Some(field) = multipart.next_field().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to parse form data: {}", e),
|
||||
)
|
||||
})? {
|
||||
let field_name = field.name().unwrap_or("").to_string();
|
||||
|
||||
match field_name.as_str() {
|
||||
"image" => {
|
||||
// Save uploaded image
|
||||
let file_name = format!("face_registration_{}.jpg", Uuid::new_v4());
|
||||
let file_path = format!("/tmp/{}", file_name);
|
||||
|
||||
let data = field.bytes().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to read image data: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
tokio::fs::write(&file_path, &data).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to save image: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
image_path = Some(file_path);
|
||||
}
|
||||
"name" => {
|
||||
let value = field.text().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to read name: {}", e),
|
||||
)
|
||||
})?;
|
||||
name = Some(value);
|
||||
}
|
||||
"metadata" => {
|
||||
let value = field.text().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to read metadata: {}", e),
|
||||
)
|
||||
})?;
|
||||
metadata = Some(serde_json::from_str(&value).map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid JSON metadata: {}", e),
|
||||
)
|
||||
})?);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
let image_path =
|
||||
image_path.ok_or((StatusCode::BAD_REQUEST, "Image is required".to_string()))?;
|
||||
|
||||
let name = name.ok_or((StatusCode::BAD_REQUEST, "Name is required".to_string()))?;
|
||||
|
||||
// Register face
|
||||
let result = match register_face(&image_path, &name, metadata.clone()).await {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
// Clean up temporary file
|
||||
let _ = tokio::fs::remove_file(&image_path).await;
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Face registration failed: {}", e),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Clean up temporary file
|
||||
let _ = tokio::fs::remove_file(&image_path).await;
|
||||
|
||||
// Store in PostgreSQL face_identities table
|
||||
if result.success && !result.embedding.is_empty() {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
tracing::warn!("[FACE_REGISTRATION] Failed to connect to DB: {}", e);
|
||||
// Return success even if DB write fails (embedding is still in JSON output)
|
||||
return Ok(Json(FaceRegistrationApiResponse {
|
||||
success: result.success,
|
||||
message: format!("{} (Warning: DB write failed: {})", result.message, e),
|
||||
result: Some(result),
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
// Convert embedding to PostgreSQL vector format
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
result
|
||||
.embedding
|
||||
.iter()
|
||||
.map(|v| v.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
// Insert into face_identities
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let attrs_json =
|
||||
serde_json::to_string(&result.attributes).unwrap_or_else(|_| "{}".to_string());
|
||||
// Use public.vector type to work across schemas
|
||||
let vector_type = if schema::SCHEMA_PREFIX.as_str().is_empty() {
|
||||
"vector".to_string()
|
||||
} else {
|
||||
"public.vector".to_string()
|
||||
};
|
||||
let insert_query = format!(
|
||||
r#"
|
||||
INSERT INTO {} (face_id, name, embedding, attributes, metadata, is_active)
|
||||
VALUES ($1, $2, $3::{}, $4::jsonb, $5, TRUE)
|
||||
ON CONFLICT (face_id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
embedding = EXCLUDED.embedding,
|
||||
attributes = EXCLUDED.attributes,
|
||||
metadata = COALESCE(EXCLUDED.metadata, {}.metadata),
|
||||
updated_at = CURRENT_TIMESTAMP,
|
||||
is_active = TRUE
|
||||
"#,
|
||||
face_identities_table, vector_type, face_identities_table
|
||||
);
|
||||
|
||||
match sqlx::query(&insert_query)
|
||||
.bind(&result.face_id)
|
||||
.bind(&name)
|
||||
.bind(&embedding_str)
|
||||
.bind(&attrs_json)
|
||||
.bind(&metadata.unwrap_or(serde_json::json!({})))
|
||||
.execute(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
tracing::info!(
|
||||
"[FACE_REGISTRATION] Stored face '{}' (face_id={}) in DB",
|
||||
name,
|
||||
result.face_id
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[FACE_REGISTRATION] Failed to store face in DB: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(FaceRegistrationApiResponse {
|
||||
success: result.success,
|
||||
message: result.message.clone(),
|
||||
result: Some(result),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn search_faces(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(request): Json<FaceSearchRequest>,
|
||||
) -> Result<Json<FaceSearchResponse>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Convert embedding to PostgreSQL vector format
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
request
|
||||
.embedding
|
||||
.iter()
|
||||
.map(|v| v.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let similarity_threshold = request.similarity_threshold.unwrap_or(0.6);
|
||||
let limit = request.limit.unwrap_or(10);
|
||||
|
||||
// Search for similar faces
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let vector_type = if schema::SCHEMA_PREFIX.as_str().is_empty() {
|
||||
"vector".to_string()
|
||||
} else {
|
||||
"public.vector".to_string()
|
||||
};
|
||||
let query = format!(
|
||||
r#"
|
||||
SELECT
|
||||
face_id,
|
||||
name,
|
||||
1 - (embedding <=> $1::{}) as similarity,
|
||||
attributes,
|
||||
metadata
|
||||
FROM {}
|
||||
WHERE is_active = TRUE
|
||||
AND embedding IS NOT NULL
|
||||
AND 1 - (embedding <=> $1::{}) >= $2
|
||||
ORDER BY embedding <=> $1::{}
|
||||
LIMIT $3
|
||||
"#,
|
||||
vector_type, face_identities_table, vector_type, vector_type
|
||||
);
|
||||
|
||||
let results: Vec<FaceSearchResult> = match sqlx::query_as::<
|
||||
_,
|
||||
(
|
||||
String,
|
||||
Option<String>,
|
||||
f64,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
),
|
||||
>(query.as_str())
|
||||
.bind(&embedding_str)
|
||||
.bind(similarity_threshold)
|
||||
.bind(limit)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(face_id, name, similarity, attributes, metadata)| FaceSearchResult {
|
||||
face_id,
|
||||
name,
|
||||
similarity,
|
||||
attributes,
|
||||
metadata,
|
||||
},
|
||||
)
|
||||
.collect(),
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to search faces: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(FaceSearchResponse {
|
||||
success: true,
|
||||
message: format!("Found {} similar faces", results.len()),
|
||||
results,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn list_faces(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Query(query): Query<FaceListQuery>,
|
||||
) -> Result<Json<FaceListResponse>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let page = query.page.unwrap_or(1);
|
||||
let page_size = query.page_size.unwrap_or(20);
|
||||
let offset = ((page - 1) as i64) * (page_size as i64);
|
||||
let active_only = query.active_only.unwrap_or(true);
|
||||
|
||||
// Build query
|
||||
let mut where_clause = "WHERE 1=1".to_string();
|
||||
if active_only {
|
||||
where_clause.push_str(" AND is_active = TRUE");
|
||||
}
|
||||
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let count_query = format!(
|
||||
"SELECT COUNT(*) FROM {} {}",
|
||||
face_identities_table, where_clause
|
||||
);
|
||||
let list_query = format!(
|
||||
"SELECT face_id, name, created_at, updated_at, is_active, metadata FROM {} {} ORDER BY created_at DESC LIMIT $1 OFFSET $2",
|
||||
face_identities_table, where_clause
|
||||
);
|
||||
|
||||
// Get total count
|
||||
let total: i64 = match sqlx::query_scalar(&count_query).fetch_one(db.pool()).await {
|
||||
Ok(count) => count,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to count faces: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Get face list
|
||||
let faces: Vec<FaceListItem> = match sqlx::query_as::<
|
||||
_,
|
||||
(
|
||||
String,
|
||||
Option<String>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
bool,
|
||||
Option<serde_json::Value>,
|
||||
),
|
||||
>(&list_query)
|
||||
.bind(page_size as i32)
|
||||
.bind(offset)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(face_id, name, created_at, updated_at, is_active, metadata)| FaceListItem {
|
||||
face_id,
|
||||
name,
|
||||
created_at: created_at.to_rfc3339(),
|
||||
updated_at: updated_at.to_rfc3339(),
|
||||
is_active,
|
||||
metadata,
|
||||
},
|
||||
)
|
||||
.collect(),
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to list faces: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(FaceListResponse {
|
||||
success: true,
|
||||
message: format!("Found {} faces", total),
|
||||
faces,
|
||||
count: total,
|
||||
page,
|
||||
page_size,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn get_face_details(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Path(face_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let query = format!(
|
||||
r#"
|
||||
SELECT
|
||||
face_id,
|
||||
name,
|
||||
embedding,
|
||||
attributes,
|
||||
metadata,
|
||||
created_at,
|
||||
updated_at,
|
||||
is_active
|
||||
FROM {}
|
||||
WHERE face_id = $1
|
||||
"#,
|
||||
face_identities_table
|
||||
);
|
||||
|
||||
let face: Option<(
|
||||
String,
|
||||
Option<String>,
|
||||
Option<String>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
bool,
|
||||
)> = match sqlx::query_as(&query)
|
||||
.bind(&face_id)
|
||||
.fetch_optional(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(face) => face,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to fetch face details: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match face {
|
||||
Some((
|
||||
face_id,
|
||||
name,
|
||||
embedding,
|
||||
attributes,
|
||||
metadata,
|
||||
created_at,
|
||||
updated_at,
|
||||
is_active,
|
||||
)) => {
|
||||
let response = serde_json::json!({
|
||||
"success": true,
|
||||
"face_id": face_id,
|
||||
"name": name,
|
||||
"has_embedding": embedding.is_some(),
|
||||
"attributes": attributes,
|
||||
"metadata": metadata,
|
||||
"created_at": created_at.to_rfc3339(),
|
||||
"updated_at": updated_at.to_rfc3339(),
|
||||
"is_active": is_active
|
||||
});
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Face not found: {}", face_id),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_face(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Path(face_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Soft delete by marking as inactive
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let query = format!(
|
||||
r#"
|
||||
UPDATE {}
|
||||
SET is_active = FALSE, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE face_id = $1 AND is_active = TRUE
|
||||
RETURNING face_id, name
|
||||
"#,
|
||||
face_identities_table
|
||||
);
|
||||
|
||||
let deleted: Option<(String, Option<String>)> = match sqlx::query_as(&query)
|
||||
.bind(&face_id)
|
||||
.fetch_optional(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to delete face: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match deleted {
|
||||
Some((deleted_id, name)) => {
|
||||
let response = serde_json::json!({
|
||||
"success": true,
|
||||
"message": format!("Face '{}' deleted successfully", name.clone().unwrap_or_else(|| deleted_id.clone())),
|
||||
"face_id": deleted_id.clone()
|
||||
});
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Face not found or already deleted: {}", face_id),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_recognition_results(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Path(video_uuid): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let query = r#"
|
||||
SELECT
|
||||
video_uuid,
|
||||
frame_count,
|
||||
fps,
|
||||
total_faces,
|
||||
recognized_faces,
|
||||
clusters_count,
|
||||
result_data,
|
||||
processing_time_secs,
|
||||
created_at
|
||||
FROM face_recognition_results
|
||||
WHERE video_uuid = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
"#;
|
||||
|
||||
let result: Option<(
|
||||
String,
|
||||
i64,
|
||||
f64,
|
||||
i32,
|
||||
i32,
|
||||
i32,
|
||||
serde_json::Value,
|
||||
Option<f64>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
)> = match sqlx::query_as(query)
|
||||
.bind(&video_uuid)
|
||||
.fetch_optional(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to fetch recognition results: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Some((
|
||||
video_uuid,
|
||||
frame_count,
|
||||
fps,
|
||||
total_faces,
|
||||
recognized_faces,
|
||||
clusters_count,
|
||||
result_data,
|
||||
processing_time_secs,
|
||||
created_at,
|
||||
)) => {
|
||||
let response = serde_json::json!({
|
||||
"success": true,
|
||||
"video_uuid": video_uuid,
|
||||
"frame_count": frame_count,
|
||||
"fps": fps,
|
||||
"total_faces": total_faces,
|
||||
"recognized_faces": recognized_faces,
|
||||
"clusters_count": clusters_count,
|
||||
"result_data": result_data,
|
||||
"processing_time_secs": processing_time_secs,
|
||||
"created_at": created_at.to_rfc3339()
|
||||
});
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("No recognition results found for video: {}", video_uuid),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn store_recognition_results(
|
||||
db: &PostgresDb,
|
||||
video_uuid: &str,
|
||||
result: &FaceRecognitionResult,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let total_faces = result.frames.iter().map(|f| f.faces.len()).sum::<usize>();
|
||||
let recognized_faces = result
|
||||
.frames
|
||||
.iter()
|
||||
.flat_map(|f| &f.faces)
|
||||
.filter(|face| face.identity.is_some())
|
||||
.count();
|
||||
|
||||
let query = r#"
|
||||
INSERT INTO face_recognition_results (
|
||||
video_uuid,
|
||||
frame_count,
|
||||
fps,
|
||||
total_faces,
|
||||
recognized_faces,
|
||||
clusters_count,
|
||||
result_data
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (video_uuid) DO UPDATE SET
|
||||
frame_count = EXCLUDED.frame_count,
|
||||
fps = EXCLUDED.fps,
|
||||
total_faces = EXCLUDED.total_faces,
|
||||
recognized_faces = EXCLUDED.recognized_faces,
|
||||
clusters_count = EXCLUDED.clusters_count,
|
||||
result_data = EXCLUDED.result_data,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#;
|
||||
|
||||
sqlx::query(query)
|
||||
.bind(video_uuid)
|
||||
.bind(result.frame_count as i64)
|
||||
.bind(result.fps)
|
||||
.bind(total_faces as i32)
|
||||
.bind(recognized_faces as i32)
|
||||
.bind(result.face_clusters.len() as i32)
|
||||
.bind(serde_json::to_value(result)?)
|
||||
.execute(db.pool())
|
||||
.await?;
|
||||
|
||||
// Store individual face detections
|
||||
for frame in &result.frames {
|
||||
for face in &frame.faces {
|
||||
if let Some(embedding) = &face.embedding {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
.iter()
|
||||
.map(|v| v.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let insert_query = r#"
|
||||
INSERT INTO face_detections (
|
||||
video_uuid,
|
||||
frame_number,
|
||||
timestamp_secs,
|
||||
face_id,
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
confidence,
|
||||
embedding,
|
||||
attributes,
|
||||
identity_confidence,
|
||||
cluster_id
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10::vector, $11, $12, $13)
|
||||
ON CONFLICT (video_uuid, frame_number, x, y, width, height) DO UPDATE SET
|
||||
face_id = EXCLUDED.face_id,
|
||||
confidence = EXCLUDED.confidence,
|
||||
embedding = EXCLUDED.embedding,
|
||||
attributes = EXCLUDED.attributes,
|
||||
identity_confidence = EXCLUDED.identity_confidence,
|
||||
cluster_id = EXCLUDED.cluster_id
|
||||
"#;
|
||||
|
||||
let identity_confidence = face.identity.as_ref().map(|id| id.confidence as f64);
|
||||
let cluster_id = result
|
||||
.face_clusters
|
||||
.iter()
|
||||
.find(|c| {
|
||||
c.face_ids
|
||||
.contains(&face.face_id.clone().unwrap_or_default())
|
||||
})
|
||||
.map(|c| c.cluster_id.clone());
|
||||
|
||||
sqlx::query(insert_query)
|
||||
.bind(video_uuid)
|
||||
.bind(frame.frame as i64)
|
||||
.bind(frame.timestamp)
|
||||
.bind(face.face_id.as_deref())
|
||||
.bind(face.x)
|
||||
.bind(face.y)
|
||||
.bind(face.width)
|
||||
.bind(face.height)
|
||||
.bind(face.confidence as f64)
|
||||
.bind(&embedding_str)
|
||||
.bind(serde_json::to_value(&face.attributes)?)
|
||||
.bind(identity_confidence)
|
||||
.bind(cluster_id)
|
||||
.execute(db.pool())
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store face clusters
|
||||
for cluster in &result.face_clusters {
|
||||
let centroid = &cluster.centroid;
|
||||
let centroid_str = format!(
|
||||
"[{}]",
|
||||
centroid
|
||||
.iter()
|
||||
.map(|v: &f32| v.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let cluster_query = r#"
|
||||
INSERT INTO face_clusters (
|
||||
cluster_id,
|
||||
video_uuid,
|
||||
centroid,
|
||||
size,
|
||||
representative_face_id,
|
||||
metadata
|
||||
) VALUES ($1, $2, $3::vector, $4, $5, $6)
|
||||
ON CONFLICT (cluster_id) DO UPDATE SET
|
||||
centroid = EXCLUDED.centroid,
|
||||
size = EXCLUDED.size,
|
||||
representative_face_id = EXCLUDED.representative_face_id,
|
||||
metadata = EXCLUDED.metadata
|
||||
"#;
|
||||
|
||||
sqlx::query(cluster_query)
|
||||
.bind(&cluster.cluster_id)
|
||||
.bind(video_uuid)
|
||||
.bind(¢roid_str)
|
||||
.bind(cluster.size as i32)
|
||||
.bind(cluster.representative_face_id.as_deref())
|
||||
.bind(&cluster.metadata)
|
||||
.execute(db.pool())
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
288
src/api/identities.rs
Normal file
288
src/api/identities.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::db::{schema, Database, PostgresDb};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RegisterFromPersonRequest {
|
||||
pub video_uuid: String,
|
||||
pub person_id: String,
|
||||
pub identity_name: String,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct RegisterFromPersonResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub identity_id: i32,
|
||||
pub identity_name: String,
|
||||
pub person_id: String,
|
||||
}
|
||||
|
||||
pub fn identity_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/identities/from-person", post(register_from_person))
|
||||
.route("/api/v1/identities", get(list_identities))
|
||||
}
|
||||
|
||||
/// Register a Global Identity from a specific Person in a video.
|
||||
/// This creates/updates the Identity record, links the Person to the Identity,
|
||||
/// and updates the Person's name to match the Identity.
|
||||
async fn register_from_person(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<RegisterFromPersonRequest>,
|
||||
) -> Result<Json<RegisterFromPersonResponse>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("DB error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut tx = match db.pool().begin().await {
|
||||
Ok(tx) => tx,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Tx error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// 1. Check if Person exists
|
||||
let person_query =
|
||||
"SELECT id, name FROM person_identities WHERE person_id = $1 AND video_uuid = $2";
|
||||
let person: Option<(i32, Option<String>)> = match sqlx::query_as(person_query)
|
||||
.bind(&req.person_id)
|
||||
.bind(&req.video_uuid)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Query error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let (person_db_id, _old_name) = match person {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!(
|
||||
"Person '{}' not found in video '{}'",
|
||||
req.person_id, req.video_uuid
|
||||
),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// 2. Check if Identity exists
|
||||
let identity_query = "SELECT id FROM identities WHERE name = $1";
|
||||
let identity_id: Option<i32> = match sqlx::query_scalar(identity_query)
|
||||
.bind(&req.identity_name)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Query error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let final_identity_id = if let Some(id) = identity_id {
|
||||
id
|
||||
} else {
|
||||
// Create new Identity
|
||||
let meta_json = req.metadata.clone().unwrap_or(serde_json::json!({}));
|
||||
|
||||
let new_id: i32 = match sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO identities (name, embedding, metadata)
|
||||
VALUES ($1, NULLIF($2, '')::public.vector, $3)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(&req.identity_name)
|
||||
.bind("".to_string()) // No embedding for now via this API
|
||||
.bind(&meta_json)
|
||||
.fetch_one(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Insert identity error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
new_id
|
||||
};
|
||||
|
||||
// 3. Create Binding
|
||||
// Columns: id, identity_id, identity_type, identity_value, confidence, metadata, created_at
|
||||
let binding_query = r#"
|
||||
INSERT INTO identity_bindings (identity_id, identity_type, identity_value, confidence, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
"#;
|
||||
|
||||
match sqlx::query(binding_query)
|
||||
.bind(final_identity_id)
|
||||
.bind("person_id") // identity_type
|
||||
.bind(&req.person_id) // identity_value
|
||||
.bind(1.0) // confidence
|
||||
.bind(&serde_json::json!({"auto_updated": true}))
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Binding error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// 4. Update Person Name
|
||||
let update_person = "UPDATE person_identities SET name = $1 WHERE id = $2";
|
||||
match sqlx::query(update_person)
|
||||
.bind(&req.identity_name)
|
||||
.bind(person_db_id)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Update person error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match tx.commit().await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Commit error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(RegisterFromPersonResponse {
|
||||
success: true,
|
||||
message: format!(
|
||||
"Successfully registered identity '{}' and linked to person '{}'",
|
||||
req.identity_name, req.person_id
|
||||
),
|
||||
identity_id: final_identity_id,
|
||||
identity_name: req.identity_name,
|
||||
person_id: req.person_id,
|
||||
}))
|
||||
}
|
||||
|
||||
/// List all global identities
|
||||
async fn list_identities(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Query(query): Query<ListIdentitiesQuery>,
|
||||
) -> Result<Json<IdentityListResponse>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("DB error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let page = query.page.unwrap_or(1);
|
||||
let page_size = query.page_size.unwrap_or(20);
|
||||
let offset = ((page - 1) as i64) * (page_size as i64);
|
||||
|
||||
// 獲取總數
|
||||
let count_sql = "SELECT COUNT(*) FROM identities";
|
||||
let total: i64 = match sqlx::query_scalar(count_sql).fetch_one(db.pool()).await {
|
||||
Ok(count) => count,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Count error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let sql = "SELECT id, name, metadata FROM identities ORDER BY id DESC LIMIT $1 OFFSET $2";
|
||||
|
||||
let rows: Vec<(i32, String, Option<serde_json::Value>)> = match sqlx::query_as(sql)
|
||||
.bind(page_size as i64)
|
||||
.bind(offset)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Query error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let identities: Vec<IdentityResponse> = rows
|
||||
.into_iter()
|
||||
.map(|r| IdentityResponse {
|
||||
id: r.0,
|
||||
name: r.1,
|
||||
metadata: r.2,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(IdentityListResponse {
|
||||
identities,
|
||||
count: total,
|
||||
page,
|
||||
page_size,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListIdentitiesQuery {
|
||||
pub page: Option<usize>,
|
||||
pub page_size: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IdentityResponse {
|
||||
pub id: i32,
|
||||
pub name: String,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IdentityListResponse {
|
||||
pub identities: Vec<IdentityResponse>,
|
||||
pub count: i64,
|
||||
pub page: usize,
|
||||
pub page_size: usize,
|
||||
}
|
||||
412
src/api/identity_binding.rs
Normal file
412
src/api/identity_binding.rs
Normal file
@@ -0,0 +1,412 @@
|
||||
use axum::{
|
||||
extract::{Path, Query},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::db::{Database, PostgresDb};
|
||||
use crate::core::person_identity::{BindIdentityRequest, Identity, UnbindIdentityRequest};
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ApiResponse<T: Serialize> {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub data: Option<T>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// API Handlers
|
||||
// ============================================================================
|
||||
|
||||
async fn get_db() -> Result<PostgresDb, (StatusCode, Json<serde_json::Value>)> {
|
||||
PostgresDb::init().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("DB init failed: {}", e) })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// 獲取 Identity (人物) 列表
|
||||
pub async fn list_identities(
|
||||
Query(params): Query<ListIdentitiesParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<Identity>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let offset = params.offset.unwrap_or(0);
|
||||
let search = params.search.unwrap_or_default();
|
||||
|
||||
let identities = db
|
||||
.list_identities(&search, limit, offset)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!("Found {} identities", identities.len()),
|
||||
data: Some(identities),
|
||||
}))
|
||||
}
|
||||
|
||||
/// 綁定身份 (Face/Speaker -> Identity)
|
||||
pub async fn bind_identity(
|
||||
Json(req): Json<BindIdentityRequest>,
|
||||
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
|
||||
let identity = if let Some(id_id) = req.identity_id {
|
||||
db.get_identity_by_id(id_id).await.ok().flatten()
|
||||
} else if let Some(name) = &req.name {
|
||||
db.get_or_create_identity(name).await.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let identity = match identity {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": "Identity not found or name required" })),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let source = req.source.unwrap_or("manual".to_string());
|
||||
db.bind_identity(
|
||||
identity.id as i64,
|
||||
&req.binding_type,
|
||||
&req.binding_value,
|
||||
&source,
|
||||
1.0,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!(
|
||||
"Bound {} '{}' to Identity '{}'",
|
||||
req.binding_type, req.binding_value, identity.name
|
||||
),
|
||||
data: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// 解綁身份
|
||||
pub async fn unbind_identity(
|
||||
Json(req): Json<UnbindIdentityRequest>,
|
||||
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
db.unbind_identity(&req.binding_type, &req.binding_value)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!("Unbound {} '{}'", req.binding_type, req.binding_value),
|
||||
data: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// 查詢機器 ID 對應的 Identity (人物)
|
||||
pub async fn get_identity_info(
|
||||
Path((binding_type, binding_value)): Path<(String, String)>,
|
||||
) -> Result<Json<ApiResponse<Identity>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let identity = db
|
||||
.get_identity_by_binding(&binding_type, &binding_value)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let identity = identity.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({ "error": "Identity not found" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: "Identity info retrieved".to_string(),
|
||||
data: Some(identity),
|
||||
}))
|
||||
}
|
||||
|
||||
/// 列出未綁定的信號 (待標註列表)
|
||||
pub async fn list_unbound_signals(
|
||||
Query(params): Query<ListSignalsParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<String>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let signals = db
|
||||
.list_unbound_signals(¶ms.uuid, ¶ms.binding_type)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!(
|
||||
"Found {} unbound {} signals",
|
||||
signals.len(),
|
||||
params.binding_type
|
||||
),
|
||||
data: Some(signals),
|
||||
}))
|
||||
}
|
||||
|
||||
/// 獲取特定信號 (Face ID 或 Speaker ID) 出現的所有 Chunk (時間軸)
|
||||
pub async fn get_signal_timeline(
|
||||
Path((uuid, binding_type, binding_value)): Path<(String, String, String)>,
|
||||
) -> Result<Json<ApiResponse<Vec<serde_json::Value>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let chunks = db
|
||||
.get_chunks_by_signal(&uuid, &binding_type, &binding_value)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!(
|
||||
"Found {} chunks for {} '{}'",
|
||||
chunks.len(),
|
||||
binding_type,
|
||||
binding_value
|
||||
),
|
||||
data: Some(chunks),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AVSuggestRequest {
|
||||
pub video_uuid: String,
|
||||
pub overlap_threshold: Option<f64>, // default 0.6
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AVSuggestion {
|
||||
pub face_id: String,
|
||||
pub speaker_id: String,
|
||||
pub overlap_score: f64,
|
||||
pub face_talent_id: Option<i32>,
|
||||
pub speaker_talent_id: Option<i32>,
|
||||
}
|
||||
|
||||
/// Suggests Face-Speaker bindings based on temporal overlap
|
||||
pub async fn suggest_audio_visual_bindings(
|
||||
Json(req): Json<AVSuggestRequest>,
|
||||
) -> Result<Json<ApiResponse<Vec<AVSuggestion>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let threshold = req.overlap_threshold.unwrap_or(0.6);
|
||||
|
||||
// 1. Get Face signals and their time ranges
|
||||
let face_signals = db
|
||||
.list_unbound_signals(&req.video_uuid, "face")
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("Face signals: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let speaker_signals = db
|
||||
.list_unbound_signals(&req.video_uuid, "speaker")
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("Speaker signals: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut suggestions = Vec::new();
|
||||
|
||||
for face_id in &face_signals {
|
||||
for speaker_id in &speaker_signals {
|
||||
// Calculate overlap
|
||||
// In a real implementation, we would query the exact timestamps from DB.
|
||||
// For now, we'll use a placeholder or simple heuristic if timestamps are available in signal timeline.
|
||||
// Let's assume we fetch timelines.
|
||||
|
||||
// Placeholder: Calculate overlap by fetching timelines
|
||||
let face_timeline = db
|
||||
.get_chunks_by_signal(&req.video_uuid, "face", face_id)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let speaker_timeline = db
|
||||
.get_chunks_by_signal(&req.video_uuid, "speaker", speaker_id)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Simplified overlap calculation based on chunk count/ids (assuming chunk_id contains time info or we have ranges)
|
||||
// Since chunk IDs are generic, we rely on content JSON having 'start_time'
|
||||
let overlap = calculate_overlap(&face_timeline, &speaker_timeline);
|
||||
|
||||
if overlap >= threshold {
|
||||
// Check if they are already bound to the same identity
|
||||
let face_identity = db
|
||||
.get_identity_by_binding("face", face_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
let speaker_identity = db
|
||||
.get_identity_by_binding("speaker", speaker_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
// If both bound to different identities, don't suggest (conflict)
|
||||
if let (Some(fi), Some(si)) = (&face_identity, &speaker_identity) {
|
||||
if fi.id != si.id {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
suggestions.push(AVSuggestion {
|
||||
face_id: face_id.clone(),
|
||||
speaker_id: speaker_id.clone(),
|
||||
overlap_score: overlap,
|
||||
face_talent_id: face_identity.as_ref().map(|i| i.id as i32),
|
||||
speaker_talent_id: speaker_identity.as_ref().map(|i| i.id as i32),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suggestions.sort_by(|a, b| b.overlap_score.partial_cmp(&a.overlap_score).unwrap());
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!("Found {} AV suggestions", suggestions.len()),
|
||||
data: Some(suggestions),
|
||||
}))
|
||||
}
|
||||
|
||||
fn calculate_overlap(
|
||||
face_chunks: &[serde_json::Value],
|
||||
speaker_chunks: &[serde_json::Value],
|
||||
) -> f64 {
|
||||
// Simplified: Extract start/end times and calculate intersection over union
|
||||
// Assuming chunks have start_frame or start_time in content JSON
|
||||
// If content is raw string, we might need to parse it.
|
||||
// In our schema, content is JSONB.
|
||||
|
||||
let mut face_ranges: Vec<(f64, f64)> = Vec::new();
|
||||
for c in face_chunks {
|
||||
if let Some(content) = c.get("content") {
|
||||
if let (Some(start), Some(end)) = (
|
||||
content.get("start_time").and_then(|v| v.as_f64()),
|
||||
content.get("end_time").and_then(|v| v.as_f64()),
|
||||
) {
|
||||
face_ranges.push((start, end));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut speaker_ranges: Vec<(f64, f64)> = Vec::new();
|
||||
for c in speaker_chunks {
|
||||
if let Some(content) = c.get("content") {
|
||||
if let (Some(start), Some(end)) = (
|
||||
content.get("start_time").and_then(|v| v.as_f64()),
|
||||
content.get("end_time").and_then(|v| v.as_f64()),
|
||||
) {
|
||||
speaker_ranges.push((start, end));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut overlap_duration = 0.0;
|
||||
for (fs, fe) in &face_ranges {
|
||||
for (ss, se) in &speaker_ranges {
|
||||
let start = fs.max(*ss);
|
||||
let end = fe.min(*se);
|
||||
if start < end {
|
||||
overlap_duration += end - start;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return normalized overlap (0.0 to 1.0+), simple version: overlap / min_duration
|
||||
let min_duration = face_ranges
|
||||
.iter()
|
||||
.map(|(_, e)| e)
|
||||
.sum::<f64>()
|
||||
.min(speaker_ranges.iter().map(|(_, e)| e).sum::<f64>());
|
||||
|
||||
if min_duration > 0.0 {
|
||||
(overlap_duration / min_duration).min(1.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Router Setup
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListIdentitiesParams {
|
||||
pub search: Option<String>,
|
||||
pub limit: Option<i32>,
|
||||
pub offset: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListSignalsParams {
|
||||
pub uuid: String,
|
||||
pub binding_type: String, // "face" or "speaker"
|
||||
}
|
||||
|
||||
pub fn identity_binding_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/identities/bind", post(bind_identity))
|
||||
.route("/api/v1/identities/unbind", post(unbind_identity))
|
||||
.route(
|
||||
"/api/v1/identity/:binding_type/:binding_value",
|
||||
get(get_identity_info),
|
||||
)
|
||||
// 信號發現 (Discovery)
|
||||
.route("/api/v1/signals/unbound", get(list_unbound_signals))
|
||||
// 信號時間軸 (Timeline)
|
||||
.route(
|
||||
"/api/v1/signals/:uuid/:binding_type/:binding_value/timeline",
|
||||
get(get_signal_timeline),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/identities/suggest-av",
|
||||
post(suggest_audio_visual_bindings),
|
||||
)
|
||||
}
|
||||
264
src/api/n8n_search.rs
Normal file
264
src/api/n8n_search.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use crate::core::db::{Bm25Result, PostgresDb};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmartSearchRequest {
|
||||
pub query: String,
|
||||
pub uuid: Option<String>,
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SmartSearchResponse {
|
||||
pub query: String,
|
||||
pub parsed_dimensions: serde_json::Value,
|
||||
pub hits: Vec<serde_json::Value>,
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct LlmDimensionResponse {
|
||||
pub who: Option<String>,
|
||||
pub what: Option<String>,
|
||||
pub when: Option<String>,
|
||||
pub r#where: Option<String>,
|
||||
pub why: Option<String>,
|
||||
#[serde(default)]
|
||||
pub keywords: Vec<String>,
|
||||
}
|
||||
|
||||
/// POST /api/v1/n8n/search/smart
|
||||
pub async fn n8n_search_smart(
|
||||
db: &PostgresDb,
|
||||
req: SmartSearchRequest,
|
||||
) -> Result<SmartSearchResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let limit = req.limit.unwrap_or(10);
|
||||
let video_uuid = req.uuid.clone();
|
||||
|
||||
// 1. Call LLM to extract 5W1H (Fallback to keywords if LLM fails)
|
||||
let dimensions = match parse_query_with_llm(&req.query).await {
|
||||
Some(dims) => dims,
|
||||
None => LlmDimensionResponse {
|
||||
who: None,
|
||||
what: None,
|
||||
when: None,
|
||||
r#where: None,
|
||||
why: None,
|
||||
keywords: extract_keywords(&req.query),
|
||||
},
|
||||
};
|
||||
|
||||
// Prepare search terms based on dimensions
|
||||
let keywords = dimensions.keywords.join(" ");
|
||||
|
||||
let semantic_query = [
|
||||
dimensions.who.clone(),
|
||||
dimensions.what.clone(),
|
||||
dimensions.r#where.clone(),
|
||||
dimensions.why.clone(),
|
||||
dimensions.when.clone(),
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
// 2. Multi-dimensional Search
|
||||
let mut hits: Vec<serde_json::Value> = Vec::new();
|
||||
let mut seen_chunk_ids: HashSet<String> = HashSet::new();
|
||||
|
||||
// Helper function
|
||||
fn add_hit(
|
||||
hits: &mut Vec<serde_json::Value>,
|
||||
seen_chunk_ids: &mut HashSet<String>,
|
||||
sr: Bm25Result,
|
||||
boost: f32,
|
||||
) {
|
||||
if seen_chunk_ids.insert(sr.chunk_id.clone()) {
|
||||
let score = sr.bm25_score * boost;
|
||||
let val = serde_json::json!({
|
||||
"id": sr.chunk_id,
|
||||
"vid": sr.uuid,
|
||||
"start": sr.start_time,
|
||||
"end": sr.end_time,
|
||||
"text": sr.text,
|
||||
"score": score,
|
||||
"chunk_type": sr.chunk_type
|
||||
});
|
||||
hits.push(val);
|
||||
}
|
||||
}
|
||||
|
||||
// A. Keyword Search (BM25)
|
||||
if !keywords.is_empty() {
|
||||
if let Ok(results) = db
|
||||
.search_bm25(&keywords, video_uuid.as_deref(), limit)
|
||||
.await
|
||||
{
|
||||
for sr in results {
|
||||
add_hit(&mut hits, &mut seen_chunk_ids, sr, 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// B. Who Search (Person Matching)
|
||||
if let Some(who_query) = &dimensions.who {
|
||||
// 1. Search Person
|
||||
if let Ok(persons) = db.search_person_candidates(who_query, &video_uuid, 5).await {
|
||||
if !persons.is_empty() {
|
||||
let person_id = persons[0]
|
||||
.get("candidate_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
|
||||
if let Some(pid) = person_id {
|
||||
// Heuristic: Search BM25 for the person's ID or Name
|
||||
let person_name = persons[0]
|
||||
.get("display_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(who_query);
|
||||
|
||||
// Re-run BM25 with person name to find specific chunks and boost them
|
||||
if let Ok(results) = db
|
||||
.search_bm25(person_name, video_uuid.as_deref(), limit)
|
||||
.await
|
||||
{
|
||||
for sr in results {
|
||||
let id = sr.chunk_id.clone();
|
||||
if seen_chunk_ids.insert(id) {
|
||||
let score = sr.bm25_score * 1.5;
|
||||
let val = serde_json::json!({
|
||||
"id": sr.chunk_id,
|
||||
"vid": sr.uuid,
|
||||
"start": sr.start_time,
|
||||
"end": sr.end_time,
|
||||
"text": sr.text,
|
||||
"score": score,
|
||||
"matched_person": pid
|
||||
});
|
||||
hits.push(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score
|
||||
hits.sort_by(|a, b| {
|
||||
let score_a = a.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||
let score_b = b.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||
score_b.partial_cmp(&score_a).unwrap()
|
||||
});
|
||||
|
||||
// Limit
|
||||
hits.truncate(limit);
|
||||
let total = hits.len();
|
||||
|
||||
Ok(SmartSearchResponse {
|
||||
query: req.query,
|
||||
parsed_dimensions: serde_json::json!(dimensions),
|
||||
hits,
|
||||
total,
|
||||
})
|
||||
}
|
||||
|
||||
/// Extracts lowercase keywords from a query.
///
/// The query is lowercased and split on every non-alphanumeric character;
/// empty pieces and common English stop words are dropped. The remaining
/// words are returned in their original order (duplicates are kept).
fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "who", "what", "where", "when", "why", "how", "is", "the", "a", "an", "and", "or", "of",
        "in", "to", "for", "with", "by", "on", "at", "from", "up", "about", "into", "over",
        "after",
    ];

    let lowered = query.to_lowercase();
    let mut keywords = Vec::new();

    for token in lowered.split(|c: char| !c.is_alphanumeric()) {
        if token.is_empty() || STOP_WORDS.contains(&token) {
            continue;
        }
        keywords.push(token.to_string());
    }

    keywords
}
|
||||
|
||||
async fn parse_query_with_llm(query: &str) -> Option<LlmDimensionResponse> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Test connectivity first
|
||||
if let Ok(resp) = client.get("http://127.0.0.1:8081/health").send().await {
|
||||
tracing::info!("LLM Health Check: {}", resp.status());
|
||||
} else {
|
||||
tracing::error!("LLM Server is unreachable at 127.0.0.1:8081");
|
||||
}
|
||||
|
||||
// We use the OpenAI-compatible endpoint provided by llama.cpp server (default port 8081)
|
||||
let prompt = format!(
|
||||
r#"Analyze the user query and extract the following dimensions into a JSON object.
|
||||
If a dimension is not present, use null.
|
||||
Dimensions: "who" (person/subject), "what" (action/event), "where" (location), "when" (time), "why" (reason/intent), "keywords" (array of specific keywords).
|
||||
|
||||
User Query: "{}"
|
||||
|
||||
Output ONLY the JSON object.
|
||||
"#,
|
||||
query
|
||||
);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"model": "gemma4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
if let Ok(response) = client
|
||||
.post("http://127.0.0.1:8081/v1/chat/completions")
|
||||
.json(&payload)
|
||||
.timeout(std::time::Duration::from_secs(60))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
tracing::info!("LLM Response Status: {}", response.status());
|
||||
if let Ok(json_resp) = response.json::<serde_json::Value>().await {
|
||||
tracing::info!("LLM Response Body: {}", json_resp);
|
||||
if let Some(choices) = json_resp.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.get(0) {
|
||||
if let Some(content) = choice
|
||||
.get("message")
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
if let Some(start) = content.find("{") {
|
||||
if let Some(end) = content.rfind("}") {
|
||||
let json_str = &content[start..=end];
|
||||
if let Ok(dims) =
|
||||
serde_json::from_str::<LlmDimensionResponse>(json_str)
|
||||
{
|
||||
tracing::info!("Parsed LLM Dimensions: {:?}", dims);
|
||||
return Some(dims);
|
||||
} else {
|
||||
tracing::warn!("Failed to parse LLM JSON: {}", json_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Failed to parse LLM response JSON");
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("LLM request failed or timed out");
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
3281
src/api/person_identity.rs
Normal file
3281
src/api/person_identity.rs
Normal file
File diff suppressed because it is too large
Load Diff
2772
src/api/person_identity.rs.bak
Normal file
2772
src/api/person_identity.rs.bak
Normal file
File diff suppressed because it is too large
Load Diff
2774
src/api/person_identity.rs.bak2
Normal file
2774
src/api/person_identity.rs.bak2
Normal file
File diff suppressed because it is too large
Load Diff
195
src/api/search.rs
Normal file
195
src/api/search.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Smart Search API
|
||||
//! Implements the 5W1H search capability using semantic vectors.
|
||||
|
||||
use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use tracing;
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
// --- Request / Response Structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmartSearchRequest {
|
||||
pub uuid: String,
|
||||
pub query: String,
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchResult {
|
||||
pub id: i32,
|
||||
pub parent_id: i32,
|
||||
pub scene_order: Option<i32>,
|
||||
|
||||
// Primary: frame-accurate position (authoritative unit)
|
||||
pub start_frame: i64,
|
||||
pub end_frame: i64,
|
||||
pub fps: f64,
|
||||
|
||||
// Reference: time derived from frames (subject to FPS variation, not precise)
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
|
||||
pub raw_text: Option<String>, // Text content of the child chunk
|
||||
pub summary: Option<String>, // Summary from parent context
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub similarity: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SmartSearchResponse {
|
||||
pub query: String,
|
||||
pub results: Vec<SearchResult>,
|
||||
pub strategy: String,
|
||||
}
|
||||
|
||||
// --- API Handler ---
|
||||
|
||||
pub async fn smart_search(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<SmartSearchRequest>,
|
||||
) -> Result<Json<SmartSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
let limit = req.limit.unwrap_or(5);
|
||||
|
||||
// 1. Generate Embedding using Ollama
|
||||
let embedding = get_ollama_embedding(&req.query).await.map_err(
|
||||
|e| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("Embedding failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
// 2. Search Database (Drill-Down: Find Parents First)
|
||||
let db_parents: Vec<crate::core::db::postgres_db::SemanticSearchResult> = db
|
||||
.search_parent_chunks_semantic(&req.uuid, &embedding, limit)
|
||||
.await
|
||||
.map_err(
|
||||
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("DB search failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
if db_parents.is_empty() {
|
||||
return Ok(Json(SmartSearchResponse {
|
||||
query: req.query,
|
||||
results: vec![],
|
||||
strategy: "semantic_vector_search".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
// Collect Parent IDs
|
||||
let parent_ids: Vec<i32> = db_parents.iter().map(|p| p.id).collect();
|
||||
|
||||
// 3. Fetch Children for these Parents (Drill Down)
|
||||
// We fetch all children for these parents (limit can be adjusted)
|
||||
let children: Vec<crate::core::db::postgres_db::ChildChunkResult> = db
|
||||
.get_children_for_parents(&parent_ids, 10) // Fetch top 10 children per parent
|
||||
.await
|
||||
.map_err(
|
||||
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("Fetching children failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
// 4. Map Parents to a lookup table
|
||||
let parent_map: std::collections::HashMap<
|
||||
i32,
|
||||
&crate::core::db::postgres_db::SemanticSearchResult,
|
||||
> = db_parents.iter().map(|p| (p.id, p)).collect();
|
||||
|
||||
// Map Children to API response struct
|
||||
let results: Vec<SearchResult> = children
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
let parent = parent_map.get(&c.parent_id);
|
||||
SearchResult {
|
||||
id: c.id,
|
||||
parent_id: c.parent_id,
|
||||
scene_order: parent.map(|p| p.scene_order),
|
||||
|
||||
start_frame: c.start_frame,
|
||||
end_frame: c.end_frame,
|
||||
fps: c.fps,
|
||||
|
||||
start_time: c.start_time,
|
||||
end_time: c.end_time,
|
||||
raw_text: Some(c.raw_text),
|
||||
summary: parent.map(|p| p.summary.clone()),
|
||||
metadata: parent.map(|p| p.metadata.clone()),
|
||||
similarity: parent.and_then(|p| p.similarity),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 6. Sort results by similarity (descending)
|
||||
// Since all children of a parent have the same parent similarity, this groups relevant chunks together
|
||||
let mut results = results;
|
||||
results.sort_by(|a, b| {
|
||||
b.similarity
|
||||
.partial_cmp(&a.similarity)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// 7. Limit the final results (optional, but good for API consistency)
|
||||
let limit = req.limit.unwrap_or(5) * 5; // Allow more children per parent context
|
||||
results.truncate(limit);
|
||||
|
||||
// 8. Format Response
|
||||
let response = SmartSearchResponse {
|
||||
query: req.query,
|
||||
results,
|
||||
strategy: "drill_down_semantic_search".to_string(),
|
||||
};
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
// --- Helper: Ollama Embedding ---
|
||||
|
||||
async fn get_ollama_embedding(
|
||||
text: &str,
|
||||
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let client = reqwest::Client::new();
|
||||
let payload = serde_json::json!({
|
||||
"model": "nomic-embed-text",
|
||||
"prompt": text
|
||||
});
|
||||
|
||||
let res = client
|
||||
.post("http://localhost:11434/api/embeddings")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?
|
||||
.json::<serde_json::Value>()
|
||||
.await?;
|
||||
|
||||
// Parse embedding array from response
|
||||
let embedding = res["embedding"]
|
||||
.as_array()
|
||||
.ok_or("No embedding found in Ollama response")?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
|
||||
.collect();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
// --- Router Setup ---
|
||||
|
||||
pub fn search_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new().route("/smart", post(smart_search))
|
||||
}
|
||||
195
src/api/search.rs.bak
Normal file
195
src/api/search.rs.bak
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Smart Search API
|
||||
//! Implements the 5W1H search capability using semantic vectors.
|
||||
|
||||
use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use tracing;
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
// --- Request / Response Structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmartSearchRequest {
|
||||
pub uuid: String,
|
||||
pub query: String,
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchResult {
|
||||
pub id: i32,
|
||||
pub parent_id: i32,
|
||||
pub scene_order: Option<i32>,
|
||||
|
||||
// Primary: frame-accurate position (authoritative unit)
|
||||
pub start_frame: i64,
|
||||
pub end_frame: i64,
|
||||
pub fps: f64,
|
||||
|
||||
// Reference: time derived from frames (subject to FPS variation, not precise)
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
|
||||
pub raw_text: Option<String>, // Text content of the child chunk
|
||||
pub summary: Option<String>, // Summary from parent context
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub similarity: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SmartSearchResponse {
|
||||
pub query: String,
|
||||
pub results: Vec<SearchResult>,
|
||||
pub strategy: String,
|
||||
}
|
||||
|
||||
// --- API Handler ---
|
||||
|
||||
pub async fn smart_search(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<SmartSearchRequest>,
|
||||
) -> Result<Json<SmartSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
let limit = req.limit.unwrap_or(5);
|
||||
|
||||
// 1. Generate Embedding using Ollama
|
||||
let embedding = get_ollama_embedding(&req.query).await.map_err(
|
||||
|e| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("Embedding failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
// 2. Search Database (Drill-Down: Find Parents First)
|
||||
let db_parents: Vec<crate::core::db::postgres_db::SemanticSearchResult> = db
|
||||
.search_parent_chunks_semantic(&req.uuid, &embedding, limit)
|
||||
.await
|
||||
.map_err(
|
||||
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("DB search failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
if db_parents.is_empty() {
|
||||
return Ok(Json(SmartSearchResponse {
|
||||
query: req.query,
|
||||
results: vec![],
|
||||
strategy: "semantic_vector_search".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
// Collect Parent IDs
|
||||
let parent_ids: Vec<i32> = db_parents.iter().map(|p| p.id).collect();
|
||||
|
||||
// 3. Fetch Children for these Parents (Drill Down)
|
||||
// We fetch all children for these parents (limit can be adjusted)
|
||||
let children: Vec<crate::core::db::postgres_db::ChildChunkResult> = db
|
||||
.get_children_for_parents(&parent_ids, 10) // Fetch top 10 children per parent
|
||||
.await
|
||||
.map_err(
|
||||
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("Fetching children failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
// 4. Map Parents to a lookup table
|
||||
let parent_map: std::collections::HashMap<
|
||||
i32,
|
||||
&crate::core::db::postgres_db::SemanticSearchResult,
|
||||
> = db_parents.iter().map(|p| (p.id, p)).collect();
|
||||
|
||||
// Map Children to API response struct
|
||||
let results: Vec<SearchResult> = children
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
let parent = parent_map.get(&c.parent_id);
|
||||
SearchResult {
|
||||
id: c.id,
|
||||
parent_id: c.parent_id,
|
||||
scene_order: parent.map(|p| p.scene_order),
|
||||
|
||||
start_frame: c.start_frame,
|
||||
end_frame: c.end_frame,
|
||||
fps: c.fps,
|
||||
|
||||
start_time: c.start_time,
|
||||
end_time: c.end_time,
|
||||
raw_text: Some(c.raw_text),
|
||||
summary: parent.map(|p| p.summary.clone()),
|
||||
metadata: parent.map(|p| p.metadata.clone()),
|
||||
similarity: parent.and_then(|p| p.similarity),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 6. Sort results by similarity (descending)
|
||||
// Since all children of a parent have the same parent similarity, this groups relevant chunks together
|
||||
let mut results = results;
|
||||
results.sort_by(|a, b| {
|
||||
b.similarity
|
||||
.partial_cmp(&a.similarity)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// 7. Limit the final results (optional, but good for API consistency)
|
||||
let limit = req.limit.unwrap_or(5) * 5; // Allow more children per parent context
|
||||
results.truncate(limit);
|
||||
|
||||
// 8. Format Response
|
||||
let response = SmartSearchResponse {
|
||||
query: req.query,
|
||||
results,
|
||||
strategy: "drill_down_semantic_search".to_string(),
|
||||
};
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
// --- Helper: Ollama Embedding ---
|
||||
|
||||
async fn get_ollama_embedding(
|
||||
text: &str,
|
||||
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let client = reqwest::Client::new();
|
||||
let payload = serde_json::json!({
|
||||
"model": "nomic-embed-text",
|
||||
"prompt": text
|
||||
});
|
||||
|
||||
let res = client
|
||||
.post("http://localhost:11434/api/embeddings")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?
|
||||
.json::<serde_json::Value>()
|
||||
.await?;
|
||||
|
||||
// Parse embedding array from response
|
||||
let embedding = res["embedding"]
|
||||
.as_array()
|
||||
.ok_or("No embedding found in Ollama response")?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
|
||||
.collect();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
// --- Router Setup ---
|
||||
|
||||
pub fn search_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new().route("/smart", post(smart_search))
|
||||
}
|
||||
@@ -30,6 +30,33 @@ use super::universal_search;
|
||||
use super::visual_chunk_search;
|
||||
use crate::core::chunk::types::Chunk;
|
||||
|
||||
// API key returned by `login` for the built-in "demo" account.
// NOTE(review): this credential is hard-coded in source and therefore ends up in
// version control; consider loading it from configuration and rotating the
// committed value.
static DEMO_USER_API_KEY: &str = "muser_68600856036340bcafc01930eb4bd839_1774418104_97221b69";
|
||||
|
||||
fn hash_password(password: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(password.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LoginRequest {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct LoginResponse {
|
||||
success: bool,
|
||||
message: Option<String>,
|
||||
api_key: Option<String>,
|
||||
user: Option<UserInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct UserInfo {
|
||||
username: String,
|
||||
}
|
||||
|
||||
// Global State
|
||||
static SERVER_START: OnceCell<Instant> = OnceCell::new();
|
||||
|
||||
@@ -334,6 +361,12 @@ struct VideoInfoResponse {
|
||||
duration: f64,
|
||||
width: u32,
|
||||
height: u32,
|
||||
status: String,
|
||||
processing_status: Option<String>,
|
||||
created_at: Option<String>,
|
||||
registration_time: Option<String>,
|
||||
file_size: Option<i64>,
|
||||
probe_json: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -348,6 +381,9 @@ struct VideosResponse {
|
||||
struct VideosQuery {
|
||||
page: Option<usize>,
|
||||
page_size: Option<usize>,
|
||||
q: Option<String>,
|
||||
status: Option<String>,
|
||||
uuid: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -408,7 +444,10 @@ async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthRe
|
||||
let qdrant = check_qdrant().await;
|
||||
let mongodb = check_mongodb(&state.mongo_cache).await;
|
||||
|
||||
let overall_status = if postgres.status == "ok" && redis.status == "ok" && qdrant.status == "ok"
|
||||
let overall_status = if postgres.status == "ok"
|
||||
&& redis.status == "ok"
|
||||
&& qdrant.status == "ok"
|
||||
&& mongodb.status == "ok"
|
||||
{
|
||||
"ok"
|
||||
} else {
|
||||
@@ -428,6 +467,30 @@ async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthRe
|
||||
})
|
||||
}
|
||||
|
||||
async fn login(Json(req): Json<LoginRequest>) -> Json<LoginResponse> {
|
||||
if req.username == "demo" && req.password == "demo" {
|
||||
Json(LoginResponse {
|
||||
success: true,
|
||||
message: Some("Login successful".to_string()),
|
||||
api_key: Some(DEMO_USER_API_KEY.to_string()),
|
||||
user: Some(UserInfo {
|
||||
username: "demo".to_string(),
|
||||
}),
|
||||
})
|
||||
} else {
|
||||
Json(LoginResponse {
|
||||
success: false,
|
||||
message: Some("Invalid username or password".to_string()),
|
||||
api_key: None,
|
||||
user: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn logout() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({ "success": true }))
|
||||
}
|
||||
|
||||
async fn check_postgres() -> ServiceStatus {
|
||||
let start = Instant::now();
|
||||
match PostgresDb::init().await {
|
||||
@@ -709,6 +772,7 @@ async fn register(
|
||||
user_id: None,
|
||||
job_id: None,
|
||||
created_at: String::new(),
|
||||
registration_time: None,
|
||||
};
|
||||
|
||||
let video_id = db
|
||||
@@ -1874,9 +1938,51 @@ async fn list_videos(
|
||||
let page = params.page.unwrap_or(1);
|
||||
let page_size = params.page_size.unwrap_or(20);
|
||||
let offset = ((page - 1) as i64) * (page_size as i64);
|
||||
let status_filter = params.status.clone();
|
||||
let query_filter = params.q.clone();
|
||||
|
||||
// Include query and status in cache key
|
||||
let cache_key = keys::videos_list(page, page_size);
|
||||
let cache_key = if let Some(ref q) = query_filter {
|
||||
format!("{}:q:{}", cache_key, q)
|
||||
} else {
|
||||
cache_key
|
||||
};
|
||||
let ttl = state.mongo_cache.ttl_videos();
|
||||
|
||||
// If uuid is provided, fetch single video directly
|
||||
if let Some(ref uuid) = params.uuid {
|
||||
let db = PostgresDb::init()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let video = db.get_video_by_uuid(uuid).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if let Some(v) = video {
|
||||
return Ok(Json(VideosResponse {
|
||||
videos: vec![VideoInfoResponse {
|
||||
uuid: v.uuid,
|
||||
file_path: v.file_path,
|
||||
file_name: v.file_name,
|
||||
duration: v.duration,
|
||||
width: v.width,
|
||||
height: v.height,
|
||||
status: v.status.as_str().to_string(),
|
||||
processing_status: None,
|
||||
created_at: Some(v.created_at),
|
||||
registration_time: v.registration_time,
|
||||
file_size: None,
|
||||
probe_json: v.probe_json,
|
||||
}],
|
||||
count: 1,
|
||||
page,
|
||||
page_size,
|
||||
}));
|
||||
} else {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"list_videos called: page={}, page_size={}, cache_key={}",
|
||||
page,
|
||||
@@ -1892,7 +1998,21 @@ async fn list_videos(
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("PG init failed: {}", e))?;
|
||||
|
||||
let (videos, count) = db.list_videos(page_size as i32, offset).await?;
|
||||
// Map status parameter to is_processed filter
|
||||
let is_processed = match status_filter.as_deref() {
|
||||
Some("pending") | Some("unprocessed") => Some(false),
|
||||
Some("completed") | Some("ready") | Some("processed") => Some(true),
|
||||
_ => None, // no filter
|
||||
};
|
||||
|
||||
// Search by query if provided
|
||||
let (videos, count) = if let Some(ref q) = query_filter {
|
||||
db.search_videos(Some(q.as_str()), is_processed, page_size as i32, offset).await?
|
||||
} else if let Some(processed) = is_processed {
|
||||
db.search_videos(None, Some(processed), page_size as i32, offset).await?
|
||||
} else {
|
||||
db.list_videos(page_size as i32, offset).await?
|
||||
};
|
||||
tracing::info!("Got {} videos from DB", videos.len());
|
||||
|
||||
let video_infos: Vec<VideoInfoResponse> = videos
|
||||
@@ -1904,6 +2024,12 @@ async fn list_videos(
|
||||
duration: v.duration,
|
||||
width: v.width,
|
||||
height: v.height,
|
||||
status: v.status.as_str().to_string(),
|
||||
processing_status: None,
|
||||
created_at: Some(v.created_at),
|
||||
registration_time: v.registration_time,
|
||||
file_size: None,
|
||||
probe_json: v.probe_json,
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -2328,19 +2454,15 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> {
|
||||
.with_state(state.clone());
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::AllowOrigin::predicate(
|
||||
|origin, _request_headers| {
|
||||
origin.as_bytes().ends_with(b"localhost")
|
||||
|| origin.as_bytes().ends_with(b"momentry.ddns.net")
|
||||
|| origin.as_bytes().ends_with(b"127.0.0.1")
|
||||
},
|
||||
))
|
||||
.allow_origin(tower_http::cors::AllowOrigin::any())
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/health/detailed", get(health_detailed))
|
||||
.route("/api/v1/auth/login", post(login))
|
||||
.route("/api/v1/auth/logout", post(logout))
|
||||
.route("/api/v1/stats/ingest", get(get_ingest_stats))
|
||||
.route("/api/v1/stats/sftpgo", get(get_sftpgo_status))
|
||||
.route("/api/v1/stats/inference", get(get_inference_health))
|
||||
|
||||
814
src/api/universal_search.rs
Normal file
814
src/api/universal_search.rs
Normal file
@@ -0,0 +1,814 @@
|
||||
//! Universal Search API
|
||||
//! Unified search across chunks, frames, and persons.
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::db::{Database, PostgresDb};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UniversalSearchRequest {
|
||||
pub query: String,
|
||||
pub uuid: Option<String>,
|
||||
#[serde(default)]
|
||||
pub types: Vec<String>, // chunk, frame, person
|
||||
pub time_range: Option<[f64; 2]>,
|
||||
pub filters: Option<SearchFilters>,
|
||||
pub limit: Option<usize>,
|
||||
pub offset: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct SearchFilters {
|
||||
pub person_id: Option<String>,
|
||||
pub object_class: Option<Vec<String>>,
|
||||
pub ocr_text: Option<String>,
|
||||
pub has_face: Option<bool>,
|
||||
pub speaker_id: Option<String>,
|
||||
// Visual chunk filters
|
||||
pub min_confidence: Option<f32>,
|
||||
pub min_unique_classes: Option<u32>,
|
||||
pub min_spatial_density: Option<f32>,
|
||||
pub max_spatial_density: Option<f32>,
|
||||
pub required_object_classes: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct UniversalSearchResponse {
|
||||
pub query: String,
|
||||
pub results: Vec<SearchResult>,
|
||||
pub total: usize,
|
||||
pub took_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum SearchResult {
|
||||
#[serde(rename = "chunk")]
|
||||
Chunk {
|
||||
chunk_id: String,
|
||||
chunk_type: String,
|
||||
// Primary: frame-accurate position
|
||||
start_frame: i64,
|
||||
end_frame: i64,
|
||||
// Reference: time derived from frames (subject to FPS variation)
|
||||
start_time: f64,
|
||||
end_time: f64,
|
||||
score: f64,
|
||||
text: Option<String>,
|
||||
speaker_id: Option<String>,
|
||||
metadata: Option<serde_json::Value>,
|
||||
},
|
||||
#[serde(rename = "frame")]
|
||||
Frame {
|
||||
// Primary: exact frame number
|
||||
frame_number: i64,
|
||||
// Reference: time derived from frame (subject to FPS variation)
|
||||
timestamp: f64,
|
||||
score: f64,
|
||||
objects: Option<Vec<serde_json::Value>>,
|
||||
ocr_texts: Option<Vec<String>>,
|
||||
faces: Option<Vec<serde_json::Value>>,
|
||||
pose_persons: Option<Vec<serde_json::Value>>,
|
||||
},
|
||||
#[serde(rename = "person")]
|
||||
Person {
|
||||
person_id: String,
|
||||
name: Option<String>,
|
||||
speaker_id: Option<String>,
|
||||
appearance_count: i32,
|
||||
score: f64,
|
||||
first_appearance_time: Option<f64>,
|
||||
last_appearance_time: Option<f64>,
|
||||
},
|
||||
}
|
||||
|
||||
pub fn universal_search_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/search/universal", post(universal_search))
|
||||
.route("/api/v1/search/frames", post(search_frames))
|
||||
.route("/api/v1/search/persons", get(search_persons))
|
||||
}
|
||||
|
||||
/// Unified search across all data types
|
||||
pub async fn universal_search(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<UniversalSearchRequest>,
|
||||
) -> Result<Json<UniversalSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let db = PostgresDb::init().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let limit = req.limit.unwrap_or(20);
|
||||
let offset = req.offset.unwrap_or(0);
|
||||
let types = if req.types.is_empty() {
|
||||
vec![
|
||||
"chunk".to_string(),
|
||||
"frame".to_string(),
|
||||
"person".to_string(),
|
||||
]
|
||||
} else {
|
||||
req.types.clone()
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
// Search chunks
|
||||
if types.contains(&"chunk".to_string()) {
|
||||
let chunk_results = search_chunks(&db, &req).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
results.extend(chunk_results);
|
||||
}
|
||||
|
||||
// Search frames
|
||||
if types.contains(&"frame".to_string()) {
|
||||
let frame_results = search_frames_internal(&db, &req).await.unwrap_or_default();
|
||||
results.extend(frame_results);
|
||||
}
|
||||
|
||||
// Search persons
|
||||
if types.contains(&"person".to_string()) {
|
||||
let person_results = search_persons_internal(&db, &req).await.unwrap_or_default();
|
||||
results.extend(person_results);
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
results.sort_by(|a, b| {
|
||||
let score_a = match a {
|
||||
SearchResult::Chunk { score, .. } => *score,
|
||||
SearchResult::Frame { score, .. } => *score,
|
||||
SearchResult::Person { score, .. } => *score,
|
||||
};
|
||||
let score_b = match b {
|
||||
SearchResult::Chunk { score, .. } => *score,
|
||||
SearchResult::Frame { score, .. } => *score,
|
||||
SearchResult::Person { score, .. } => *score,
|
||||
};
|
||||
score_b
|
||||
.partial_cmp(&score_a)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let total = results.len();
|
||||
let end = std::cmp::min(offset + limit, results.len());
|
||||
let paginated = if offset < results.len() {
|
||||
results[offset..end].to_vec()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let took = start_time.elapsed().as_millis() as u64;
|
||||
|
||||
Ok(Json(UniversalSearchResponse {
|
||||
query: req.query,
|
||||
results: paginated,
|
||||
total,
|
||||
took_ms: took,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Search frames by YOLO objects, OCR text, or face IDs
|
||||
pub async fn search_frames(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<FrameSearchRequest>,
|
||||
) -> Result<Json<FrameSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = PostgresDb::init().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let frames = search_frames_internal_v2(&db, &req).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("Search error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let frames_count = frames.len();
|
||||
Ok(Json(FrameSearchResponse {
|
||||
frames,
|
||||
total: frames_count,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Search persons by name or speaker_id
|
||||
pub async fn search_persons(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Query(query): Query<PersonSearchQuery>,
|
||||
) -> Result<Json<PersonSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = PostgresDb::init().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let limit = query.limit.unwrap_or(20);
|
||||
let persons = search_persons_by_query(
|
||||
&db,
|
||||
&query.query,
|
||||
query.min_appearances,
|
||||
query.max_age,
|
||||
limit,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("Search error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let persons_count = persons.len();
|
||||
Ok(Json(PersonSearchResponse {
|
||||
persons,
|
||||
total: persons_count,
|
||||
}))
|
||||
}
|
||||
|
||||
// --- Internal search functions ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FrameSearchRequest {
|
||||
pub uuid: Option<String>,
|
||||
pub object_class: Option<String>,
|
||||
pub ocr_text: Option<String>,
|
||||
pub face_id: Option<String>,
|
||||
pub time_range: Option<[f64; 2]>,
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FrameSearchResponse {
|
||||
pub frames: Vec<FrameResult>,
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FrameResult {
|
||||
pub frame_number: i64,
|
||||
pub timestamp: f64,
|
||||
pub uuid: String,
|
||||
pub objects: Option<Vec<serde_json::Value>>,
|
||||
pub ocr_texts: Option<Vec<String>>,
|
||||
pub faces: Option<Vec<serde_json::Value>>,
|
||||
pub pose_persons: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PersonSearchQuery {
|
||||
pub video_uuid: String,
|
||||
pub query: Option<String>,
|
||||
pub min_appearances: Option<i32>,
|
||||
pub max_age: Option<i32>, // New filter for "children"
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PersonSearchResponse {
|
||||
pub persons: Vec<PersonResult>,
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PersonResult {
|
||||
pub person_id: String,
|
||||
pub name: Option<String>,
|
||||
pub character_name: Option<String>,
|
||||
pub aliases: Option<Vec<String>>,
|
||||
pub age: Option<i32>,
|
||||
pub gender: Option<String>,
|
||||
pub speaker_id: Option<String>,
|
||||
pub appearance_count: i32,
|
||||
pub first_appearance_time: Option<f64>,
|
||||
pub last_appearance_time: Option<f64>,
|
||||
}
|
||||
|
||||
async fn search_chunks(
|
||||
db: &PostgresDb,
|
||||
req: &UniversalSearchRequest,
|
||||
) -> Result<Vec<SearchResult>, anyhow::Error> {
|
||||
// uuid is required for chunk search - chunk_id is only unique within a video
|
||||
let uuid = match &req.uuid {
|
||||
Some(u) => u.replace('\'', "''"),
|
||||
None => return Err(anyhow::anyhow!("uuid is required for chunk search")),
|
||||
};
|
||||
|
||||
let mut sql = format!(
|
||||
"SELECT chunk_id, chunk_type, start_time, end_time, start_frame, end_frame, text_content, content FROM chunks WHERE uuid = '{}'",
|
||||
uuid
|
||||
);
|
||||
if let Some(tr) = &req.time_range {
|
||||
sql.push_str(&format!(
|
||||
" AND start_time >= {} AND end_time <= {}",
|
||||
tr[0], tr[1]
|
||||
));
|
||||
}
|
||||
if !req.query.is_empty() {
|
||||
let q = req.query.replace('\'', "''");
|
||||
sql.push_str(&format!(
|
||||
" AND (text_content ILIKE '%{}%' OR content::text ILIKE '%{}%')",
|
||||
q, q
|
||||
));
|
||||
}
|
||||
if let Some(ref filters) = req.filters {
|
||||
if let Some(ref speaker_id) = filters.speaker_id {
|
||||
sql.push_str(&format!(
|
||||
" AND content->>'speaker_id' = '{}'",
|
||||
speaker_id.replace('\'', "''")
|
||||
));
|
||||
}
|
||||
if let Some(ref person_id) = filters.person_id {
|
||||
sql.push_str(&format!(
|
||||
" AND content::text LIKE '%{}%'",
|
||||
person_id.replace('\'', "''")
|
||||
));
|
||||
}
|
||||
// Visual chunk filters
|
||||
if let Some(min_confidence) = filters.min_confidence {
|
||||
sql.push_str(&format!(
|
||||
" AND (content->'metadata'->>'avg_confidence')::float >= {}",
|
||||
min_confidence
|
||||
));
|
||||
}
|
||||
if let Some(min_unique_classes) = filters.min_unique_classes {
|
||||
sql.push_str(&format!(
|
||||
" AND jsonb_array_length(content->'metadata'->'unique_classes') >= {}",
|
||||
min_unique_classes
|
||||
));
|
||||
}
|
||||
if let Some(min_density) = filters.min_spatial_density {
|
||||
sql.push_str(&format!(
|
||||
" AND (content->'metadata'->>'spatial_density')::float >= {}",
|
||||
min_density
|
||||
));
|
||||
}
|
||||
if let Some(max_density) = filters.max_spatial_density {
|
||||
sql.push_str(&format!(
|
||||
" AND (content->'metadata'->>'spatial_density')::float <= {}",
|
||||
max_density
|
||||
));
|
||||
}
|
||||
if let Some(ref required_classes) = filters.required_object_classes {
|
||||
if !required_classes.is_empty() {
|
||||
let class_conditions: Vec<String> = required_classes
|
||||
.iter()
|
||||
.map(|class| {
|
||||
format!(
|
||||
"content->'keyframe_objects' @> '[{{ \"class_name\": \"{}\"}}]'",
|
||||
class.replace('\'', "''")
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
sql.push_str(&format!(" AND ({})", class_conditions.join(" OR ")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY start_time ASC");
|
||||
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
|
||||
|
||||
let rows: Vec<(
|
||||
String,
|
||||
String,
|
||||
f64,
|
||||
f64,
|
||||
i64,
|
||||
i64,
|
||||
Option<String>,
|
||||
Option<serde_json::Value>,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<SearchResult> = rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(
|
||||
chunk_id,
|
||||
chunk_type,
|
||||
start_time,
|
||||
end_time,
|
||||
start_frame,
|
||||
end_frame,
|
||||
text_content,
|
||||
content,
|
||||
)| {
|
||||
let text = text_content.or_else(|| {
|
||||
content
|
||||
.as_ref()
|
||||
.and_then(|c| c.get("text").and_then(|v| v.as_str()).map(String::from))
|
||||
});
|
||||
let speaker_id = content.as_ref().and_then(|c| {
|
||||
c.get("speaker_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
});
|
||||
// Simple scoring: if query matches, score 0.8
|
||||
let score = if !req.query.is_empty()
|
||||
&& text.as_ref().map_or(false, |t| {
|
||||
t.to_lowercase().contains(&req.query.to_lowercase())
|
||||
}) {
|
||||
0.9
|
||||
} else {
|
||||
0.5
|
||||
};
|
||||
|
||||
SearchResult::Chunk {
|
||||
chunk_id,
|
||||
chunk_type,
|
||||
start_time,
|
||||
end_time,
|
||||
start_frame,
|
||||
end_frame,
|
||||
score,
|
||||
text,
|
||||
speaker_id,
|
||||
metadata: content,
|
||||
}
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn search_frames_internal(
|
||||
db: &PostgresDb,
|
||||
req: &UniversalSearchRequest,
|
||||
) -> Result<Vec<SearchResult>, anyhow::Error> {
|
||||
let table = "frames";
|
||||
let video_table = "videos";
|
||||
|
||||
let mut sql = format!(
|
||||
"SELECT f.frame_number, f.timestamp, f.yolo_objects, f.ocr_results, f.face_results, f.pose_results, v.uuid
|
||||
FROM {} f JOIN {} v ON f.file_id = v.id WHERE 1=1",
|
||||
table, video_table
|
||||
);
|
||||
|
||||
if let Some(uuid) = &req.uuid {
|
||||
sql.push_str(&format!(" AND v.uuid = '{}'", uuid));
|
||||
}
|
||||
if let Some(tr) = &req.time_range {
|
||||
sql.push_str(&format!(
|
||||
" AND f.timestamp >= {} AND f.timestamp <= {}",
|
||||
tr[0], tr[1]
|
||||
));
|
||||
}
|
||||
if let Some(ref filters) = req.filters {
|
||||
if let Some(ref classes) = filters.object_class {
|
||||
for class in classes {
|
||||
sql.push_str(&format!(" AND f.yolo_objects::text ILIKE '%{}%'", class));
|
||||
}
|
||||
}
|
||||
if let Some(ref ocr) = filters.ocr_text {
|
||||
sql.push_str(&format!(" AND f.ocr_results::text ILIKE '%{}%'", ocr));
|
||||
}
|
||||
if let Some(true) = filters.has_face {
|
||||
sql.push_str(
|
||||
" AND f.face_results IS NOT NULL AND jsonb_array_length(f.face_results) > 0",
|
||||
);
|
||||
}
|
||||
if let Some(ref person_id) = filters.person_id {
|
||||
sql.push_str(&format!(" AND f.face_results::text LIKE '%{}%'", person_id));
|
||||
}
|
||||
}
|
||||
if !req.query.is_empty() {
|
||||
// Search across all frame data
|
||||
sql.push_str(&format!(
|
||||
" AND (f.yolo_objects::text ILIKE '%{}%' OR f.ocr_results::text ILIKE '%{}%' OR f.face_results::text ILIKE '%{}%')",
|
||||
req.query, req.query, req.query
|
||||
));
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY f.timestamp ASC");
|
||||
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
|
||||
|
||||
let rows: Vec<(
|
||||
i64,
|
||||
f64,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
String,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<SearchResult> = rows
|
||||
.into_iter()
|
||||
.map(|(frame_number, timestamp, yolo, ocr, face, pose, _uuid)| {
|
||||
let objects = yolo.as_ref().and_then(|v| {
|
||||
v.get("objects")
|
||||
.map(|o| o.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
let ocr_texts = ocr.as_ref().and_then(|v| {
|
||||
v.get("texts").and_then(|t| {
|
||||
t.as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|item| {
|
||||
item.get("text").and_then(|x| x.as_str()).map(String::from)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
});
|
||||
let faces = face.as_ref().and_then(|v| {
|
||||
v.get("faces")
|
||||
.map(|f| f.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
let pose_persons = pose.as_ref().and_then(|v| {
|
||||
v.get("persons")
|
||||
.map(|p| p.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
|
||||
SearchResult::Frame {
|
||||
frame_number,
|
||||
timestamp,
|
||||
score: 0.7,
|
||||
objects: objects.map(|arr| arr.iter().map(|v| v.clone()).collect()),
|
||||
ocr_texts,
|
||||
faces,
|
||||
pose_persons,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn search_persons_internal(
|
||||
db: &PostgresDb,
|
||||
req: &UniversalSearchRequest,
|
||||
) -> Result<Vec<SearchResult>, anyhow::Error> {
|
||||
let table = "person_identities";
|
||||
let mut sql = format!(
|
||||
"SELECT person_id, name, speaker_id, appearance_count, first_appearance_time, last_appearance_time FROM {} WHERE 1=1",
|
||||
table
|
||||
);
|
||||
|
||||
if !req.query.is_empty() {
|
||||
sql.push_str(&format!(
|
||||
" AND (name ILIKE '%{}%' OR person_id ILIKE '%{}%' OR speaker_id ILIKE '%{}%')",
|
||||
req.query, req.query, req.query
|
||||
));
|
||||
}
|
||||
if let Some(ref filters) = req.filters {
|
||||
if let Some(ref speaker_id) = filters.speaker_id {
|
||||
sql.push_str(&format!(" AND speaker_id = '{}'", speaker_id));
|
||||
}
|
||||
if let Some(ref person_id) = filters.person_id {
|
||||
sql.push_str(&format!(" AND person_id = '{}'", person_id));
|
||||
}
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY appearance_count DESC");
|
||||
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
|
||||
|
||||
let rows: Vec<(
|
||||
String,
|
||||
Option<String>,
|
||||
Option<String>,
|
||||
i32,
|
||||
Option<f64>,
|
||||
Option<f64>,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<SearchResult> = rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(person_id, name, speaker_id, appearance_count, first_time, last_time)| {
|
||||
let score = if !req.query.is_empty()
|
||||
&& name.as_ref().map_or(false, |n| {
|
||||
n.to_lowercase().contains(&req.query.to_lowercase())
|
||||
}) {
|
||||
0.95
|
||||
} else {
|
||||
0.5
|
||||
};
|
||||
|
||||
SearchResult::Person {
|
||||
person_id,
|
||||
name,
|
||||
speaker_id,
|
||||
appearance_count,
|
||||
score,
|
||||
first_appearance_time: first_time,
|
||||
last_appearance_time: last_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn search_frames_internal_v2(
|
||||
db: &PostgresDb,
|
||||
req: &FrameSearchRequest,
|
||||
) -> Result<Vec<FrameResult>, anyhow::Error> {
|
||||
let table = "frames";
|
||||
let video_table = "videos";
|
||||
|
||||
let mut sql = format!(
|
||||
"SELECT f.frame_number, f.timestamp, f.yolo_objects, f.ocr_results, f.face_results, f.pose_results, v.uuid
|
||||
FROM {} f JOIN {} v ON f.file_id = v.id WHERE 1=1",
|
||||
table, video_table
|
||||
);
|
||||
|
||||
if let Some(uuid) = &req.uuid {
|
||||
sql.push_str(&format!(" AND v.uuid = '{}'", uuid));
|
||||
}
|
||||
if let Some(tr) = &req.time_range {
|
||||
sql.push_str(&format!(
|
||||
" AND f.timestamp >= {} AND f.timestamp <= {}",
|
||||
tr[0], tr[1]
|
||||
));
|
||||
}
|
||||
if let Some(ref class) = req.object_class {
|
||||
sql.push_str(&format!(" AND f.yolo_objects::text ILIKE '%{}%'", class));
|
||||
}
|
||||
if let Some(ref ocr) = req.ocr_text {
|
||||
sql.push_str(&format!(" AND f.ocr_results::text ILIKE '%{}%'", ocr));
|
||||
}
|
||||
if let Some(ref face_id) = req.face_id {
|
||||
sql.push_str(&format!(" AND f.face_results::text LIKE '%{}%'", face_id));
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY f.timestamp ASC");
|
||||
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(50)));
|
||||
|
||||
let rows: Vec<(
|
||||
i64,
|
||||
f64,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
String,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<FrameResult> = rows
|
||||
.into_iter()
|
||||
.map(|(frame_number, timestamp, yolo, ocr, face, pose, uuid)| {
|
||||
let objects = yolo.as_ref().and_then(|v| {
|
||||
v.get("objects")
|
||||
.map(|o| o.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
let ocr_texts = ocr.as_ref().and_then(|v| {
|
||||
v.get("texts").and_then(|t| {
|
||||
t.as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|item| {
|
||||
item.get("text").and_then(|x| x.as_str()).map(String::from)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
});
|
||||
let faces = face.as_ref().and_then(|v| {
|
||||
v.get("faces")
|
||||
.map(|f| f.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
let pose_persons = pose.as_ref().and_then(|v| {
|
||||
v.get("persons")
|
||||
.map(|p| p.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
|
||||
FrameResult {
|
||||
frame_number,
|
||||
timestamp,
|
||||
uuid,
|
||||
objects: objects.map(|arr| arr.iter().map(|v| v.clone()).collect()),
|
||||
ocr_texts,
|
||||
faces,
|
||||
pose_persons,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn search_persons_by_query(
|
||||
db: &PostgresDb,
|
||||
query: &Option<String>,
|
||||
min_appearances: Option<i32>,
|
||||
max_age: Option<i32>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<PersonResult>, anyhow::Error> {
|
||||
let table = "person_identities";
|
||||
let mut sql = format!(
|
||||
"SELECT person_id, name, character_name, aliases, age, gender, speaker_id, appearance_count, first_appearance_time, last_appearance_time FROM {} WHERE 1=1",
|
||||
table
|
||||
);
|
||||
|
||||
if let Some(ref q) = query {
|
||||
// Search name, character_name, aliases (cast to text), person_id, speaker_id
|
||||
sql.push_str(&format!(
|
||||
" AND (name ILIKE '%{}%' OR character_name ILIKE '%{}%' OR aliases::text ILIKE '%{}%' OR person_id ILIKE '%{}%' OR speaker_id ILIKE '%{}%')",
|
||||
q, q, q, q, q
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(min) = min_appearances {
|
||||
sql.push_str(&format!(" AND appearance_count >= {}", min));
|
||||
}
|
||||
if let Some(max_a) = max_age {
|
||||
// Strictly filter for age <= max_age.
|
||||
// Note: This excludes entries with NULL age.
|
||||
sql.push_str(&format!(" AND age <= {}", max_a));
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY appearance_count DESC");
|
||||
sql.push_str(&format!(" LIMIT {}", limit));
|
||||
|
||||
let rows: Vec<(
|
||||
String,
|
||||
Option<String>,
|
||||
Option<String>,
|
||||
Option<serde_json::Value>,
|
||||
Option<i32>,
|
||||
Option<String>,
|
||||
Option<String>,
|
||||
i32,
|
||||
Option<f64>,
|
||||
Option<f64>,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<PersonResult> = rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(
|
||||
person_id,
|
||||
name,
|
||||
character_name,
|
||||
aliases_json,
|
||||
age,
|
||||
gender,
|
||||
speaker_id,
|
||||
appearance_count,
|
||||
first_time,
|
||||
last_time,
|
||||
)| {
|
||||
let aliases = aliases_json.and_then(|v| {
|
||||
v.as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|val| val.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
});
|
||||
|
||||
PersonResult {
|
||||
person_id,
|
||||
name,
|
||||
character_name,
|
||||
aliases,
|
||||
age,
|
||||
gender,
|
||||
speaker_id,
|
||||
appearance_count,
|
||||
first_appearance_time: first_time,
|
||||
last_appearance_time: last_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `SearchFilters` can be built with only the visual filters populated,
    /// and the confidence threshold reads back unchanged.
    #[test]
    fn test_search_filters_with_visual() {
        let visual_only = SearchFilters {
            // Visual criteria under test.
            min_confidence: Some(0.8),
            min_unique_classes: Some(3),
            min_spatial_density: Some(0.5),
            max_spatial_density: Some(0.9),
            required_object_classes: Some(vec!["person".to_string()]),
            // Everything else stays unset.
            person_id: None,
            object_class: None,
            ocr_text: None,
            has_face: None,
            speaker_id: None,
        };

        assert_eq!(visual_only.min_confidence, Some(0.8));
    }
}
|
||||
504
src/api/visual_chunk_search.rs
Normal file
504
src/api/visual_chunk_search.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
//! Visual chunk search functionality.
|
||||
//!
|
||||
//! This module provides search capabilities for visual chunks based on:
|
||||
//! - Object classes (e.g., "person", "car", "envelope")
|
||||
//! - Confidence thresholds
|
||||
//! - Object counts
|
||||
//! - Spatial density
|
||||
//! - Object relationships
|
||||
|
||||
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
|
||||
use crate::core::db::PostgresDb;
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Criteria for searching visual chunks
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct VisualChunkSearchCriteria {
|
||||
/// Minimum average confidence across frames
|
||||
pub min_avg_confidence: Option<f32>,
|
||||
/// Minimum number of frames with objects
|
||||
pub min_frames_with_objects: Option<u32>,
|
||||
/// Minimum number of unique object classes
|
||||
pub min_unique_classes: Option<u32>,
|
||||
/// Specific object classes to include (empty means all)
|
||||
pub required_classes: Vec<String>,
|
||||
/// Object class counts to filter by
|
||||
pub class_counts: HashMap<String, (u32, u32)>,
|
||||
/// Time range (optional)
|
||||
pub time_range: Option<(f64, f64)>,
|
||||
}
|
||||
|
||||
impl Default for VisualChunkSearchCriteria {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_avg_confidence: None,
|
||||
min_frames_with_objects: None,
|
||||
min_unique_classes: None,
|
||||
required_classes: Vec::new(),
|
||||
class_counts: HashMap::new(),
|
||||
time_range: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Search visual chunks based on criteria
|
||||
pub async fn search_visual_chunks(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
criteria: &VisualChunkSearchCriteria,
|
||||
) -> Result<Vec<Chunk>> {
|
||||
// First, get all visual chunks for this video
|
||||
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
||||
|
||||
// Apply filters
|
||||
let filtered_chunks: Vec<Chunk> = all_chunks
|
||||
.into_iter()
|
||||
.filter(|chunk| {
|
||||
// Check min avg confidence
|
||||
if let Some(min_avg_confidence) = criteria.min_avg_confidence {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(metadata) = content.get("metadata") {
|
||||
if let Some(avg_confidence) = metadata.get("avg_confidence") {
|
||||
if let Some(conf) = avg_confidence.as_f64() {
|
||||
if conf < min_avg_confidence as f64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check min frames with objects
|
||||
if let Some(min_frames) = criteria.min_frames_with_objects {
|
||||
if let Some(stats) = &chunk.visual_stats {
|
||||
if let Some(frames_with_objects) = stats.get("frames_with_objects") {
|
||||
if let Some(count) = frames_with_objects.as_u64() {
|
||||
if count < min_frames as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check min unique classes
|
||||
if let Some(min_unique_classes) = criteria.min_unique_classes {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(metadata) = content.get("metadata") {
|
||||
if let Some(unique_classes) = metadata.get("unique_classes") {
|
||||
if let Some(classes) = unique_classes.as_array() {
|
||||
if (classes.len() as u32) < min_unique_classes {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check required classes
|
||||
if !criteria.required_classes.is_empty() {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(keyframe_objects) = content.get("keyframe_objects") {
|
||||
if let Some(objects) = keyframe_objects.as_array() {
|
||||
let mut found_all = true;
|
||||
for required_class in &criteria.required_classes {
|
||||
let mut found = false;
|
||||
for obj in objects {
|
||||
if let Some(class_name) = obj.get("class_name") {
|
||||
if let Some(class_str) = class_name.as_str() {
|
||||
if class_str == required_class {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
found_all = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !found_all {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check class counts
|
||||
if !criteria.class_counts.is_empty() {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(metadata) = content.get("metadata") {
|
||||
if let Some(object_counts) = metadata.get("object_counts") {
|
||||
for (class, (min, max)) in &criteria.class_counts {
|
||||
if let Some(count_value) = object_counts.get(class) {
|
||||
if let Some(count) = count_value.as_u64() {
|
||||
if *min > 0 && count < *min as u64 {
|
||||
return false;
|
||||
}
|
||||
if *max < u32::MAX && count > *max as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if *min > 0 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if criteria.class_counts.values().any(|(min, _)| *min > 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check time range
|
||||
if let Some((start_time, end_time)) = criteria.time_range {
|
||||
// Calculate chunk time from frames
|
||||
let chunk_start_time = chunk.start_frame as f64 / chunk.fps;
|
||||
let chunk_end_time = chunk.end_frame as f64 / chunk.fps;
|
||||
|
||||
if chunk_start_time < start_time || chunk_end_time > end_time {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered_chunks)
|
||||
}
|
||||
|
||||
/// Get all visual chunks for a video UUID
|
||||
async fn get_visual_chunks_by_uuid(db: &PostgresDb, uuid: &str) -> Result<Vec<Chunk>> {
|
||||
let sql = format!(
|
||||
"SELECT file_id, uuid, chunk_id, chunk_index, chunk_type, fps, start_frame, end_frame, text_content, content, metadata, vector_id, visual_stats FROM chunks WHERE uuid = '{}' AND chunk_type = 'visual' ORDER BY start_frame ASC",
|
||||
uuid.replace('\'', "''")
|
||||
);
|
||||
|
||||
let rows: Vec<(
|
||||
i32, // file_id
|
||||
String, // uuid
|
||||
String, // chunk_id
|
||||
i32, // chunk_index
|
||||
String, // chunk_type
|
||||
f64, // fps
|
||||
i64, // start_frame
|
||||
i64, // end_frame
|
||||
Option<String>, // text_content
|
||||
Value, // content
|
||||
Option<Value>, // metadata
|
||||
Option<String>, // vector_id
|
||||
Option<Value>, // visual_stats
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
for row in rows {
|
||||
let chunk_type = match row.4.as_str() {
|
||||
"visual" => ChunkType::Visual,
|
||||
"sentence" => ChunkType::Sentence,
|
||||
"time_based" => ChunkType::TimeBased,
|
||||
"cut" => ChunkType::Cut,
|
||||
"trace" => ChunkType::Trace,
|
||||
"story" => ChunkType::Story,
|
||||
_ => ChunkType::TimeBased,
|
||||
};
|
||||
|
||||
// Calculate frame_count
|
||||
let frame_count = (row.7 - row.6) as i32;
|
||||
|
||||
chunks.push(Chunk {
|
||||
file_id: row.0,
|
||||
uuid: row.1,
|
||||
chunk_id: row.2,
|
||||
chunk_index: row.3 as u32,
|
||||
chunk_type,
|
||||
rule: ChunkRule::Rule2, // Visual chunks use Rule2
|
||||
fps: row.5,
|
||||
start_frame: row.6,
|
||||
end_frame: row.7,
|
||||
text_content: row.8,
|
||||
content: row.9,
|
||||
metadata: row.10,
|
||||
vector_id: row.11,
|
||||
frame_count,
|
||||
pre_chunk_ids: Vec::new(),
|
||||
parent_chunk_id: None,
|
||||
child_chunk_ids: Vec::new(),
|
||||
visual_stats: row.12,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
/// Search visual chunks by object class
|
||||
pub async fn search_visual_chunks_by_class(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
object_class: &str,
|
||||
min_count: Option<u32>,
|
||||
max_count: Option<u32>,
|
||||
) -> Result<Vec<Chunk>> {
|
||||
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
||||
|
||||
let filtered_chunks: Vec<Chunk> = all_chunks
|
||||
.into_iter()
|
||||
.filter(|chunk| {
|
||||
// Check if chunk contains the object class
|
||||
let mut contains_class = false;
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(keyframe_objects) = content.get("keyframe_objects") {
|
||||
if let Some(objects) = keyframe_objects.as_array() {
|
||||
for obj in objects {
|
||||
if let Some(class_name) = obj.get("class_name") {
|
||||
if let Some(class_str) = class_name.as_str() {
|
||||
if class_str == object_class {
|
||||
contains_class = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !contains_class {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check count in visual_stats
|
||||
if let Some(stats) = &chunk.visual_stats {
|
||||
if let Some(count) = stats.get(object_class) {
|
||||
if let Some(c) = count.as_u64() {
|
||||
if let Some(min) = min_count {
|
||||
if c < min as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(max) = max_count {
|
||||
if c > max as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered_chunks)
|
||||
}
|
||||
|
||||
/// Search visual chunks by spatial density
|
||||
pub async fn search_visual_chunks_by_density(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
min_density: f32,
|
||||
max_density: Option<f32>,
|
||||
) -> Result<Vec<Chunk>> {
|
||||
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
||||
|
||||
let filtered_chunks: Vec<Chunk> = all_chunks
|
||||
.into_iter()
|
||||
.filter(|chunk| {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(metadata) = content.get("metadata") {
|
||||
if let Some(density_value) = metadata.get("spatial_density") {
|
||||
if let Some(density) = density_value.as_f64() {
|
||||
if density < min_density as f64 {
|
||||
return false;
|
||||
}
|
||||
if let Some(max_dens) = max_density {
|
||||
if density > max_dens as f64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered_chunks)
|
||||
}
|
||||
|
||||
/// Find chunks containing specific object combinations
|
||||
pub async fn search_visual_chunks_by_combination(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
combination: &[(&str, u32)],
|
||||
) -> Result<Vec<Chunk>> {
|
||||
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
||||
|
||||
let filtered_chunks: Vec<Chunk> = all_chunks
|
||||
.into_iter()
|
||||
.filter(|chunk| {
|
||||
// Check if all required combinations are present
|
||||
for (object_class, min_count) in combination {
|
||||
let mut found = false;
|
||||
if let Some(stats) = &chunk.visual_stats {
|
||||
if let Some(object_counts) = stats.get("object_counts") {
|
||||
if let Some(count_value) = object_counts.get(*object_class) {
|
||||
if let Some(count) = count_value.as_u64() {
|
||||
if count >= *min_count as u64 {
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered_chunks)
|
||||
}
|
||||
|
||||
/// Get visual chunk statistics
|
||||
pub async fn get_visual_chunk_statistics(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
let sql = format!(
|
||||
"SELECT
|
||||
COUNT(*) as total_chunks,
|
||||
AVG((content->'metadata'->>'avg_confidence')::float) as avg_confidence,
|
||||
MIN((content->'metadata'->>'avg_confidence')::float) as min_confidence,
|
||||
MAX((content->'metadata'->>'avg_confidence')::float) as max_confidence,
|
||||
SUM((content->'metadata'->>'object_count')::int) as total_objects,
|
||||
AVG((content->'metadata'->>'spatial_density')::float) as avg_density
|
||||
FROM chunks
|
||||
WHERE uuid = '{}'
|
||||
AND chunk_type = 'visual'",
|
||||
uuid.replace('\'', "''")
|
||||
);
|
||||
|
||||
let row: (i64, Option<f64>, Option<f64>, Option<f64>, i64, Option<f64>) =
|
||||
sqlx::query_as(&sql).fetch_one(db.pool()).await?;
|
||||
|
||||
let mut stats = HashMap::new();
|
||||
stats.insert("total_chunks".to_string(), Value::from(row.0));
|
||||
stats.insert(
|
||||
"avg_confidence".to_string(),
|
||||
Value::from(row.1.unwrap_or(0.0)),
|
||||
);
|
||||
stats.insert(
|
||||
"min_confidence".to_string(),
|
||||
Value::from(row.2.unwrap_or(0.0)),
|
||||
);
|
||||
stats.insert(
|
||||
"max_confidence".to_string(),
|
||||
Value::from(row.3.unwrap_or(0.0)),
|
||||
);
|
||||
stats.insert("total_objects".to_string(), Value::from(row.4));
|
||||
stats.insert("avg_density".to_string(), Value::from(row.5.unwrap_or(0.0)));
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The default criteria impose no constraint at all.
    #[test]
    fn test_visual_chunk_search_criteria_default() {
        let VisualChunkSearchCriteria {
            min_avg_confidence,
            min_frames_with_objects,
            min_unique_classes,
            required_classes,
            class_counts,
            time_range,
        } = VisualChunkSearchCriteria::default();

        assert_eq!(min_avg_confidence, None);
        assert_eq!(min_frames_with_objects, None);
        assert_eq!(min_unique_classes, None);
        assert!(required_classes.is_empty());
        assert!(class_counts.is_empty());
        assert_eq!(time_range, None);
    }

    /// Explicitly set fields read back with the values they were given.
    #[test]
    fn test_visual_chunk_search_criteria_with_values() {
        let populated = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.8),
            min_frames_with_objects: Some(10),
            min_unique_classes: Some(3),
            required_classes: vec!["person".to_string(), "car".to_string()],
            time_range: Some((0.0, 60.0)),
            ..VisualChunkSearchCriteria::default()
        };

        assert_eq!(populated.min_avg_confidence, Some(0.8));
        assert_eq!(populated.min_frames_with_objects, Some(10));
        assert_eq!(populated.min_unique_classes, Some(3));
        assert_eq!(populated.required_classes.len(), 2);
        assert_eq!(populated.time_range, Some((0.0, 60.0)));
    }

    /// Criteria survive a JSON round trip through serde.
    #[test]
    fn test_visual_chunk_search_criteria_serialization() {
        let original = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.85),
            min_frames_with_objects: Some(5),
            min_unique_classes: Some(2),
            required_classes: vec!["person".to_string()],
            class_counts: HashMap::new(),
            time_range: Some((10.0, 30.0)),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("min_avg_confidence"));
        assert!(encoded.contains("required_classes"));

        let decoded: VisualChunkSearchCriteria = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.min_avg_confidence, Some(0.85));
        assert_eq!(decoded.required_classes.len(), 1);
    }

    /// Per-class `(min, max)` bounds are stored under the class name.
    #[test]
    fn test_visual_chunk_search_criteria_with_class_counts() {
        let mut criteria = VisualChunkSearchCriteria::default();
        for (class, bounds) in [("person", (5, 20)), ("car", (1, 10))] {
            criteria.class_counts.insert(class.to_string(), bounds);
        }

        assert_eq!(criteria.class_counts.len(), 2);
        assert_eq!(criteria.class_counts.get("person"), Some(&(5, 20)));
        assert_eq!(criteria.class_counts.get("car"), Some(&(1, 10)));
    }

    /// Mirrors the string-to-`ChunkType` mapping used when loading chunks,
    /// including the `TimeBased` fallback for unknown names.
    #[test]
    fn test_chunk_type_conversion() {
        fn parse_chunk_type(name: &str) -> ChunkType {
            match name {
                "visual" => ChunkType::Visual,
                "sentence" => ChunkType::Sentence,
                "time_based" => ChunkType::TimeBased,
                "cut" => ChunkType::Cut,
                "trace" => ChunkType::Trace,
                "story" => ChunkType::Story,
                _ => ChunkType::TimeBased,
            }
        }

        let expectations = [
            ("visual", ChunkType::Visual),
            ("sentence", ChunkType::Sentence),
            ("time_based", ChunkType::TimeBased),
            ("cut", ChunkType::Cut),
            ("trace", ChunkType::Trace),
            ("story", ChunkType::Story),
            ("unknown", ChunkType::TimeBased), // Default fallback
        ];

        for (name, expected) in expectations {
            assert_eq!(parse_chunk_type(name), expected);
        }
    }
}
|
||||
147
src/api/who.rs
Normal file
147
src/api/who.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
//! Who API - 身份識別與 ID 映射接口 (Video-Scoped)
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::db::Database;
|
||||
|
||||
// --- Request / Response Structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WhoQuery {
|
||||
pub face_id: Option<String>,
|
||||
pub speaker_id: Option<String>,
|
||||
pub uuid: Option<String>,
|
||||
pub chunk_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WhoCandidatesRequest {
|
||||
pub query: String,
|
||||
pub video_uuid: Option<String>,
|
||||
pub limit: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DefinePersonRequest {
|
||||
pub uuid: String,
|
||||
pub identity_id: Option<i32>,
|
||||
pub name: String,
|
||||
pub face_ids: Option<Vec<String>>,
|
||||
pub speaker_ids: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct WhoIdentity {
|
||||
pub identity_id: i32,
|
||||
pub uuid: String,
|
||||
pub name: String,
|
||||
pub tags: Option<Vec<String>>,
|
||||
pub face_ids: Vec<String>,
|
||||
pub speaker_ids: Vec<String>,
|
||||
}
|
||||
|
||||
// --- API Handlers ---
|
||||
|
||||
/// GET /api/v1/who
|
||||
pub async fn get_who_identity(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
axum::extract::Query(query): axum::extract::Query<WhoQuery>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
|
||||
// Priority 1: Query by Chunk (UUID + Chunk ID)
|
||||
if let (Some(uuid), Some(chunk_id)) = (&query.uuid, &query.chunk_id) {
|
||||
let info = db
|
||||
.get_who_info_by_chunk(uuid, chunk_id)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
return Ok(Json(info));
|
||||
}
|
||||
|
||||
// Priority 2: List all for a specific UUID
|
||||
if let Some(uuid) = &query.uuid {
|
||||
// TODO: Implement list_all_persons(uuid)
|
||||
return Ok(Json(serde_json::json!({ "message": "List all pending" })));
|
||||
}
|
||||
|
||||
Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": "Missing uuid" })),
|
||||
))
|
||||
}
|
||||
|
||||
/// POST /api/v1/who/candidates
|
||||
/// Search person_identities table for n8n workflow
|
||||
pub async fn get_who_candidates(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<WhoCandidatesRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
let limit = req.limit.unwrap_or(20);
|
||||
let query_str = format!("%{}%", req.query);
|
||||
|
||||
let results = db
|
||||
.search_person_candidates(&query_str, &req.video_uuid, limit)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Format for n8n
|
||||
let response = serde_json::json!({
|
||||
"query": req.query,
|
||||
"items": results,
|
||||
"total": results.len()
|
||||
});
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
/// POST /api/v1/who
|
||||
pub async fn define_person(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<DefinePersonRequest>,
|
||||
) -> Result<Json<WhoIdentity>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
|
||||
let identity = db
|
||||
.create_or_update_person(
|
||||
&req.uuid,
|
||||
req.identity_id,
|
||||
req.name.clone(),
|
||||
req.face_ids.unwrap_or_default(),
|
||||
req.speaker_ids.unwrap_or_default(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(identity))
|
||||
}
|
||||
|
||||
// --- Router Setup ---
|
||||
|
||||
pub fn who_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/who", get(get_who_identity).post(define_person))
|
||||
.route("/api/v1/who/candidates", post(get_who_candidates))
|
||||
}
|
||||
Reference in New Issue
Block a user