feat: backup architecture docs, source code, and scripts

This commit is contained in:
Warren
2026-04-25 17:15:45 +08:00
parent 59809dae1f
commit 1f84e5469f
368 changed files with 146329 additions and 261 deletions

936
src/api/face_recognition.rs Normal file
View File

@@ -0,0 +1,936 @@
use axum::{
extract::{Multipart, Path, Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::core::db::{schema, Database, PostgresDb};
use crate::core::processor::face_recognition::{
process_face_recognition, register_face, FaceRecognitionResult, FaceRegistrationResult,
};
/// JSON body accepted by `POST /api/v1/face/recognize`.
#[derive(Debug, Deserialize)]
pub struct FaceRecognitionRequest {
/// UUID of the stored video to analyse; `recognize_faces` resolves it to a file path via the database.
pub video_uuid: String,
/// Flag forwarded to the processor; treated as `true` when omitted.
pub enable_recognition: Option<bool>,
/// Flag forwarded to the processor; treated as `true` when omitted.
pub enable_tracking: Option<bool>,
/// Flag forwarded to the processor; treated as `true` when omitted.
pub enable_clustering: Option<bool>,
/// NOTE(review): not read by `recognize_faces` in this file — confirm whether it is still needed.
pub database_path: Option<String>,
}
/// JSON response returned by `recognize_faces`.
#[derive(Debug, Serialize)]
pub struct FaceRecognitionResponse {
/// Set to `true` by the handler on its success path.
pub success: bool,
/// Human-readable summary of the outcome.
pub message: String,
/// Processor output for the analysed video.
pub result: Option<FaceRecognitionResult>,
/// UUID string generated per request; it is also embedded in the output file name.
pub processing_id: String,
}
/// Deserializable description of a face registration.
///
/// NOTE(review): no handler visible in this file deserializes this type —
/// `register_face_api` reads a multipart form instead. Confirm it is used elsewhere.
#[derive(Debug, Deserialize)]
pub struct FaceRegistrationRequest {
/// UUID of the related video.
pub video_uuid: String,
/// Display name to associate with the face.
pub name: String,
/// Optional free-form JSON metadata.
pub metadata: Option<serde_json::Value>,
}
/// JSON response returned by `register_face_api`.
#[derive(Debug, Serialize)]
pub struct FaceRegistrationApiResponse {
/// Mirrors the `success` flag of the registration result.
pub success: bool,
/// Message from the registration result; a database warning is appended when
/// the database connection could not be established.
pub message: String,
/// Full registration result produced by the processor.
pub result: Option<FaceRegistrationResult>,
}
/// JSON body accepted by `POST /api/v1/face/search`.
#[derive(Debug, Deserialize)]
pub struct FaceSearchRequest {
/// NOTE(review): not read by `search_faces` — the search spans all active faces. Confirm intent.
pub video_uuid: String,
/// Query embedding; it is serialized to a `[v1,v2,...]` literal and cast to the vector type in SQL.
pub embedding: Vec<f32>,
/// Minimum similarity a row must reach; defaults to `0.6` when omitted.
pub similarity_threshold: Option<f64>,
/// Maximum number of rows returned; defaults to `10` when omitted.
pub limit: Option<i32>,
}
/// JSON response returned by `search_faces`.
#[derive(Debug, Serialize)]
pub struct FaceSearchResponse {
/// Set to `true` by the handler on its success path.
pub success: bool,
/// Summary including the number of matches.
pub message: String,
/// Matching faces in the order produced by the query (ordered by distance to the query embedding).
pub results: Vec<FaceSearchResult>,
}
/// One row of a face similarity search.
#[derive(Debug, Serialize)]
pub struct FaceSearchResult {
/// Identifier of the stored face.
pub face_id: String,
/// Registered name, if any.
pub name: Option<String>,
/// Computed in SQL as `1 - (embedding <=> query)`.
pub similarity: f64,
/// Stored `attributes` column, passed through unchanged.
pub attributes: Option<serde_json::Value>,
/// Stored `metadata` column, passed through unchanged.
pub metadata: Option<serde_json::Value>,
}
/// Query-string parameters accepted by `GET /api/v1/face/list`.
#[derive(Debug, Deserialize)]
pub struct FaceListQuery {
/// NOTE(review): not read by `list_faces` — the listing is not filtered by video. Confirm intent.
pub video_uuid: String,
/// 1-based page number; defaults to `1` when omitted.
pub page: Option<usize>,
/// Rows per page; defaults to `20` when omitted.
pub page_size: Option<usize>,
/// When `true` (the default) only rows with `is_active = TRUE` are listed.
pub active_only: Option<bool>,
}
/// JSON response returned by `list_faces`.
#[derive(Debug, Serialize)]
pub struct FaceListResponse {
/// Set to `true` by the handler on its success path.
pub success: bool,
/// Summary including the total number of matching faces.
pub message: String,
/// Faces on the requested page, newest first.
pub faces: Vec<FaceListItem>,
/// Total number of rows matching the filter (not the length of `faces`).
pub count: i64,
/// Page number this response corresponds to.
pub page: usize,
/// Page size used for this response.
pub page_size: usize,
}
/// One entry of the face listing.
#[derive(Debug, Serialize)]
pub struct FaceListItem {
/// Identifier of the stored face.
pub face_id: String,
/// Registered name, if any.
pub name: Option<String>,
/// Creation timestamp formatted as RFC 3339.
pub created_at: String,
/// Last-update timestamp formatted as RFC 3339.
pub updated_at: String,
/// `false` once the face has been soft-deleted.
pub is_active: bool,
/// Stored `metadata` column, passed through unchanged.
pub metadata: Option<serde_json::Value>,
}
/// Builds the router exposing the face endpoints under `/api/v1/face`.
///
/// Routes: `recognize`, `register` and `search` (POST), `list` (GET),
/// `:face_id` (GET and DELETE) and `results/:video_uuid` (GET).
pub fn face_recognition_routes() -> Router<crate::api::server::AppState> {
    // GET and DELETE share one path, so both handlers live on one method router.
    let single_face = get(get_face_details).delete(delete_face);

    Router::new()
        .route("/api/v1/face/recognize", post(recognize_faces))
        .route("/api/v1/face/register", post(register_face_api))
        .route("/api/v1/face/search", post(search_faces))
        .route("/api/v1/face/list", get(list_faces))
        .route("/api/v1/face/:face_id", single_face)
        .route(
            "/api/v1/face/results/:video_uuid",
            get(get_recognition_results),
        )
}
/// Runs face recognition on a stored video and returns the processor output.
///
/// The video is looked up by `request.video_uuid`. The output path handed to
/// the processor lives under `OUTPUT_DIR` and embeds a freshly generated
/// processing id. Persisting the result is best-effort: a storage failure is
/// only logged and the recognition result is still returned.
///
/// # Errors
/// * `404 NOT_FOUND` when no video matches the UUID.
/// * `500 INTERNAL_SERVER_ERROR` when the database or the processor fails.
async fn recognize_faces(
    State(_state): State<crate::api::server::AppState>,
    Json(request): Json<FaceRecognitionRequest>,
) -> Result<Json<FaceRecognitionResponse>, (StatusCode, String)> {
    let processing_id = Uuid::new_v4().to_string();
    tracing::info!(
        "[FACE_RECOGNITION] Starting recognition for video: {}, processing_id: {}",
        request.video_uuid,
        processing_id
    );

    // Resolve the video's file path through the database.
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;
    let video_record = db
        .get_video_by_uuid(&request.video_uuid)
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to fetch video: {}", e),
            )
        })?
        .ok_or_else(|| {
            (
                StatusCode::NOT_FOUND,
                format!("Video not found: {}", request.video_uuid),
            )
        })?;

    let source_path = video_record.file_path;
    let report_path = format!(
        "{}/face_recognition_{}.json",
        crate::core::config::OUTPUT_DIR.as_str(),
        processing_id
    );

    // Every processing stage is enabled unless the caller opts out.
    let with_recognition = request.enable_recognition.unwrap_or(true);
    let with_tracking = request.enable_tracking.unwrap_or(true);
    let with_clustering = request.enable_clustering.unwrap_or(true);

    let result = process_face_recognition(
        &source_path,
        &report_path,
        Some(&processing_id),
        with_recognition,
        with_tracking,
        with_clustering,
    )
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Face recognition failed: {}", e),
        )
    })?;

    // Best-effort persistence: the caller still receives the result on failure.
    if let Err(store_err) = store_recognition_results(&db, &request.video_uuid, &result).await {
        tracing::warn!("Failed to store recognition results: {}", store_err);
    }

    let message = format!("Face recognition completed for {}", request.video_uuid);
    Ok(Json(FaceRecognitionResponse {
        success: true,
        message,
        result: Some(result),
        processing_id,
    }))
}
async fn register_face_api(
State(_state): State<crate::api::server::AppState>,
mut multipart: Multipart,
) -> Result<Json<FaceRegistrationApiResponse>, (StatusCode, String)> {
let mut image_path: Option<String> = None;
let mut name: Option<String> = None;
let mut metadata: Option<serde_json::Value> = None;
// Parse multipart form data
while let Some(field) = multipart.next_field().await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Failed to parse form data: {}", e),
)
})? {
let field_name = field.name().unwrap_or("").to_string();
match field_name.as_str() {
"image" => {
// Save uploaded image
let file_name = format!("face_registration_{}.jpg", Uuid::new_v4());
let file_path = format!("/tmp/{}", file_name);
let data = field.bytes().await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Failed to read image data: {}", e),
)
})?;
tokio::fs::write(&file_path, &data).await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to save image: {}", e),
)
})?;
image_path = Some(file_path);
}
"name" => {
let value = field.text().await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Failed to read name: {}", e),
)
})?;
name = Some(value);
}
"metadata" => {
let value = field.text().await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Failed to read metadata: {}", e),
)
})?;
metadata = Some(serde_json::from_str(&value).map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("Invalid JSON metadata: {}", e),
)
})?);
}
_ => {}
}
}
// Validate required fields
let image_path =
image_path.ok_or((StatusCode::BAD_REQUEST, "Image is required".to_string()))?;
let name = name.ok_or((StatusCode::BAD_REQUEST, "Name is required".to_string()))?;
// Register face
let result = match register_face(&image_path, &name, metadata.clone()).await {
Ok(result) => result,
Err(e) => {
// Clean up temporary file
let _ = tokio::fs::remove_file(&image_path).await;
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("Face registration failed: {}", e),
));
}
};
// Clean up temporary file
let _ = tokio::fs::remove_file(&image_path).await;
// Store in PostgreSQL face_identities table
if result.success && !result.embedding.is_empty() {
let db = match PostgresDb::init().await {
Ok(db) => db,
Err(e) => {
tracing::warn!("[FACE_REGISTRATION] Failed to connect to DB: {}", e);
// Return success even if DB write fails (embedding is still in JSON output)
return Ok(Json(FaceRegistrationApiResponse {
success: result.success,
message: format!("{} (Warning: DB write failed: {})", result.message, e),
result: Some(result),
}));
}
};
// Convert embedding to PostgreSQL vector format
let embedding_str = format!(
"[{}]",
result
.embedding
.iter()
.map(|v| v.to_string())
.collect::<Vec<_>>()
.join(",")
);
// Insert into face_identities
let face_identities_table = schema::table_name("face_identities");
let attrs_json =
serde_json::to_string(&result.attributes).unwrap_or_else(|_| "{}".to_string());
// Use public.vector type to work across schemas
let vector_type = if schema::SCHEMA_PREFIX.as_str().is_empty() {
"vector".to_string()
} else {
"public.vector".to_string()
};
let insert_query = format!(
r#"
INSERT INTO {} (face_id, name, embedding, attributes, metadata, is_active)
VALUES ($1, $2, $3::{}, $4::jsonb, $5, TRUE)
ON CONFLICT (face_id) DO UPDATE SET
name = EXCLUDED.name,
embedding = EXCLUDED.embedding,
attributes = EXCLUDED.attributes,
metadata = COALESCE(EXCLUDED.metadata, {}.metadata),
updated_at = CURRENT_TIMESTAMP,
is_active = TRUE
"#,
face_identities_table, vector_type, face_identities_table
);
match sqlx::query(&insert_query)
.bind(&result.face_id)
.bind(&name)
.bind(&embedding_str)
.bind(&attrs_json)
.bind(&metadata.unwrap_or(serde_json::json!({})))
.execute(db.pool())
.await
{
Ok(_) => {
tracing::info!(
"[FACE_REGISTRATION] Stored face '{}' (face_id={}) in DB",
name,
result.face_id
);
}
Err(e) => {
tracing::warn!("[FACE_REGISTRATION] Failed to store face in DB: {}", e);
}
}
}
Ok(Json(FaceRegistrationApiResponse {
success: result.success,
message: result.message.clone(),
result: Some(result),
}))
}
/// Finds registered faces whose embedding is similar to the supplied one.
///
/// Only active rows with a non-NULL embedding are considered. Similarity is
/// computed in SQL as `1 - (embedding <=> query)`. Rows below
/// `similarity_threshold` (default `0.6`) are dropped and at most `limit`
/// (default `10`) rows are returned, ordered by distance to the query.
///
/// # Errors
/// * `400 BAD_REQUEST` when the embedding is empty or contains a non-finite
///   value, or when `limit` is negative. These inputs would otherwise be
///   rejected by the database and surface as a 500.
/// * `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
///   query fails.
async fn search_faces(
    State(_state): State<crate::api::server::AppState>,
    Json(request): Json<FaceSearchRequest>,
) -> Result<Json<FaceSearchResponse>, (StatusCode, String)> {
    // Validate client input before touching the database.
    if request.embedding.is_empty() {
        return Err((
            StatusCode::BAD_REQUEST,
            "Embedding must not be empty".to_string(),
        ));
    }
    if request.embedding.iter().any(|v| !v.is_finite()) {
        return Err((
            StatusCode::BAD_REQUEST,
            "Embedding must contain only finite values".to_string(),
        ));
    }
    let limit = request.limit.unwrap_or(10);
    if limit < 0 {
        return Err((
            StatusCode::BAD_REQUEST,
            format!("Limit must not be negative: {}", limit),
        ));
    }
    let similarity_threshold = request.similarity_threshold.unwrap_or(0.6);

    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    // Convert embedding to PostgreSQL vector format
    let embedding_str = format!(
        "[{}]",
        request
            .embedding
            .iter()
            .map(|v| v.to_string())
            .collect::<Vec<_>>()
            .join(",")
    );

    let face_identities_table = schema::table_name("face_identities");
    // Use public.vector type to work across schemas
    let vector_type = if schema::SCHEMA_PREFIX.as_str().is_empty() {
        "vector"
    } else {
        "public.vector"
    };
    let query = format!(
        r#"
        SELECT
        face_id,
        name,
        1 - (embedding <=> $1::{}) as similarity,
        attributes,
        metadata
        FROM {}
        WHERE is_active = TRUE
        AND embedding IS NOT NULL
        AND 1 - (embedding <=> $1::{}) >= $2
        ORDER BY embedding <=> $1::{}
        LIMIT $3
        "#,
        vector_type, face_identities_table, vector_type, vector_type
    );

    let rows = sqlx::query_as::<
        _,
        (
            String,
            Option<String>,
            f64,
            Option<serde_json::Value>,
            Option<serde_json::Value>,
        ),
    >(query.as_str())
    .bind(&embedding_str)
    .bind(similarity_threshold)
    .bind(limit)
    .fetch_all(db.pool())
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to search faces: {}", e),
        )
    })?;

    let results: Vec<FaceSearchResult> = rows
        .into_iter()
        .map(
            |(face_id, name, similarity, attributes, metadata)| FaceSearchResult {
                face_id,
                name,
                similarity,
                attributes,
                metadata,
            },
        )
        .collect();

    Ok(Json(FaceSearchResponse {
        success: true,
        message: format!("Found {} similar faces", results.len()),
        results,
    }))
}
/// Lists registered faces, newest first, with page-based pagination.
///
/// `page` is 1-based and defaults to `1`; a requested page of `0` is treated
/// as `1` instead of underflowing. `page_size` defaults to `20`. Unless
/// `active_only=false` is passed, soft-deleted faces are excluded.
///
/// `count` in the response is the total number of matching rows, not the
/// number of rows on the returned page.
///
/// # Errors
/// `500 INTERNAL_SERVER_ERROR` when the database is unreachable or a query fails.
async fn list_faces(
    State(_state): State<crate::api::server::AppState>,
    Query(query): Query<FaceListQuery>,
) -> Result<Json<FaceListResponse>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    // Clamp to the first page so `page - 1` cannot underflow on `page=0`.
    let page = query.page.unwrap_or(1).max(1);
    let page_size = query.page_size.unwrap_or(20);
    // Saturate instead of wrapping/truncating on absurdly large inputs.
    let limit = i64::try_from(page_size).unwrap_or(i64::MAX);
    let offset = i64::try_from(page - 1)
        .unwrap_or(i64::MAX)
        .saturating_mul(limit);
    let active_only = query.active_only.unwrap_or(true);

    // Build query
    let where_clause = if active_only {
        "WHERE 1=1 AND is_active = TRUE"
    } else {
        "WHERE 1=1"
    };
    let face_identities_table = schema::table_name("face_identities");
    let count_query = format!(
        "SELECT COUNT(*) FROM {} {}",
        face_identities_table, where_clause
    );
    let list_query = format!(
        "SELECT face_id, name, created_at, updated_at, is_active, metadata FROM {} {} ORDER BY created_at DESC LIMIT $1 OFFSET $2",
        face_identities_table, where_clause
    );

    // Get total count
    let total: i64 = sqlx::query_scalar(&count_query)
        .fetch_one(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to count faces: {}", e),
            )
        })?;

    // Get face list
    let rows = sqlx::query_as::<
        _,
        (
            String,
            Option<String>,
            chrono::DateTime<chrono::Utc>,
            chrono::DateTime<chrono::Utc>,
            bool,
            Option<serde_json::Value>,
        ),
    >(&list_query)
    .bind(limit)
    .bind(offset)
    .fetch_all(db.pool())
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to list faces: {}", e),
        )
    })?;

    let faces: Vec<FaceListItem> = rows
        .into_iter()
        .map(
            |(face_id, name, created_at, updated_at, is_active, metadata)| FaceListItem {
                face_id,
                name,
                created_at: created_at.to_rfc3339(),
                updated_at: updated_at.to_rfc3339(),
                is_active,
                metadata,
            },
        )
        .collect();

    Ok(Json(FaceListResponse {
        success: true,
        message: format!("Found {} faces", total),
        faces,
        count: total,
        page,
        page_size,
    }))
}
/// Returns the stored details of one face as a JSON object.
///
/// The response reports only *whether* an embedding exists
/// (`has_embedding`), so the query selects `embedding IS NOT NULL` rather
/// than the vector itself. This avoids transferring the whole vector.
/// NOTE(review): decoding a `vector` column into `String` is expected to be
/// rejected by sqlx's runtime type check, which this also sidesteps — confirm.
///
/// # Errors
/// * `404 NOT_FOUND` when no row has the given `face_id`.
/// * `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
///   query fails.
async fn get_face_details(
    State(_state): State<crate::api::server::AppState>,
    Path(face_id): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    let face_identities_table = schema::table_name("face_identities");
    let query = format!(
        r#"
        SELECT
        face_id,
        name,
        (embedding IS NOT NULL) AS has_embedding,
        attributes,
        metadata,
        created_at,
        updated_at,
        is_active
        FROM {}
        WHERE face_id = $1
        "#,
        face_identities_table
    );

    let face: Option<(
        String,
        Option<String>,
        bool,
        Option<serde_json::Value>,
        Option<serde_json::Value>,
        chrono::DateTime<chrono::Utc>,
        chrono::DateTime<chrono::Utc>,
        bool,
    )> = sqlx::query_as(&query)
        .bind(&face_id)
        .fetch_optional(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to fetch face details: {}", e),
            )
        })?;

    match face {
        Some((
            face_id,
            name,
            has_embedding,
            attributes,
            metadata,
            created_at,
            updated_at,
            is_active,
        )) => {
            let response = serde_json::json!({
                "success": true,
                "face_id": face_id,
                "name": name,
                "has_embedding": has_embedding,
                "attributes": attributes,
                "metadata": metadata,
                "created_at": created_at.to_rfc3339(),
                "updated_at": updated_at.to_rfc3339(),
                "is_active": is_active
            });
            Ok(Json(response))
        }
        None => Err((
            StatusCode::NOT_FOUND,
            format!("Face not found: {}", face_id),
        )),
    }
}
/// Soft-deletes a face by flipping `is_active` to `FALSE`.
///
/// Only rows that are currently active are touched, so deleting the same
/// face twice yields `404` the second time.
///
/// # Errors
/// * `404 NOT_FOUND` when the face does not exist or is already inactive.
/// * `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
///   update fails.
async fn delete_face(
    State(_state): State<crate::api::server::AppState>,
    Path(face_id): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    // Soft delete by marking as inactive
    let table = schema::table_name("face_identities");
    let statement = format!(
        r#"
UPDATE {}
SET is_active = FALSE, updated_at = CURRENT_TIMESTAMP
WHERE face_id = $1 AND is_active = TRUE
RETURNING face_id, name
"#,
        table
    );

    let deactivated: Option<(String, Option<String>)> = sqlx::query_as(&statement)
        .bind(&face_id)
        .fetch_optional(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to delete face: {}", e),
            )
        })?;

    let (deleted_id, name) = deactivated.ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            format!("Face not found or already deleted: {}", face_id),
        )
    })?;

    // Fall back to the id when the face was never given a name.
    let label = name.unwrap_or_else(|| deleted_id.clone());
    Ok(Json(serde_json::json!({
        "success": true,
        "message": format!("Face '{}' deleted successfully", label),
        "face_id": deleted_id
    })))
}
/// Returns the most recent stored recognition summary for a video.
///
/// # Errors
/// * `404 NOT_FOUND` when nothing has been stored for `video_uuid`.
/// * `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
///   query fails.
async fn get_recognition_results(
    State(_state): State<crate::api::server::AppState>,
    Path(video_uuid): Path<String>,
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
    // Column order of the SELECT below.
    type SummaryRow = (
        String,
        i64,
        f64,
        i32,
        i32,
        i32,
        serde_json::Value,
        Option<f64>,
        chrono::DateTime<chrono::Utc>,
    );

    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Failed to connect to database: {}", e),
        )
    })?;

    let latest_sql = r#"
SELECT
video_uuid,
frame_count,
fps,
total_faces,
recognized_faces,
clusters_count,
result_data,
processing_time_secs,
created_at
FROM face_recognition_results
WHERE video_uuid = $1
ORDER BY created_at DESC
LIMIT 1
"#;

    let latest: Option<SummaryRow> = sqlx::query_as(latest_sql)
        .bind(&video_uuid)
        .fetch_optional(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Failed to fetch recognition results: {}", e),
            )
        })?;

    let (
        stored_uuid,
        frame_count,
        fps,
        total_faces,
        recognized_faces,
        clusters_count,
        result_data,
        processing_time_secs,
        created_at,
    ) = latest.ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            format!("No recognition results found for video: {}", video_uuid),
        )
    })?;

    Ok(Json(serde_json::json!({
        "success": true,
        "video_uuid": stored_uuid,
        "frame_count": frame_count,
        "fps": fps,
        "total_faces": total_faces,
        "recognized_faces": recognized_faces,
        "clusters_count": clusters_count,
        "result_data": result_data,
        "processing_time_secs": processing_time_secs,
        "created_at": created_at.to_rfc3339()
    })))
}
/// Persists a recognition result: the per-video summary, every face detection
/// that carries an embedding, and every face cluster.
///
/// All statements run inside one transaction, so a failure part-way through
/// leaves none of the three tables partially updated for this call.
///
/// # Errors
/// Returns an error when a count does not fit the `i32` bind type, when the
/// result cannot be serialized to JSON, or when any statement or the commit
/// fails. Dropping the uncommitted transaction discards earlier writes.
async fn store_recognition_results(
    db: &PostgresDb,
    video_uuid: &str,
    result: &FaceRecognitionResult,
) -> Result<(), anyhow::Error> {
    // Fail loudly instead of silently truncating with `as i32`.
    let total_faces = i32::try_from(result.frames.iter().map(|f| f.faces.len()).sum::<usize>())?;
    let recognized_faces = i32::try_from(
        result
            .frames
            .iter()
            .flat_map(|f| &f.faces)
            .filter(|face| face.identity.is_some())
            .count(),
    )?;
    let clusters_count = i32::try_from(result.face_clusters.len())?;

    let summary_query = r#"
        INSERT INTO face_recognition_results (
        video_uuid,
        frame_count,
        fps,
        total_faces,
        recognized_faces,
        clusters_count,
        result_data
        ) VALUES ($1, $2, $3, $4, $5, $6, $7)
        ON CONFLICT (video_uuid) DO UPDATE SET
        frame_count = EXCLUDED.frame_count,
        fps = EXCLUDED.fps,
        total_faces = EXCLUDED.total_faces,
        recognized_faces = EXCLUDED.recognized_faces,
        clusters_count = EXCLUDED.clusters_count,
        result_data = EXCLUDED.result_data,
        updated_at = CURRENT_TIMESTAMP
        "#;
    let detection_query = r#"
        INSERT INTO face_detections (
        video_uuid,
        frame_number,
        timestamp_secs,
        face_id,
        x,
        y,
        width,
        height,
        confidence,
        embedding,
        attributes,
        identity_confidence,
        cluster_id
        ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10::vector, $11, $12, $13)
        ON CONFLICT (video_uuid, frame_number, x, y, width, height) DO UPDATE SET
        face_id = EXCLUDED.face_id,
        confidence = EXCLUDED.confidence,
        embedding = EXCLUDED.embedding,
        attributes = EXCLUDED.attributes,
        identity_confidence = EXCLUDED.identity_confidence,
        cluster_id = EXCLUDED.cluster_id
        "#;
    let cluster_query = r#"
        INSERT INTO face_clusters (
        cluster_id,
        video_uuid,
        centroid,
        size,
        representative_face_id,
        metadata
        ) VALUES ($1, $2, $3::vector, $4, $5, $6)
        ON CONFLICT (cluster_id) DO UPDATE SET
        centroid = EXCLUDED.centroid,
        size = EXCLUDED.size,
        representative_face_id = EXCLUDED.representative_face_id,
        metadata = EXCLUDED.metadata
        "#;

    let mut tx = db.pool().begin().await?;

    sqlx::query(summary_query)
        .bind(video_uuid)
        .bind(result.frame_count as i64)
        .bind(result.fps)
        .bind(total_faces)
        .bind(recognized_faces)
        .bind(clusters_count)
        .bind(serde_json::to_value(result)?)
        .execute(&mut *tx)
        .await?;

    // Store individual face detections (only those that carry an embedding)
    for frame in &result.frames {
        for face in &frame.faces {
            if let Some(embedding) = &face.embedding {
                let embedding_str = format!(
                    "[{}]",
                    embedding
                        .iter()
                        .map(|v| v.to_string())
                        .collect::<Vec<_>>()
                        .join(",")
                );
                let identity_confidence = face.identity.as_ref().map(|id| id.confidence as f64);
                // Build the lookup key once per face rather than once per cluster.
                let face_key = face.face_id.clone().unwrap_or_default();
                let cluster_id = result
                    .face_clusters
                    .iter()
                    .find(|c| c.face_ids.contains(&face_key))
                    .map(|c| c.cluster_id.clone());

                sqlx::query(detection_query)
                    .bind(video_uuid)
                    .bind(frame.frame as i64)
                    .bind(frame.timestamp)
                    .bind(face.face_id.as_deref())
                    .bind(face.x)
                    .bind(face.y)
                    .bind(face.width)
                    .bind(face.height)
                    .bind(face.confidence as f64)
                    .bind(&embedding_str)
                    .bind(serde_json::to_value(&face.attributes)?)
                    .bind(identity_confidence)
                    .bind(cluster_id)
                    .execute(&mut *tx)
                    .await?;
            }
        }
    }

    // Store face clusters
    for cluster in &result.face_clusters {
        let centroid_str = format!(
            "[{}]",
            cluster
                .centroid
                .iter()
                .map(|v: &f32| v.to_string())
                .collect::<Vec<_>>()
                .join(",")
        );
        sqlx::query(cluster_query)
            .bind(&cluster.cluster_id)
            .bind(video_uuid)
            .bind(&centroid_str)
            .bind(cluster.size as i32)
            .bind(cluster.representative_face_id.as_deref())
            .bind(&cluster.metadata)
            .execute(&mut *tx)
            .await?;
    }

    tx.commit().await?;
    Ok(())
}

288
src/api/identities.rs Normal file
View File

@@ -0,0 +1,288 @@
use axum::{
extract::{Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::{schema, Database, PostgresDb};
/// JSON body accepted by `POST /api/v1/identities/from-person`.
#[derive(Debug, Deserialize)]
pub struct RegisterFromPersonRequest {
/// UUID of the video the person belongs to.
pub video_uuid: String,
/// Person identifier, looked up together with `video_uuid` in `person_identities`.
pub person_id: String,
/// Name of the global identity to create or reuse; it also becomes the person's name.
pub identity_name: String,
/// Metadata stored on a newly created identity; `{}` is used when omitted.
pub metadata: Option<serde_json::Value>,
}
/// JSON response returned by `register_from_person`.
#[derive(Debug, Serialize)]
pub struct RegisterFromPersonResponse {
/// Set to `true` by the handler on its success path.
pub success: bool,
/// Human-readable confirmation naming the identity and the person.
pub message: String,
/// Database id of the identity that was reused or created.
pub identity_id: i32,
/// Identity name echoed from the request.
pub identity_name: String,
/// Person id echoed from the request.
pub person_id: String,
}
/// Builds the router for the global-identity endpoints:
/// `POST /api/v1/identities/from-person` and `GET /api/v1/identities`.
pub fn identity_routes() -> Router<crate::api::server::AppState> {
    let router = Router::new();
    let router = router.route("/api/v1/identities/from-person", post(register_from_person));
    router.route("/api/v1/identities", get(list_identities))
}
/// Registers a global identity from a specific person in a video.
///
/// Inside a single transaction this:
/// 1. verifies the person exists for the given video,
/// 2. reuses the identity with the requested name or creates it,
/// 3. inserts a `person_id` binding for that identity (ignored on conflict),
/// 4. renames the person to the identity name.
///
/// Any early return leaves the transaction uncommitted.
///
/// # Errors
/// * `404 NOT_FOUND` when the person is not part of the video.
/// * `500 INTERNAL_SERVER_ERROR` for any database failure.
async fn register_from_person(
    State(_state): State<crate::api::server::AppState>,
    Json(req): Json<RegisterFromPersonRequest>,
) -> Result<Json<RegisterFromPersonResponse>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("DB error: {}", e),
        )
    })?;
    let mut tx = db.pool().begin().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Tx error: {}", e),
        )
    })?;

    // 1. Check if Person exists
    let person: Option<(i32, Option<String>)> = sqlx::query_as(
        "SELECT id, name FROM person_identities WHERE person_id = $1 AND video_uuid = $2",
    )
    .bind(&req.person_id)
    .bind(&req.video_uuid)
    .fetch_optional(&mut *tx)
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Query error: {}", e),
        )
    })?;
    let (person_db_id, _old_name) = person.ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            format!(
                "Person '{}' not found in video '{}'",
                req.person_id, req.video_uuid
            ),
        )
    })?;

    // 2. Check if Identity exists
    let existing_identity: Option<i32> =
        sqlx::query_scalar("SELECT id FROM identities WHERE name = $1")
            .bind(&req.identity_name)
            .fetch_optional(&mut *tx)
            .await
            .map_err(|e| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Query error: {}", e),
                )
            })?;

    let final_identity_id = match existing_identity {
        Some(id) => id,
        None => {
            // Create new Identity
            let meta_json = req.metadata.clone().unwrap_or(serde_json::json!({}));
            sqlx::query_scalar::<_, i32>(
                r#"
INSERT INTO identities (name, embedding, metadata)
VALUES ($1, NULLIF($2, '')::public.vector, $3)
RETURNING id
"#,
            )
            .bind(&req.identity_name)
            .bind("".to_string()) // No embedding for now via this API
            .bind(&meta_json)
            .fetch_one(&mut *tx)
            .await
            .map_err(|e| {
                (
                    StatusCode::INTERNAL_SERVER_ERROR,
                    format!("Insert identity error: {}", e),
                )
            })?
        }
    };

    // 3. Create Binding
    // Columns: id, identity_id, identity_type, identity_value, confidence, metadata, created_at
    let binding_sql = r#"
INSERT INTO identity_bindings (identity_id, identity_type, identity_value, confidence, metadata)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT DO NOTHING
"#;
    sqlx::query(binding_sql)
        .bind(final_identity_id)
        .bind("person_id") // identity_type
        .bind(&req.person_id) // identity_value
        .bind(1.0) // confidence
        .bind(&serde_json::json!({"auto_updated": true}))
        .execute(&mut *tx)
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Binding error: {}", e),
            )
        })?;

    // 4. Update Person Name
    sqlx::query("UPDATE person_identities SET name = $1 WHERE id = $2")
        .bind(&req.identity_name)
        .bind(person_db_id)
        .execute(&mut *tx)
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Update person error: {}", e),
            )
        })?;

    tx.commit().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("Commit error: {}", e),
        )
    })?;

    let message = format!(
        "Successfully registered identity '{}' and linked to person '{}'",
        req.identity_name, req.person_id
    );
    Ok(Json(RegisterFromPersonResponse {
        success: true,
        message,
        identity_id: final_identity_id,
        identity_name: req.identity_name,
        person_id: req.person_id,
    }))
}
/// Lists global identities, highest id first, with page-based pagination.
///
/// `page` is 1-based and defaults to `1`; a requested page of `0` is treated
/// as `1` instead of underflowing. `page_size` defaults to `20`. `count` in
/// the response is the total number of identities, not the page length.
///
/// # Errors
/// `500 INTERNAL_SERVER_ERROR` when the database is unreachable or a query fails.
async fn list_identities(
    State(_state): State<crate::api::server::AppState>,
    Query(query): Query<ListIdentitiesQuery>,
) -> Result<Json<IdentityListResponse>, (StatusCode, String)> {
    let db = PostgresDb::init().await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            format!("DB error: {}", e),
        )
    })?;

    // Clamp to the first page so `page - 1` cannot underflow on `page=0`.
    let page = query.page.unwrap_or(1).max(1);
    let page_size = query.page_size.unwrap_or(20);
    // Saturate instead of wrapping on absurdly large inputs.
    let limit = i64::try_from(page_size).unwrap_or(i64::MAX);
    let offset = i64::try_from(page - 1)
        .unwrap_or(i64::MAX)
        .saturating_mul(limit);

    // Get the total count
    let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM identities")
        .fetch_one(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Count error: {}", e),
            )
        })?;

    let sql = "SELECT id, name, metadata FROM identities ORDER BY id DESC LIMIT $1 OFFSET $2";
    let rows: Vec<(i32, String, Option<serde_json::Value>)> = sqlx::query_as(sql)
        .bind(limit)
        .bind(offset)
        .fetch_all(db.pool())
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Query error: {}", e),
            )
        })?;

    let identities: Vec<IdentityResponse> = rows
        .into_iter()
        .map(|(id, name, metadata)| IdentityResponse { id, name, metadata })
        .collect();

    Ok(Json(IdentityListResponse {
        identities,
        count: total,
        page,
        page_size,
    }))
}
/// Query-string parameters accepted by `GET /api/v1/identities`.
#[derive(Debug, Deserialize)]
pub struct ListIdentitiesQuery {
/// 1-based page number; defaults to `1` when omitted.
pub page: Option<usize>,
/// Rows per page; defaults to `20` when omitted.
pub page_size: Option<usize>,
}
/// One global identity as returned by the listing endpoint.
#[derive(Debug, Serialize)]
pub struct IdentityResponse {
/// Database id of the identity.
pub id: i32,
/// Identity name.
pub name: String,
/// Stored `metadata` column, passed through unchanged.
pub metadata: Option<serde_json::Value>,
}
/// JSON response returned by the identity listing endpoint.
#[derive(Debug, Serialize)]
pub struct IdentityListResponse {
/// Identities on the requested page, highest id first.
pub identities: Vec<IdentityResponse>,
/// Total number of identities (not the length of `identities`).
pub count: i64,
/// Page number this response corresponds to.
pub page: usize,
/// Page size used for this response.
pub page_size: usize,
}

412
src/api/identity_binding.rs Normal file
View File

@@ -0,0 +1,412 @@
use axum::{
extract::{Path, Query},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::{Database, PostgresDb};
use crate::core::person_identity::{BindIdentityRequest, Identity, UnbindIdentityRequest};
/// Generic JSON envelope used by the identity-binding handlers.
#[derive(Debug, Clone, Serialize)]
pub struct ApiResponse<T: Serialize> {
/// Set to `true` by the handlers on their success paths.
pub success: bool,
/// Human-readable summary of the outcome.
pub message: String,
/// Payload of the response; `None` for operations that return no data.
pub data: Option<T>,
}
// ============================================================================
// API Handlers
// ============================================================================
async fn get_db() -> Result<PostgresDb, (StatusCode, Json<serde_json::Value>)> {
PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB init failed: {}", e) })),
)
})
}
/// Lists identities (people).
///
/// `limit` defaults to `100`, `offset` to `0` and `search` to its default
/// value when omitted; all three are forwarded to the database layer.
///
/// # Errors
/// `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
/// lookup fails.
pub async fn list_identities(
    Query(params): Query<ListIdentitiesParams>,
) -> Result<Json<ApiResponse<Vec<Identity>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    let limit = params.limit.unwrap_or(100);
    let offset = params.offset.unwrap_or(0);
    let search = params.search.unwrap_or_default();

    match db.list_identities(&search, limit, offset).await {
        Ok(identities) => {
            let message = format!("Found {} identities", identities.len());
            Ok(Json(ApiResponse {
                success: true,
                message,
                data: Some(identities),
            }))
        }
        Err(e) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e.to_string() })),
        )),
    }
}
/// 綁定身份 (Face/Speaker -> Identity)
pub async fn bind_identity(
Json(req): Json<BindIdentityRequest>,
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
let db = get_db().await?;
let identity = if let Some(id_id) = req.identity_id {
db.get_identity_by_id(id_id).await.ok().flatten()
} else if let Some(name) = &req.name {
db.get_or_create_identity(name).await.ok()
} else {
None
};
let identity = match identity {
Some(t) => t,
None => {
return Err((
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": "Identity not found or name required" })),
));
}
};
let source = req.source.unwrap_or("manual".to_string());
db.bind_identity(
identity.id as i64,
&req.binding_type,
&req.binding_value,
&source,
1.0,
)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
Ok(Json(ApiResponse {
success: true,
message: format!(
"Bound {} '{}' to Identity '{}'",
req.binding_type, req.binding_value, identity.name
),
data: None,
}))
}
/// Unbinds an identity from the given signal.
///
/// # Errors
/// `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
/// unbind call fails.
pub async fn unbind_identity(
    Json(req): Json<UnbindIdentityRequest>,
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    if let Err(e) = db
        .unbind_identity(&req.binding_type, &req.binding_value)
        .await
    {
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e.to_string() })),
        ));
    }

    let message = format!("Unbound {} '{}'", req.binding_type, req.binding_value);
    Ok(Json(ApiResponse {
        success: true,
        message,
        data: None,
    }))
}
/// Looks up the identity (person) bound to a machine-generated ID.
///
/// # Errors
/// * `404 NOT_FOUND` when no identity is bound to the given type/value pair.
/// * `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
///   lookup fails.
pub async fn get_identity_info(
    Path((binding_type, binding_value)): Path<(String, String)>,
) -> Result<Json<ApiResponse<Identity>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    match db
        .get_identity_by_binding(&binding_type, &binding_value)
        .await
    {
        Ok(Some(identity)) => Ok(Json(ApiResponse {
            success: true,
            message: "Identity info retrieved".to_string(),
            data: Some(identity),
        })),
        Ok(None) => Err((
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({ "error": "Identity not found" })),
        )),
        Err(e) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e.to_string() })),
        )),
    }
}
/// Lists unbound signals (the to-be-labelled list) for a video and binding type.
///
/// # Errors
/// `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
/// lookup fails.
pub async fn list_unbound_signals(
    Query(params): Query<ListSignalsParams>,
) -> Result<Json<ApiResponse<Vec<String>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    match db
        .list_unbound_signals(&params.uuid, &params.binding_type)
        .await
    {
        Ok(signals) => {
            let message = format!(
                "Found {} unbound {} signals",
                signals.len(),
                params.binding_type
            );
            Ok(Json(ApiResponse {
                success: true,
                message,
                data: Some(signals),
            }))
        }
        Err(e) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e.to_string() })),
        )),
    }
}
/// Returns every chunk in which a given signal (face ID or speaker ID)
/// appears, i.e. its timeline within the video.
///
/// # Errors
/// `500 INTERNAL_SERVER_ERROR` when the database is unreachable or the
/// lookup fails.
pub async fn get_signal_timeline(
    Path((uuid, binding_type, binding_value)): Path<(String, String, String)>,
) -> Result<Json<ApiResponse<Vec<serde_json::Value>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    match db
        .get_chunks_by_signal(&uuid, &binding_type, &binding_value)
        .await
    {
        Ok(chunks) => {
            let message = format!(
                "Found {} chunks for {} '{}'",
                chunks.len(),
                binding_type,
                binding_value
            );
            Ok(Json(ApiResponse {
                success: true,
                message,
                data: Some(chunks),
            }))
        }
        Err(e) => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e.to_string() })),
        )),
    }
}
/// JSON body for the face/speaker binding suggestion handler.
#[derive(Debug, Deserialize)]
pub struct AVSuggestRequest {
/// UUID of the video whose unbound face and speaker signals are compared.
pub video_uuid: String,
/// Minimum overlap score a face/speaker pair must reach to be suggested.
pub overlap_threshold: Option<f64>, // default 0.6
}
/// One suggested pairing of a face signal with a speaker signal.
#[derive(Debug, Serialize)]
pub struct AVSuggestion {
/// Face signal identifier.
pub face_id: String,
/// Speaker signal identifier.
pub speaker_id: String,
/// Overlap score computed between the two signals' timelines.
pub overlap_score: f64,
/// Id of the identity already bound to the face, if any.
pub face_talent_id: Option<i32>,
/// Id of the identity already bound to the speaker, if any.
pub speaker_talent_id: Option<i32>,
}
/// Suggests face–speaker bindings for a video based on temporal overlap.
///
/// Every unbound face signal is compared with every unbound speaker signal.
/// A pair is suggested when [`calculate_overlap`] of their timelines reaches
/// `overlap_threshold` (default `0.6`), unless both signals are already bound
/// to *different* identities. Suggestions are returned sorted by overlap
/// score, highest first.
///
/// # Errors
///
/// Returns `500 Internal Server Error` when the face or speaker signal list
/// cannot be loaded. A timeline that fails to load is treated as empty, and
/// an identity lookup that fails is treated as "not bound".
pub async fn suggest_audio_visual_bindings(
    Json(req): Json<AVSuggestRequest>,
) -> Result<Json<ApiResponse<Vec<AVSuggestion>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;
    let threshold = req.overlap_threshold.unwrap_or(0.6);

    let face_signals = db
        .list_unbound_signals(&req.video_uuid, "face")
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": format!("Face signals: {}", e) })),
            )
        })?;
    let speaker_signals = db
        .list_unbound_signals(&req.video_uuid, "speaker")
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": format!("Speaker signals: {}", e) })),
            )
        })?;

    let mut suggestions = Vec::new();

    // Nothing to pair up (and no timeline worth fetching) if either side is empty.
    if !face_signals.is_empty() && !speaker_signals.is_empty() {
        // Load each speaker timeline once instead of once per face/speaker pair.
        let mut speaker_timelines = Vec::with_capacity(speaker_signals.len());
        for speaker_id in &speaker_signals {
            let timeline = db
                .get_chunks_by_signal(&req.video_uuid, "speaker", speaker_id)
                .await
                .unwrap_or_default();
            speaker_timelines.push(timeline);
        }

        for face_id in &face_signals {
            // Likewise, one face timeline query per face rather than per pair.
            let face_timeline = db
                .get_chunks_by_signal(&req.video_uuid, "face", face_id)
                .await
                .unwrap_or_default();

            for (speaker_id, speaker_timeline) in speaker_signals.iter().zip(&speaker_timelines) {
                let overlap = calculate_overlap(&face_timeline, speaker_timeline);
                if overlap < threshold {
                    continue;
                }

                // Look up identities the two signals may already be bound to.
                let face_identity = db
                    .get_identity_by_binding("face", face_id)
                    .await
                    .ok()
                    .flatten();
                let speaker_identity = db
                    .get_identity_by_binding("speaker", speaker_id)
                    .await
                    .ok()
                    .flatten();

                // Both bound, but to different identities: a conflict, not a suggestion.
                if let (Some(fi), Some(si)) = (&face_identity, &speaker_identity) {
                    if fi.id != si.id {
                        continue;
                    }
                }

                suggestions.push(AVSuggestion {
                    face_id: face_id.clone(),
                    speaker_id: speaker_id.clone(),
                    overlap_score: overlap,
                    face_talent_id: face_identity.as_ref().map(|i| i.id as i32),
                    speaker_talent_id: speaker_identity.as_ref().map(|i| i.id as i32),
                });
            }
        }
    }

    // Highest overlap first; incomparable scores are treated as equal rather than panicking.
    suggestions.sort_by(|a, b| {
        b.overlap_score
            .partial_cmp(&a.overlap_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    Ok(Json(ApiResponse {
        success: true,
        message: format!("Found {} AV suggestions", suggestions.len()),
        data: Some(suggestions),
    }))
}
/// Scores how much a face timeline and a speaker timeline overlap in time.
///
/// Each chunk is expected to carry `content.start_time` and
/// `content.end_time` as JSON numbers; chunks without both, or with
/// `end_time <= start_time`, are ignored. The score is the summed pairwise
/// intersection of face and speaker ranges divided by the shorter of the two
/// total durations, clamped to at most `1.0`. Returns `0.0` when either
/// timeline has no usable range.
fn calculate_overlap(
    face_chunks: &[serde_json::Value],
    speaker_chunks: &[serde_json::Value],
) -> f64 {
    let face_ranges = chunk_time_ranges(face_chunks);
    let speaker_ranges = chunk_time_ranges(speaker_chunks);

    let mut overlap_duration = 0.0;
    for &(fs, fe) in &face_ranges {
        for &(ss, se) in &speaker_ranges {
            let start = fs.max(ss);
            let end = fe.min(se);
            if start < end {
                overlap_duration += end - start;
            }
        }
    }

    // Normalize by total *duration* (end - start), not by the sum of end timestamps:
    // summing end times inflates the denominator the later a range sits in the video.
    let total_duration =
        |ranges: &[(f64, f64)]| ranges.iter().map(|(start, end)| end - start).sum::<f64>();
    let min_duration = total_duration(&face_ranges).min(total_duration(&speaker_ranges));

    if min_duration > 0.0 {
        (overlap_duration / min_duration).min(1.0)
    } else {
        0.0
    }
}

/// Extracts the `(start_time, end_time)` pairs found under each chunk's `content` object.
fn chunk_time_ranges(chunks: &[serde_json::Value]) -> Vec<(f64, f64)> {
    chunks
        .iter()
        .filter_map(|chunk| {
            let content = chunk.get("content")?;
            let start = content.get("start_time")?.as_f64()?;
            let end = content.get("end_time")?.as_f64()?;
            // Empty or inverted ranges can never intersect anything; drop them so
            // they do not distort the duration total either.
            if end > start {
                Some((start, end))
            } else {
                None
            }
        })
        .collect()
}
// ============================================================================
// Router Setup
// ============================================================================
/// Query parameters for listing identities.
/// NOTE(review): the handler that consumes this struct is not in this part of
/// the file; field meanings below are inferred from their names — confirm.
#[derive(Debug, Deserialize)]
pub struct ListIdentitiesParams {
// Optional search text.
pub search: Option<String>,
// Optional page size.
pub limit: Option<i32>,
// Optional number of entries to skip.
pub offset: Option<i32>,
}
/// Query parameters accepted by `list_unbound_signals`.
#[derive(Debug, Deserialize)]
pub struct ListSignalsParams {
// UUID handed to the database lookup as its first argument.
pub uuid: String,
// Kind of signal to list.
pub binding_type: String, // "face" or "speaker"
}
/// Builds the router for identity binding, signal discovery, signal timeline
/// and audio-visual suggestion endpoints.
pub fn identity_binding_routes() -> Router<crate::api::server::AppState> {
    // Identity binding management.
    let router = Router::new()
        .route("/api/v1/identities/bind", post(bind_identity))
        .route("/api/v1/identities/unbind", post(unbind_identity))
        .route(
            "/api/v1/identity/:binding_type/:binding_value",
            get(get_identity_info),
        );

    // Signal discovery, then the per-signal timeline.
    let router = router
        .route("/api/v1/signals/unbound", get(list_unbound_signals))
        .route(
            "/api/v1/signals/:uuid/:binding_type/:binding_value/timeline",
            get(get_signal_timeline),
        );

    // Face/speaker binding suggestions.
    router.route(
        "/api/v1/identities/suggest-av",
        post(suggest_audio_visual_bindings),
    )
}

264
src/api/n8n_search.rs Normal file
View File

@@ -0,0 +1,264 @@
use crate::core::db::{Bm25Result, PostgresDb};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Input for `n8n_search_smart`.
#[derive(Debug, Deserialize)]
pub struct SmartSearchRequest {
// Free-text user query; parsed into dimensions and keywords before searching.
pub query: String,
// Optional UUID forwarded to the BM25 and person searches.
pub uuid: Option<String>,
// Maximum number of hits to return; `n8n_search_smart` defaults this to 10.
pub limit: Option<usize>,
}
/// Output of `n8n_search_smart`.
#[derive(Debug, Serialize)]
pub struct SmartSearchResponse {
// The original query string, echoed back.
pub query: String,
// The dimensions extracted from the query, serialized as JSON.
pub parsed_dimensions: serde_json::Value,
// Hits as JSON objects, sorted by score (descending) and truncated to the limit.
pub hits: Vec<serde_json::Value>,
// Number of entries in `hits`.
pub total: usize,
}
/// Query dimensions extracted from a user query, as requested from the LLM by
/// `parse_query_with_llm`. A dimension absent from the query is `None`.
#[derive(Debug, Deserialize, Serialize)]
struct LlmDimensionResponse {
// Person or subject.
pub who: Option<String>,
// Action or event.
pub what: Option<String>,
// Time.
pub when: Option<String>,
// Location.
pub r#where: Option<String>,
// Reason or intent.
pub why: Option<String>,
// Specific keywords; defaults to an empty list when missing from the JSON.
#[serde(default)]
pub keywords: Vec<String>,
}
/// POST /api/v1/n8n/search/smart
///
/// Parses the query into dimensions (via the LLM, falling back to plain
/// keyword extraction), then collects hits from:
///
/// * a BM25 search over the extracted keywords (score boost `1.0`), and
/// * when a `who` dimension is present, a BM25 search for the best-matching
///   person's display name (score boost `1.5`, tagged with `matched_person`).
///
/// Each chunk is reported at most once. Hits are sorted by score, highest
/// first, and truncated to `limit` (default 10). Failures of the individual
/// searches are ignored, so the result may be empty rather than an error.
pub async fn n8n_search_smart(
    db: &PostgresDb,
    req: SmartSearchRequest,
) -> Result<SmartSearchResponse, Box<dyn std::error::Error + Send + Sync>> {
    /// Records a BM25 result as a hit unless its chunk has already been collected.
    fn add_hit(
        hits: &mut Vec<serde_json::Value>,
        seen_chunk_ids: &mut HashSet<String>,
        sr: Bm25Result,
        boost: f32,
    ) {
        if seen_chunk_ids.insert(sr.chunk_id.clone()) {
            let score = sr.bm25_score * boost;
            let val = serde_json::json!({
                "id": sr.chunk_id,
                "vid": sr.uuid,
                "start": sr.start_time,
                "end": sr.end_time,
                "text": sr.text,
                "score": score,
                "chunk_type": sr.chunk_type
            });
            hits.push(val);
        }
    }

    let limit = req.limit.unwrap_or(10);
    let video_uuid = req.uuid.clone();

    // 1. Ask the LLM for the query dimensions; fall back to keywords only if that fails.
    let dimensions = match parse_query_with_llm(&req.query).await {
        Some(dims) => dims,
        None => LlmDimensionResponse {
            who: None,
            what: None,
            when: None,
            r#where: None,
            why: None,
            keywords: extract_keywords(&req.query),
        },
    };
    let keywords = dimensions.keywords.join(" ");

    // 2. Multi-dimensional search.
    let mut hits: Vec<serde_json::Value> = Vec::new();
    let mut seen_chunk_ids: HashSet<String> = HashSet::new();

    // A. Keyword search (BM25).
    if !keywords.is_empty() {
        if let Ok(results) = db
            .search_bm25(&keywords, video_uuid.as_deref(), limit)
            .await
        {
            for sr in results {
                add_hit(&mut hits, &mut seen_chunk_ids, sr, 1.0);
            }
        }
    }

    // B. Who search (person matching): only the first candidate is used.
    if let Some(who_query) = &dimensions.who {
        if let Ok(persons) = db.search_person_candidates(who_query, &video_uuid, 5).await {
            if let Some(best) = persons.first() {
                let person_id = best
                    .get("candidate_id")
                    .and_then(|v| v.as_str())
                    .map(String::from);
                if let Some(pid) = person_id {
                    // Heuristic: search BM25 for the person's display name,
                    // falling back to the raw `who` text.
                    let person_name = best
                        .get("display_name")
                        .and_then(|v| v.as_str())
                        .unwrap_or(who_query);
                    if let Ok(results) = db
                        .search_bm25(person_name, video_uuid.as_deref(), limit)
                        .await
                    {
                        for sr in results {
                            if seen_chunk_ids.insert(sr.chunk_id.clone()) {
                                // Person-matched chunks are boosted over plain keyword hits.
                                let score = sr.bm25_score * 1.5;
                                let val = serde_json::json!({
                                    "id": sr.chunk_id,
                                    "vid": sr.uuid,
                                    "start": sr.start_time,
                                    "end": sr.end_time,
                                    "text": sr.text,
                                    "score": score,
                                    "matched_person": pid
                                });
                                hits.push(val);
                            }
                        }
                    }
                }
            }
        }
    }

    // Sort by score, highest first; a missing or non-numeric score counts as 0.
    let score_of = |hit: &serde_json::Value| hit.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
    hits.sort_by(|a, b| {
        score_of(b)
            .partial_cmp(&score_of(a))
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    hits.truncate(limit);
    let total = hits.len();

    Ok(SmartSearchResponse {
        query: req.query,
        parsed_dimensions: serde_json::json!(dimensions),
        hits,
        total,
    })
}
/// Splits a query into lowercase keywords.
///
/// The query is lowercased, every non-alphanumeric character acts as a
/// separator, and common English stop words are dropped. Keywords are
/// returned in their original order; duplicates are kept.
fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "who", "what", "where", "when", "why", "how", "is", "the", "a", "an", "and", "or", "of",
        "in", "to", "for", "with", "by", "on", "at", "from", "up", "about", "into", "over",
        "after",
    ];

    let lowered = query.to_lowercase();
    let mut keywords = Vec::new();
    for token in lowered.split(|c: char| !c.is_alphanumeric()) {
        // Consecutive separators yield empty tokens; skip those and stop words.
        if token.is_empty() || STOP_WORDS.contains(&token) {
            continue;
        }
        keywords.push(token.to_string());
    }
    keywords
}
/// Asks the local LLM server to split a user query into search dimensions.
///
/// Sends a chat-completion request to the OpenAI-compatible endpoint at
/// `127.0.0.1:8081` and parses the JSON object embedded in the first choice's
/// message content (the text from the first `{` to the last `}`).
///
/// Returns `None` when the request fails or times out (60 s), the response is
/// not JSON, no message content is present, or the content holds no JSON
/// object that deserializes into [`LlmDimensionResponse`].
async fn parse_query_with_llm(query: &str) -> Option<LlmDimensionResponse> {
    let client = reqwest::Client::new();

    // Connectivity probe: logged only, the request below is attempted either way.
    if let Ok(resp) = client.get("http://127.0.0.1:8081/health").send().await {
        tracing::info!("LLM Health Check: {}", resp.status());
    } else {
        tracing::error!("LLM Server is unreachable at 127.0.0.1:8081");
    }

    // We use the OpenAI-compatible endpoint provided by llama.cpp server (default port 8081)
    let prompt = format!(
        r#"Analyze the user query and extract the following dimensions into a JSON object.
If a dimension is not present, use null.
Dimensions: "who" (person/subject), "what" (action/event), "where" (location), "when" (time), "why" (reason/intent), "keywords" (array of specific keywords).
User Query: "{}"
Output ONLY the JSON object.
"#,
        query
    );
    let payload = serde_json::json!({
        "model": "gemma4",
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "temperature": 0.1,
        "stream": false
    });

    let response = match client
        .post("http://127.0.0.1:8081/v1/chat/completions")
        .json(&payload)
        .timeout(std::time::Duration::from_secs(60))
        .send()
        .await
    {
        Ok(response) => response,
        Err(_) => {
            tracing::warn!("LLM request failed or timed out");
            return None;
        }
    };
    tracing::info!("LLM Response Status: {}", response.status());

    let json_resp = match response.json::<serde_json::Value>().await {
        Ok(body) => body,
        Err(_) => {
            tracing::warn!("Failed to parse LLM response JSON");
            return None;
        }
    };
    tracing::info!("LLM Response Body: {}", json_resp);

    let content = json_resp
        .get("choices")
        .and_then(|v| v.as_array())
        .and_then(|choices| choices.first())
        .and_then(|choice| choice.get("message"))
        .and_then(|m| m.get("content"))
        .and_then(|c| c.as_str())?;

    // The model may wrap the object in prose or code fences; keep only `{ ... }`.
    let start = content.find('{')?;
    let end = content.rfind('}')?;
    // Checked slicing: yields `None` instead of panicking when the last `}`
    // comes before the first `{`.
    let json_str = content.get(start..=end)?;

    match serde_json::from_str::<LlmDimensionResponse>(json_str) {
        Ok(dims) => {
            tracing::info!("Parsed LLM Dimensions: {:?}", dims);
            Some(dims)
        }
        Err(_) => {
            tracing::warn!("Failed to parse LLM JSON: {}", json_str);
            None
        }
    }
}

3281
src/api/person_identity.rs Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

195
src/api/search.rs Normal file
View File

@@ -0,0 +1,195 @@
//! Smart Search API
//! Implements the 5W1H search capability using semantic vectors.
use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router};
use serde::{Deserialize, Serialize};
use serde_json;
use tracing;
use crate::core::db::PostgresDb;
// --- Request / Response Structures ---
/// Request body for `smart_search`.
#[derive(Debug, Deserialize)]
pub struct SmartSearchRequest {
// UUID passed to the semantic parent-chunk search.
pub uuid: String,
// Natural-language query; embedded via Ollama before searching.
pub query: String,
// Number of parent chunks to retrieve; `smart_search` defaults this to 5.
pub limit: Option<usize>,
}
/// One child chunk returned by `smart_search`, enriched with fields taken
/// from its parent chunk when the parent is among the search results.
#[derive(Debug, Serialize)]
pub struct SearchResult {
// Child chunk ID.
pub id: i32,
// ID of the parent chunk this child belongs to.
pub parent_id: i32,
// Scene order copied from the parent.
pub scene_order: Option<i32>,
// Primary: frame-accurate position (authoritative unit)
pub start_frame: i64,
pub end_frame: i64,
pub fps: f64,
// Reference: time derived from frames (subject to FPS variation, not precise)
pub start_time: f64,
pub end_time: f64,
pub raw_text: Option<String>, // Text content of the child chunk
pub summary: Option<String>, // Summary from parent context
// Metadata copied from the parent.
pub metadata: Option<serde_json::Value>,
// Similarity of the parent to the query; results are sorted on this value.
pub similarity: Option<f64>,
}
/// Response body of `smart_search`.
#[derive(Debug, Serialize)]
pub struct SmartSearchResponse {
// The original query string, echoed back.
pub query: String,
// Child chunks, sorted by parent similarity (descending).
pub results: Vec<SearchResult>,
// Name of the search strategy that produced the results.
pub strategy: String,
}
// --- API Handler ---
pub async fn smart_search(
State(state): State<crate::api::server::AppState>,
Json(req): Json<SmartSearchRequest>,
) -> Result<Json<SmartSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let db = &state.db;
let limit = req.limit.unwrap_or(5);
// 1. Generate Embedding using Ollama
let embedding = get_ollama_embedding(&req.query).await.map_err(
|e| -> (StatusCode, Json<serde_json::Value>) {
tracing::error!("Embedding failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
},
)?;
// 2. Search Database (Drill-Down: Find Parents First)
let db_parents: Vec<crate::core::db::postgres_db::SemanticSearchResult> = db
.search_parent_chunks_semantic(&req.uuid, &embedding, limit)
.await
.map_err(
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
tracing::error!("DB search failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
},
)?;
if db_parents.is_empty() {
return Ok(Json(SmartSearchResponse {
query: req.query,
results: vec![],
strategy: "semantic_vector_search".to_string(),
}));
}
// Collect Parent IDs
let parent_ids: Vec<i32> = db_parents.iter().map(|p| p.id).collect();
// 3. Fetch Children for these Parents (Drill Down)
// We fetch all children for these parents (limit can be adjusted)
let children: Vec<crate::core::db::postgres_db::ChildChunkResult> = db
.get_children_for_parents(&parent_ids, 10) // Fetch top 10 children per parent
.await
.map_err(
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
tracing::error!("Fetching children failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
},
)?;
// 4. Map Parents to a lookup table
let parent_map: std::collections::HashMap<
i32,
&crate::core::db::postgres_db::SemanticSearchResult,
> = db_parents.iter().map(|p| (p.id, p)).collect();
// Map Children to API response struct
let results: Vec<SearchResult> = children
.into_iter()
.map(|c| {
let parent = parent_map.get(&c.parent_id);
SearchResult {
id: c.id,
parent_id: c.parent_id,
scene_order: parent.map(|p| p.scene_order),
start_frame: c.start_frame,
end_frame: c.end_frame,
fps: c.fps,
start_time: c.start_time,
end_time: c.end_time,
raw_text: Some(c.raw_text),
summary: parent.map(|p| p.summary.clone()),
metadata: parent.map(|p| p.metadata.clone()),
similarity: parent.and_then(|p| p.similarity),
}
})
.collect();
// 6. Sort results by similarity (descending)
// Since all children of a parent have the same parent similarity, this groups relevant chunks together
let mut results = results;
results.sort_by(|a, b| {
b.similarity
.partial_cmp(&a.similarity)
.unwrap_or(std::cmp::Ordering::Equal)
});
// 7. Limit the final results (optional, but good for API consistency)
let limit = req.limit.unwrap_or(5) * 5; // Allow more children per parent context
results.truncate(limit);
// 8. Format Response
let response = SmartSearchResponse {
query: req.query,
results,
strategy: "drill_down_semantic_search".to_string(),
};
Ok(Json(response))
}
// --- Helper: Ollama Embedding ---
/// Requests an embedding for `text` from the local Ollama server
/// (`nomic-embed-text` model).
///
/// # Errors
///
/// Fails when the request cannot be sent, Ollama answers with a non-success
/// HTTP status, the body is not JSON, the `embedding` array is missing, or
/// the array contains a non-numeric entry.
async fn get_ollama_embedding(
    text: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
    let client = reqwest::Client::new();
    let payload = serde_json::json!({
        "model": "nomic-embed-text",
        "prompt": text
    });

    let res = client
        .post("http://localhost:11434/api/embeddings")
        .json(&payload)
        .send()
        .await?
        // Surface HTTP-level failures directly instead of as a missing field.
        .error_for_status()?
        .json::<serde_json::Value>()
        .await?;

    let values = res["embedding"]
        .as_array()
        .ok_or("No embedding found in Ollama response")?;

    // Reject malformed vectors rather than silently substituting 0.0, which
    // would produce a plausible-looking but wrong embedding.
    let embedding = values
        .iter()
        .map(|v| {
            v.as_f64()
                .map(|f| f as f32)
                .ok_or("Non-numeric value in Ollama embedding")
        })
        .collect::<Result<Vec<f32>, _>>()?;

    Ok(embedding)
}
// --- Router Setup ---
/// Builds the router exposing `smart_search` as `POST /smart`.
pub fn search_routes() -> Router<crate::api::server::AppState> {
    let smart_handler = post(smart_search);
    Router::new().route("/smart", smart_handler)
}

195
src/api/search.rs.bak Normal file
View File

@@ -0,0 +1,195 @@
//! Smart Search API
//! Implements the 5W1H search capability using semantic vectors.
use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router};
use serde::{Deserialize, Serialize};
use serde_json;
use tracing;
use crate::core::db::PostgresDb;
// --- Request / Response Structures ---
/// Request body for `smart_search`.
#[derive(Debug, Deserialize)]
pub struct SmartSearchRequest {
// UUID passed to the semantic parent-chunk search.
pub uuid: String,
// Natural-language query; embedded via Ollama before searching.
pub query: String,
// Number of parent chunks to retrieve; `smart_search` defaults this to 5.
pub limit: Option<usize>,
}
/// One child chunk returned by `smart_search`, enriched with fields taken
/// from its parent chunk when the parent is among the search results.
#[derive(Debug, Serialize)]
pub struct SearchResult {
// Child chunk ID.
pub id: i32,
// ID of the parent chunk this child belongs to.
pub parent_id: i32,
// Scene order copied from the parent.
pub scene_order: Option<i32>,
// Primary: frame-accurate position (authoritative unit)
pub start_frame: i64,
pub end_frame: i64,
pub fps: f64,
// Reference: time derived from frames (subject to FPS variation, not precise)
pub start_time: f64,
pub end_time: f64,
pub raw_text: Option<String>, // Text content of the child chunk
pub summary: Option<String>, // Summary from parent context
// Metadata copied from the parent.
pub metadata: Option<serde_json::Value>,
// Similarity of the parent to the query; results are sorted on this value.
pub similarity: Option<f64>,
}
/// Response body of `smart_search`.
#[derive(Debug, Serialize)]
pub struct SmartSearchResponse {
// The original query string, echoed back.
pub query: String,
// Child chunks, sorted by parent similarity (descending).
pub results: Vec<SearchResult>,
// Name of the search strategy that produced the results.
pub strategy: String,
}
// --- API Handler ---
pub async fn smart_search(
State(state): State<crate::api::server::AppState>,
Json(req): Json<SmartSearchRequest>,
) -> Result<Json<SmartSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let db = &state.db;
let limit = req.limit.unwrap_or(5);
// 1. Generate Embedding using Ollama
let embedding = get_ollama_embedding(&req.query).await.map_err(
|e| -> (StatusCode, Json<serde_json::Value>) {
tracing::error!("Embedding failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
},
)?;
// 2. Search Database (Drill-Down: Find Parents First)
let db_parents: Vec<crate::core::db::postgres_db::SemanticSearchResult> = db
.search_parent_chunks_semantic(&req.uuid, &embedding, limit)
.await
.map_err(
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
tracing::error!("DB search failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
},
)?;
if db_parents.is_empty() {
return Ok(Json(SmartSearchResponse {
query: req.query,
results: vec![],
strategy: "semantic_vector_search".to_string(),
}));
}
// Collect Parent IDs
let parent_ids: Vec<i32> = db_parents.iter().map(|p| p.id).collect();
// 3. Fetch Children for these Parents (Drill Down)
// We fetch all children for these parents (limit can be adjusted)
let children: Vec<crate::core::db::postgres_db::ChildChunkResult> = db
.get_children_for_parents(&parent_ids, 10) // Fetch top 10 children per parent
.await
.map_err(
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
tracing::error!("Fetching children failed: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
},
)?;
// 4. Map Parents to a lookup table
let parent_map: std::collections::HashMap<
i32,
&crate::core::db::postgres_db::SemanticSearchResult,
> = db_parents.iter().map(|p| (p.id, p)).collect();
// Map Children to API response struct
let results: Vec<SearchResult> = children
.into_iter()
.map(|c| {
let parent = parent_map.get(&c.parent_id);
SearchResult {
id: c.id,
parent_id: c.parent_id,
scene_order: parent.map(|p| p.scene_order),
start_frame: c.start_frame,
end_frame: c.end_frame,
fps: c.fps,
start_time: c.start_time,
end_time: c.end_time,
raw_text: Some(c.raw_text),
summary: parent.map(|p| p.summary.clone()),
metadata: parent.map(|p| p.metadata.clone()),
similarity: parent.and_then(|p| p.similarity),
}
})
.collect();
// 6. Sort results by similarity (descending)
// Since all children of a parent have the same parent similarity, this groups relevant chunks together
let mut results = results;
results.sort_by(|a, b| {
b.similarity
.partial_cmp(&a.similarity)
.unwrap_or(std::cmp::Ordering::Equal)
});
// 7. Limit the final results (optional, but good for API consistency)
let limit = req.limit.unwrap_or(5) * 5; // Allow more children per parent context
results.truncate(limit);
// 8. Format Response
let response = SmartSearchResponse {
query: req.query,
results,
strategy: "drill_down_semantic_search".to_string(),
};
Ok(Json(response))
}
// --- Helper: Ollama Embedding ---
/// Requests an embedding for `text` from the local Ollama server
/// (`nomic-embed-text` model).
///
/// # Errors
///
/// Fails when the request cannot be sent, Ollama answers with a non-success
/// HTTP status, the body is not JSON, the `embedding` array is missing, or
/// the array contains a non-numeric entry.
async fn get_ollama_embedding(
    text: &str,
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
    let client = reqwest::Client::new();
    let payload = serde_json::json!({
        "model": "nomic-embed-text",
        "prompt": text
    });

    let res = client
        .post("http://localhost:11434/api/embeddings")
        .json(&payload)
        .send()
        .await?
        // Surface HTTP-level failures directly instead of as a missing field.
        .error_for_status()?
        .json::<serde_json::Value>()
        .await?;

    let values = res["embedding"]
        .as_array()
        .ok_or("No embedding found in Ollama response")?;

    // Reject malformed vectors rather than silently substituting 0.0, which
    // would produce a plausible-looking but wrong embedding.
    let embedding = values
        .iter()
        .map(|v| {
            v.as_f64()
                .map(|f| f as f32)
                .ok_or("Non-numeric value in Ollama embedding")
        })
        .collect::<Result<Vec<f32>, _>>()?;

    Ok(embedding)
}
// --- Router Setup ---
/// Builds the router exposing `smart_search` as `POST /smart`.
pub fn search_routes() -> Router<crate::api::server::AppState> {
    let smart_handler = post(smart_search);
    Router::new().route("/smart", smart_handler)
}

View File

@@ -30,6 +30,33 @@ use super::universal_search;
use super::visual_chunk_search;
use crate::core::chunk::types::Chunk;
// API key handed out by `login` for the built-in "demo" account.
// NOTE(review): this credential is committed in source; consider loading it
// from configuration or the environment instead.
static DEMO_USER_API_KEY: &str = "muser_68600856036340bcafc01930eb4bd839_1774418104_97221b69";
/// Returns the SHA-256 digest of `password` as a lowercase hex string.
///
/// NOTE(review): the digest is unsalted and single-round; confirm this is not
/// used as the stored form of real user passwords.
fn hash_password(password: &str) -> String {
    let digest = Sha256::digest(password.as_bytes());
    format!("{:x}", digest)
}
/// Request body for the `login` handler.
#[derive(Debug, Deserialize)]
struct LoginRequest {
// Account name supplied by the client.
username: String,
// Plain-text password supplied by the client.
password: String,
}
/// Response body of the `login` handler.
#[derive(Debug, Serialize)]
struct LoginResponse {
// Whether the credentials were accepted.
success: bool,
// Human-readable outcome message.
message: Option<String>,
// API key for the session; set only on success.
api_key: Option<String>,
// Logged-in user details; set only on success.
user: Option<UserInfo>,
}
/// User details embedded in a successful `LoginResponse`.
#[derive(Debug, Serialize)]
struct UserInfo {
// Account name of the logged-in user.
username: String,
}
// Global State
static SERVER_START: OnceCell<Instant> = OnceCell::new();
@@ -334,6 +361,12 @@ struct VideoInfoResponse {
duration: f64,
width: u32,
height: u32,
status: String,
processing_status: Option<String>,
created_at: Option<String>,
registration_time: Option<String>,
file_size: Option<i64>,
probe_json: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -348,6 +381,9 @@ struct VideosResponse {
// Query parameters read by `list_videos`.
struct VideosQuery {
// 1-based page number; defaults to 1.
page: Option<usize>,
// Entries per page; defaults to 20.
page_size: Option<usize>,
// Optional search text used to filter videos.
q: Option<String>,
// Optional status filter, mapped to a processed/unprocessed flag.
status: Option<String>,
// When set, the single video with this UUID is fetched directly.
uuid: Option<String>,
}
#[derive(Clone)]
@@ -408,7 +444,10 @@ async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthRe
let qdrant = check_qdrant().await;
let mongodb = check_mongodb(&state.mongo_cache).await;
let overall_status = if postgres.status == "ok" && redis.status == "ok" && qdrant.status == "ok"
let overall_status = if postgres.status == "ok"
&& redis.status == "ok"
&& qdrant.status == "ok"
&& mongodb.status == "ok"
{
"ok"
} else {
@@ -428,6 +467,30 @@ async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthRe
})
}
/// Authenticates the built-in demo account.
///
/// Only the username/password pair `demo`/`demo` is accepted; on success the
/// response carries `DEMO_USER_API_KEY`. Rejected credentials still produce a
/// JSON body (with `success: false`) rather than an error status.
///
/// NOTE(review): credentials are hard-coded and compared in plain text —
/// confirm this endpoint is intended for demo use only.
async fn login(Json(req): Json<LoginRequest>) -> Json<LoginResponse> {
    let is_demo_user = req.username == "demo" && req.password == "demo";
    if !is_demo_user {
        return Json(LoginResponse {
            success: false,
            message: Some("Invalid username or password".to_string()),
            api_key: None,
            user: None,
        });
    }

    let user = UserInfo {
        username: "demo".to_string(),
    };
    Json(LoginResponse {
        success: true,
        message: Some("Login successful".to_string()),
        api_key: Some(DEMO_USER_API_KEY.to_string()),
        user: Some(user),
    })
}
/// Acknowledges a logout request; always answers `{"success": true}`.
async fn logout() -> Json<serde_json::Value> {
    let body = serde_json::json!({ "success": true });
    Json(body)
}
async fn check_postgres() -> ServiceStatus {
let start = Instant::now();
match PostgresDb::init().await {
@@ -709,6 +772,7 @@ async fn register(
user_id: None,
job_id: None,
created_at: String::new(),
registration_time: None,
};
let video_id = db
@@ -1874,9 +1938,51 @@ async fn list_videos(
let page = params.page.unwrap_or(1);
let page_size = params.page_size.unwrap_or(20);
let offset = ((page - 1) as i64) * (page_size as i64);
let status_filter = params.status.clone();
let query_filter = params.q.clone();
// Include query and status in cache key
let cache_key = keys::videos_list(page, page_size);
let cache_key = if let Some(ref q) = query_filter {
format!("{}:q:{}", cache_key, q)
} else {
cache_key
};
let ttl = state.mongo_cache.ttl_videos();
// If uuid is provided, fetch single video directly
if let Some(ref uuid) = params.uuid {
let db = PostgresDb::init()
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let video = db.get_video_by_uuid(uuid).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
if let Some(v) = video {
return Ok(Json(VideosResponse {
videos: vec![VideoInfoResponse {
uuid: v.uuid,
file_path: v.file_path,
file_name: v.file_name,
duration: v.duration,
width: v.width,
height: v.height,
status: v.status.as_str().to_string(),
processing_status: None,
created_at: Some(v.created_at),
registration_time: v.registration_time,
file_size: None,
probe_json: v.probe_json,
}],
count: 1,
page,
page_size,
}));
} else {
return Err(StatusCode::NOT_FOUND);
}
}
tracing::info!(
"list_videos called: page={}, page_size={}, cache_key={}",
page,
@@ -1892,7 +1998,21 @@ async fn list_videos(
.await
.map_err(|e| anyhow::anyhow!("PG init failed: {}", e))?;
let (videos, count) = db.list_videos(page_size as i32, offset).await?;
// Map status parameter to is_processed filter
let is_processed = match status_filter.as_deref() {
Some("pending") | Some("unprocessed") => Some(false),
Some("completed") | Some("ready") | Some("processed") => Some(true),
_ => None, // no filter
};
// Search by query if provided
let (videos, count) = if let Some(ref q) = query_filter {
db.search_videos(Some(q.as_str()), is_processed, page_size as i32, offset).await?
} else if let Some(processed) = is_processed {
db.search_videos(None, Some(processed), page_size as i32, offset).await?
} else {
db.list_videos(page_size as i32, offset).await?
};
tracing::info!("Got {} videos from DB", videos.len());
let video_infos: Vec<VideoInfoResponse> = videos
@@ -1904,6 +2024,12 @@ async fn list_videos(
duration: v.duration,
width: v.width,
height: v.height,
status: v.status.as_str().to_string(),
processing_status: None,
created_at: Some(v.created_at),
registration_time: v.registration_time,
file_size: None,
probe_json: v.probe_json,
})
.collect();
@@ -2328,19 +2454,15 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> {
.with_state(state.clone());
let cors = CorsLayer::new()
.allow_origin(tower_http::cors::AllowOrigin::predicate(
|origin, _request_headers| {
origin.as_bytes().ends_with(b"localhost")
|| origin.as_bytes().ends_with(b"momentry.ddns.net")
|| origin.as_bytes().ends_with(b"127.0.0.1")
},
))
.allow_origin(tower_http::cors::AllowOrigin::any())
.allow_methods(Any)
.allow_headers(Any);
let app = Router::new()
.route("/health", get(health))
.route("/health/detailed", get(health_detailed))
.route("/api/v1/auth/login", post(login))
.route("/api/v1/auth/logout", post(logout))
.route("/api/v1/stats/ingest", get(get_ingest_stats))
.route("/api/v1/stats/sftpgo", get(get_sftpgo_status))
.route("/api/v1/stats/inference", get(get_inference_health))

814
src/api/universal_search.rs Normal file
View File

@@ -0,0 +1,814 @@
//! Universal Search API
//! Unified search across chunks, frames, and persons.
use axum::{
extract::{Query, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::{Database, PostgresDb};
/// Request body for `universal_search`.
#[derive(Debug, Deserialize)]
pub struct UniversalSearchRequest {
// Search text; echoed back in the response.
pub query: String,
// Optional UUID. NOTE(review): consumed by the internal search helpers, which
// are outside this part of the file — presumably restricts the search to one video.
pub uuid: Option<String>,
// Result kinds to search; an empty list means all three kinds.
#[serde(default)]
pub types: Vec<String>, // chunk, frame, person
// Optional pair of bounds. NOTE(review): presumably [start, end] — confirm units.
pub time_range: Option<[f64; 2]>,
// Optional additional filters.
pub filters: Option<SearchFilters>,
// Page size; `universal_search` defaults this to 20.
pub limit: Option<usize>,
// Number of sorted results to skip; defaults to 0.
pub offset: Option<usize>,
}
/// Optional filters attached to a `UniversalSearchRequest`.
/// NOTE(review): the code applying these filters is outside this part of the
/// file; the field descriptions are inferred from names — confirm.
#[derive(Debug, Deserialize, Clone)]
pub struct SearchFilters {
// Person to match.
pub person_id: Option<String>,
// Detected object classes to match.
pub object_class: Option<Vec<String>>,
// OCR text to match.
pub ocr_text: Option<String>,
// Whether a face must be present.
pub has_face: Option<bool>,
// Speaker to match.
pub speaker_id: Option<String>,
// Visual chunk filters
pub min_confidence: Option<f32>,
pub min_unique_classes: Option<u32>,
pub min_spatial_density: Option<f32>,
pub max_spatial_density: Option<f32>,
pub required_object_classes: Option<Vec<String>>,
}
/// Response body of `universal_search`.
#[derive(Debug, Serialize)]
pub struct UniversalSearchResponse {
// The original query string, echoed back.
pub query: String,
// The requested page of results, sorted by score (descending).
pub results: Vec<SearchResult>,
// Number of results found before pagination.
pub total: usize,
// Wall-clock time spent handling the request, in milliseconds.
pub took_ms: u64,
}
/// A single universal-search hit. Serialized with a `type` tag of `chunk`,
/// `frame` or `person`; every variant carries a `score` used for ranking.
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "type")]
pub enum SearchResult {
// A chunk hit.
#[serde(rename = "chunk")]
Chunk {
chunk_id: String,
chunk_type: String,
// Primary: frame-accurate position
start_frame: i64,
end_frame: i64,
// Reference: time derived from frames (subject to FPS variation)
start_time: f64,
end_time: f64,
score: f64,
text: Option<String>,
speaker_id: Option<String>,
metadata: Option<serde_json::Value>,
},
// A single-frame hit.
#[serde(rename = "frame")]
Frame {
// Primary: exact frame number
frame_number: i64,
// Reference: time derived from frame (subject to FPS variation)
timestamp: f64,
score: f64,
objects: Option<Vec<serde_json::Value>>,
ocr_texts: Option<Vec<String>>,
faces: Option<Vec<serde_json::Value>>,
pose_persons: Option<Vec<serde_json::Value>>,
},
// A person hit.
#[serde(rename = "person")]
Person {
person_id: String,
name: Option<String>,
speaker_id: Option<String>,
appearance_count: i32,
score: f64,
first_appearance_time: Option<f64>,
last_appearance_time: Option<f64>,
},
}
/// Builds the router for the universal, frame and person search endpoints.
pub fn universal_search_routes() -> Router<crate::api::server::AppState> {
    let router = Router::new();
    let router = router.route("/api/v1/search/universal", post(universal_search));
    let router = router.route("/api/v1/search/frames", post(search_frames));
    router.route("/api/v1/search/persons", get(search_persons))
}
/// Unified search across all data types
pub async fn universal_search(
State(_state): State<crate::api::server::AppState>,
Json(req): Json<UniversalSearchRequest>,
) -> Result<Json<UniversalSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let start_time = std::time::Instant::now();
let db = PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
)
})?;
let limit = req.limit.unwrap_or(20);
let offset = req.offset.unwrap_or(0);
let types = if req.types.is_empty() {
vec![
"chunk".to_string(),
"frame".to_string(),
"person".to_string(),
]
} else {
req.types.clone()
};
let mut results = Vec::new();
// Search chunks
if types.contains(&"chunk".to_string()) {
let chunk_results = search_chunks(&db, &req).await.map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
results.extend(chunk_results);
}
// Search frames
if types.contains(&"frame".to_string()) {
let frame_results = search_frames_internal(&db, &req).await.unwrap_or_default();
results.extend(frame_results);
}
// Search persons
if types.contains(&"person".to_string()) {
let person_results = search_persons_internal(&db, &req).await.unwrap_or_default();
results.extend(person_results);
}
// Sort by score descending
results.sort_by(|a, b| {
let score_a = match a {
SearchResult::Chunk { score, .. } => *score,
SearchResult::Frame { score, .. } => *score,
SearchResult::Person { score, .. } => *score,
};
let score_b = match b {
SearchResult::Chunk { score, .. } => *score,
SearchResult::Frame { score, .. } => *score,
SearchResult::Person { score, .. } => *score,
};
score_b
.partial_cmp(&score_a)
.unwrap_or(std::cmp::Ordering::Equal)
});
let total = results.len();
let end = std::cmp::min(offset + limit, results.len());
let paginated = if offset < results.len() {
results[offset..end].to_vec()
} else {
vec![]
};
let took = start_time.elapsed().as_millis() as u64;
Ok(Json(UniversalSearchResponse {
query: req.query,
results: paginated,
total,
took_ms: took,
}))
}
/// Search frames by YOLO objects, OCR text, or face IDs
pub async fn search_frames(
State(_state): State<crate::api::server::AppState>,
Json(req): Json<FrameSearchRequest>,
) -> Result<Json<FrameSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let db = PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
)
})?;
let frames = search_frames_internal_v2(&db, &req).await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("Search error: {}", e) })),
)
})?;
let frames_count = frames.len();
Ok(Json(FrameSearchResponse {
frames,
total: frames_count,
}))
}
/// Search persons by name or speaker_id
pub async fn search_persons(
State(_state): State<crate::api::server::AppState>,
Query(query): Query<PersonSearchQuery>,
) -> Result<Json<PersonSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
let db = PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
)
})?;
let limit = query.limit.unwrap_or(20);
let persons = search_persons_by_query(
&db,
&query.query,
query.min_appearances,
query.max_age,
limit,
)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("Search error: {}", e) })),
)
})?;
let persons_count = persons.len();
Ok(Json(PersonSearchResponse {
persons,
total: persons_count,
}))
}
// --- Internal search functions ---
/// JSON request body accepted by the `search_frames` handler.
///
/// Every field is optional; each one that is present adds a further
/// restriction to the frame query.
#[derive(Debug, Deserialize)]
pub struct FrameSearchRequest {
/// Only return frames belonging to the video with this uuid.
pub uuid: Option<String>,
/// Case-insensitive substring matched against the frame's YOLO object data.
pub object_class: Option<String>,
/// Case-insensitive substring matched against the frame's OCR data.
pub ocr_text: Option<String>,
/// Case-sensitive substring matched against the frame's face data.
pub face_id: Option<String>,
/// Inclusive `[start, end]` bounds on the frame timestamp.
pub time_range: Option<[f64; 2]>,
/// Maximum number of frames to return (the query defaults to 50).
pub limit: Option<usize>,
}
/// JSON response returned by the `search_frames` handler.
#[derive(Debug, Serialize)]
pub struct FrameSearchResponse {
/// Frames that matched the request.
pub frames: Vec<FrameResult>,
/// Number of entries in `frames` (the handler sets it to `frames.len()`).
pub total: usize,
}
/// One frame row returned by the frame search, with its per-frame analysis data.
#[derive(Debug, Serialize)]
pub struct FrameResult {
/// Frame number as stored in the `frames` table.
pub frame_number: i64,
/// Frame timestamp as stored in the `frames` table (unit not shown here; presumably seconds).
pub timestamp: f64,
/// uuid of the video the frame belongs to.
pub uuid: String,
/// Contents of the `objects` array in the frame's YOLO data, when that key exists.
pub objects: Option<Vec<serde_json::Value>>,
/// The `text` strings found in the `texts` array of the frame's OCR data.
pub ocr_texts: Option<Vec<String>>,
/// Contents of the `faces` array in the frame's face data, when that key exists.
pub faces: Option<Vec<serde_json::Value>>,
/// Contents of the `persons` array in the frame's pose data, when that key exists.
pub pose_persons: Option<Vec<serde_json::Value>>,
}
/// Query-string parameters accepted by the `search_persons` handler.
#[derive(Debug, Deserialize)]
pub struct PersonSearchQuery {
/// Required video uuid.
/// NOTE(review): `search_persons` does not pass this on to the person query,
/// so results are not restricted to this video - confirm whether that is intended.
pub video_uuid: String,
/// Case-insensitive substring matched against name, character name, aliases,
/// person id and speaker id.
pub query: Option<String>,
/// Only return persons whose `appearance_count` is at least this value.
pub min_appearances: Option<i32>,
pub max_age: Option<i32>, // New filter for "children"; persons with no recorded age are excluded when set
/// Maximum number of persons to return (the handler defaults to 20).
pub limit: Option<usize>,
}
/// JSON response returned by the `search_persons` handler.
#[derive(Debug, Serialize)]
pub struct PersonSearchResponse {
/// Persons that matched the query.
pub persons: Vec<PersonResult>,
/// Number of entries in `persons` (the handler sets it to `persons.len()`).
pub total: usize,
}
/// One row of the `person_identities` table as returned by the person search.
#[derive(Debug, Serialize)]
pub struct PersonResult {
/// Identifier of the person record.
pub person_id: String,
/// Stored name, if any.
pub name: Option<String>,
/// Stored character name, if any.
pub character_name: Option<String>,
/// String entries of the stored `aliases` JSON array; `None` when the column
/// is NULL or is not an array.
pub aliases: Option<Vec<String>>,
/// Stored age, if any.
pub age: Option<i32>,
/// Stored gender, if any.
pub gender: Option<String>,
/// Speaker id associated with the person, if any.
pub speaker_id: Option<String>,
/// Stored appearance count; person search results are ordered by it, descending.
pub appearance_count: i32,
/// Stored time of the first appearance (unit not shown here; presumably seconds).
pub first_appearance_time: Option<f64>,
/// Stored time of the last appearance (unit not shown here; presumably seconds).
pub last_appearance_time: Option<f64>,
}
/// Searches the `chunks` table of a single video for the universal search.
///
/// `req.uuid` is mandatory because `chunk_id` is only unique within a video.
/// The optional time range, free-text query and `req.filters` each add a
/// further `AND` condition. Results are ordered by `start_time` and capped at
/// `req.limit` (default 20).
///
/// All caller-supplied strings are sent as bound parameters rather than being
/// spliced into the SQL text. LIKE/ILIKE wildcards (`%`, `_`) inside those
/// strings are still interpreted by PostgreSQL, as before.
///
/// # Errors
///
/// Returns an error when `req.uuid` is missing or when the query fails.
async fn search_chunks(
    db: &PostgresDb,
    req: &UniversalSearchRequest,
) -> Result<Vec<SearchResult>, anyhow::Error> {
    /// Registers `value` as the next bind parameter and returns its `$n` placeholder.
    fn bind_next(params: &mut Vec<String>, value: String) -> String {
        params.push(value);
        format!("${}", params.len())
    }

    // uuid is required for chunk search - chunk_id is only unique within a video
    let uuid = req
        .uuid
        .as_ref()
        .ok_or_else(|| anyhow::anyhow!("uuid is required for chunk search"))?;

    let mut params: Vec<String> = Vec::new();
    let uuid_param = bind_next(&mut params, uuid.to_string());
    let mut sql = format!(
        "SELECT chunk_id, chunk_type, start_time, end_time, start_frame, end_frame, text_content, content FROM chunks WHERE uuid = {}",
        uuid_param
    );

    // Numeric values are typed (not free text), so formatting them cannot inject SQL.
    if let Some(tr) = &req.time_range {
        sql.push_str(&format!(
            " AND start_time >= {} AND end_time <= {}",
            tr[0], tr[1]
        ));
    }

    if !req.query.is_empty() {
        let pattern = bind_next(&mut params, format!("%{}%", req.query));
        sql.push_str(&format!(
            " AND (text_content ILIKE {0} OR content::text ILIKE {0})",
            pattern
        ));
    }

    if let Some(ref filters) = req.filters {
        if let Some(ref speaker_id) = filters.speaker_id {
            let p = bind_next(&mut params, speaker_id.to_string());
            sql.push_str(&format!(" AND content->>'speaker_id' = {}", p));
        }
        if let Some(ref person_id) = filters.person_id {
            let p = bind_next(&mut params, format!("%{}%", person_id));
            sql.push_str(&format!(" AND content::text LIKE {}", p));
        }

        // Visual chunk filters
        if let Some(min_confidence) = filters.min_confidence {
            sql.push_str(&format!(
                " AND (content->'metadata'->>'avg_confidence')::float >= {}",
                min_confidence
            ));
        }
        if let Some(min_unique_classes) = filters.min_unique_classes {
            sql.push_str(&format!(
                " AND jsonb_array_length(content->'metadata'->'unique_classes') >= {}",
                min_unique_classes
            ));
        }
        if let Some(min_density) = filters.min_spatial_density {
            sql.push_str(&format!(
                " AND (content->'metadata'->>'spatial_density')::float >= {}",
                min_density
            ));
        }
        if let Some(max_density) = filters.max_spatial_density {
            sql.push_str(&format!(
                " AND (content->'metadata'->>'spatial_density')::float <= {}",
                max_density
            ));
        }
        if let Some(ref required_classes) = filters.required_object_classes {
            if !required_classes.is_empty() {
                // Build the containment probe with serde_json so that quotes and
                // backslashes in a class name cannot produce malformed JSON.
                let class_conditions: Vec<String> = required_classes
                    .iter()
                    .map(|class| {
                        let probe = serde_json::json!([{ "class_name": class }]).to_string();
                        let p = bind_next(&mut params, probe);
                        format!("content->'keyframe_objects' @> {}::jsonb", p)
                    })
                    .collect();
                // A chunk matches when it contains ANY of the listed classes.
                sql.push_str(&format!(" AND ({})", class_conditions.join(" OR ")));
            }
        }
    }

    sql.push_str(" ORDER BY start_time ASC");
    sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));

    let mut query = sqlx::query_as::<
        _,
        (
            String,
            String,
            f64,
            f64,
            i64,
            i64,
            Option<String>,
            Option<serde_json::Value>,
        ),
    >(&sql);
    for value in &params {
        query = query.bind(value.as_str());
    }
    let rows = query.fetch_all(db.pool()).await?;

    // Lower-case the query once instead of once per row.
    let needle = req.query.to_lowercase();

    let results: Vec<SearchResult> = rows
        .into_iter()
        .map(
            |(
                chunk_id,
                chunk_type,
                start_time,
                end_time,
                start_frame,
                end_frame,
                text_content,
                content,
            )| {
                // Prefer the text column; fall back to the `text` key of the JSON content.
                let text = text_content.or_else(|| {
                    content
                        .as_ref()
                        .and_then(|c| c.get("text").and_then(|v| v.as_str()).map(String::from))
                });
                let speaker_id = content.as_ref().and_then(|c| {
                    c.get("speaker_id")
                        .and_then(|v| v.as_str())
                        .map(String::from)
                });

                // Simple scoring: 0.9 when the query occurs in the chunk text, 0.5 otherwise.
                let query_in_text = !req.query.is_empty()
                    && text
                        .as_ref()
                        .map_or(false, |t| t.to_lowercase().contains(&needle));
                let score = if query_in_text { 0.9 } else { 0.5 };

                SearchResult::Chunk {
                    chunk_id,
                    chunk_type,
                    start_time,
                    end_time,
                    start_frame,
                    end_frame,
                    score,
                    text,
                    speaker_id,
                    metadata: content,
                }
            },
        )
        .collect();

    Ok(results)
}
/// Searches the `frames` table (joined to `videos`) for the universal search.
///
/// Optional restrictions come from `req.uuid`, `req.time_range`, `req.filters`
/// (object classes, OCR text, face presence, person id) and the free-text
/// `req.query`, which is matched against the YOLO, OCR and face JSON of each
/// frame. Results are ordered by timestamp, capped at `req.limit` (default 20)
/// and all carry a fixed score of 0.7.
///
/// All caller-supplied strings are sent as bound parameters; they were
/// previously interpolated into the SQL text without any escaping. LIKE/ILIKE
/// wildcards (`%`, `_`) inside those strings are still interpreted by
/// PostgreSQL, as before.
///
/// # Errors
///
/// Returns an error when the query fails.
async fn search_frames_internal(
    db: &PostgresDb,
    req: &UniversalSearchRequest,
) -> Result<Vec<SearchResult>, anyhow::Error> {
    /// Registers `value` as the next bind parameter and returns its `$n` placeholder.
    fn bind_next(params: &mut Vec<String>, value: String) -> String {
        params.push(value);
        format!("${}", params.len())
    }

    let mut params: Vec<String> = Vec::new();
    let mut sql = String::from(
        "SELECT f.frame_number, f.timestamp, f.yolo_objects, f.ocr_results, f.face_results, f.pose_results, v.uuid \
         FROM frames f JOIN videos v ON f.file_id = v.id WHERE 1=1",
    );

    if let Some(uuid) = &req.uuid {
        let p = bind_next(&mut params, uuid.to_string());
        sql.push_str(&format!(" AND v.uuid = {}", p));
    }

    // Numeric values are typed (not free text), so formatting them cannot inject SQL.
    if let Some(tr) = &req.time_range {
        sql.push_str(&format!(
            " AND f.timestamp >= {} AND f.timestamp <= {}",
            tr[0], tr[1]
        ));
    }

    if let Some(ref filters) = req.filters {
        if let Some(ref classes) = filters.object_class {
            // Every listed class must occur in the frame's object data.
            for class in classes {
                let p = bind_next(&mut params, format!("%{}%", class));
                sql.push_str(&format!(" AND f.yolo_objects::text ILIKE {}", p));
            }
        }
        if let Some(ref ocr) = filters.ocr_text {
            let p = bind_next(&mut params, format!("%{}%", ocr));
            sql.push_str(&format!(" AND f.ocr_results::text ILIKE {}", p));
        }
        if let Some(true) = filters.has_face {
            sql.push_str(
                " AND f.face_results IS NOT NULL AND jsonb_array_length(f.face_results) > 0",
            );
        }
        if let Some(ref person_id) = filters.person_id {
            let p = bind_next(&mut params, format!("%{}%", person_id));
            sql.push_str(&format!(" AND f.face_results::text LIKE {}", p));
        }
    }

    if !req.query.is_empty() {
        // Search across all frame data
        let p = bind_next(&mut params, format!("%{}%", req.query));
        sql.push_str(&format!(
            " AND (f.yolo_objects::text ILIKE {0} OR f.ocr_results::text ILIKE {0} OR f.face_results::text ILIKE {0})",
            p
        ));
    }

    sql.push_str(" ORDER BY f.timestamp ASC");
    sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));

    let mut query = sqlx::query_as::<
        _,
        (
            i64,
            f64,
            Option<serde_json::Value>,
            Option<serde_json::Value>,
            Option<serde_json::Value>,
            Option<serde_json::Value>,
            String,
        ),
    >(&sql);
    for value in &params {
        query = query.bind(value.as_str());
    }
    let rows = query.fetch_all(db.pool()).await?;

    let results: Vec<SearchResult> = rows
        .into_iter()
        .map(|(frame_number, timestamp, yolo, ocr, face, pose, _uuid)| {
            // Each JSON column is expected to wrap its list under a fixed key;
            // a key that exists but is not an array yields an empty list.
            let objects = yolo.as_ref().and_then(|v| {
                v.get("objects")
                    .map(|o| o.as_array().cloned().unwrap_or_default())
            });
            let ocr_texts = ocr.as_ref().and_then(|v| {
                v.get("texts").and_then(|t| {
                    t.as_array().map(|arr| {
                        arr.iter()
                            .filter_map(|item| {
                                item.get("text").and_then(|x| x.as_str()).map(String::from)
                            })
                            .collect()
                    })
                })
            });
            let faces = face.as_ref().and_then(|v| {
                v.get("faces")
                    .map(|f| f.as_array().cloned().unwrap_or_default())
            });
            let pose_persons = pose.as_ref().and_then(|v| {
                v.get("persons")
                    .map(|p| p.as_array().cloned().unwrap_or_default())
            });

            SearchResult::Frame {
                frame_number,
                timestamp,
                score: 0.7,
                objects: objects.map(|arr| arr.into_iter().collect()),
                ocr_texts,
                faces,
                pose_persons,
            }
        })
        .collect();

    Ok(results)
}
async fn search_persons_internal(
db: &PostgresDb,
req: &UniversalSearchRequest,
) -> Result<Vec<SearchResult>, anyhow::Error> {
let table = "person_identities";
let mut sql = format!(
"SELECT person_id, name, speaker_id, appearance_count, first_appearance_time, last_appearance_time FROM {} WHERE 1=1",
table
);
if !req.query.is_empty() {
sql.push_str(&format!(
" AND (name ILIKE '%{}%' OR person_id ILIKE '%{}%' OR speaker_id ILIKE '%{}%')",
req.query, req.query, req.query
));
}
if let Some(ref filters) = req.filters {
if let Some(ref speaker_id) = filters.speaker_id {
sql.push_str(&format!(" AND speaker_id = '{}'", speaker_id));
}
if let Some(ref person_id) = filters.person_id {
sql.push_str(&format!(" AND person_id = '{}'", person_id));
}
}
sql.push_str(" ORDER BY appearance_count DESC");
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
let rows: Vec<(
String,
Option<String>,
Option<String>,
i32,
Option<f64>,
Option<f64>,
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
let results: Vec<SearchResult> = rows
.into_iter()
.map(
|(person_id, name, speaker_id, appearance_count, first_time, last_time)| {
let score = if !req.query.is_empty()
&& name.as_ref().map_or(false, |n| {
n.to_lowercase().contains(&req.query.to_lowercase())
}) {
0.95
} else {
0.5
};
SearchResult::Person {
person_id,
name,
speaker_id,
appearance_count,
score,
first_appearance_time: first_time,
last_appearance_time: last_time,
}
},
)
.collect();
Ok(results)
}
/// Runs the frame query behind the dedicated frame search endpoint.
///
/// Each field of `req` that is present adds an `AND` condition: video uuid,
/// inclusive timestamp range, and substring matches against the YOLO, OCR and
/// face JSON of the frame. Results are ordered by timestamp and capped at
/// `req.limit` (default 50).
///
/// All caller-supplied strings are sent as bound parameters; they were
/// previously interpolated into the SQL text without any escaping. LIKE/ILIKE
/// wildcards (`%`, `_`) inside those strings are still interpreted by
/// PostgreSQL, as before.
///
/// # Errors
///
/// Returns an error when the query fails.
async fn search_frames_internal_v2(
    db: &PostgresDb,
    req: &FrameSearchRequest,
) -> Result<Vec<FrameResult>, anyhow::Error> {
    /// Registers `value` as the next bind parameter and returns its `$n` placeholder.
    fn bind_next(params: &mut Vec<String>, value: String) -> String {
        params.push(value);
        format!("${}", params.len())
    }

    let mut params: Vec<String> = Vec::new();
    let mut sql = String::from(
        "SELECT f.frame_number, f.timestamp, f.yolo_objects, f.ocr_results, f.face_results, f.pose_results, v.uuid \
         FROM frames f JOIN videos v ON f.file_id = v.id WHERE 1=1",
    );

    if let Some(uuid) = &req.uuid {
        let p = bind_next(&mut params, uuid.to_string());
        sql.push_str(&format!(" AND v.uuid = {}", p));
    }

    // Numeric values are typed (not free text), so formatting them cannot inject SQL.
    if let Some(tr) = &req.time_range {
        sql.push_str(&format!(
            " AND f.timestamp >= {} AND f.timestamp <= {}",
            tr[0], tr[1]
        ));
    }

    if let Some(ref class) = req.object_class {
        let p = bind_next(&mut params, format!("%{}%", class));
        sql.push_str(&format!(" AND f.yolo_objects::text ILIKE {}", p));
    }
    if let Some(ref ocr) = req.ocr_text {
        let p = bind_next(&mut params, format!("%{}%", ocr));
        sql.push_str(&format!(" AND f.ocr_results::text ILIKE {}", p));
    }
    if let Some(ref face_id) = req.face_id {
        let p = bind_next(&mut params, format!("%{}%", face_id));
        sql.push_str(&format!(" AND f.face_results::text LIKE {}", p));
    }

    sql.push_str(" ORDER BY f.timestamp ASC");
    sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(50)));

    let mut query = sqlx::query_as::<
        _,
        (
            i64,
            f64,
            Option<serde_json::Value>,
            Option<serde_json::Value>,
            Option<serde_json::Value>,
            Option<serde_json::Value>,
            String,
        ),
    >(&sql);
    for value in &params {
        query = query.bind(value.as_str());
    }
    let rows = query.fetch_all(db.pool()).await?;

    let results: Vec<FrameResult> = rows
        .into_iter()
        .map(|(frame_number, timestamp, yolo, ocr, face, pose, uuid)| {
            // Each JSON column is expected to wrap its list under a fixed key;
            // a key that exists but is not an array yields an empty list.
            let objects = yolo.as_ref().and_then(|v| {
                v.get("objects")
                    .map(|o| o.as_array().cloned().unwrap_or_default())
            });
            let ocr_texts = ocr.as_ref().and_then(|v| {
                v.get("texts").and_then(|t| {
                    t.as_array().map(|arr| {
                        arr.iter()
                            .filter_map(|item| {
                                item.get("text").and_then(|x| x.as_str()).map(String::from)
                            })
                            .collect()
                    })
                })
            });
            let faces = face.as_ref().and_then(|v| {
                v.get("faces")
                    .map(|f| f.as_array().cloned().unwrap_or_default())
            });
            let pose_persons = pose.as_ref().and_then(|v| {
                v.get("persons")
                    .map(|p| p.as_array().cloned().unwrap_or_default())
            });

            FrameResult {
                frame_number,
                timestamp,
                uuid,
                objects,
                ocr_texts,
                faces,
                pose_persons,
            }
        })
        .collect();

    Ok(results)
}
/// Runs the person query behind the dedicated person search endpoint.
///
/// * `query` - optional case-insensitive substring matched against name,
///   character name, aliases (cast to text), person id and speaker id.
/// * `min_appearances` - lower bound on `appearance_count`.
/// * `max_age` - upper bound on `age`; rows whose age is NULL are excluded
///   when this is set.
/// * `limit` - maximum number of rows returned.
///
/// Results are ordered by `appearance_count`, descending.
///
/// The query string is sent as a bound parameter; it was previously
/// interpolated into the SQL text without any escaping. ILIKE wildcards
/// (`%`, `_`) inside it are still interpreted by PostgreSQL, as before.
///
/// # Errors
///
/// Returns an error when the query fails.
async fn search_persons_by_query(
    db: &PostgresDb,
    query: &Option<String>,
    min_appearances: Option<i32>,
    max_age: Option<i32>,
    limit: usize,
) -> Result<Vec<PersonResult>, anyhow::Error> {
    let mut sql = String::from(
        "SELECT person_id, name, character_name, aliases, age, gender, speaker_id, appearance_count, first_appearance_time, last_appearance_time FROM person_identities WHERE 1=1",
    );

    // The only free-text input; when present it is always parameter $1.
    let pattern = query.as_deref().map(|q| format!("%{}%", q));
    if pattern.is_some() {
        // Search name, character_name, aliases (cast to text), person_id, speaker_id
        sql.push_str(
            " AND (name ILIKE $1 OR character_name ILIKE $1 OR aliases::text ILIKE $1 OR person_id ILIKE $1 OR speaker_id ILIKE $1)",
        );
    }

    // Integers are typed (not free text), so formatting them cannot inject SQL.
    if let Some(min) = min_appearances {
        sql.push_str(&format!(" AND appearance_count >= {}", min));
    }
    if let Some(max_a) = max_age {
        // Strictly filter for age <= max_age.
        // Note: This excludes entries with NULL age.
        sql.push_str(&format!(" AND age <= {}", max_a));
    }

    sql.push_str(" ORDER BY appearance_count DESC");
    sql.push_str(&format!(" LIMIT {}", limit));

    let mut db_query = sqlx::query_as::<
        _,
        (
            String,
            Option<String>,
            Option<String>,
            Option<serde_json::Value>,
            Option<i32>,
            Option<String>,
            Option<String>,
            i32,
            Option<f64>,
            Option<f64>,
        ),
    >(&sql);
    if let Some(ref p) = pattern {
        db_query = db_query.bind(p.as_str());
    }
    let rows = db_query.fetch_all(db.pool()).await?;

    let results: Vec<PersonResult> = rows
        .into_iter()
        .map(
            |(
                person_id,
                name,
                character_name,
                aliases_json,
                age,
                gender,
                speaker_id,
                appearance_count,
                first_time,
                last_time,
            )| {
                // Keep only the string entries of the aliases array; a non-array value maps to None.
                let aliases = aliases_json.and_then(|v| {
                    v.as_array().map(|arr| {
                        arr.iter()
                            .filter_map(|val| val.as_str().map(String::from))
                            .collect()
                    })
                });

                PersonResult {
                    person_id,
                    name,
                    character_name,
                    aliases,
                    age,
                    gender,
                    speaker_id,
                    appearance_count,
                    first_appearance_time: first_time,
                    last_appearance_time: last_time,
                }
            },
        )
        .collect();

    Ok(results)
}
// Unit tests for the search request types.
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: a `SearchFilters` value can be built with the visual-chunk
// filter fields populated, and `min_confidence` reads back unchanged.
#[test]
fn test_search_filters_with_visual() {
let filters = SearchFilters {
person_id: None,
object_class: None,
ocr_text: None,
has_face: None,
speaker_id: None,
min_confidence: Some(0.8),
min_unique_classes: Some(3),
min_spatial_density: Some(0.5),
max_spatial_density: Some(0.9),
required_object_classes: Some(vec!["person".to_string()]),
};
assert_eq!(filters.min_confidence, Some(0.8));
}
}

View File

@@ -0,0 +1,504 @@
//! Visual chunk search functionality.
//!
//! This module provides search capabilities for visual chunks based on:
//! - Object classes (e.g., "person", "car", "envelope")
//! - Confidence thresholds
//! - Object counts
//! - Spatial density
//! - Object relationships
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
use crate::core::db::PostgresDb;
use anyhow::Result;
use serde_json::Value;
use std::collections::HashMap;
/// Filter criteria applied by `search_visual_chunks`.
///
/// The derived `Default` imposes no constraints: every optional bound is
/// `None` and both collections are empty.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct VisualChunkSearchCriteria {
    /// Lower bound on the chunk's `metadata.avg_confidence` value.
    pub min_avg_confidence: Option<f32>,
    /// Lower bound on the chunk's `visual_stats.frames_with_objects` value.
    pub min_frames_with_objects: Option<u32>,
    /// Lower bound on the number of entries in `metadata.unique_classes`.
    pub min_unique_classes: Option<u32>,
    /// Object classes that must all appear among the chunk's keyframe objects
    /// (empty means no class requirement).
    pub required_classes: Vec<String>,
    /// Per-class `(min, max)` bounds on `metadata.object_counts`; a `min` of 0
    /// or a `max` of `u32::MAX` disables that side of the bound.
    pub class_counts: HashMap<String, (u32, u32)>,
    /// Optional `(start, end)` window; a chunk must lie entirely inside it.
    pub time_range: Option<(f64, f64)>,
}
/// Search visual chunks based on criteria
pub async fn search_visual_chunks(
db: &PostgresDb,
uuid: &str,
criteria: &VisualChunkSearchCriteria,
) -> Result<Vec<Chunk>> {
// First, get all visual chunks for this video
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
// Apply filters
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check min avg confidence
if let Some(min_avg_confidence) = criteria.min_avg_confidence {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(avg_confidence) = metadata.get("avg_confidence") {
if let Some(conf) = avg_confidence.as_f64() {
if conf < min_avg_confidence as f64 {
return false;
}
}
}
}
}
}
// Check min frames with objects
if let Some(min_frames) = criteria.min_frames_with_objects {
if let Some(stats) = &chunk.visual_stats {
if let Some(frames_with_objects) = stats.get("frames_with_objects") {
if let Some(count) = frames_with_objects.as_u64() {
if count < min_frames as u64 {
return false;
}
}
}
}
}
// Check min unique classes
if let Some(min_unique_classes) = criteria.min_unique_classes {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(unique_classes) = metadata.get("unique_classes") {
if let Some(classes) = unique_classes.as_array() {
if (classes.len() as u32) < min_unique_classes {
return false;
}
}
}
}
}
}
// Check required classes
if !criteria.required_classes.is_empty() {
if let Some(content) = &chunk.content.as_object() {
if let Some(keyframe_objects) = content.get("keyframe_objects") {
if let Some(objects) = keyframe_objects.as_array() {
let mut found_all = true;
for required_class in &criteria.required_classes {
let mut found = false;
for obj in objects {
if let Some(class_name) = obj.get("class_name") {
if let Some(class_str) = class_name.as_str() {
if class_str == required_class {
found = true;
break;
}
}
}
}
if !found {
found_all = false;
break;
}
}
if !found_all {
return false;
}
}
}
}
}
// Check class counts
if !criteria.class_counts.is_empty() {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(object_counts) = metadata.get("object_counts") {
for (class, (min, max)) in &criteria.class_counts {
if let Some(count_value) = object_counts.get(class) {
if let Some(count) = count_value.as_u64() {
if *min > 0 && count < *min as u64 {
return false;
}
if *max < u32::MAX && count > *max as u64 {
return false;
}
}
} else if *min > 0 {
return false;
}
}
} else if criteria.class_counts.values().any(|(min, _)| *min > 0) {
return false;
}
}
}
}
// Check time range
if let Some((start_time, end_time)) = criteria.time_range {
// Calculate chunk time from frames
let chunk_start_time = chunk.start_frame as f64 / chunk.fps;
let chunk_end_time = chunk.end_frame as f64 / chunk.fps;
if chunk_start_time < start_time || chunk_end_time > end_time {
return false;
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Loads all chunks of type `visual` for the video `uuid`, ordered by
/// `start_frame`.
///
/// The uuid is sent as a bound parameter instead of being spliced into the
/// SQL text.
///
/// # Errors
///
/// Returns an error when the query fails.
async fn get_visual_chunks_by_uuid(db: &PostgresDb, uuid: &str) -> Result<Vec<Chunk>> {
    // Column order of the SELECT below.
    type ChunkRow = (
        i32,            // file_id
        String,         // uuid
        String,         // chunk_id
        i32,            // chunk_index
        String,         // chunk_type
        f64,            // fps
        i64,            // start_frame
        i64,            // end_frame
        Option<String>, // text_content
        Value,          // content
        Option<Value>,  // metadata
        Option<String>, // vector_id
        Option<Value>,  // visual_stats
    );

    let rows: Vec<ChunkRow> = sqlx::query_as(
        "SELECT file_id, uuid, chunk_id, chunk_index, chunk_type, fps, start_frame, end_frame, text_content, content, metadata, vector_id, visual_stats FROM chunks WHERE uuid = $1 AND chunk_type = 'visual' ORDER BY start_frame ASC",
    )
    .bind(uuid)
    .fetch_all(db.pool())
    .await?;

    let chunks = rows
        .into_iter()
        .map(
            |(
                file_id,
                uuid,
                chunk_id,
                chunk_index,
                chunk_type_name,
                fps,
                start_frame,
                end_frame,
                text_content,
                content,
                metadata,
                vector_id,
                visual_stats,
            )| {
                // Unknown type names fall back to `TimeBased`.
                let chunk_type = match chunk_type_name.as_str() {
                    "visual" => ChunkType::Visual,
                    "sentence" => ChunkType::Sentence,
                    "time_based" => ChunkType::TimeBased,
                    "cut" => ChunkType::Cut,
                    "trace" => ChunkType::Trace,
                    "story" => ChunkType::Story,
                    _ => ChunkType::TimeBased,
                };

                // NOTE(review): both casts below can truncate or wrap for
                // out-of-range values; kept as in the original.
                let frame_count = (end_frame - start_frame) as i32;

                Chunk {
                    file_id,
                    uuid,
                    chunk_id,
                    chunk_index: chunk_index as u32,
                    chunk_type,
                    rule: ChunkRule::Rule2, // Visual chunks use Rule2
                    fps,
                    start_frame,
                    end_frame,
                    text_content,
                    content,
                    metadata,
                    vector_id,
                    frame_count,
                    // Relationship fields are not loaded by this query.
                    pre_chunk_ids: Vec::new(),
                    parent_chunk_id: None,
                    child_chunk_ids: Vec::new(),
                    visual_stats,
                }
            },
        )
        .collect();

    Ok(chunks)
}
/// Search visual chunks by object class
pub async fn search_visual_chunks_by_class(
db: &PostgresDb,
uuid: &str,
object_class: &str,
min_count: Option<u32>,
max_count: Option<u32>,
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check if chunk contains the object class
let mut contains_class = false;
if let Some(content) = &chunk.content.as_object() {
if let Some(keyframe_objects) = content.get("keyframe_objects") {
if let Some(objects) = keyframe_objects.as_array() {
for obj in objects {
if let Some(class_name) = obj.get("class_name") {
if let Some(class_str) = class_name.as_str() {
if class_str == object_class {
contains_class = true;
break;
}
}
}
}
}
}
}
if !contains_class {
return false;
}
// Check count in visual_stats
if let Some(stats) = &chunk.visual_stats {
if let Some(count) = stats.get(object_class) {
if let Some(c) = count.as_u64() {
if let Some(min) = min_count {
if c < min as u64 {
return false;
}
}
if let Some(max) = max_count {
if c > max as u64 {
return false;
}
}
}
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Search visual chunks by spatial density
pub async fn search_visual_chunks_by_density(
db: &PostgresDb,
uuid: &str,
min_density: f32,
max_density: Option<f32>,
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(density_value) = metadata.get("spatial_density") {
if let Some(density) = density_value.as_f64() {
if density < min_density as f64 {
return false;
}
if let Some(max_dens) = max_density {
if density > max_dens as f64 {
return false;
}
}
return true;
}
}
}
}
false
})
.collect();
Ok(filtered_chunks)
}
/// Find chunks containing specific object combinations
pub async fn search_visual_chunks_by_combination(
db: &PostgresDb,
uuid: &str,
combination: &[(&str, u32)],
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check if all required combinations are present
for (object_class, min_count) in combination {
let mut found = false;
if let Some(stats) = &chunk.visual_stats {
if let Some(object_counts) = stats.get("object_counts") {
if let Some(count_value) = object_counts.get(*object_class) {
if let Some(count) = count_value.as_u64() {
if count >= *min_count as u64 {
found = true;
}
}
}
}
}
if !found {
return false;
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Computes aggregate statistics over the visual chunks of video `uuid`.
///
/// The returned map holds `total_chunks`, `avg_confidence`, `min_confidence`,
/// `max_confidence`, `total_objects` and `avg_density`. Aggregates that are
/// NULL in SQL (for example when the video has no visual chunks) are reported
/// as `0`.
///
/// The uuid is sent as a bound parameter instead of being spliced into the
/// SQL text.
///
/// # Errors
///
/// Returns an error when the query fails.
pub async fn get_visual_chunk_statistics(
    db: &PostgresDb,
    uuid: &str,
) -> Result<HashMap<String, Value>> {
    const STATS_SQL: &str = "SELECT
            COUNT(*) as total_chunks,
            AVG((content->'metadata'->>'avg_confidence')::float) as avg_confidence,
            MIN((content->'metadata'->>'avg_confidence')::float) as min_confidence,
            MAX((content->'metadata'->>'avg_confidence')::float) as max_confidence,
            SUM((content->'metadata'->>'object_count')::int) as total_objects,
            AVG((content->'metadata'->>'spatial_density')::float) as avg_density
         FROM chunks
         WHERE uuid = $1
           AND chunk_type = 'visual'";

    // SUM, like AVG/MIN/MAX, yields NULL when no row contributes a value, so it
    // must be decoded as Option; decoding it as a plain i64 made this call fail
    // for videos without visual chunks.
    let (total_chunks, avg_confidence, min_confidence, max_confidence, total_objects, avg_density): (
        i64,
        Option<f64>,
        Option<f64>,
        Option<f64>,
        Option<i64>,
        Option<f64>,
    ) = sqlx::query_as(STATS_SQL)
        .bind(uuid)
        .fetch_one(db.pool())
        .await?;

    let mut stats = HashMap::new();
    stats.insert("total_chunks".to_string(), Value::from(total_chunks));
    stats.insert(
        "avg_confidence".to_string(),
        Value::from(avg_confidence.unwrap_or(0.0)),
    );
    stats.insert(
        "min_confidence".to_string(),
        Value::from(min_confidence.unwrap_or(0.0)),
    );
    stats.insert(
        "max_confidence".to_string(),
        Value::from(max_confidence.unwrap_or(0.0)),
    );
    stats.insert(
        "total_objects".to_string(),
        Value::from(total_objects.unwrap_or(0)),
    );
    stats.insert(
        "avg_density".to_string(),
        Value::from(avg_density.unwrap_or(0.0)),
    );

    Ok(stats)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Local copy of the chunk-type name mapping used when loading chunk rows;
    /// unknown names fall back to `TimeBased`.
    fn chunk_type_from_name(name: &str) -> ChunkType {
        match name {
            "visual" => ChunkType::Visual,
            "sentence" => ChunkType::Sentence,
            "time_based" => ChunkType::TimeBased,
            "cut" => ChunkType::Cut,
            "trace" => ChunkType::Trace,
            "story" => ChunkType::Story,
            _ => ChunkType::TimeBased,
        }
    }

    /// The default criteria impose no constraints.
    #[test]
    fn test_visual_chunk_search_criteria_default() {
        let defaults = VisualChunkSearchCriteria::default();

        assert_eq!(defaults.min_avg_confidence, None);
        assert_eq!(defaults.min_frames_with_objects, None);
        assert_eq!(defaults.min_unique_classes, None);
        assert!(defaults.required_classes.is_empty());
        assert!(defaults.class_counts.is_empty());
        assert_eq!(defaults.time_range, None);
    }

    /// Explicitly set fields read back unchanged.
    #[test]
    fn test_visual_chunk_search_criteria_with_values() {
        let populated = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.8),
            min_frames_with_objects: Some(10),
            min_unique_classes: Some(3),
            required_classes: vec!["person".to_string(), "car".to_string()],
            time_range: Some((0.0, 60.0)),
            ..VisualChunkSearchCriteria::default()
        };

        assert_eq!(populated.min_avg_confidence, Some(0.8));
        assert_eq!(populated.min_frames_with_objects, Some(10));
        assert_eq!(populated.min_unique_classes, Some(3));
        assert_eq!(populated.required_classes.len(), 2);
        assert_eq!(populated.time_range, Some((0.0, 60.0)));
    }

    /// Criteria survive a JSON round trip.
    #[test]
    fn test_visual_chunk_search_criteria_serialization() {
        let original = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.85),
            min_frames_with_objects: Some(5),
            min_unique_classes: Some(2),
            required_classes: vec!["person".to_string()],
            class_counts: HashMap::new(),
            time_range: Some((10.0, 30.0)),
        };

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("min_avg_confidence"));
        assert!(encoded.contains("required_classes"));

        let decoded: VisualChunkSearchCriteria = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.min_avg_confidence, Some(0.85));
        assert_eq!(decoded.required_classes.len(), 1);
    }

    /// Per-class bounds are stored and retrievable by class name.
    #[test]
    fn test_visual_chunk_search_criteria_with_class_counts() {
        let mut bounded = VisualChunkSearchCriteria::default();
        bounded.class_counts.insert("person".to_string(), (5, 20));
        bounded.class_counts.insert("car".to_string(), (1, 10));

        assert_eq!(bounded.class_counts.len(), 2);
        assert_eq!(bounded.class_counts.get("person"), Some(&(5, 20)));
        assert_eq!(bounded.class_counts.get("car"), Some(&(1, 10)));
    }

    /// Test chunk type string to enum conversion logic.
    #[test]
    fn test_chunk_type_conversion() {
        let expectations = [
            ("visual", ChunkType::Visual),
            ("sentence", ChunkType::Sentence),
            ("time_based", ChunkType::TimeBased),
            ("cut", ChunkType::Cut),
            ("trace", ChunkType::Trace),
            ("story", ChunkType::Story),
            ("unknown", ChunkType::TimeBased), // Default fallback
        ];

        for (name, expected) in expectations {
            assert_eq!(chunk_type_from_name(name), expected);
        }
    }
}

147
src/api/who.rs Normal file
View File

@@ -0,0 +1,147 @@
//! Who API - identity recognition and ID mapping interface (video-scoped)
use axum::{
extract::State,
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::Database;
// --- Request / Response Structures ---
/// Query-string parameters of `GET /api/v1/who`.
///
/// The handler answers a chunk lookup when both `uuid` and `chunk_id` are
/// present, and rejects the request when `uuid` is missing.
#[derive(Debug, Deserialize)]
pub struct WhoQuery {
/// Face id; accepted but not read by the `get_who_identity` handler.
pub face_id: Option<String>,
/// Speaker id; accepted but not read by the `get_who_identity` handler.
pub speaker_id: Option<String>,
/// Video uuid that scopes the lookup.
pub uuid: Option<String>,
/// Chunk id within the video identified by `uuid`.
pub chunk_id: Option<String>,
}
/// JSON request body of `POST /api/v1/who/candidates`.
#[derive(Debug, Deserialize)]
pub struct WhoCandidatesRequest {
/// Search text; the handler wraps it in `%` wildcards before searching.
pub query: String,
/// Optional video uuid passed through to the candidate search.
pub video_uuid: Option<String>,
/// Maximum number of candidates (the handler defaults to 20).
pub limit: Option<i32>,
}
/// JSON request body of `POST /api/v1/who`, forwarded to `create_or_update_person`.
#[derive(Debug, Deserialize)]
pub struct DefinePersonRequest {
/// Video uuid the identity belongs to.
pub uuid: String,
/// Existing identity id; presumably selects update rather than create - confirm
/// against `create_or_update_person`.
pub identity_id: Option<i32>,
/// Name to assign to the identity.
pub name: String,
/// Face ids to associate; treated as empty when omitted.
pub face_ids: Option<Vec<String>>,
/// Speaker ids to associate; treated as empty when omitted.
pub speaker_ids: Option<Vec<String>>,
}
/// Identity record returned as JSON by `POST /api/v1/who`.
#[derive(Debug, Serialize)]
pub struct WhoIdentity {
/// Numeric id of the identity.
pub identity_id: i32,
/// Video uuid the identity belongs to.
pub uuid: String,
/// Name of the identity.
pub name: String,
/// Optional tags attached to the identity.
pub tags: Option<Vec<String>>,
/// Face ids associated with the identity.
pub face_ids: Vec<String>,
/// Speaker ids associated with the identity.
pub speaker_ids: Vec<String>,
}
// --- API Handlers ---
/// GET /api/v1/who
///
/// Resolves identity information for a video:
///
/// * `uuid` and `chunk_id` given - returns the result of
///   `get_who_info_by_chunk` for that chunk.
/// * only `uuid` given - returns a placeholder message (listing is not
///   implemented yet).
/// * `uuid` missing - responds with `400 Bad Request`.
///
/// # Errors
///
/// Responds with `500 Internal Server Error` and a JSON `{"error": ...}` body
/// when the chunk lookup fails.
pub async fn get_who_identity(
    State(state): State<crate::api::server::AppState>,
    axum::extract::Query(query): axum::extract::Query<WhoQuery>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let db = &state.db;

    match (&query.uuid, &query.chunk_id) {
        // Priority 1: Query by Chunk (UUID + Chunk ID)
        (Some(uuid), Some(chunk_id)) => {
            let info = db
                .get_who_info_by_chunk(uuid, chunk_id)
                .await
                .map_err(|e| {
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        Json(serde_json::json!({ "error": e.to_string() })),
                    )
                })?;
            Ok(Json(info))
        }
        // Priority 2: List all for a specific UUID
        // TODO: Implement list_all_persons(uuid)
        (Some(_), None) => Ok(Json(serde_json::json!({ "message": "List all pending" }))),
        (None, _) => Err((
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({ "error": "Missing uuid" })),
        )),
    }
}
/// POST /api/v1/who/candidates
///
/// Searches the `person_identities` table for the n8n workflow. The request
/// text is wrapped in `%` wildcards and passed, together with the optional
/// video uuid and the limit (default 20), to `search_person_candidates`.
///
/// The response echoes the original query and carries the matches under
/// `items` with their count under `total`.
///
/// # Errors
///
/// Responds with `500 Internal Server Error` and a JSON `{"error": ...}` body
/// when the search fails.
pub async fn get_who_candidates(
    State(state): State<crate::api::server::AppState>,
    Json(req): Json<WhoCandidatesRequest>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
    let db = &state.db;
    let max_items = req.limit.unwrap_or(20);
    let pattern = format!("%{}%", req.query);

    let lookup = db
        .search_person_candidates(&pattern, &req.video_uuid, max_items)
        .await;
    let items = match lookup {
        Ok(found) => found,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": e.to_string() })),
            ))
        }
    };

    // Format for n8n
    let total = items.len();
    Ok(Json(serde_json::json!({
        "query": req.query,
        "items": items,
        "total": total
    })))
}
/// POST /api/v1/who
pub async fn define_person(
State(state): State<crate::api::server::AppState>,
Json(req): Json<DefinePersonRequest>,
) -> Result<Json<WhoIdentity>, (StatusCode, Json<serde_json::Value>)> {
let db = &state.db;
let identity = db
.create_or_update_person(
&req.uuid,
req.identity_id,
req.name.clone(),
req.face_ids.unwrap_or_default(),
req.speaker_ids.unwrap_or_default(),
)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
Ok(Json(identity))
}
// --- Router Setup ---
/// Builds the router for the Who API.
///
/// * `GET  /api/v1/who` -> `get_who_identity`
/// * `POST /api/v1/who` -> `define_person`
/// * `POST /api/v1/who/candidates` -> `get_who_candidates`
pub fn who_routes() -> Router<crate::api::server::AppState> {
Router::new()
.route("/api/v1/who", get(get_who_identity).post(define_person))
.route("/api/v1/who/candidates", post(get_who_candidates))
}

38
src/bin/debug_tsquery.rs Normal file
View File

@@ -0,0 +1,38 @@
use momentry_core::core::text::global_synonym_expander;
/// Debug helper: expands a fixed Chinese query with the global synonym
/// expander and prints how the expansion splits into groups, terms and
/// cleaned terms.
fn main() {
    let query = "電腦";
    println!("原始查詢: '{}'", query);

    let expanded = global_synonym_expander().expand_chinese_query(query);
    println!("擴展結果: '{}'", expanded);

    // Test the split: groups are separated by '&' when present, otherwise by whitespace.
    let groups: Vec<&str> = match expanded.contains('&') {
        true => expanded.split('&').map(str::trim).collect(),
        false => expanded.split_whitespace().collect(),
    };
    println!("分組: {:?}", groups);

    for group in groups {
        println!(" 分組: '{}'", group);

        // A parenthesised group lists alternatives separated by '|'.
        let alternatives = group
            .strip_prefix('(')
            .and_then(|rest| rest.strip_suffix(')'));
        let terms: Vec<&str> = match alternatives {
            Some(inner) => inner.split('|').map(str::trim).collect(),
            None => vec![group],
        };
        println!(" 詞語: {:?}", terms);

        for term in &terms {
            // Keep letters and digits only (alphanumeric already covers alphabetic).
            let cleaned: String = term.chars().filter(|c| c.is_alphanumeric()).collect();
            println!(" 詞語 '{}' -> 清理後 '{}'", term, cleaned);
        }
    }
}

View File

@@ -0,0 +1,659 @@
use anyhow::{Context, Result};
use clap::Parser;
use crossterm::event::{self, Event, KeyCode};
use crossterm::terminal as crossterm_terminal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{self, IsTerminal, Write};
use std::path::PathBuf;
use std::process::{Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
// Command-line arguments of the integrated player.
//
// NOTE: plain `//` comments are used on purpose - with clap's derive, `///`
// doc comments would become part of the generated help text.
#[derive(Parser, Debug)]
#[command(name = "integrated_player")]
#[command(about = "Integrated player for ASR, Face, ASRX, and Pose")]
struct Args {
// Required path of the video (-v / --video).
#[arg(short, long)]
video: PathBuf,
// Optional path of the ASR data (-r / --asr).
#[arg(short = 'r', long)]
asr: Option<PathBuf>,
// Optional path of the face data (-f / --face).
#[arg(short = 'f', long)]
face: Option<PathBuf>,
// Optional path of the ASRX data (-x / --asrx).
#[arg(short = 'x', long)]
asrx: Option<PathBuf>,
// Optional path of the pose data (-p / --pose).
#[arg(short = 'p', long)]
pose: Option<PathBuf>,
// Start position (-s / --start), default 0.0; presumably seconds - confirm.
#[arg(short = 's', long, default_value = "0.0")]
start: f64,
// Optional speaker name (--speaker-name).
#[arg(long)]
speaker_name: Option<String>,
// Flag --auto-play-speaker.
#[arg(long)]
auto_play_speaker: bool,
// Flag --demo.
#[arg(long)]
demo: bool,
// Number of segments per speaker in demo mode, default 3.
#[arg(long, default_value = "3")]
demo_segments_per_speaker: usize,
// Demo speed, default 2.0.
#[arg(long, default_value = "2.0")]
demo_speed: f64,
// Flag --show-video.
#[arg(long)]
show_video: bool,
// Video width, default 800; presumably pixels - confirm.
#[arg(long, default_value = "800")]
video_width: u32,
// Video height, default 600; presumably pixels - confirm.
#[arg(long, default_value = "600")]
video_height: u32,
// Flag --continuous-demo.
#[arg(long)]
continuous_demo: bool,
}
/// One transcribed segment of the ASR data.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AsrSegment {
/// Segment start time (presumably seconds - confirm against the data producer).
start: f64,
/// Segment end time (same unit as `start`).
end: f64,
/// Transcribed text of the segment.
text: String,
}
/// Top-level shape of the ASR JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AsrData {
    /// Optional language tag; deserialized but not read by the visible code.
    language: Option<String>,
    /// Transcribed segments, scanned in file order by the player.
    segments: Vec<AsrSegment>,
}
/// A single detected face inside one frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FaceInfo {
    /// Optional identifier for the face; deserialized but not read by the visible code.
    face_id: Option<String>,
    /// Bounding-box origin X (presumably pixels — TODO confirm).
    x: i32,
    /// Bounding-box origin Y (presumably pixels — TODO confirm).
    y: i32,
    /// Bounding-box width.
    width: i32,
    /// Bounding-box height.
    height: i32,
    /// Detector confidence, printed with two decimals.
    confidence: f64,
}
/// All faces detected in one video frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FaceFrame {
    /// Frame number; deserialized but not read by the visible code.
    frame: u64,
    /// Frame time, compared against the playback time in `get_current_segment`.
    timestamp: f64,
    /// Faces found in this frame; only the first one is used by the player.
    faces: Vec<FaceInfo>,
}
/// Top-level shape of the face-detection JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FaceData {
    /// Frame rate recorded in the file; not read by the visible code.
    fps: f64,
    /// Frame count recorded in the file; not read by the visible code.
    frame_count: u64,
    /// Per-frame detections, scanned in file order.
    frames: Vec<FaceFrame>,
}
/// One speaker-attributed span from the ASRX JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AsrxSegment {
    /// Segment index from the file; not read by the visible code in this binary.
    index: usize,
    /// Start of the span; passed to ffplay `-ss`.
    start: f64,
    /// End of the span, in the same time base as `start`.
    end: f64,
    /// Length of the span; passed to ffplay `-t`.
    duration: f64,
    /// Speaker id (e.g. `SPEAKER_0`), used as the key for `get_speaker_info`.
    speaker: String,
}
/// Top-level shape of the ASRX JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AsrxData {
    /// Speaker-attributed segments, scanned in file order.
    segments: Vec<AsrxSegment>,
    /// Per-speaker totals keyed by speaker id; printed by `list_speakers`.
    speaker_stats: HashMap<String, SpeakerStats>,
}
/// Aggregate figures for one speaker.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SpeakerStats {
    /// Printed in the "Segments" column of the speaker table.
    count: usize,
    /// Printed in the "Duration" column with an `s` suffix.
    duration: f64,
}
/// A named pose keypoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Keypoint {
    /// Keypoint label; names containing "mouth" or "lip", or equal to "nose",
    /// are kept as mouth landmarks.
    name: String,
    /// X coordinate of the keypoint.
    x: f32,
    /// Y coordinate of the keypoint.
    y: f32,
    /// Confidence value stored with the keypoint.
    confidence: f32,
}
/// Pose estimate for one person in a frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersonPose {
    /// Keypoints for this person; filtered by name to find mouth landmarks.
    keypoints: Vec<Keypoint>,
    /// Bounding box for the person; deserialized but not read by the visible code.
    bbox: Bbox,
}
/// Axis-aligned bounding box (units presumably pixels — TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Bbox {
    /// Origin X.
    x: i32,
    /// Origin Y.
    y: i32,
    /// Box width.
    width: i32,
    /// Box height.
    height: i32,
}
/// Pose estimates for one video frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PoseFrame {
    /// Frame number; deserialized but not read by the visible code.
    frame: u64,
    /// Frame time, compared against the playback time in `get_current_segment`.
    timestamp: f64,
    /// People found in this frame; only the first one is used by the player.
    persons: Vec<PersonPose>,
}
/// Top-level shape of the pose JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PoseData {
    /// Per-frame pose estimates, scanned in file order.
    frames: Vec<PoseFrame>,
}
/// Everything the player knows about one instant of the video, merged from
/// the ASR, ASRX, face and pose data by `get_current_segment`.
#[derive(Debug, Clone)]
struct IntegratedSegment {
    /// Start of the matching ASR/ASRX segment (0.0 when neither matched).
    start: f64,
    /// End of the matching ASR/ASRX segment (0.0 when neither matched).
    end: f64,
    /// Transcribed text, when an ASR segment covers the instant.
    text: Option<String>,
    /// Speaker id, when an ASRX segment covers the instant.
    speaker: Option<String>,
    /// First face of the first nearby face frame, if any.
    face: Option<FaceInfo>,
    /// Mouth/lip/nose keypoints of the first person in a nearby pose frame, if any.
    mouth_landmarks: Option<Vec<Keypoint>>,
}
/// Holds the optional analysis data sets for one video plus the speaker-name table.
struct IntegratedPlayer {
    /// ASR transcript, set by `load_asr`.
    asr_data: Option<AsrData>,
    /// Face detections, set by `load_face`.
    face_data: Option<FaceData>,
    /// Speaker segments and statistics, set by `load_asrx`.
    asrx_data: Option<AsrxData>,
    /// Pose estimates, set by `load_pose`.
    pose_data: Option<PoseData>,
    /// Initialised to 0.0 in `new`; not read by the visible code in this binary.
    current_time: f64,
    /// Speaker id -> (actor name, character name).
    speaker_names: HashMap<String, (String, String)>,
}
impl IntegratedPlayer {
fn new() -> Self {
let mut speaker_names = HashMap::new();
speaker_names.insert(
"SPEAKER_0".to_string(),
("Cary Grant".to_string(), "Peter Joshua".to_string()),
);
speaker_names.insert(
"SPEAKER_1".to_string(),
("Audrey Hepburn".to_string(), "Regina Lampert".to_string()),
);
speaker_names.insert(
"SPEAKER_2".to_string(),
(
"Walter Matthau".to_string(),
"Hamilton Bartholomew".to_string(),
),
);
speaker_names.insert(
"SPEAKER_4".to_string(),
("James Coburn".to_string(), "Tex Panthollow".to_string()),
);
Self {
asr_data: None,
face_data: None,
asrx_data: None,
pose_data: None,
current_time: 0.0,
speaker_names,
}
}
fn load_asr(&mut self, path: &PathBuf) -> Result<()> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read ASR file: {:?}", path))?;
self.asr_data = Some(serde_json::from_str(&content)?);
println!(
"✓ Loaded {} ASR segments",
self.asr_data.as_ref().unwrap().segments.len()
);
Ok(())
}
fn load_face(&mut self, path: &PathBuf) -> Result<()> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read Face file: {:?}", path))?;
self.face_data = Some(serde_json::from_str(&content)?);
let total_faces = self
.face_data
.as_ref()
.unwrap()
.frames
.iter()
.map(|f| f.faces.len())
.sum::<usize>();
println!(
"✓ Loaded {} face frames, {} total detections",
self.face_data.as_ref().unwrap().frames.len(),
total_faces
);
Ok(())
}
fn load_asrx(&mut self, path: &PathBuf) -> Result<()> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read ASRX file: {:?}", path))?;
self.asrx_data = Some(serde_json::from_str(&content)?);
println!(
"✓ Loaded {} ASRX segments, {} speakers",
self.asrx_data.as_ref().unwrap().segments.len(),
self.asrx_data.as_ref().unwrap().speaker_stats.len()
);
Ok(())
}
fn load_pose(&mut self, path: &PathBuf) -> Result<()> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read Pose file: {:?}", path))?;
self.pose_data = Some(serde_json::from_str(&content)?);
println!(
"✓ Loaded {} pose frames",
self.pose_data.as_ref().unwrap().frames.len()
);
Ok(())
}
fn get_current_segment(&self, time: f64) -> Option<IntegratedSegment> {
let mut segment = IntegratedSegment {
start: 0.0,
end: 0.0,
text: None,
speaker: None,
face: None,
mouth_landmarks: None,
};
if let Some(asr) = &self.asr_data {
for seg in &asr.segments {
if time >= seg.start && time <= seg.end {
segment.start = seg.start;
segment.end = seg.end;
segment.text = Some(seg.text.clone());
break;
}
}
}
if let Some(asrx) = &self.asrx_data {
for seg in &asrx.segments {
if time >= seg.start && time <= seg.end {
segment.start = seg.start;
segment.end = seg.end;
segment.speaker = Some(seg.speaker.clone());
break;
}
}
}
if let Some(face) = &self.face_data {
for frame in &face.frames {
if (frame.timestamp - time).abs() < 1.0 {
if let Some(face_info) = frame.faces.first() {
segment.face = Some(face_info.clone());
break;
}
}
}
}
if let Some(pose) = &self.pose_data {
for frame in &pose.frames {
if (frame.timestamp - time).abs() < 0.5 {
if let Some(person) = frame.persons.first() {
let mouth_points: Vec<Keypoint> = person
.keypoints
.iter()
.filter(|kp| {
kp.name.contains("mouth")
|| kp.name.contains("lip")
|| kp.name == "nose"
})
.cloned()
.collect();
if !mouth_points.is_empty() {
segment.mouth_landmarks = Some(mouth_points);
break;
}
}
}
}
}
if segment.text.is_some()
|| segment.speaker.is_some()
|| segment.face.is_some()
|| segment.mouth_landmarks.is_some()
{
Some(segment)
} else {
None
}
}
fn get_speaker_info(&self, speaker_id: &str) -> (String, String) {
self.speaker_names
.get(speaker_id)
.cloned()
.unwrap_or_else(|| ("Unknown".to_string(), "Unknown".to_string()))
}
fn list_speakers(&self) {
if let Some(asrx) = &self.asrx_data {
println!("\n📊 Speaker Statistics:");
println!("{:-<80}", "");
println!(
"{:15} {:20} {:20} {:>10} {:>10}",
"Speaker ID", "Actor", "Character", "Segments", "Duration"
);
println!("{:-<80}", "");
for (speaker_id, stats) in &asrx.speaker_stats {
let (actor, character) = self.get_speaker_info(speaker_id);
println!(
"{:15} {:20} {:20} {:>10} {:>9.1}s",
speaker_id, actor, character, stats.count, stats.duration
);
}
println!("{:-<80}", "");
}
}
}
/// Restores the terminal to cooked mode when dropped, so that every exit
/// path (including `?` error returns) leaves the terminal usable.
struct RawModeGuard;

impl Drop for RawModeGuard {
    fn drop(&mut self) {
        crossterm_terminal::disable_raw_mode().ok();
    }
}

/// Blocks while the pause flag is set. Returns `true` when quit was requested.
fn wait_while_paused(paused: &AtomicBool, quit: &AtomicBool) -> Result<bool> {
    if paused.load(Ordering::SeqCst) && !quit.load(Ordering::SeqCst) {
        // `print!` (not `println!`): the hint is shown once instead of
        // emitting a new line every polling interval.
        print!("\r⏸️ Paused - Press SPACE to resume");
        io::stdout().flush()?;
    }
    while paused.load(Ordering::SeqCst) && !quit.load(Ordering::SeqCst) {
        thread::sleep(Duration::from_millis(100));
    }
    Ok(quit.load(Ordering::SeqCst))
}

/// Prints the face box and mouth-landmark count known for `time`, if any.
fn print_visual_context(player: &IntegratedPlayer, time: f64) {
    let Some(segment) = player.get_current_segment(time) else {
        return;
    };
    if let Some(face) = &segment.face {
        println!(
            "👤 Face: bbox=({},{}) {}x{}, conf={:.2}",
            face.x, face.y, face.width, face.height, face.confidence
        );
    }
    if let Some(landmarks) = &segment.mouth_landmarks {
        println!("👄 Mouth landmarks: {} points", landmarks.len());
    }
}

/// Plays `duration` seconds of the video starting at `start` with ffplay and
/// waits for it to finish. If `quit` becomes set, ffplay is killed early.
fn play_clip_blocking(args: &Args, start: f64, duration: f64, quit: &AtomicBool) -> Result<()> {
    let mut cmd = Command::new("ffplay");
    cmd.arg("-ss")
        .arg(format!("{:.2}", start))
        .arg("-t")
        .arg(format!("{:.2}", duration))
        .arg("-autoexit");
    if args.show_video {
        cmd.arg("-x")
            .arg(args.video_width.to_string())
            .arg("-y")
            .arg(args.video_height.to_string());
    } else {
        cmd.arg("-nodisp");
    }
    // Pass the path as an OS string: no panic on non-UTF-8 file names.
    cmd.arg(&args.video);
    let mut child = cmd
        // The key-reader thread owns stdin; keep ffplay away from it.
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .spawn()
        .context("Failed to start ffplay")?;
    // Poll instead of sleeping for a fixed time: the child is always reaped,
    // clips never overlap, and a quit request stops playback promptly.
    loop {
        if child
            .try_wait()
            .context("Failed to poll ffplay")?
            .is_some()
        {
            return Ok(());
        }
        if quit.load(Ordering::SeqCst) {
            child.kill().ok();
            child.wait().ok();
            return Ok(());
        }
        thread::sleep(Duration::from_millis(50));
    }
}

/// Plays every ASR segment in order (or every ASRX segment when no ASR data
/// is loaded), printing the merged speaker/face/pose information before each.
///
/// On an interactive terminal SPACE toggles pause and Q, Esc or Ctrl+C quits.
///
/// # Errors
/// Fails if stdout cannot be flushed or ffplay cannot be started or polled.
fn run_continuous_demo(player: &IntegratedPlayer, args: &Args) -> Result<()> {
    println!("\n🎬 Continuous Demo Mode");
    println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
    let is_interactive = io::stdin().is_terminal();
    if is_interactive {
        println!("Controls:");
        println!(" SPACE - Pause/Resume");
        println!(" Q - Quit");
    } else {
        println!("Running in non-interactive mode (no keyboard control)");
        println!("Use Ctrl+C to stop");
    }
    println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
    println!();

    let paused = Arc::new(AtomicBool::new(false));
    let quit = Arc::new(AtomicBool::new(false));

    // The guard lives until the function returns, whichever way that happens.
    // NOTE(review): while raw mode is on, `\n` from `println!` may not return
    // the cursor to column 0 — confirm against the crossterm raw-mode docs.
    let raw_guard = if is_interactive {
        crossterm_terminal::enable_raw_mode()
            .is_ok()
            .then_some(RawModeGuard)
    } else {
        None
    };

    if raw_guard.is_some() {
        let paused_flag = Arc::clone(&paused);
        let quit_flag = Arc::clone(&quit);
        thread::spawn(move || loop {
            if let Ok(Event::Key(key_event)) = event::read() {
                // Raw mode suppresses the usual Ctrl+C signal, so treat the
                // key combination as a quit request as well.
                let ctrl_c = key_event.code == KeyCode::Char('c')
                    && key_event
                        .modifiers
                        .contains(event::KeyModifiers::CONTROL);
                match key_event.code {
                    KeyCode::Char(' ') => {
                        paused_flag.fetch_xor(true, Ordering::SeqCst);
                    }
                    KeyCode::Char('q') | KeyCode::Char('Q') | KeyCode::Esc => {
                        quit_flag.store(true, Ordering::SeqCst);
                        break;
                    }
                    _ if ctrl_c => {
                        quit_flag.store(true, Ordering::SeqCst);
                        break;
                    }
                    _ => {}
                }
            }
            if quit_flag.load(Ordering::SeqCst) {
                break;
            }
            thread::sleep(Duration::from_millis(50));
        });
    }

    if let Some(asr) = &player.asr_data {
        let total_segments = asr.segments.len();
        for (i, seg) in asr.segments.iter().enumerate() {
            if wait_while_paused(&paused, &quit)? {
                break;
            }
            println!("\n[{}/{}] Segment", i + 1, total_segments);
            println!("{:=<80}", "");
            println!("📝 ASR Text: {}", seg.text);
            println!("⏱ Time: {:.2}s - {:.2}s", seg.start, seg.end);
            // Speaker of the first ASRX segment that contains this segment's start.
            let speaker_seg = player.asrx_data.as_ref().and_then(|asrx| {
                asrx.segments
                    .iter()
                    .find(|s| seg.start >= s.start && seg.start <= s.end)
            });
            if let Some(asrx_seg) = speaker_seg {
                let (actor, character) = player.get_speaker_info(&asrx_seg.speaker);
                println!(
                    "🎤 Speaker: {} → {} ({})",
                    asrx_seg.speaker, actor, character
                );
            }
            print_visual_context(player, seg.start + 0.01);
            let duration = seg.end - seg.start;
            println!(
                "▶️ Playing: {:.2}s - {:.2}s ({:.2}s)",
                seg.start, seg.end, duration
            );
            play_clip_blocking(args, seg.start, duration, &quit)?;
        }
        if quit.load(Ordering::SeqCst) {
            println!("\n⏹️ Stopped by user");
        } else {
            println!("\n{:=<80}", "");
            println!("✅ Demo completed! Played {} segments", total_segments);
            println!("{:=<80}", "");
        }
    } else if let Some(asrx) = &player.asrx_data {
        let total_segments = asrx.segments.len();
        println!(
            "Playing {} ASRX segments (no ASR text available)",
            total_segments
        );
        for (i, seg) in asrx.segments.iter().enumerate() {
            if wait_while_paused(&paused, &quit)? {
                break;
            }
            let (actor, character) = player.get_speaker_info(&seg.speaker);
            println!("\n[{}/{}] Segment", i + 1, total_segments);
            println!("{:=<80}", "");
            println!(
                "⏱ Time: {:.2}s - {:.2}s ({:.2}s)",
                seg.start, seg.end, seg.duration
            );
            println!("🎤 Speaker: {} → {} ({})", seg.speaker, actor, character);
            print_visual_context(player, seg.start + 0.01);
            println!("▶️ Playing audio segment");
            play_clip_blocking(args, seg.start, seg.duration, &quit)?;
        }
        if quit.load(Ordering::SeqCst) {
            println!("\n⏹️ Stopped by user");
        } else {
            println!("\n{:=<80}", "");
            println!("✅ Demo completed! Played {} segments", total_segments);
            println!("{:=<80}", "");
        }
    } else {
        println!("⚠️ No ASR or ASRX data loaded");
    }
    Ok(())
}
/// Entry point: parses the command line, loads whichever analysis files were
/// supplied, prints the speaker table and runs the continuous demo on request.
///
/// # Errors
/// Fails if the video file does not exist, if a supplied data file cannot be
/// loaded, or if the demo itself fails.
fn main() -> Result<()> {
    let args = Args::parse();
    if !args.video.exists() {
        anyhow::bail!("Video file not found: {:?}", args.video);
    }
    println!("🎬 Integrated Player for ASR/Face/ASRX/Pose");
    println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
    println!("Video: {:?}", args.video);

    let mut player = IntegratedPlayer::new();
    // A path that was given but does not exist is still skipped (best effort),
    // but no longer silently: the user gets a warning on stderr.
    if let Some(asr_path) = &args.asr {
        if asr_path.exists() {
            player.load_asr(asr_path)?;
        } else {
            eprintln!("⚠️ ASR file not found, skipping: {:?}", asr_path);
        }
    }
    if let Some(face_path) = &args.face {
        if face_path.exists() {
            player.load_face(face_path)?;
        } else {
            eprintln!("⚠️ Face file not found, skipping: {:?}", face_path);
        }
    }
    if let Some(asrx_path) = &args.asrx {
        if asrx_path.exists() {
            player.load_asrx(asrx_path)?;
        } else {
            eprintln!("⚠️ ASRX file not found, skipping: {:?}", asrx_path);
        }
    }
    if let Some(pose_path) = &args.pose {
        if pose_path.exists() {
            player.load_pose(pose_path)?;
        } else {
            eprintln!("⚠️ Pose file not found, skipping: {:?}", pose_path);
        }
    }

    player.list_speakers();
    if args.continuous_demo {
        run_continuous_demo(&player, &args)?;
    } else {
        println!("\n⚠️ Please use --continuous-demo flag");
    }
    Ok(())
}

View File

@@ -0,0 +1,711 @@
use anyhow::{Context, Result};
use clap::Parser;
use crossterm::event::{self, Event, KeyCode, KeyModifiers};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{self, Write};
use std::path::PathBuf;
use std::process::{Child, Command, Stdio};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
// Command-line options for the integrated player (parsed with clap's derive API).
// Plain `//` comments are used here so that the generated CLI help stays unchanged.
#[derive(Parser, Debug)]
#[command(name = "integrated_player")]
#[command(about = "Integrated player for ASR, Face, ASRX, and Pose")]
struct Args {
    // Video file handed to ffplay; `main` aborts if it does not exist.
    #[arg(short, long)]
    video: PathBuf,
    // Optional ASR JSON file (deserialized into `AsrData`).
    #[arg(short = 'r', long)]
    asr: Option<PathBuf>,
    // Optional face-detection JSON file (deserialized into `FaceResult`).
    #[arg(short = 'f', long)]
    face: Option<PathBuf>,
    // Optional ASRX (speaker segment) JSON file (deserialized into `AsrxData`).
    #[arg(short = 'x', long)]
    asrx: Option<PathBuf>,
    // Optional pose JSON file (deserialized into `PoseData`).
    #[arg(short = 'p', long)]
    pose: Option<PathBuf>,
    // NOTE(review): not read by any code visible in this review; presumably a start offset.
    #[arg(short = 's', long, default_value = "0.0")]
    start: f64,
    // Speaker id to play when `auto_play_speaker` is set.
    #[arg(long)]
    speaker_name: Option<String>,
    // Plays up to five segments of `speaker_name` and exits.
    #[arg(long)]
    auto_play_speaker: bool,
    // Selects `run_demo`; checked before the other modes in `main`.
    #[arg(long)]
    demo: bool,
    // Maximum number of segments `run_demo` plays for each speaker.
    #[arg(long, default_value = "3")]
    demo_segments_per_speaker: usize,
    // Divisor applied to each segment's duration in `run_demo`.
    #[arg(long, default_value = "2.0")]
    demo_speed: f64,
    // When set, the continuous demo starts ffplay with a window (`-x`/`-y`).
    #[arg(long)]
    show_video: bool,
    // Value passed to ffplay `-x` when `show_video` is set.
    #[arg(long, default_value = "800")]
    video_width: u32,
    // Value passed to ffplay `-y` when `show_video` is set.
    #[arg(long, default_value = "600")]
    video_height: u32,
    // Selects `run_continuous_demo`.
    #[arg(long)]
    continuous_demo: bool,
}
/// One transcribed span of speech, as read from the ASR JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AsrSegment {
    /// Start of the span; printed with an `s` suffix and passed to ffplay `-ss`.
    start: f64,
    /// End of the span, in the same time base as `start`.
    end: f64,
    /// Transcribed text for the span.
    text: String,
}
/// Top-level shape of the ASR JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AsrData {
    /// Optional language tag; deserialized but not read by the visible code.
    language: Option<String>,
    /// Transcribed segments, scanned in file order by the player.
    segments: Vec<AsrSegment>,
}
/// A single face detection, carrying its own frame number and timestamp.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FaceDetection {
    /// Frame number; deserialized but not read by the visible code.
    frame: u64,
    /// Detection time, compared against the playback time in `get_current_segment`.
    timestamp: f64,
    /// Bounding-box origin X (presumably pixels — TODO confirm).
    x: i32,
    /// Bounding-box origin Y (presumably pixels — TODO confirm).
    y: i32,
    /// Bounding-box width.
    width: i32,
    /// Bounding-box height.
    height: i32,
    /// Detector confidence, printed with two decimals.
    confidence: f64,
}
/// Top-level shape of the face JSON file: a single `results` object.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FaceResult {
    /// Wrapper object that holds the detection list.
    results: FaceResults,
}
/// Payload of `FaceResult::results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FaceResults {
    /// Flat list of detections, scanned in file order.
    detections: Vec<FaceDetection>,
}
/// One speaker-attributed span from the ASRX JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AsrxSegment {
    /// Segment index from the file; printed by `play_speaker_segments`.
    index: usize,
    /// Start of the span; passed to ffplay `-ss`.
    start: f64,
    /// End of the span, in the same time base as `start`.
    end: f64,
    /// Length of the span; used to derive the ffplay `-t` value.
    duration: f64,
    /// Speaker id (e.g. `SPEAKER_0`), used as the key for `get_speaker_info`.
    speaker: String,
}
/// Top-level shape of the ASRX JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct AsrxData {
    /// Speaker-attributed segments, scanned in file order.
    segments: Vec<AsrxSegment>,
    /// Per-speaker totals keyed by speaker id.
    speaker_stats: HashMap<String, SpeakerStats>,
}
/// Aggregate figures for one speaker.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SpeakerStats {
    /// Printed in the "Segments" column of the speaker table.
    count: usize,
    /// Printed in the "Duration" column with an `s` suffix.
    duration: f64,
}
/// A named pose keypoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Keypoint {
    /// Keypoint label; names containing "mouth" or "lip", or equal to "nose",
    /// are kept as mouth landmarks.
    name: String,
    /// X coordinate, printed with one decimal by `print_segment`.
    x: f32,
    /// Y coordinate, printed with one decimal by `print_segment`.
    y: f32,
    /// Confidence value, printed with two decimals by `print_segment`.
    confidence: f32,
}
/// Pose estimate for one person in a frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersonPose {
    /// Keypoints for this person; filtered by name to find mouth landmarks.
    keypoints: Vec<Keypoint>,
    /// Bounding box for the person; deserialized but not read by the visible code.
    bbox: Bbox,
}
/// Axis-aligned bounding box (units presumably pixels — TODO confirm).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Bbox {
    /// Origin X.
    x: i32,
    /// Origin Y.
    y: i32,
    /// Box width.
    width: i32,
    /// Box height.
    height: i32,
}
/// Pose estimates for one video frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PoseFrame {
    /// Frame number; deserialized but not read by the visible code.
    frame: u64,
    /// Frame time, compared against the playback time in `get_current_segment`.
    timestamp: f64,
    /// People found in this frame; only the first one is used by the player.
    persons: Vec<PersonPose>,
}
/// Top-level shape of the pose JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PoseData {
    /// Per-frame pose estimates, scanned in file order.
    frames: Vec<PoseFrame>,
}
/// Everything the player knows about one instant of the video, merged from
/// the ASR, ASRX, face and pose data by `get_current_segment`.
#[derive(Debug, Clone)]
struct IntegratedSegment {
    /// Start of the matching ASR/ASRX segment (0.0 when neither matched).
    start: f64,
    /// End of the matching ASR/ASRX segment (0.0 when neither matched).
    end: f64,
    /// Transcribed text, when an ASR segment covers the instant.
    text: Option<String>,
    /// Speaker id, when an ASRX segment covers the instant.
    speaker: Option<String>,
    /// First face detection within one second of the instant, if any.
    face: Option<FaceDetection>,
    /// Mouth/lip/nose keypoints of the first person in a nearby pose frame, if any.
    mouth_landmarks: Option<Vec<Keypoint>>,
}
/// Holds the optional analysis data sets for one video, the interactive
/// cursor position and the speaker-name table.
struct IntegratedPlayer {
    /// ASR transcript, set by `load_asr`.
    asr_data: Option<AsrData>,
    /// Face detections, set by `load_face`.
    face_data: Option<FaceResult>,
    /// Speaker segments and statistics, set by `load_asrx`.
    asrx_data: Option<AsrxData>,
    /// Pose estimates, set by `load_pose`.
    pose_data: Option<PoseData>,
    /// Time last entered in interactive mode; used by the `s`/`show` command.
    current_time: f64,
    /// Initialised to `false` in `new`; not read by the visible code.
    is_playing: bool,
    /// Speaker id -> (actor name, character name).
    speaker_names: HashMap<String, (String, String)>,
}
impl IntegratedPlayer {
fn new() -> Self {
let mut speaker_names = HashMap::new();
speaker_names.insert(
"SPEAKER_0".to_string(),
("Cary Grant".to_string(), "Peter Joshua".to_string()),
);
speaker_names.insert(
"SPEAKER_1".to_string(),
("Audrey Hepburn".to_string(), "Regina Lampert".to_string()),
);
speaker_names.insert(
"SPEAKER_2".to_string(),
(
"Walter Matthau".to_string(),
"Hamilton Bartholomew".to_string(),
),
);
speaker_names.insert(
"SPEAKER_4".to_string(),
("James Coburn".to_string(), "Tex Panthollow".to_string()),
);
Self {
asr_data: None,
face_data: None,
asrx_data: None,
pose_data: None,
current_time: 0.0,
is_playing: false,
speaker_names,
}
}
fn load_asr(&mut self, path: &PathBuf) -> Result<()> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read ASR file: {:?}", path))?;
self.asr_data = Some(serde_json::from_str(&content)?);
println!(
"✓ Loaded {} ASR segments",
self.asr_data.as_ref().unwrap().segments.len()
);
Ok(())
}
fn load_face(&mut self, path: &PathBuf) -> Result<()> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read Face file: {:?}", path))?;
self.face_data = Some(serde_json::from_str(&content)?);
println!(
"✓ Loaded {} face detections",
self.face_data.as_ref().unwrap().results.detections.len()
);
Ok(())
}
fn load_asrx(&mut self, path: &PathBuf) -> Result<()> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read ASRX file: {:?}", path))?;
self.asrx_data = Some(serde_json::from_str(&content)?);
println!(
"✓ Loaded {} ASRX segments, {} speakers",
self.asrx_data.as_ref().unwrap().segments.len(),
self.asrx_data.as_ref().unwrap().speaker_stats.len()
);
Ok(())
}
fn load_pose(&mut self, path: &PathBuf) -> Result<()> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read Pose file: {:?}", path))?;
self.pose_data = Some(serde_json::from_str(&content)?);
println!(
"✓ Loaded {} pose frames",
self.pose_data.as_ref().unwrap().frames.len()
);
Ok(())
}
fn get_current_segment(&self, time: f64) -> Option<IntegratedSegment> {
let mut segment = IntegratedSegment {
start: 0.0,
end: 0.0,
text: None,
speaker: None,
face: None,
mouth_landmarks: None,
};
if let Some(asr) = &self.asr_data {
for seg in &asr.segments {
if time >= seg.start && time <= seg.end {
segment.start = seg.start;
segment.end = seg.end;
segment.text = Some(seg.text.clone());
break;
}
}
}
if let Some(asrx) = &self.asrx_data {
for seg in &asrx.segments {
if time >= seg.start && time <= seg.end {
segment.start = seg.start;
segment.end = seg.end;
segment.speaker = Some(seg.speaker.clone());
break;
}
}
}
if let Some(face) = &self.face_data {
for det in &face.results.detections {
if (det.timestamp - time).abs() < 1.0 {
segment.face = Some(det.clone());
break;
}
}
}
if let Some(pose) = &self.pose_data {
for frame in &pose.frames {
if (frame.timestamp - time).abs() < 0.5 {
if let Some(person) = frame.persons.first() {
let mouth_points: Vec<Keypoint> = person
.keypoints
.iter()
.filter(|kp| {
kp.name.contains("mouth")
|| kp.name.contains("lip")
|| kp.name == "nose"
})
.cloned()
.collect();
if !mouth_points.is_empty() {
segment.mouth_landmarks = Some(mouth_points);
break;
}
}
}
}
}
if segment.text.is_some()
|| segment.speaker.is_some()
|| segment.face.is_some()
|| segment.mouth_landmarks.is_some()
{
Some(segment)
} else {
None
}
}
fn get_speaker_info(&self, speaker_id: &str) -> (String, String) {
self.speaker_names
.get(speaker_id)
.cloned()
.unwrap_or_else(|| ("Unknown".to_string(), "Unknown".to_string()))
}
fn print_segment(&self, segment: &IntegratedSegment) {
println!("\n{:=<80}", "");
println!("⏱ Time: {:.2}s - {:.2}s", segment.start, segment.end);
if let Some(text) = &segment.text {
println!("📝 Text: {}", text);
}
if let Some(speaker) = &segment.speaker {
let (actor, character) = self.get_speaker_info(speaker);
println!("🎤 Speaker: {}{} ({})", speaker, actor, character);
}
if let Some(face) = &segment.face {
println!(
"👤 Face: bbox=({},{}) {}x{}, confidence={:.2}",
face.x, face.y, face.width, face.height, face.confidence
);
}
if let Some(landmarks) = &segment.mouth_landmarks {
println!("👄 Mouth landmarks: {} points", landmarks.len());
for kp in landmarks.iter().take(3) {
println!(
"{}: ({:.1}, {:.1}) conf={:.2}",
kp.name, kp.x, kp.y, kp.confidence
);
}
}
println!("{:=<80}", "");
}
fn list_speakers(&self) {
if let Some(asrx) = &self.asrx_data {
println!("\n📊 Speaker Statistics:");
println!("{:-<80}", "");
println!(
"{:15} {:20} {:20} {:>10} {:>10}",
"Speaker ID", "Actor", "Character", "Segments", "Duration"
);
println!("{:-<80}", "");
for (speaker_id, stats) in &asrx.speaker_stats {
let (actor, character) = self.get_speaker_info(speaker_id);
println!(
"{:15} {:20} {:20} {:>10} {:>9.1}s",
speaker_id, actor, character, stats.count, stats.duration
);
}
println!("{:-<80}", "");
}
}
}
fn play_segment(video_path: &PathBuf, start: f64, duration: f64, show_video: bool) -> Result<()> {
println!("▶️ Playing {:.2}s - {:.2}s", start, start + duration);
let mut cmd = Command::new("ffplay");
if show_video {
cmd.args([
"-ss",
&format!("{:.2}", start),
"-t",
&format!("{:.2}", duration),
"-autoexit",
video_path.to_str().unwrap(),
]);
} else {
cmd.args([
"-ss",
&format!("{:.2}", start),
"-t",
&format!("{:.2}", duration),
"-autoexit",
"-nodisp",
video_path.to_str().unwrap(),
]);
}
let _child = cmd
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.context("Failed to start ffplay")?;
Ok(())
}
/// Plays the first `limit` ASRX segments spoken by `speaker_id` (all of them
/// when `limit` is `None`), printing the merged segment info before each.
///
/// Does nothing when no ASRX data is loaded.
///
/// # Errors
/// Propagates any error from `play_segment`.
fn play_speaker_segments(
    player: &IntegratedPlayer,
    video_path: &PathBuf,
    speaker_id: &str,
    limit: Option<usize>,
) -> Result<()> {
    let asrx = match &player.asrx_data {
        Some(data) => data,
        None => return Ok(()),
    };

    let matching: Vec<&AsrxSegment> = asrx
        .segments
        .iter()
        .filter(|candidate| candidate.speaker == speaker_id)
        .collect();
    // Never ask for more segments than the speaker actually has.
    let count = match limit {
        Some(requested) => requested.min(matching.len()),
        None => matching.len(),
    };

    println!("\n🎬 Playing {} segments for {}", count, speaker_id);
    for (position, seg) in matching[..count].iter().enumerate() {
        println!("\n[{}/{}] Segment {}", position + 1, count, seg.index);
        // Probe slightly inside the segment rather than exactly on its edge.
        if let Some(info) = player.get_current_segment(seg.start + 0.1) {
            player.print_segment(&info);
        }
        play_segment(video_path, seg.start, seg.duration, false)?;
        thread::sleep(Duration::from_millis(500));
    }
    println!("\n✅ Finished playing {} segments", count);
    Ok(())
}
// NOTE(review): this region looks truncated in the reviewed view. The header
// below says `run_demo`, but the body is the tail of the continuous-demo loop
// (it uses `paused`/`quit` flags whose setup is not visible here). Confirm
// against the real file before relying on it.
fn run_demo(player: &IntegratedPlayer, args: &Args) -> Result<()> {
}
}
});
if let Some(asr) = &player.asr_data {
    let total_segments = asr.segments.len();
    for (i, seg) in asr.segments.iter().enumerate() {
        // Stop as soon as the user has asked to quit.
        if quit.load(Ordering::SeqCst) {
            println!("\n⏹️ Stopped by user");
            break;
        }
        // While paused, poll every 100 ms; a quit request ends the demo.
        while paused.load(Ordering::SeqCst) {
            println!("\r⏸️ Paused - Press SPACE to resume",);
            std::io::stdout().flush()?;
            thread::sleep(Duration::from_millis(100));
            if quit.load(Ordering::SeqCst) {
                println!("\n⏹️ Stopped by user");
                return Ok(());
            }
        }
        println!("\n[{}/{}] Segment", i + 1, total_segments);
        println!("{:=<80}", "");
        // Show everything known about this segment.
        if let Some(segment) = player.get_current_segment(seg.start + 0.01) {
            player.print_segment(&segment);
        }
        // Play the audio/video for the segment.
        let duration = seg.end - seg.start;
        println!(
            "▶️ Playing: {:.2}s - {:.2}s ({:.2}s)",
            seg.start, seg.end, duration
        );
        let mut cmd = Command::new("ffplay");
        if args.show_video {
            cmd.args([
                "-ss",
                &format!("{:.2}", seg.start),
                "-t",
                &format!("{:.2}", duration),
                "-autoexit",
                "-x",
                &format!("{}", args.video_width),
                "-y",
                &format!("{}", args.video_height),
                args.video.to_str().unwrap(),
            ]);
        } else {
            cmd.args([
                "-ss",
                &format!("{:.2}", seg.start),
                "-t",
                &format!("{:.2}", duration),
                "-autoexit",
                "-nodisp",
                args.video.to_str().unwrap(),
            ]);
        }
        let _child = cmd
            .stdout(Stdio::null())
            .stderr(Stdio::null())
            .spawn()
            .context("Failed to start ffplay")?;
        // Sleep for the clip length plus 100 ms; the child itself is not waited on.
        thread::sleep(Duration::from_millis((duration * 1000.0) as u64 + 100));
    }
    println!("\n{:=<80}", "");
    println!("✅ Demo completed! Played {} segments", total_segments);
    println!("{:=<80}", "");
} else {
    println!("⚠️ No ASR data loaded");
}
Ok(())
}
fn run_demo(player: &IntegratedPlayer, args: &Args) -> Result<()> {
println!("\n🎬 Auto Demo Mode");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("Segments per speaker: {}", args.demo_segments_per_speaker);
println!("Demo speed: {:.1}x", args.demo_speed);
println!();
if let Some(asrx) = &player.asrx_data {
let mut speaker_ids: Vec<String> = asrx.speaker_stats.keys().cloned().collect();
speaker_ids.sort();
for speaker_id in &speaker_ids {
let (actor, character) = player.get_speaker_info(speaker_id);
println!("\n{:=<80}", "");
println!("🎭 Demo: {}{} ({})", speaker_id, actor, character);
println!("{:=<80}", "");
let segments: Vec<&AsrxSegment> = asrx
.segments
.iter()
.filter(|s| s.speaker == *speaker_id)
.collect();
let count = args.demo_segments_per_speaker.min(segments.len());
for (i, seg) in segments.iter().take(count).enumerate() {
println!("\n[Segment {}/{}]", i + 1, count);
if let Some(segment) = player.get_current_segment(seg.start + 0.1) {
player.print_segment(&segment);
}
println!(
"⏳ Playing audio ({:.1}s)...",
seg.duration / args.demo_speed
);
let _child = Command::new("ffplay")
.args([
"-ss",
&format!("{:.2}", seg.start),
"-t",
&format!("{:.2}", seg.duration / args.demo_speed),
"-autoexit",
"-nodisp",
args.video.to_str().unwrap(),
])
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.context("Failed to start ffplay")?;
thread::sleep(Duration::from_millis(
((seg.duration / args.demo_speed) * 1000.0) as u64 + 500,
));
}
println!("\n⏸️ Pausing 2 seconds before next speaker...");
thread::sleep(Duration::from_secs(2));
}
println!("\n{:=<80}", "");
println!("✅ Demo completed!");
println!("{:=<80}", "");
}
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
if !args.video.exists() {
anyhow::bail!("Video file not found: {:?}", args.video);
}
println!("🎬 Integrated Player for ASR/Face/ASRX/Pose");
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
println!("Video: {:?}", args.video);
let mut player = IntegratedPlayer::new();
if let Some(asr_path) = &args.asr {
if asr_path.exists() {
player.load_asr(asr_path)?;
}
}
if let Some(face_path) = &args.face {
if face_path.exists() {
player.load_face(face_path)?;
}
}
if let Some(asrx_path) = &args.asrx {
if asrx_path.exists() {
player.load_asrx(asrx_path)?;
}
}
if let Some(pose_path) = &args.pose {
if pose_path.exists() {
player.load_pose(pose_path)?;
}
}
player.list_speakers();
if args.demo {
run_demo(&player, &args)?;
} else if args.continuous_demo {
run_continuous_demo(&player, &args)?;
} else if args.auto_play_speaker {
if let Some(speaker_id) = &args.speaker_name {
play_speaker_segments(&player, &args.video, speaker_id, Some(5))?;
} else {
println!("\n⚠️ --speaker-name required for --auto-play-speaker");
}
} else {
println!("\n🎮 Interactive Mode");
println!(" Commands:");
println!(" • Enter time in seconds to seek");
println!(" • 's' to show current segment");
println!(" • 'l' to list speakers");
println!(" • 'p <speaker>' to play speaker segments");
println!(" • 'q' to quit");
println!();
loop {
print!("> ");
std::io::Write::flush(&mut std::io::stdout())?;
let mut input = String::new();
std::io::stdin().read_line(&mut input)?;
let input = input.trim();
if input == "q" || input == "quit" || input == "exit" {
break;
} else if input == "s" || input == "show" {
if let Some(segment) = player.get_current_segment(player.current_time) {
player.print_segment(&segment);
} else {
println!("No segment at time {:.2}s", player.current_time);
}
} else if input == "l" || input == "list" {
player.list_speakers();
} else if input.starts_with("p ") {
let speaker_id = input.strip_prefix("p ").unwrap();
play_speaker_segments(&player, &args.video, speaker_id, Some(3))?;
} else if let Ok(time) = input.parse::<f64>() {
player.current_time = time;
println!("Seeked to {:.2}s", time);
if let Some(segment) = player.get_current_segment(time) {
player.print_segment(&segment);
} else {
println!("No segment at this time");
}
}
}
}
Ok(())
}

View File

@@ -0,0 +1,92 @@
// Migration script to tokenize existing Chinese text in the database
// Usage: cargo run --bin migrate_chinese_text
use dotenv;
use momentry_core::core::text::tokenizer::tokenize_chinese_text;
use sqlx::{postgres::PgPoolOptions, Row};
use std::env;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Load environment variables from .env file
dotenv::dotenv().ok();
// Get database URL from environment
let database_url = env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://accusys@localhost:5432/momentry".to_string());
println!("Connecting to database...");
// Create connection pool
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(&database_url)
.await?;
println!("Fetching Chinese chunks from database...");
// Get all chunks with Chinese text using raw query to avoid sqlx macro issues
let query = r#"
SELECT id, text_content, content->'data'->>'text' as chinese_text, content->>'text' as english_text
FROM chunks
WHERE text_content ~ '[\u4e00-\u9fff]'
ORDER BY id
"#;
let rows = sqlx::query(query).fetch_all(&pool).await?;
println!("Found {} Chinese chunks to process", rows.len());
let mut updated_count = 0;
for row in &rows {
let id: i32 = row.get(0);
let text_content: Option<String> = row.get(1);
let chinese_text: Option<String> = row.get(2);
let english_text: Option<String> = row.get(3);
// Clone text_content for later comparison
let text_content_clone = text_content.clone();
// Determine the original text (prioritize chinese_text from content->'data'->>'text')
let original_text = if let Some(ref chinese_text) = chinese_text {
chinese_text.as_str()
} else if let Some(ref english_text) = english_text {
english_text.as_str()
} else {
text_content.as_deref().unwrap_or("")
};
// Tokenize the text
let tokenized_text = tokenize_chinese_text(original_text);
// Check if tokenization changed the text
let current_text = text_content_clone.unwrap_or_default();
if current_text == tokenized_text {
println!("Skipping chunk {} - already tokenized", id);
continue;
}
println!("Updating chunk {}:", id);
println!(" Original: {}", original_text);
println!(" Tokenized: {}", tokenized_text);
// Update the chunk
sqlx::query("UPDATE chunks SET text_content = $1 WHERE id = $2")
.bind(&tokenized_text)
.bind(id)
.execute(&pool)
.await?;
updated_count += 1;
}
println!("\nMigration completed!");
println!(
"Updated {} out of {} Chinese chunks",
updated_count,
rows.len()
);
Ok(())
}

View File

@@ -0,0 +1,68 @@
use anyhow::{Context, Result};
use momentry_core::core::db::{Database, PostgresDb};
use std::env;
/// Smoke test for BM25 search: runs a few fixed queries (optionally scoped to
/// a video uuid) and prints the generated tsquery and the top five hits.
#[tokio::main]
async fn main() -> Result<()> {
    env::set_var("RUST_LOG", "info");
    println!("=== BM25 簡單測試 ===\n");

    // Initialise the PostgreSQL connection.
    let pg = PostgresDb::init()
        .await
        .context("Failed to initialize PostgreSQL database")?;

    // (query, optional uuid filter) pairs to exercise.
    let cases: [(&str, Option<&str>); 4] = [
        ("telephone", Some("384b0ff44aaaa1f1")),
        ("工作", Some("9760d0820f0cf9a7")),
        // Simplified Chinese, should match Traditional "團體"
        ("团体", Some("9760d0820f0cf9a7")),
        ("computer", None),
    ];

    for (query_str, uuid_opt) in cases.iter().copied() {
        let uuid_label = match uuid_opt {
            Some(u) => format!("(uuid: {})", u),
            None => String::new(),
        };
        println!("\n🔍 測試查詢: '{}' {}", query_str, uuid_label);

        // Show the converted tsquery (for debugging).
        match pg.prepare_tsquery(query_str) {
            Ok(tsquery) => println!(" TSQUERY: {}", tsquery),
            Err(e) => println!(" TSQUERY 錯誤: {}", e),
        }

        let results = pg.search_bm25(query_str, uuid_opt, 5).await?;
        println!("找到 {} 筆結果:", results.len());
        for (rank, hit) in results.iter().enumerate() {
            // Preview: first 60 characters, with an ellipsis when truncated.
            let mut text_preview: String = hit.text.chars().take(60).collect();
            if hit.text.chars().nth(60).is_some() {
                text_preview.push_str("...");
            }
            println!(
                " {}. {} (uuid: {}, chunk_id: {})",
                rank + 1,
                text_preview,
                hit.uuid,
                hit.chunk_id
            );
            println!(
                " 分數: {:.4}, 時間: {:.1}-{:.1}s, 類型: {}",
                hit.bm25_score, hit.start_time, hit.end_time, hit.chunk_type
            );
        }
        if results.is_empty() {
            println!(" ⚠️ 沒有找到結果");
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,37 @@
use anyhow::{Context, Result};
use momentry_core::core::db::{Database, PostgresDb};
use std::env;
/// Smoke test for Simplified-Chinese queries: prints the generated tsquery
/// and the top five BM25 hits for a few fixed Simplified-Chinese terms.
#[tokio::main]
async fn main() -> Result<()> {
    env::set_var("RUST_LOG", "info");
    println!("=== 簡體中文轉換測試 ===\n");

    // Initialise the PostgreSQL connection.
    let pg = PostgresDb::init()
        .await
        .context("Failed to initialize PostgreSQL database")?;

    // Test queries, all written in Simplified Chinese.
    const QUERIES: [&str; 3] = ["团体", "视频", "文件"];
    for query_str in QUERIES.iter().copied() {
        println!("\n🔍 測試查詢 (簡體): '{}'", query_str);

        // Show the converted tsquery.
        match pg.prepare_tsquery(query_str) {
            Ok(tsquery) => println!(" TSQUERY: {}", tsquery),
            Err(e) => println!(" TSQUERY 錯誤: {}", e),
        }

        // Run the search and list every hit.
        let results = pg.search_bm25(query_str, None, 5).await?;
        println!(" 找到 {} 筆結果", results.len());
        let mut rank = 0;
        for hit in results.iter() {
            rank += 1;
            println!(" {}. [{}] {}", rank, hit.uuid, hit.text);
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,23 @@
use momentry_core::core::text::global_synonym_expander;
/// Manual test of the global synonym expander on Traditional Chinese queries.
///
/// For every sample query this prints the result of `expand_chinese_query` and then
/// the result of looking the whole query up with `get_synonyms`.
fn main() {
    let expander = global_synonym_expander();
    println!("=== 中文同義詞擴展測試 ===");

    let samples = vec!["電腦", "電腦工作", "工作檔案", "視頻分析", "電腦工作檔案"];
    for sample in samples {
        println!("\n查詢: '{}'", sample);
        println!("擴展結果: {}", expander.expand_chinese_query(sample));

        // Look the full query string up as a single term.
        println!("單詞擴展:");
        match expander.get_synonyms(sample) {
            Some(syns) => println!(" '{}' -> {:?}", sample, syns),
            None => println!(" '{}' 沒有同義詞", sample),
        }
    }
}

View File

@@ -0,0 +1,56 @@
use anyhow::{Context, Result};
use momentry_core::core::db::{Database, PostgresDb};
use momentry_core::core::text::tokenizer::{contains_chinese, tokenize_chinese_text};
use momentry_core::core::text::{global_synonym_expander, normalize_chinese_query};
use std::env;
/// Manual test: synonym expansion combined with BM25 search.
///
/// For each sample query this prints the synonym-expanded form (only when the query
/// contains Chinese), the `tsquery` produced by `prepare_tsquery`, and up to two
/// unscoped `search_bm25` hits.
///
/// # Errors
/// Fails if the database cannot be initialized or a search call returns an error.
#[tokio::main]
async fn main() -> Result<()> {
    env::set_var("RUST_LOG", "info");
    println!("=== 同義詞擴展測試 ===\n");

    // Connect to PostgreSQL first, then fetch the shared expander.
    let db = PostgresDb::init()
        .await
        .context("Failed to initialize PostgreSQL database")?;
    let expander = global_synonym_expander();

    let samples = vec![
        "電腦",
        "視頻",
        "分析",
        "工作",
        "檔案",
        "電腦工作",
        "工作檔案",
    ];

    for sample in samples {
        println!("\n🔍 測試查詢: '{}'", sample);

        // Synonym expansion is only shown for queries that contain Chinese.
        if contains_chinese(sample) {
            let normalized = normalize_chinese_query(sample);
            println!(
                " 同義詞擴展: {}",
                expander.expand_chinese_query(&normalized)
            );
        }

        // Show the converted tsquery, or the conversion error.
        let tsquery_line = match db.prepare_tsquery(sample) {
            Ok(tsquery) => format!(" TSQUERY: {}", tsquery),
            Err(e) => format!(" TSQUERY 錯誤: {}", e),
        };
        println!("{}", tsquery_line);

        // Run the search even when no hits are expected.
        let hits = db.search_bm25(sample, None, 2).await?;
        println!(" 找到 {} 筆結果", hits.len());
        for (rank, hit) in (1usize..).zip(hits.iter()) {
            println!(" {}. [{}] {}", rank, hit.uuid, hit.text);
        }
    }

    Ok(())
}

View File

@@ -0,0 +1,56 @@
use anyhow::{Context, Result};
use momentry_core::core::db::{Database, PostgresDb};
use momentry_core::core::text::tokenizer::{contains_chinese, tokenize_chinese_text};
use momentry_core::core::text::{global_synonym_expander, normalize_chinese_query};
use std::env;
/// Manual test: synonym expansion combined with BM25 search.
///
/// For each sample query this prints the synonym-expanded form (only when the query
/// contains Chinese), the `tsquery` produced by `prepare_tsquery`, and up to two
/// unscoped `search_bm25` hits.
///
/// # Errors
/// Fails if the database cannot be initialized or a search call returns an error.
#[tokio::main]
async fn main() -> Result<()> {
    env::set_var("RUST_LOG", "info");
    println!("=== 同義詞擴展測試 ===\n");

    // Connect to PostgreSQL first, then fetch the shared expander.
    let db = PostgresDb::init()
        .await
        .context("Failed to initialize PostgreSQL database")?;
    let expander = global_synonym_expander();

    let samples = vec![
        "電腦",
        "視頻",
        "分析",
        "工作",
        "檔案",
        "電腦工作",
        "工作檔案",
    ];

    for sample in samples {
        println!("\n🔍 測試查詢: '{}'", sample);

        // Synonym expansion is only shown for queries that contain Chinese.
        if contains_chinese(sample) {
            let normalized = normalize_chinese_query(sample);
            println!(
                " 同義詞擴展: {}",
                expander.expand_chinese_query(&normalized)
            );
        }

        // Show the converted tsquery, or the conversion error.
        let tsquery_line = match db.prepare_tsquery(sample) {
            Ok(tsquery) => format!(" TSQUERY: {}", tsquery),
            Err(e) => format!(" TSQUERY 錯誤: {}", e),
        };
        println!("{}", tsquery_line);

        // Run the search even when no hits are expected.
        let hits = db.search_bm25(sample, None, 2).await?;
        println!(" 找到 {} 筆結果", hits.len());
        for (rank, hit) in (1usize..).zip(hits.iter()) {
            println!(" {}. [{}] {}", rank, hit.uuid, hit.text);
        }
    }

    Ok(())
}

View File

@@ -0,0 +1,27 @@
use momentry_core::core::text::tokenizer::{contains_chinese, tokenize_chinese_text};
/// Manual test of `tokenize_chinese_text`.
///
/// Prints the tokenizer output, and its whitespace-split pieces, first for plain
/// words and then for strings that already contain tsquery-style operators.
fn main() {
    let plain_samples = ["電腦", "工作", "視頻", "分析", "檔案", "這是一個測試"];
    for sample in plain_samples {
        let tokenized = tokenize_chinese_text(sample);
        println!("Text: '{}' -> Tokens: '{}'", sample, tokenized);
        let pieces = tokenized.split_whitespace().collect::<Vec<&str>>();
        println!(" Split: {:?}", pieces);
    }

    println!("\n=== Testing complex queries ===");

    // Inputs that already contain grouping, OR/AND operators, and prefix markers.
    let operator_samples = [
        "(電腦 | 計算機 | 微机)",
        "(工作 | 任務 | 作業)",
        "電腦 & 工作",
        "(電腦:* | 計算機:* | 微机:*)",
    ];
    for sample in operator_samples {
        let tokenized = tokenize_chinese_text(sample);
        println!("Query: '{}' -> Tokens: '{}'", sample, tokenized);
        let pieces = tokenized.split_whitespace().collect::<Vec<&str>>();
        println!(" Split: {:?}", pieces);
        println!("---");
    }
}

View File

@@ -0,0 +1,94 @@
use crate::core::config::OUTPUT_DIR;
use anyhow::{Context, Result};
use serde::Deserialize;
use sqlx::PgPool;
use std::fs;
use std::path::Path;
// --- Struct definitions (aligned with the output format of the external processors) ---
/// One transcript segment, deserialized from an element of `<asset_uuid>.asr.json`.
#[derive(Debug, Deserialize)]
struct AsrSegment {
/// Segment start; multiplied by `fps` to get a frame index, so presumably seconds.
start: f64,
/// Segment end, in the same unit as `start`.
end: f64,
/// Transcribed text; written to the `content` column of `chunks_rule1`.
text: String,
}
/// One speaker segment, deserialized from an element of `<asset_uuid>.asrx.json`.
#[derive(Debug, Deserialize)]
struct AsrxSegment {
/// Segment start, compared directly against `AsrSegment::start`/`end`.
start: f64,
/// Segment end, in the same unit as `start`.
end: f64,
/// Speaker label; written to `speaker_id` for the ASR segment it overlaps the most.
speaker: String,
}
// --- 核心邏輯 ---
/// 執行 Rule 1 入庫
/// 讀取 asr.json 與 asrx.json合併 Speaker 資訊,寫入 chunks_rule1
pub async fn ingest_rule1(pool: &PgPool, asset_uuid: &str, fps: f64) -> Result<usize> {
// 1. 讀取檔案
let asr_path = format!("{}/{}.asr.json", *OUTPUT_DIR, asset_uuid);
let asrx_path = format!("{}/{}.asrx.json", *OUTPUT_DIR, asset_uuid);
let asr_content = fs::read_to_string(&asr_path)
.with_context(|| format!("Failed to read ASR file: {}", asr_path))?;
let asrx_content = fs::read_to_string(&asrx_path)
.with_context(|| format!("Failed to read ASRX file: {}", asrx_path))?;
let asr_segments: Vec<AsrSegment> = serde_json::from_str(&asr_content)?;
let asrx_segments: Vec<AsrxSegment> = serde_json::from_str(&asrx_content)?;
let mut count = 0;
// 2. 交易處理
let mut tx = pool.begin().await?;
for seg in &asr_segments {
// 時間轉幀
let start_frame = (seg.start * fps).round() as i64;
let end_frame = (seg.end * fps).round() as i64;
// 3. 尋找重疊最多的 Speaker
let mut best_speaker: Option<String> = None;
let mut max_overlap = 0.0f64;
for spk in &asrx_segments {
let overlap = (seg.end.min(spk.end) - seg.start.max(spk.start)).max(0.0);
if overlap > max_overlap {
max_overlap = overlap;
best_speaker = Some(spk.speaker.clone());
}
}
let speaker_id = best_speaker.unwrap_or("UNKNOWN".to_string());
// 4. 寫入 DB
sqlx::query!(
r#"
INSERT INTO chunks_rule1 (
id, asset_uuid, start_frame, end_frame, content, speaker_id
) VALUES (
gen_random_uuid(), $1, $2, $3, $4, $5
)
"#,
asset_uuid,
start_frame,
end_frame,
seg.text,
speaker_id
)
.execute(&mut *tx)
.await?;
count += 1;
// 每 100 筆 Commit 一次 (可選優化)
if count % 500 == 0 {
tx.commit().await?;
tx = pool.begin().await?;
}
}
tx.commit().await?;
Ok(count)
}

View File

@@ -0,0 +1,182 @@
use crate::core::config::OUTPUT_DIR;
use crate::core::llm::client::generate_5w1h_summary;
use anyhow::{Context, Result};
use serde::Deserialize;
use sqlx::PgPool;
use std::fs;
use tracing::{info, warn};
/// One scene entry from the `scenes` array of `<asset_uuid>.cut.json`.
#[derive(Debug, Deserialize)]
struct CutScene {
/// Scene number; used to build the chunk id `scene_<n>` and as the chunk index.
scene_number: u32,
/// First frame of the scene; lower bound when selecting Rule 1 rows.
start_frame: u64,
/// Last frame of the scene; upper bound when selecting Rule 1 rows.
end_frame: u64,
/// Scene start time, stored as the chunk's `start_time` (presumably seconds).
start_time: f64,
/// Scene end time, stored as the chunk's `end_time` (presumably seconds).
end_time: f64,
}
/// Top-level shape of `<asset_uuid>.cut.json`: an object with a `scenes` array.
#[derive(Debug, Deserialize)]
struct CutResult {
/// Scenes in file order; one parent chunk is produced per entry.
scenes: Vec<CutScene>,
}
/// Executes Rule 3 ingestion: scene-based chunking with an LLM 5W1H+ summary.
///
/// For every scene listed in `<OUTPUT_DIR>/<asset_uuid>.cut.json`:
/// 1. Selects the Rule 1 (sentence) rows of `chunks_rule1` whose frame range lies
///    entirely inside the scene, ordered by start frame.
/// 2. Joins their text with spaces and asks the LLM for a 5W1H+ summary. The summary is
///    `"No Audio"` when the scene has no text and `"LLM Error"` when the LLM call fails.
/// 3. Inserts one parent chunk of type `"cut"` into `chunks`, with the selected Rule 1
///    row ids as `child_chunk_ids`. An existing `(uuid, chunk_id)` row is left untouched.
///
/// All inserts happen in one transaction that is committed at the end.
///
/// # Returns
/// The number of scenes processed (including scenes whose insert hit the conflict
/// clause and changed nothing).
///
/// # Errors
/// Returns an error if the CUT file cannot be read or parsed, if a scene number or
/// frame index does not fit the database integer types, or if a database call fails.
pub async fn ingest_rule3(pool: &PgPool, asset_uuid: &str) -> Result<usize> {
    // Frame rate recorded on the chunk when `videos` has no row for this asset.
    const DEFAULT_FPS: f64 = 29.97;

    // 1. Load CUT data.
    let cut_path = format!("{}/{}.cut.json", *OUTPUT_DIR, asset_uuid);
    let cut_content = fs::read_to_string(&cut_path)
        .with_context(|| format!("Failed to read CUT file: {}", cut_path))?;
    let cut_result: CutResult = serde_json::from_str(&cut_content).context("Invalid CUT JSON")?;

    let mut count = 0;
    // NOTE(review): the transaction stays open across the LLM calls below, so one
    // pooled connection is held for the whole run; consider summarising first and
    // inserting afterwards if that becomes a problem.
    let mut tx = pool.begin().await?;

    // The frame rate is a property of the asset, not of a scene: look it up once.
    let stored_fps: Option<f64> = sqlx::query_scalar("SELECT fps FROM videos WHERE uuid = $1")
        .bind(asset_uuid)
        .fetch_optional(&mut *tx)
        .await?;
    let fps = stored_fps.unwrap_or(DEFAULT_FPS);

    // 2. Process each scene.
    for scene in &cut_result.scenes {
        let chunk_id = format!("scene_{}", scene.scene_number);

        // Checked conversions: the JSON values are unsigned, the columns are signed.
        let chunk_index = i32::try_from(scene.scene_number)
            .with_context(|| format!("Scene number {} does not fit in i32", scene.scene_number))?;
        let start_frame = i64::try_from(scene.start_frame).with_context(|| {
            format!(
                "Start frame {} of scene {} does not fit in i64",
                scene.start_frame, scene.scene_number
            )
        })?;
        let end_frame = i64::try_from(scene.end_frame).with_context(|| {
            format!(
                "End frame {} of scene {} does not fit in i64",
                scene.end_frame, scene.scene_number
            )
        })?;

        // Rule 1 rows fully contained in the scene: id for linking, content for the text.
        let rule1_rows: Vec<(String, String)> = sqlx::query_as(
            r#"
            SELECT id::text, content FROM chunks_rule1
            WHERE asset_uuid = $1
            AND start_frame >= $2
            AND end_frame <= $3
            ORDER BY start_frame ASC
            "#,
        )
        .bind(asset_uuid)
        .bind(start_frame)
        .bind(end_frame)
        .fetch_all(&mut *tx)
        .await?;
        let (child_ids, texts): (Vec<String>, Vec<String>) = rule1_rows.into_iter().unzip();
        let aggregated_text = texts.join(" ");

        // 3. Call the LLM for a summary; a failure is logged and recorded, not propagated.
        let summary = if aggregated_text.is_empty() {
            "No Audio".to_string()
        } else {
            match generate_5w1h_summary(&aggregated_text).await {
                Ok(s) => s,
                Err(e) => {
                    warn!("LLM Summary failed for scene {}: {}", scene.scene_number, e);
                    "LLM Error".to_string()
                }
            }
        };

        info!(
            "Scene {}: {} -> {} ({} sentences)",
            scene.scene_number,
            scene.start_time,
            scene.end_time,
            texts.len()
        );

        // 4. Insert the parent chunk. The same JSON object is stored in both the
        //    `content` and the `metadata` column.
        let metadata = serde_json::json!({
            "type": "scene",
            "scene_number": scene.scene_number
        });
        sqlx::query(
            r#"
            INSERT INTO chunks (
            uuid, chunk_id, chunk_index, chunk_type,
            start_time, end_time, fps, start_frame, end_frame,
            content, text_content, summary_text, metadata, child_chunk_ids
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
            ON CONFLICT (uuid, chunk_id) DO NOTHING
            "#,
        )
        .bind(asset_uuid)
        .bind(&chunk_id)
        .bind(chunk_index)
        .bind("cut") // Chunk type
        .bind(scene.start_time)
        .bind(scene.end_time)
        .bind(fps)
        .bind(start_frame)
        .bind(end_frame)
        .bind(&metadata) // Content JSON
        .bind(&aggregated_text) // Text content
        .bind(&summary) // Summary
        .bind(&metadata) // Metadata
        .bind(&child_ids) // Child IDs
        .execute(&mut *tx)
        .await?;
        count += 1;
    }

    tx.commit().await?;
    Ok(count)
}

755
src/core/chunk/types.rs.bak Normal file
View File

@@ -0,0 +1,755 @@
use crate::core::time::FrameTime;
use serde::{Deserialize, Serialize};
/// Category of a chunk.
///
/// Serde serializes the variants in `snake_case` (`TimeBased` becomes `"time_based"`),
/// which is not the same label that `ChunkType::as_str` returns for it (`"time"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChunkType {
TimeBased,
Sentence,
Cut,
Trace,
Story, // Parent chunk from story analysis
Visual, // Visual object-based chunk from YOLO detection
}
impl ChunkType {
pub fn as_str(&self) -> &'static str {
match self {
ChunkType::TimeBased => "time",
ChunkType::Sentence => "sentence",
ChunkType::Cut => "cut",
ChunkType::Trace => "trace",
ChunkType::Story => "story",
ChunkType::Visual => "visual",
}
}
}
/// Rule that produced a chunk; serialized by serde in `snake_case`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChunkRule {
Rule1, // Direct conversion
Rule2, // Aggregated content
}
/// Objects detected on one keyframe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyframeObjects {
/// Keyframe time (seconds)
pub timestamp: f64,
/// Keyframe frame number
pub frame_number: u64,
/// Objects detected on this keyframe
pub objects: Vec<DetectedObject>,
}
/// A single detected object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedObject {
/// Object class name
pub class_name: String,
/// Object class ID
pub class_id: u32,
/// Confidence (0.0-1.0)
pub confidence: f32,
/// Bounding box (x, y, width, height)
pub bbox: Option<BoundingBox>,
/// Number of occurrences (within the chunk)
pub occurrence: u32,
}
/// Content payload of a visual chunk.
///
/// NOTE(review): the original doc comment here read "bounding box"; the `BoundingBox`
/// struct that `DetectedObject::bbox` refers to is not defined anywhere in this backup
/// file, so that comment appears to be left over from a lost definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualChunkContent {
pub start_time: f64,
pub end_time: f64,
pub keyframe_objects: Vec<KeyframeObjects>,
pub dominant_objects: Vec<String>,
pub object_relationships: Vec<(String, String, String)>, // (object1, relationship, object2)
pub scene_description: Option<String>,
pub metadata: VisualMetadata,
}
/// Aggregate statistics stored alongside a visual chunk's content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualMetadata {
pub object_count: u32,
pub unique_classes: Vec<String>,
pub max_confidence: f32,
pub avg_confidence: f32,
pub spatial_density: f32, // objects per frame
}
impl ChunkRule {
pub fn as_str(&self) -> &'static str {
match self {
ChunkRule::Rule1 => "rule_1",
ChunkRule::Rule2 => "rule_2",
}
}
}
/// A chunk of one asset, positioned by frame range.
///
/// Times are derived from `start_frame`/`end_frame` and `fps`; the struct stores no
/// seconds-based fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
pub file_id: i32,
pub uuid: String,
pub chunk_id: String,
pub chunk_index: u32,
pub chunk_type: ChunkType,
pub rule: ChunkRule,
/// Frames per second (can be fractional, e.g., 29.97, 23.976)
pub fps: f64,
/// Start frame (0-based)
pub start_frame: i64,
/// End frame (exclusive)
pub end_frame: i64,
pub text_content: Option<String>,
pub content: serde_json::Value,
pub metadata: Option<serde_json::Value>,
pub vector_id: Option<String>,
pub frame_count: i32,
pub pre_chunk_ids: Vec<i32>,
pub parent_chunk_id: Option<String>, // For parent-child chunk hierarchy
pub child_chunk_ids: Vec<String>, // Child chunk IDs (for parent chunks)
pub visual_stats: Option<serde_json::Value>,
}
id: i64,
video_id: i64,
yolo_result: &crate::core::processor::yolo::YoloResult,
min_frames_per_chunk: usize,
similarity_threshold: f32,
) -> Vec<Self> {
if yolo_result.frames.is_empty() {
return vec![];
}
let mut chunks = Vec::new();
let mut current_chunk_frames = Vec::new();
let mut current_id = id;
for (i, frame) in yolo_result.frames.iter().enumerate() {
if current_chunk_frames.is_empty() {
current_chunk_frames.push(frame);
continue;
}
// Check similarity with last frame in current chunk
let last_frame = current_chunk_frames.last().unwrap();
let similarity = VisualChunkContent::frame_similarity(last_frame, frame);
if similarity >= similarity_threshold && current_chunk_frames.len() < 100 {
// Similar enough, add to current chunk
current_chunk_frames.push(frame);
} else {
// Not similar enough or chunk too large, create new chunk
if current_chunk_frames.len() >= min_frames_per_chunk {
if let Some(chunk) =
Self::create_chunk_from_frames(current_id, video_id, &current_chunk_frames)
{
chunks.push(chunk);
current_id += 1;
}
}
// NOTE(review): this test module sits in the middle of a function body in this backup
// file, and several tests below use names that do not match the definitions visible in
// this file:
//   - `Chunk::from_yolo_result` is called with five arguments and `.unwrap()`, while the
//     visible definition takes seven arguments and returns `Self`;
//   - `chunk.id`, `chunk.video_id`, `chunk.start_time` and `ChunkContent::Visual` are
//     not declared on/near `Chunk`, whose `content` is a `serde_json::Value`;
//   - `KeyframeObjects` is built with `frame` (declared: `frame_number`) and
//     `DetectedObject` with `bounding_box` and no `occurrence` (declared: `bbox`,
//     `occurrence`).
// Only `test_chunk_type_visual_serialization` and `test_frame_similarity` line up with
// the definitions in this file.
#[cfg(test)]
mod tests {
use super::*;
use crate::core::processor::yolo::{YoloFrame, YoloObject, YoloResult};
// `ChunkType::Visual` must round-trip through serde as the string "visual".
#[test]
fn test_chunk_type_visual_serialization() {
let chunk_type = ChunkType::Visual;
let json = serde_json::to_string(&chunk_type).unwrap();
assert_eq!(json, "\"visual\"");
let deserialized: ChunkType = serde_json::from_str(&json).unwrap();
assert_eq!(deserialized, ChunkType::Visual);
}
// Builds a two-frame YOLO result (person + car, then person) and checks the chunk made from it.
#[test]
fn test_visual_chunk_creation() {
// Create a mock YOLO result
let yolo_result = YoloResult {
frame_count: 2,
fps: 30.0,
frames: vec![
YoloFrame {
frame: 0,
timestamp: 0.0,
objects: vec![
YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 100,
y: 200,
width: 50,
height: 100,
confidence: 0.95,
},
YoloObject {
class_name: "car".to_string(),
class_id: 2,
x: 300,
y: 150,
width: 80,
height: 60,
confidence: 0.87,
},
],
},
YoloFrame {
frame: 1,
timestamp: 0.033, // 1/30 second
objects: vec![YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 110,
y: 210,
width: 52,
height: 102,
confidence: 0.92,
}],
},
],
};
// Create visual chunk from YOLO result
let chunk = Chunk::from_yolo_result(1, 100, &yolo_result, 0, 1).unwrap();
// Verify chunk properties
assert_eq!(chunk.id, 1);
assert_eq!(chunk.video_id, 100);
assert_eq!(chunk.chunk_type, ChunkType::Visual);
assert_eq!(chunk.start_time, 0.0);
assert_eq!(chunk.end_time, 0.033);
// Verify visual content
if let ChunkContent::Visual(content) = chunk.content {
assert_eq!(content.metadata.object_count, 3);
assert_eq!(content.metadata.unique_classes.len(), 2);
assert!(content
.metadata
.unique_classes
.contains(&"person".to_string()));
assert!(content.metadata.unique_classes.contains(&"car".to_string()));
assert_eq!(content.dominant_objects, vec!["person"]);
assert_eq!(content.keyframe_objects.len(), 2);
} else {
panic!("Expected Visual content type");
}
}
// Exercises `summary`, `contains_object` and `high_confidence_objects` on a hand-built value.
#[test]
fn test_visual_chunk_content_methods() {
let content = VisualChunkContent {
start_time: 0.0,
end_time: 5.0,
keyframe_objects: vec![KeyframeObjects {
frame: 0,
timestamp: 0.0,
objects: vec![
DetectedObject {
class_name: "person".to_string(),
class_id: 0,
bounding_box: BoundingBox {
x: 100,
y: 200,
width: 50,
height: 100,
},
confidence: 0.95,
},
DetectedObject {
class_name: "car".to_string(),
class_id: 2,
bounding_box: BoundingBox {
x: 300,
y: 150,
width: 80,
height: 60,
},
confidence: 0.87,
},
],
}],
dominant_objects: vec!["person".to_string()],
object_relationships: vec![],
scene_description: Some("A person near a car".to_string()),
metadata: VisualMetadata {
object_count: 2,
unique_classes: vec!["person".to_string(), "car".to_string()],
max_confidence: 0.95,
avg_confidence: 0.91,
spatial_density: 2.0,
},
};
// Test summary method
let summary = content.summary();
assert!(summary.contains("Visual chunk from 0.0s to 5.0s"));
assert!(summary.contains("person"));
// Test contains_object method
assert!(content.contains_object("person"));
assert!(content.contains_object("car"));
assert!(!content.contains_object("dog"));
// Test high_confidence_objects method
let high_conf_objects = content.high_confidence_objects(0.9);
assert_eq!(high_conf_objects.len(), 1);
assert_eq!(high_conf_objects[0].class_name, "person");
}
// Similarity compares class-name sets only: positions and confidences do not matter.
#[test]
fn test_frame_similarity() {
let frame1 = YoloFrame {
frame: 0,
timestamp: 0.0,
objects: vec![
YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 100,
y: 200,
width: 50,
height: 100,
confidence: 0.95,
},
YoloObject {
class_name: "car".to_string(),
class_id: 2,
x: 300,
y: 150,
width: 80,
height: 60,
confidence: 0.87,
},
],
};
let frame2 = YoloFrame {
frame: 1,
timestamp: 0.033,
objects: vec![
YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 110,
y: 210,
width: 52,
height: 102,
confidence: 0.92,
},
YoloObject {
class_name: "car".to_string(),
class_id: 2,
x: 310,
y: 155,
width: 82,
height: 62,
confidence: 0.85,
},
],
};
let frame3 = YoloFrame {
frame: 2,
timestamp: 0.066,
objects: vec![YoloObject {
class_name: "dog".to_string(),
class_id: 16,
x: 150,
y: 250,
width: 40,
height: 60,
confidence: 0.78,
}],
};
// Test similar frames (same objects)
let similarity_same =
VisualChunkContent::frame_similarity(&frame1, &frame2);
assert!((similarity_same - 1.0).abs() < 0.001);
// Test dissimilar frames (different objects)
let similarity_diff =
VisualChunkContent::frame_similarity(&frame1, &frame3);
assert!((similarity_diff - 0.0).abs() < 0.001);
// Test empty frames
let empty_frame = YoloFrame {
frame: 3,
timestamp: 0.1,
objects: vec![],
};
let similarity_empty =
VisualChunkContent::frame_similarity(&empty_frame, &empty_frame);
assert!((similarity_empty - 1.0).abs() < 0.001);
let similarity_mixed =
VisualChunkContent::frame_similarity(&empty_frame, &frame1);
assert!((similarity_mixed - 0.0).abs() < 0.001);
}
}
current_chunk_frames = vec![frame];
}
}
// Handle last chunk
if current_chunk_frames.len() >= min_frames_per_chunk {
if let Some(chunk) =
Self::create_chunk_from_frames(current_id, video_id, &current_chunk_frames)
{
chunks.push(chunk);
}
}
chunks
}
/// Builds one chunk from a run of borrowed YOLO frames.
///
/// Returns `None` for an empty slice. Otherwise the frames are cloned into a temporary
/// `YoloResult` (its `fps` is set to `0.0`) and handed to `from_yolo_result` together
/// with the frame numbers of the first and last frame.
fn create_chunk_from_frames(
    id: i64,
    video_id: i64,
    frames: &[&crate::core::processor::yolo::YoloFrame],
) -> Option<Self> {
    let (first, last) = match (frames.first(), frames.last()) {
        (Some(first), Some(last)) => (first, last),
        _ => return None,
    };
    // Simple conversion - could use the from_yolo_result method
    let start_frame = first.frame;
    let end_frame = last.frame;
    let owned_frames = frames.iter().map(|&f| f.clone()).collect();
    let dummy_yolo_result = crate::core::processor::yolo::YoloResult {
        frame_count: frames.len() as u64,
        fps: 0.0, // Not used in this context
        frames: owned_frames,
    };
    Self::from_yolo_result(id, video_id, &dummy_yolo_result, start_frame, end_frame)
}
/// Creates a new chunk from a time range given in seconds (legacy conversion).
///
/// Intended for data migrated from systems that store times as seconds. Each bound is
/// converted to a frame index as `(seconds * fps).round()` and the result is passed on
/// to `Self::new`.
#[allow(clippy::too_many_arguments)]
pub fn from_seconds(
    file_id: i32,
    uuid: String,
    chunk_index: u32,
    chunk_type: ChunkType,
    rule: ChunkRule,
    start_time: f64,
    end_time: f64,
    fps: f64,
    content: serde_json::Value,
) -> Self {
    // Seconds -> nearest frame index.
    let to_frame = |seconds: f64| (seconds * fps).round() as i64;
    Self::new(
        file_id,
        uuid,
        chunk_index,
        chunk_type,
        rule,
        to_frame(start_time),
        to_frame(end_time),
        fps,
        content,
    )
}
/// Returns the start time as a `FrameTime`, built from `start_frame` and `fps`.
pub fn start_time(&self) -> FrameTime {
FrameTime::from_frames(self.start_frame, self.fps)
}
/// Returns the end time as a `FrameTime`, built from `end_frame` and `fps`.
pub fn end_time(&self) -> FrameTime {
FrameTime::from_frames(self.end_frame, self.fps)
}
/// Returns the duration in frames, computed as `end_frame - start_frame`.
pub fn duration_frames(&self) -> i64 {
self.end_frame - self.start_frame
}
/// Returns the duration in seconds, computed as `duration_frames() / fps`.
///
/// There is no guard on `fps`: with `fps == 0.0` the floating-point division yields
/// an infinite or NaN result rather than an error.
pub fn duration_seconds(&self) -> f64 {
self.duration_frames() as f64 / self.fps
}
/// Formats the start time as "seconds.frame" (e.g., "123.04").
///
/// Delegates to `FrameTime::format_sec_frame` on `start_time()`.
pub fn format_start_sec_frame(&self) -> String {
self.start_time().format_sec_frame()
}
/// Formats the end time as "seconds.frame" (e.g., "456.15").
///
/// Delegates to `FrameTime::format_sec_frame` on `end_time()`.
pub fn format_end_sec_frame(&self) -> String {
self.end_time().format_sec_frame()
}
/// Formats the start time as "HH:MM:SS".
///
/// Delegates to `FrameTime::format_hms` on `start_time()`.
pub fn format_start_hms(&self) -> String {
self.start_time().format_hms()
}
/// Formats the end time as "HH:MM:SS".
///
/// Delegates to `FrameTime::format_hms` on `end_time()`.
pub fn format_end_hms(&self) -> String {
self.end_time().format_hms()
}
/// Formats the start time as "HH:MM:SS.FF".
///
/// Delegates to `FrameTime::format_hms_frame` on `start_time()`.
pub fn format_start_hms_frame(&self) -> String {
self.start_time().format_hms_frame()
}
/// Formats the end time as "HH:MM:SS.FF".
///
/// Delegates to `FrameTime::format_hms_frame` on `end_time()`.
pub fn format_end_hms_frame(&self) -> String {
self.end_time().format_hms_frame()
}
/// Returns a tuple of (start_seconds, end_seconds) for compatibility.
///
/// This is provided for backward compatibility during migration.
/// Prefer using `start_time()` and `end_time()` methods.
///
/// Both values come from `FrameTime::seconds()` on the respective bound.
pub fn time_range_seconds(&self) -> (f64, f64) {
(self.start_time().seconds(), self.end_time().seconds())
}
/// Builder-style setter: stores `metadata` (replacing any previous value) and returns the chunk.
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
self.metadata = Some(metadata);
self
}
/// Builder-style setter: stores `vector_id` (replacing any previous value) and returns the chunk.
pub fn with_vector_id(mut self, vector_id: String) -> Self {
self.vector_id = Some(vector_id);
self
}
/// Builder-style setter: stores `text` as `text_content` and returns the chunk.
pub fn with_text_content(mut self, text: String) -> Self {
self.text_content = Some(text);
self
}
/// Builder-style setter: overwrites `frame_count` with `count` and returns the chunk.
///
/// The value is stored as given; it is not checked against `start_frame`/`end_frame`.
pub fn with_frame_count(mut self, count: i32) -> Self {
self.frame_count = count;
self
}
/// Builder-style setter: replaces `pre_chunk_ids` with `ids` and returns the chunk.
pub fn with_pre_chunk_ids(mut self, ids: Vec<i32>) -> Self {
self.pre_chunk_ids = ids;
self
}
/// Builder-style setter: records `parent_id` as this chunk's parent and returns the chunk.
///
/// After this call `is_child_chunk()` returns `true`.
pub fn with_parent_chunk_id(mut self, parent_id: String) -> Self {
self.parent_chunk_id = Some(parent_id);
self
}
/// Builder-style setter: replaces `child_chunk_ids` with `child_ids` and returns the chunk.
///
/// A non-empty list makes `is_parent_chunk()` return `true`.
pub fn with_child_chunk_ids(mut self, child_ids: Vec<String>) -> Self {
self.child_chunk_ids = child_ids;
self
}
/// Returns `true` when this chunk lists at least one child chunk id.
pub fn is_parent_chunk(&self) -> bool {
!self.child_chunk_ids.is_empty()
}
/// Returns `true` when this chunk has a parent chunk id set.
pub fn is_child_chunk(&self) -> bool {
self.parent_chunk_id.is_some()
}
/// 創建視覺分片
pub fn new_visual(
file_id: i32,
uuid: String,
chunk_index: u32,
start_frame: i64,
end_frame: i64,
fps: f64,
visual_content: VisualChunkContent,
) -> Self {
let content = serde_json::to_value(&visual_content)
.unwrap_or_else(|_| serde_json::json!({"error": "Failed to serialize visual content"}));
Self::new(
file_id,
uuid,
chunk_index,
ChunkType::Visual,
ChunkRule::Rule2,
start_frame,
end_frame,
fps,
content,
)
}
/// Creates a visual chunk from YOLO detection results.
///
/// Counts detections per class name, converts every frame that has at least one
/// detection into a `KeyframeObjects` entry, and passes the assembled content to
/// `Self::new_visual`.
///
/// NOTE(review): the `VisualChunkContent` literal below sets `primary_objects`,
/// `object_stats`, `object_frequency` and `visual_summary`, none of which are declared
/// on the `VisualChunkContent` struct in this file (it declares `start_time`,
/// `end_time`, `keyframe_objects`, `dominant_objects`, `object_relationships`,
/// `scene_description`, `metadata`). `BoundingBox` is also not defined in this file.
/// This function looks like it was written against an earlier schema - confirm before
/// reviving this backup.
pub fn from_yolo_result(
file_id: i32,
uuid: String,
chunk_index: u32,
start_frame: i64,
end_frame: i64,
fps: f64,
yolo_frames: Vec<crate::core::processor::yolo::YoloFrame>,
) -> Self {
// NOTE(review): `YoloFrame` is imported here but the body never names it.
use crate::core::processor::yolo::YoloFrame;
use std::collections::HashMap;
// Collect per-class object statistics
let mut object_counts = HashMap::new();
let mut keyframe_objects = Vec::new();
// NOTE(review): `all_objects` is filled below but never read.
let mut all_objects = Vec::new();
for frame in &yolo_frames {
let mut frame_objects = Vec::new();
for obj in &frame.objects {
// Update the per-class count
*object_counts.entry(obj.class_name.clone()).or_insert(0) += 1;
// Build the detected-object record
let detected_obj = DetectedObject {
class_name: obj.class_name.clone(),
class_id: obj.class_id,
confidence: obj.confidence,
bbox: Some(BoundingBox {
x: obj.x,
y: obj.y,
width: obj.width,
height: obj.height,
}),
occurrence: 1,
};
frame_objects.push(detected_obj.clone());
all_objects.push(detected_obj);
}
// Frames without any detection produce no keyframe entry.
if !frame_objects.is_empty() {
keyframe_objects.push(KeyframeObjects {
timestamp: frame.timestamp,
frame_number: frame.frame,
objects: frame_objects,
});
}
}
// Build the primary-object label (order follows HashMap iteration, so it is not stable)
let primary_objects = object_counts
.iter()
.filter(|(_, &count)| count >= 3) // objects seen at least 3 times
.map(|(name, _)| name.clone())
.collect::<Vec<_>>()
.join(", ");
// Build the object statistics JSON
let object_stats =
serde_json::to_value(&object_counts).unwrap_or_else(|_| serde_json::json!({}));
// Build the visual content
let visual_content = VisualChunkContent {
primary_objects: if primary_objects.is_empty() {
"no objects detected".to_string()
} else {
primary_objects
},
object_stats,
keyframe_objects,
object_frequency: serde_json::to_value(&object_counts)
.unwrap_or_else(|_| serde_json::json!({})),
visual_summary: None, // optional; an LLM-generated summary may be added later
};
Self::new_visual(
file_id,
uuid,
chunk_index,
start_frame,
end_frame,
fps,
visual_content,
)
}
}
impl VisualChunkContent {
    /// Jaccard similarity of the sets of object class names seen in two YOLO frames.
    ///
    /// Only class names are compared; positions, sizes, confidences and duplicate
    /// detections of the same class do not influence the result. Two frames without
    /// objects score `1.0`; one empty and one non-empty frame score `0.0`.
    pub fn frame_similarity(
        frame1: &crate::core::processor::yolo::YoloFrame,
        frame2: &crate::core::processor::yolo::YoloFrame,
    ) -> f32 {
        match (frame1.objects.is_empty(), frame2.objects.is_empty()) {
            // Both empty frames are perfectly similar.
            (true, true) => return 1.0,
            // One empty, one non-empty are dissimilar.
            (true, false) | (false, true) => return 0.0,
            (false, false) => {}
        }

        let classes_a: std::collections::HashSet<&str> = frame1
            .objects
            .iter()
            .map(|o| o.class_name.as_str())
            .collect();
        let classes_b: std::collections::HashSet<&str> = frame2
            .objects
            .iter()
            .map(|o| o.class_name.as_str())
            .collect();

        // |A ∩ B| / |A ∪ B|
        let shared = classes_a.intersection(&classes_b).count();
        let combined = classes_a.union(&classes_b).count();
        if combined == 0 {
            0.0
        } else {
            shared as f32 / combined as f32
        }
    }

    /// Returns a one-line, human-readable description of the chunk: time range,
    /// duration, number of keyframes, object totals, and the dominant objects
    /// (`none` when the list is empty).
    pub fn summary(&self) -> String {
        let dominant = if self.dominant_objects.is_empty() {
            "none".to_string()
        } else {
            self.dominant_objects.join(", ")
        };
        format!(
            "Visual chunk from {:.1}s to {:.1}s (duration: {:.1}s, {} frames). Objects: {} total, {} unique. Dominant objects: {}",
            self.start_time,
            self.end_time,
            self.end_time - self.start_time,
            self.keyframe_objects.len(),
            self.metadata.object_count,
            self.metadata.unique_classes.len(),
            dominant
        )
    }

    /// Returns `true` if any keyframe contains an object whose class name equals
    /// `class_name` exactly.
    pub fn contains_object(&self, class_name: &str) -> bool {
        self.keyframe_objects
            .iter()
            .flat_map(|keyframe| keyframe.objects.iter())
            .any(|obj| obj.class_name == class_name)
    }

    /// Returns references to all objects, in keyframe order, whose confidence is
    /// greater than or equal to `threshold`.
    pub fn high_confidence_objects(&self, threshold: f32) -> Vec<&DetectedObject> {
        let mut selected = Vec::new();
        for keyframe in &self.keyframe_objects {
            selected.extend(
                keyframe
                    .objects
                    .iter()
                    .filter(|obj| obj.confidence >= threshold),
            );
        }
        selected
    }
}

View File

@@ -0,0 +1,320 @@
use crate::core::time::FrameTime;
use serde::{Deserialize, Serialize};
/// Category of a chunk.
///
/// Serde serializes the variants in `snake_case` (`TimeBased` becomes `"time_based"`),
/// which is not the same label that `ChunkType::as_str` returns for it (`"time"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChunkType {
TimeBased,
Sentence,
Cut,
Trace,
Story, // Parent chunk from story analysis
Visual, // Visual object-based chunk from YOLO detection (Phase 2.1)
}
impl ChunkType {
pub fn as_str(&self) -> &'static str {
match self {
ChunkType::TimeBased => "time",
ChunkType::Sentence => "sentence",
ChunkType::Cut => "cut",
ChunkType::Trace => "trace",
ChunkType::Story => "story",
ChunkType::Visual => "visual",
}
}
}
/// Rule that produced a chunk; serialized by serde in `snake_case`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChunkRule {
Rule1, // Direct conversion
Rule2, // Aggregated content
}
/// Objects detected on one keyframe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyframeObjects {
/// Keyframe time (seconds)
pub timestamp: f64,
/// Keyframe frame number
pub frame_number: u64,
/// Objects detected on this keyframe
pub objects: Vec<DetectedObject>,
}
/// A single detected object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedObject {
/// Object class name
pub class_name: String,
/// Object class ID
pub class_id: u32,
/// Confidence (0.0-1.0)
pub confidence: f32,
/// Bounding box (x, y, width, height)
pub bbox: Option<BoundingBox>,
/// Number of occurrences (within the chunk)
pub occurrence: u32,
}
/// Bounding box of a detected object, copied from the YOLO object's
/// `x`, `y`, `width` and `height` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundingBox {
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
}
/// Content payload of a visual chunk (Phase 2.1).
///
/// Serialized to JSON and stored as `Chunk::content` by `Chunk::new_visual`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualChunkContent {
pub start_time: f64,
pub end_time: f64,
pub keyframe_objects: Vec<KeyframeObjects>,
pub dominant_objects: Vec<String>,
pub object_relationships: Vec<(String, String, String)>, // (object1, relationship, object2)
pub scene_description: Option<String>,
pub metadata: VisualMetadata,
}
/// Visual metadata (Phase 2.1): aggregate statistics over a visual chunk's detections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualMetadata {
pub object_count: u32,
pub unique_classes: Vec<String>,
pub max_confidence: f32,
pub avg_confidence: f32,
pub spatial_density: f32, // objects per frame
}
impl ChunkRule {
pub fn as_str(&self) -> &'static str {
match self {
ChunkRule::Rule1 => "rule_1",
ChunkRule::Rule2 => "rule_2",
}
}
}
/// A chunk of one asset, positioned by frame range.
///
/// `Chunk::new` derives `chunk_id` as `"<uuid>_<chunk_index>"` and `frame_count` as
/// `end_frame - start_frame`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Chunk {
pub file_id: i32,
pub uuid: String,
pub chunk_id: String,
pub chunk_index: u32,
pub chunk_type: ChunkType,
pub rule: ChunkRule,
/// Frames per second (can be fractional, e.g., 29.97, 23.976)
pub fps: f64,
/// Start frame (0-based)
pub start_frame: i64,
/// End frame (exclusive)
pub end_frame: i64,
pub text_content: Option<String>,
pub content: serde_json::Value,
pub metadata: Option<serde_json::Value>,
pub vector_id: Option<String>,
pub frame_count: i32,
pub pre_chunk_ids: Vec<i32>,
pub parent_chunk_id: Option<String>, // For parent-child chunk hierarchy
pub child_chunk_ids: Vec<String>, // Child chunk IDs (for parent chunks)
pub visual_stats: Option<serde_json::Value>,
}
impl Chunk {
/// 創建視覺分片 (Phase 2.1)
pub fn new_visual(
file_id: i32,
uuid: String,
chunk_index: u32,
start_frame: i64,
end_frame: i64,
fps: f64,
visual_content: VisualChunkContent,
) -> Self {
let content = serde_json::to_value(&visual_content)
.unwrap_or_else(|_| serde_json::json!({"error": "Failed to serialize visual content"}));
Self::new(
file_id,
uuid,
chunk_index,
ChunkType::Visual,
ChunkRule::Rule2,
start_frame,
end_frame,
fps,
content,
)
}
/// Creates a visual chunk from YOLO detection results (Phase 2.1).
///
/// Every frame in `yolo_frames` becomes one `KeyframeObjects` entry (frames without
/// detections are kept with an empty object list). The content's metadata is computed
/// over all detections:
///
/// * `object_count` - total number of detections.
/// * `unique_classes` - distinct class names, sorted alphabetically.
/// * `max_confidence` / `avg_confidence` - over all detections, `0.0` when there are none.
/// * `spatial_density` - detections per frame, `0.0` when there are no frames.
///
/// `dominant_objects` lists, sorted alphabetically, the non-empty class names whose
/// detection count divided by the number of frames exceeds `0.5`.
///
/// `start_time`/`end_time` of the content are the timestamps of the first and last
/// frame (`0.0` when `yolo_frames` is empty); `start_frame`, `end_frame` and `fps` are
/// passed through to `Self::new_visual` unchanged.
pub fn from_yolo_result(
    file_id: i32,
    uuid: String,
    chunk_index: u32,
    start_frame: i64,
    end_frame: i64,
    fps: f64,
    yolo_frames: Vec<crate::core::processor::yolo::YoloFrame>,
) -> Self {
    let frame_total = yolo_frames.len();

    let keyframe_objects: Vec<KeyframeObjects> = yolo_frames
        .iter()
        .map(|frame| KeyframeObjects {
            timestamp: frame.timestamp,
            frame_number: frame.frame,
            objects: frame
                .objects
                .iter()
                .map(|obj| DetectedObject {
                    class_name: obj.class_name.clone(),
                    class_id: obj.class_id,
                    confidence: obj.confidence,
                    bbox: Some(BoundingBox {
                        x: obj.x,
                        y: obj.y,
                        width: obj.width,
                        height: obj.height,
                    }),
                    occurrence: 1,
                })
                .collect(),
        })
        .collect();

    // Single pass over all detections. A BTreeMap keeps class names sorted, so the
    // lists derived from it come out in a stable order.
    let mut class_counts: std::collections::BTreeMap<&str, u32> =
        std::collections::BTreeMap::new();
    let mut total_objects: u32 = 0;
    let mut max_confidence = 0.0f32;
    let mut confidence_sum = 0.0f32;
    for obj in yolo_frames.iter().flat_map(|frame| frame.objects.iter()) {
        *class_counts.entry(obj.class_name.as_str()).or_insert(0) += 1;
        total_objects += 1;
        max_confidence = max_confidence.max(obj.confidence);
        confidence_sum += obj.confidence;
    }

    let avg_confidence = if total_objects > 0 {
        confidence_sum / total_objects as f32
    } else {
        0.0
    };
    let spatial_density = if frame_total > 0 {
        total_objects as f32 / frame_total as f32
    } else {
        0.0
    };

    let unique_classes: Vec<String> = class_counts
        .keys()
        .map(|name| (*name).to_owned())
        .collect();

    // Dominant = detected, on average, in more than half of the frames. The map is
    // empty whenever there are no frames, so the division below never sees zero.
    let dominant_objects: Vec<String> = class_counts
        .iter()
        .filter(|(name, &count)| {
            !name.is_empty() && count as f32 / frame_total as f32 > 0.5
        })
        .map(|(name, _)| (*name).to_owned())
        .collect();

    let visual_content = VisualChunkContent {
        start_time: yolo_frames.first().map_or(0.0, |frame| frame.timestamp),
        end_time: yolo_frames.last().map_or(0.0, |frame| frame.timestamp),
        keyframe_objects,
        dominant_objects,
        object_relationships: vec![], // optional: relationship detection may be added later
        scene_description: None, // optional: an LLM-generated scene description may be added later
        metadata: VisualMetadata {
            object_count: total_objects,
            unique_classes,
            max_confidence,
            avg_confidence,
            spatial_density,
        },
    };

    Self::new_visual(
        file_id,
        uuid,
        chunk_index,
        start_frame,
        end_frame,
        fps,
        visual_content,
    )
}
/// Creates a new chunk.
///
/// `chunk_id` is derived as `"{uuid}_{chunk_index}"` and `frame_count` as
/// `end_frame - start_frame`. All optional fields (`text_content`, `metadata`,
/// `vector_id`, `parent_chunk_id`, `visual_stats`) start as `None` and both id lists
/// start empty.
pub fn new(
    file_id: i32,
    uuid: String,
    chunk_index: u32,
    chunk_type: ChunkType,
    rule: ChunkRule,
    start_frame: i64,
    end_frame: i64,
    fps: f64,
    content: serde_json::Value,
) -> Self {
    // Saturate instead of overflowing the subtraction or silently truncating the
    // i64 span to i32; for every span that fits in i32 the result is unchanged.
    let frame_count = end_frame
        .saturating_sub(start_frame)
        .clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32;
    let chunk_id = format!("{}_{}", uuid, chunk_index);

    Self {
        file_id,
        uuid,
        chunk_id,
        chunk_index,
        chunk_type,
        rule,
        fps,
        start_frame,
        end_frame,
        text_content: None,
        content,
        metadata: None,
        vector_id: None,
        frame_count,
        pre_chunk_ids: vec![],
        parent_chunk_id: None,
        child_chunk_ids: vec![],
        visual_stats: None,
    }
}
/// Converts this chunk's frame range into a `FrameTime`.
///
/// The start and end frames are cast to `u64` and passed, together with the chunk's
/// `fps`, to `FrameTime::from_frames`.
pub fn to_frame_time(&self) -> FrameTime {
    let first_frame = self.start_frame as u64;
    let last_frame = self.end_frame as u64;
    FrameTime::from_frames(first_frame, last_frame, self.fps)
}
/// Returns `true` when this chunk is a parent chunk, i.e. it owns at least one
/// child chunk.
///
/// `parent_chunk_id` is deliberately not consulted: a chunk that has a parent id is
/// a child in the hierarchy, not a parent.
pub fn is_parent(&self) -> bool {
    !self.child_chunk_ids.is_empty()
}
}

View File

@@ -0,0 +1,486 @@
//! 視覺分片測試
//!
//! 測試視覺分片數據結構和功能
use serde::{Deserialize, Serialize};
/// Kind of chunk.
///
/// Serialized in `snake_case`, so `Visual` is written as `"visual"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChunkType {
TimeBased,
Sentence,
Cut,
Trace,
Story,
Visual,
}
/// A single detected object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedObject {
/// Object class name.
pub class_name: String,
/// Object class ID.
pub class_id: u32,
/// Confidence score (0.0-1.0).
pub confidence: f32,
/// Bounding box as `(x, y, width, height)`.
pub bbox: Option<(i32, i32, i32, i32)>,
}
/// The list of objects detected in one keyframe.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyframeObjects {
/// Keyframe time (seconds).
pub timestamp: f64,
/// Keyframe frame number.
pub frame_number: u64,
/// Objects detected in this keyframe.
pub objects: Vec<DetectedObject>,
}
/// Content of a visual chunk: its time span, per-keyframe detections and
/// aggregate information derived from them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualChunkContent {
// Start/end of the chunk; `summary` prints both with an `s` (seconds) suffix.
pub start_time: f64,
pub end_time: f64,
// One entry per keyframe contained in the chunk.
pub keyframe_objects: Vec<KeyframeObjects>,
// Class names considered dominant for the chunk.
pub dominant_objects: Vec<String>,
pub object_relationships: Vec<(String, String, String)>, // (object1, relationship, object2)
pub scene_description: Option<String>,
pub metadata: VisualMetadata,
}
/// Aggregate visual metadata for a chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VisualMetadata {
// Total number of detections across all keyframes.
pub object_count: u32,
// Distinct class names seen in the chunk.
pub unique_classes: Vec<String>,
pub max_confidence: f32,
pub avg_confidence: f32,
pub spatial_density: f32, // objects per frame
}
impl VisualChunkContent {
    /// Similarity of two frames based on which object classes they contain.
    ///
    /// This is the Jaccard index of the two sets of class names: 1.0 when both
    /// frames are empty, 0.0 when exactly one is empty, otherwise
    /// `|intersection| / |union|`. Duplicate detections of a class count once.
    pub fn frame_similarity(
        frame1_objects: &[DetectedObject],
        frame2_objects: &[DetectedObject],
    ) -> f32 {
        match (frame1_objects.is_empty(), frame2_objects.is_empty()) {
            (true, true) => return 1.0,                      // two empty frames are identical
            (true, false) | (false, true) => return 0.0,     // empty vs non-empty
            (false, false) => {}
        }

        let names_a: std::collections::HashSet<&str> = frame1_objects
            .iter()
            .map(|obj| obj.class_name.as_str())
            .collect();
        let names_b: std::collections::HashSet<&str> = frame2_objects
            .iter()
            .map(|obj| obj.class_name.as_str())
            .collect();

        // Both inputs are non-empty here, so the union holds at least one name.
        let shared = names_a.intersection(&names_b).count();
        let combined = names_a.union(&names_b).count();
        shared as f32 / combined as f32
    }

    /// One-line human-readable summary of the chunk: time span, duration, keyframe
    /// count, total and unique object counts, and the dominant classes.
    pub fn summary(&self) -> String {
        // Joining an empty list already yields an empty string.
        let dominant = self.dominant_objects.join(", ");
        format!(
            "視覺分片: {:.1}s 到 {:.1}s (持續時間: {:.1}s, {} 幀). 物件: {} 個總計, {} 個唯一. 主要物件: {}",
            self.start_time,
            self.end_time,
            self.end_time - self.start_time,
            self.keyframe_objects.len(),
            self.metadata.object_count,
            self.metadata.unique_classes.len(),
            dominant
        )
    }

    /// Whether any keyframe contains an object of class `class_name`.
    pub fn contains_object(&self, class_name: &str) -> bool {
        self.keyframe_objects
            .iter()
            .flat_map(|keyframe| keyframe.objects.iter())
            .any(|obj| obj.class_name == class_name)
    }

    /// All detections, in keyframe order, whose confidence is at least `threshold`.
    pub fn high_confidence_objects(&self, threshold: f32) -> Vec<&DetectedObject> {
        let mut confident = Vec::new();
        for keyframe in &self.keyframe_objects {
            for obj in &keyframe.objects {
                if obj.confidence >= threshold {
                    confident.push(obj);
                }
            }
        }
        confident
    }
}
/// Mock YOLO result: the list of frames with their detections.
#[derive(Debug, Clone)]
pub struct MockYoloResult {
pub frames: Vec<MockYoloFrame>,
}
/// One frame of a mock YOLO result.
#[derive(Debug, Clone)]
pub struct MockYoloFrame {
// Frame number; `to_visual_chunk` filters frames on this value.
pub frame: u64,
// Timestamp of the frame; copied into `KeyframeObjects::timestamp`.
pub timestamp: f64,
pub objects: Vec<MockYoloObject>,
}
/// One detection inside a mock YOLO frame.
#[derive(Debug, Clone)]
pub struct MockYoloObject {
pub class_name: String,
pub class_id: u32,
// `x`, `y`, `width`, `height` are copied, in that order, into `DetectedObject::bbox`.
pub x: i32,
pub y: i32,
pub width: i32,
pub height: i32,
pub confidence: f32,
}
impl MockYoloResult {
/// 從模擬 YOLO 結果創建視覺分片
pub fn to_visual_chunk(&self, start_frame: u64, end_frame: u64) -> Option<VisualChunkContent> {
let frames: Vec<_> = self
.frames
.iter()
.filter(|f| f.frame >= start_frame && f.frame <= end_frame)
.collect();
if frames.is_empty() {
return None;
}
// 轉換幀為關鍵幀物件
let keyframe_objects: Vec<KeyframeObjects> = frames
.iter()
.map(|frame| {
let objects: Vec<DetectedObject> = frame
.objects
.iter()
.map(|obj| DetectedObject {
class_name: obj.class_name.clone(),
class_id: obj.class_id,
confidence: obj.confidence,
bbox: Some((obj.x, obj.y, obj.width, obj.height)),
})
.collect();
KeyframeObjects {
timestamp: frame.timestamp,
frame_number: frame.frame,
objects,
}
})
.collect();
// 計算元數據
let total_objects: u32 = frames.iter().map(|f| f.objects.len() as u32).sum();
let all_classes: Vec<String> = frames
.iter()
.flat_map(|f| f.objects.iter().map(|o| o.class_name.clone()))
.collect();
let unique_classes: Vec<String> = all_classes
.iter()
.cloned()
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
let confidences: Vec<f32> = frames
.iter()
.flat_map(|f| f.objects.iter().map(|o| o.confidence))
.collect();
let max_confidence = confidences.iter().copied().fold(0.0f32, f32::max);
let avg_confidence = if !confidences.is_empty() {
confidences.iter().sum::<f32>() / confidences.len() as f32
} else {
0.0
};
let start_time = frames.first().map(|f| f.timestamp).unwrap_or(0.0);
let end_time = frames.last().map(|f| f.timestamp).unwrap_or(0.0);
// 查找主要物件(出現在大多數幀中的物件)
let mut object_counts = std::collections::HashMap::new();
for frame in &frames {
let frame_classes: std::collections::HashSet<_> =
frame.objects.iter().map(|o| o.class_name.clone()).collect();
for class in frame_classes {
*object_counts.entry(class).or_insert(0) += 1;
}
}
let mut dominant_objects: Vec<String> = object_counts
.into_iter()
.filter(|(_, count)| *count as f32 / frames.len() as f32 > 0.5) // 出現在 >50% 的幀中
.map(|(class, _)| class)
.collect();
dominant_objects.sort();
Some(VisualChunkContent {
start_time,
end_time,
keyframe_objects,
dominant_objects,
object_relationships: vec![], // 需要關係檢測邏輯
scene_description: None, // 可由 LLM 後期生成
metadata: VisualMetadata {
object_count: total_objects,
unique_classes,
max_confidence,
avg_confidence,
spatial_density: if frames.len() > 0 {
total_objects as f32 / frames.len() as f32
} else {
0.0
},
},
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `DetectedObject` fixture; `bbox` is `(x, y, width, height)`.
    fn detected(
        class_name: &str,
        class_id: u32,
        confidence: f32,
        bbox: (i32, i32, i32, i32),
    ) -> DetectedObject {
        DetectedObject {
            class_name: class_name.to_string(),
            class_id,
            confidence,
            bbox: Some(bbox),
        }
    }

    /// Shorthand for a `MockYoloObject` fixture; `bbox` is `(x, y, width, height)`.
    fn yolo_object(
        class_name: &str,
        class_id: u32,
        bbox: (i32, i32, i32, i32),
        confidence: f32,
    ) -> MockYoloObject {
        let (x, y, width, height) = bbox;
        MockYoloObject {
            class_name: class_name.to_string(),
            class_id,
            x,
            y,
            width,
            height,
            confidence,
        }
    }

    #[test]
    fn test_chunk_type_visual() {
        let chunk_type = ChunkType::Visual;
        let json = serde_json::to_string(&chunk_type).unwrap();
        assert_eq!(json, "\"visual\"");
        let deserialized: ChunkType = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, ChunkType::Visual);
    }

    #[test]
    fn test_visual_chunk_creation() {
        // Mock YOLO result: "person" is in both frames, "car" only in the first.
        let yolo_result = MockYoloResult {
            frames: vec![
                MockYoloFrame {
                    frame: 0,
                    timestamp: 0.0,
                    objects: vec![
                        yolo_object("person", 0, (100, 200, 50, 100), 0.95),
                        yolo_object("car", 2, (300, 150, 80, 60), 0.87),
                    ],
                },
                MockYoloFrame {
                    frame: 1,
                    timestamp: 0.033, // 1/30 second
                    objects: vec![yolo_object("person", 0, (110, 210, 52, 102), 0.92)],
                },
            ],
        };

        let chunk = yolo_result.to_visual_chunk(0, 1).unwrap();

        assert_eq!(chunk.start_time, 0.0);
        assert_eq!(chunk.end_time, 0.033);
        assert_eq!(chunk.metadata.object_count, 3);
        assert_eq!(chunk.metadata.unique_classes.len(), 2);
        assert!(chunk
            .metadata
            .unique_classes
            .contains(&"person".to_string()));
        assert!(chunk.metadata.unique_classes.contains(&"car".to_string()));
        assert_eq!(chunk.dominant_objects, vec!["person"]);
        assert_eq!(chunk.keyframe_objects.len(), 2);
    }

    #[test]
    fn test_visual_chunk_content_methods() {
        let content = VisualChunkContent {
            start_time: 0.0,
            end_time: 5.0,
            keyframe_objects: vec![KeyframeObjects {
                timestamp: 0.0,
                frame_number: 0,
                objects: vec![
                    detected("person", 0, 0.95, (100, 200, 50, 100)),
                    detected("car", 2, 0.87, (300, 150, 80, 60)),
                ],
            }],
            dominant_objects: vec!["person".to_string()],
            object_relationships: vec![],
            scene_description: Some("一個人站在車旁".to_string()),
            metadata: VisualMetadata {
                object_count: 2,
                unique_classes: vec!["person".to_string(), "car".to_string()],
                max_confidence: 0.95,
                avg_confidence: 0.91,
                spatial_density: 2.0,
            },
        };

        // summary
        let summary = content.summary();
        assert!(summary.contains("視覺分片"));
        assert!(summary.contains("person"));

        // contains_object
        assert!(content.contains_object("person"));
        assert!(content.contains_object("car"));
        assert!(!content.contains_object("dog"));

        // high_confidence_objects
        let high_conf_objects = content.high_confidence_objects(0.9);
        assert_eq!(high_conf_objects.len(), 1);
        assert_eq!(high_conf_objects[0].class_name, "person");
    }

    #[test]
    fn test_frame_similarity() {
        let frame1_objects = vec![
            detected("person", 0, 0.95, (100, 200, 50, 100)),
            detected("car", 2, 0.87, (300, 150, 80, 60)),
        ];
        let frame2_objects = vec![
            detected("person", 0, 0.92, (110, 210, 52, 102)),
            detected("car", 2, 0.85, (310, 155, 82, 62)),
        ];
        let frame3_objects = vec![detected("dog", 16, 0.78, (150, 250, 40, 60))];
        let empty_objects: Vec<DetectedObject> = vec![];

        // Same classes in both frames.
        let similarity_same =
            VisualChunkContent::frame_similarity(&frame1_objects, &frame2_objects);
        assert!((similarity_same - 1.0).abs() < 0.001);

        // Disjoint classes.
        let similarity_diff =
            VisualChunkContent::frame_similarity(&frame1_objects, &frame3_objects);
        assert!((similarity_diff - 0.0).abs() < 0.001);

        // Empty frames.
        let similarity_empty =
            VisualChunkContent::frame_similarity(&empty_objects, &empty_objects);
        assert!((similarity_empty - 1.0).abs() < 0.001);
        let similarity_mixed =
            VisualChunkContent::frame_similarity(&empty_objects, &frame1_objects);
        assert!((similarity_mixed - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_serialization_deserialization() {
        let content = VisualChunkContent {
            start_time: 0.0,
            end_time: 5.0,
            keyframe_objects: vec![KeyframeObjects {
                timestamp: 0.0,
                frame_number: 0,
                objects: vec![detected("person", 0, 0.95, (100, 200, 50, 100))],
            }],
            dominant_objects: vec!["person".to_string()],
            object_relationships: vec![],
            scene_description: Some("場景描述".to_string()),
            metadata: VisualMetadata {
                object_count: 1,
                unique_classes: vec!["person".to_string()],
                max_confidence: 0.95,
                avg_confidence: 0.95,
                spatial_density: 1.0,
            },
        };

        // Serialize. The JSON carries the struct's field names; nothing in it is
        // called "visual_chunk", so the check is made against a real key instead.
        let json = serde_json::to_string(&content).unwrap();
        assert!(json.contains("person"));
        assert!(json.contains("keyframe_objects"));

        // Deserialize.
        let deserialized: VisualChunkContent = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.start_time, 0.0);
        assert_eq!(deserialized.end_time, 5.0);
        assert_eq!(deserialized.dominant_objects, vec!["person"]);
    }
}

View File

@@ -77,6 +77,8 @@ pub struct VideoRow {
pub status: String,
pub user_id: Option<i32>,
pub job_id: Option<i32>,
pub created_at: Option<String>,
pub registration_time: Option<String>,
}
impl From<VideoRow> for VideoRecord {
@@ -103,7 +105,8 @@ impl From<VideoRow> for VideoRecord {
status: VideoStatus::from_db_str(&row.status).unwrap_or(VideoStatus::Pending),
user_id: row.user_id.map(|v| v as i64),
job_id: row.job_id.map(|v| v as i64),
created_at: String::new(),
created_at: row.created_at.unwrap_or_default(),
registration_time: row.registration_time,
}
}
}
@@ -124,6 +127,7 @@ pub struct VideoRecord {
pub user_id: Option<i64>,
pub job_id: Option<i64>,
pub created_at: String,
pub registration_time: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -701,7 +705,7 @@ impl PostgresDb {
let table = schema::table_name("videos");
let result = sqlx::query_as::<_, VideoRow>(
&format!(
"SELECT id, uuid, file_path, file_name, duration, width, height, fps, probe_json, fs_video, fs_json, psql_chunk, pobject_chunk, mobject_chunk, pvector_chunk, qvector_chunk, status, user_id, job_id FROM {} WHERE uuid = $1",
"SELECT id, uuid, file_path, file_name, duration, width, height, fps, probe_json, fs_video, fs_json, psql_chunk, pobject_chunk, mobject_chunk, pvector_chunk, qvector_chunk, status, user_id, job_id, created_at::text, registration_time::text FROM {} WHERE uuid = $1",
table
)
)
@@ -796,28 +800,90 @@ impl PostgresDb {
}
/// Lists one page of videos that are not yet processed, together with the total
/// number of matching rows.
///
/// Delegates to `search_videos` with no text query and `is_processed = Some(false)`,
/// which that method translates into `status != 'ready'`.
pub async fn list_videos(&self, limit: i32, offset: i64) -> Result<(Vec<VideoRecord>, i64)> {
    let unprocessed_only = Some(false);
    self.search_videos(None, unprocessed_only, limit, offset).await
}
/// Returns one page of videos plus the total number of rows matching the filters.
///
/// * `query` - optional case-insensitive substring matched against `file_name`,
///   `file_path` and the text form of `probe_json`.
/// * `is_processed` - `Some(true)` keeps rows with `status = 'ready'`, `Some(false)`
///   keeps rows with `status != 'ready'`, `None` applies no status filter.
/// * `limit` / `offset` - paging window; rows are ordered by `id` descending.
///
/// The total is counted with the same `WHERE` clause as the page query, so it
/// reflects the filtered set. Database errors are propagated.
pub async fn search_videos(
    &self,
    query: Option<&str>,
    is_processed: Option<bool>,
    limit: i32,
    offset: i64,
) -> Result<(Vec<VideoRecord>, i64)> {
    let table = schema::table_name("videos");

    let status_cond = match is_processed {
        Some(true) => "AND status = 'ready'",
        Some(false) => "AND status != 'ready'",
        None => "",
    };

    // The search term is only ever sent as bound parameter `$1`; it is never
    // interpolated into the SQL text. It is built once and shared by both queries.
    let pattern = query.map(|q| format!("%{}%", q.to_lowercase()));
    let search_cond = if pattern.is_some() {
        "AND (LOWER(file_name) LIKE $1 OR LOWER(file_path) LIKE $1 OR LOWER(probe_json::text) LIKE $1)"
    } else {
        ""
    };
    let where_clause = format!("WHERE 1=1 {} {}", status_cond, search_cond);

    // 1. Count query. The number of binds must match the placeholders in the SQL,
    //    so the pattern is bound only when the search condition is present.
    let count_query = format!("SELECT COUNT(*) FROM {} {}", table, where_clause);
    let total: i64 = match &pattern {
        Some(pattern) => {
            sqlx::query_scalar(&count_query)
                .bind(pattern)
                .fetch_one(&self.pool)
                .await?
        }
        None => {
            sqlx::query_scalar(&count_query)
                .fetch_one(&self.pool)
                .await?
        }
    };

    // 2. Page query. `created_at` and `registration_time` are cast to text.
    let columns = "id, uuid, file_path, file_name, duration, width, height, fps, probe_json, fs_video, fs_json, psql_chunk, pobject_chunk, mobject_chunk, pvector_chunk, qvector_chunk, status, user_id, job_id, created_at::text, registration_time::text";
    // With a search pattern occupying `$1`, LIMIT/OFFSET move to `$2`/`$3`.
    let (limit_param, offset_param) = if pattern.is_some() { (2, 3) } else { (1, 2) };
    let select_query = format!(
        "SELECT {} FROM {} {} ORDER BY id DESC LIMIT ${} OFFSET ${}",
        columns, table, where_clause, limit_param, offset_param
    );
    let rows = match &pattern {
        Some(pattern) => {
            sqlx::query_as::<_, VideoRow>(&select_query)
                .bind(pattern)
                .bind(limit)
                .bind(offset)
                .fetch_all(&self.pool)
                .await?
        }
        None => {
            sqlx::query_as::<_, VideoRow>(&select_query)
                .bind(limit)
                .bind(offset)
                .fetch_all(&self.pool)
                .await?
        }
    };

    let videos: Vec<VideoRecord> = rows.into_iter().map(|r| r.into()).collect();
    Ok((videos, total))
}
@@ -850,6 +916,19 @@ impl PostgresDb {
Ok(())
}
/// Stamps `registration_time` (and `updated_at`) with the current timestamp for the
/// video identified by `uuid`.
///
/// The `registration_time IS NULL` guard means a row that already has a
/// registration time is left untouched. Database errors are propagated.
pub async fn set_registration_time(&self, uuid: &str) -> Result<()> {
    let table = schema::table_name("videos");
    let sql = format!(
        "UPDATE {} SET registration_time = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE uuid = $1 AND registration_time IS NULL",
        table
    );
    sqlx::query(&sql).bind(uuid).execute(&self.pool).await?;
    Ok(())
}
pub async fn delete_video(&self, uuid: &str) -> Result<()> {
tracing::info!("[PostgresDb] Deleting video: {}", uuid);

68
src/core/db/schema_ctx.rs Normal file
View File

@@ -0,0 +1,68 @@
use anyhow::Result;
use sqlx::PgPool;
use std::sync::atomic::{AtomicU32, Ordering};
/// Schema context for database operations.
///
/// Carries the prefix that is put in front of every table name so that queries
/// address the configured schema.
#[derive(Debug, Clone)]
pub struct SchemaContext {
    /// Either empty (for the `public` schema) or `"<schema>."`.
    pub prefix: String,
}

/// Process-wide context, created on first use by [`SchemaContext::global`].
static SCHEMA_INSTANCE: std::sync::OnceLock<SchemaContext> = std::sync::OnceLock::new();
/// Counter bumped by [`SchemaContext::reload`].
static SCHEMA_VERSION: AtomicU32 = AtomicU32::new(0);

impl SchemaContext {
    /// Builds a context from the `DATABASE_SCHEMA` environment variable.
    ///
    /// Falls back to `dev` when the variable cannot be read. The `public` schema
    /// maps to an empty prefix; any other value maps to `"<value>."`.
    pub fn init() -> Self {
        let schema = match std::env::var("DATABASE_SCHEMA") {
            Ok(value) => value,
            Err(_) => "dev".to_string(),
        };
        let prefix = match schema.as_str() {
            "public" => String::new(),
            other => format!("{}.", other),
        };
        Self { prefix }
    }

    /// Returns the global context, initializing it from the environment on first call.
    pub fn global() -> &'static Self {
        SCHEMA_INSTANCE.get_or_init(Self::init)
    }

    /// Returns `name` with the schema prefix prepended.
    pub fn table(&self, name: &str) -> String {
        let mut qualified = String::with_capacity(self.prefix.len() + name.len());
        qualified.push_str(&self.prefix);
        qualified.push_str(name);
        qualified
    }

    /// Reload hook (for testing).
    ///
    /// Only increments the version counter: the `OnceLock` holding the global
    /// context cannot be reset, so the cached context itself is not rebuilt.
    /// In production the schema does not change at runtime.
    pub fn reload() {
        SCHEMA_VERSION.fetch_add(1, Ordering::SeqCst);
    }
}
/// Shorthand for `SchemaContext::global().table(name)`: the table name qualified
/// with the current schema prefix.
pub fn t(name: &str) -> String {
    let ctx = SchemaContext::global();
    ctx.table(name)
}
/// Checks whether `table_name` exists in the current schema.
///
/// The schema name is the global prefix without its trailing dot, or `public` when
/// the prefix is empty. Both the schema and the table name are passed as bound
/// parameters to an `information_schema.tables` lookup. Database errors are
/// propagated.
pub async fn table_exists(pool: &PgPool, table_name: &str) -> Result<bool> {
    let schema = SchemaContext::global();
    let schema_name = if schema.prefix.is_empty() {
        "public"
    } else {
        schema.prefix.trim_end_matches('.')
    };

    // The statement is a fixed string: nothing is interpolated, so no `format!`.
    let exists: bool = sqlx::query_scalar(
        "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)",
    )
    .bind(schema_name)
    .bind(table_name)
    .fetch_one(pool)
    .await?;
    Ok(exists)
}

143
src/core/ingestion.rs Normal file
View File

@@ -0,0 +1,143 @@
use anyhow::{Context, Result};
use std::path::Path;
use tracing::{info, warn};
use crate::core::db::{Database, PostgresDb, VideoRecord, VideoStatus};
use crate::core::probe;
use crate::core::storage::FileManager;
use crate::uuid as uuid_utils;
/// Handles the automatic ingestion of video files.
///
/// This service is responsible for:
/// 1. Running `ffprobe` (pre-processing)
/// 2. Saving the probe JSON
/// 3. Registering the video in the database (making it visible in the API)
pub struct IngestionService {
// Database handle used to look up and register videos.
db: PostgresDb,
}
impl IngestionService {
pub fn new(db: PostgresDb) -> Self {
Self { db }
}
/// Registers a video file found in the watched directory.
/// This function is idempotent: if the video (UUID) already exists, it skips.
pub async fn ingest(&self, file_path: &str) -> Result<Option<String>> {
let path = Path::new(file_path);
// 1. Validate extension
if !is_video_extension(path) {
return Ok(None);
}
// 2. Compute UUID
let uuid = uuid_utils::compute_uuid_from_path(file_path);
// 3. Check if already registered
if let Ok(Some(_)) = self.db.get_video_by_uuid(&uuid).await {
info!(
"Video already registered: {} ({})",
path.file_name().unwrap_or_default().to_string_lossy(),
uuid
);
return Ok(None);
}
info!("Starting ingestion for: {} ({})", path.display(), uuid);
// 4. Run ffprobe
let probe_result = probe::probe_video(file_path)
.with_context(|| format!("Failed to probe video: {}", file_path))?;
// 5. Extract metadata
let duration = probe_result
.format
.duration
.as_ref()
.and_then(|s| s.parse::<f64>().ok())
.unwrap_or(0.0);
let mut width = 0u32;
let mut height = 0u32;
let mut fps = 0.0;
for stream in &probe_result.streams {
if stream.codec_type.as_deref() == Some("video") {
width = stream.width.unwrap_or(0);
height = stream.height.unwrap_or(0);
if let Some(fps_str) = &stream.r_frame_rate {
if let Some((num, den)) = fps_str.split_once('/') {
if let (Ok(n), Ok(d)) = (num.parse::<f64>(), den.parse::<f64>()) {
if d > 0.0 {
fps = n / d;
}
}
}
}
}
}
// 6. Save Probe JSON
let file_manager = FileManager::new(std::path::PathBuf::from("."));
let probe_json_str = serde_json::to_string_pretty(&probe_result)?;
if let Err(e) = file_manager.save_json(&uuid, "probe", &probe_json_str) {
warn!("Failed to save probe JSON for {}: {}", uuid, e);
} else {
info!("Probe JSON saved for {}", uuid);
}
// 7. Create Record
// Use absolute path for safety
let canonical_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
let record = VideoRecord {
id: 0,
uuid: uuid.clone(),
file_path: canonical_path.to_string_lossy().to_string(),
file_name: path
.file_name()
.unwrap_or_default()
.to_string_lossy()
.to_string(),
duration,
width,
height,
fps,
probe_json: Some(probe_json_str),
storage: Default::default(),
status: VideoStatus::Pending, // Ready for processing
user_id: None,
job_id: None,
created_at: String::new(),
registration_time: None,
};
// 8. Insert DB
self.db
.register_video(&record)
.await
.with_context(|| "Failed to register video in database")?;
self.db
.set_registration_time(&uuid)
.await
.with_context(|| "Failed to set registration_time")?;
info!(
"Successfully registered video: {} (UUID: {})",
record.file_name, uuid
);
Ok(Some(uuid))
}
}
/// Returns `true` when the path's extension, compared case-insensitively, is one of
/// the recognized video extensions. Paths without a (UTF-8) extension are rejected.
fn is_video_extension(path: &Path) -> bool {
    const VIDEO_EXTENSIONS: [&str; 6] = ["mp4", "mov", "mkv", "avi", "webm", "m4v"];

    match path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => VIDEO_EXTENSIONS.contains(&ext.to_lowercase().as_str()),
        None => false,
    }
}

104
src/core/llm/client.rs Normal file
View File

@@ -0,0 +1,104 @@
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{debug, error, warn};
use crate::core::config;
/// JSON request body posted to the summary LLM endpoint.
// NOTE(review): the field set looks like an OpenAI-compatible chat-completions
// payload — confirm against the configured endpoint.
#[derive(Debug, Serialize)]
struct ChatRequest {
// Model name, taken from the LLM summary configuration.
model: String,
// Conversation sent to the model (system message followed by user message).
messages: Vec<ChatMessage>,
temperature: f32,
max_tokens: u32,
stream: bool,
}
/// One chat message; used both in the request and in the parsed response.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
// Speaker role; this file sends "system" and "user".
role: String,
// Message text.
content: String,
}
/// Subset of the LLM response that is deserialized; only `choices` is read.
#[derive(Debug, Deserialize)]
struct ChatResponse {
// Candidate completions; the caller uses the first one.
choices: Vec<Choice>,
}
/// One completion candidate in a `ChatResponse`.
#[derive(Debug, Deserialize)]
struct Choice {
// The message produced by the model.
message: ChatMessage,
}
/// Generates a 5W1H+ summary for a given scene context.
/// Context should include the combined text of all sentences in the scene.
pub async fn generate_5w1h_summary(scene_text: &str) -> Result<String> {
if !*config::llm::SUMMARY_ENABLED {
warn!("LLM Summary is disabled via config");
return Ok("LLM Disabled".to_string());
}
let client = Client::builder()
.timeout(Duration::from_secs(*config::llm::SUMMARY_TIMEOUT_SECS))
.build()?;
let prompt = format!(
r#"Analyze the following video scene transcript and provide a concise 5W1H+ summary in JSON format.
Focus on: Who, What, Where, When, Why, How, and Key Objects/Actions.
Transcript:
"{}"
Output format:
{{
"who": "...",
"what": "...",
"where": "...",
"when": "...",
"why": "...",
"how": "...",
"summary": "..."
}}"#,
scene_text
);
let req = ChatRequest {
model: (*config::llm::SUMMARY_MODEL).clone(),
messages: vec![
ChatMessage {
role: "system".to_string(),
content: "You are an expert video analyst assistant.".to_string(),
},
ChatMessage {
role: "user".to_string(),
content: prompt,
},
],
temperature: 0.1,
max_tokens: 512,
stream: false,
};
debug!("Calling LLM for summary: {}", *config::llm::SUMMARY_URL);
let res = client
.post(&*config::llm::SUMMARY_URL)
.json(&req)
.send()
.await?;
if !res.status().is_success() {
error!("LLM API error: {}", res.status());
let text = res.text().await.unwrap_or_default();
anyhow::bail!("LLM API error: {}", text);
}
let chat_res: ChatResponse = res.json().await?;
if let Some(choice) = chat_res.choices.into_iter().next() {
Ok(choice.message.content.trim().to_string())
} else {
anyhow::bail!("Empty response from LLM");
}
}

1
src/core/llm/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod client;

266
src/core/person_identity.rs Normal file
View File

@@ -0,0 +1,266 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
// ==========================================
// 舊版結構體 (保留以向後兼容)
// ==========================================
/// Person identity record; can be loaded from a database row (`FromRow`) and
/// (de)serialized with serde.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct PersonIdentity {
pub id: i32,
pub person_id: String,
// Optional links to a face identity and to a speaker label.
pub face_identity_id: Option<i32>,
pub speaker_id: Option<String>,
pub video_uuid: String,
pub confidence: f64,
pub name: Option<String>,
pub metadata: serde_json::Value,
// Appearance summary. NOTE(review): times/durations are presumably seconds — confirm.
pub first_appearance_time: Option<f64>,
pub last_appearance_time: Option<f64>,
pub total_appearance_duration: f64,
pub appearance_count: i32,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub is_confirmed: bool,
}
// ==========================================
// 新版結構體 (V5 身份綁定系統)
// ==========================================
/// Person identity (Identity) - unified management of actors, public figures,
/// family members, friends, etc.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct Identity {
pub id: i32,
pub name: String,
pub embedding: Option<String>, // Vector embedding stored as text/json
pub metadata: Option<serde_json::Value>,
pub created_at: DateTime<Utc>,
}
/// Identity binding record (Identity Binding).
/// Binds a machine ID (face_x, speaker_y) to an Identity.
// NOTE(review): `identity_id` is i64 here while `Identity::id` is i32 — confirm
// which width the database column actually uses.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct IdentityBinding {
pub id: i64,
pub identity_id: i64,
pub binding_type: String, // 'face', 'speaker'
pub binding_value: String, // e.g. "face_1", "speaker_3"
pub source: String, // 'auto', 'manual'
pub confidence: f64,
pub is_active: bool,
pub created_at: DateTime<Utc>,
}
/// Bind request (used by the API).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BindIdentityRequest {
pub identity_id: Option<i64>,
pub name: Option<String>, // if identity_id is not provided, a new Identity is created
pub binding_type: String, // 'face' or 'speaker'
pub binding_value: String, // e.g. "face_1"
pub source: Option<String>, // defaults to 'manual'
}
/// Unbind request: identifies the binding to remove by its type and value.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UnbindIdentityRequest {
pub binding_type: String,
pub binding_value: String,
}
/// Suggested binding (generated automatically by the system, confirmed manually).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SuggestedBinding {
pub binding_type: String,
pub binding_value: String,
// The identity the system proposes for this binding.
pub suggested_identity_id: i64,
pub suggested_identity_name: String,
pub confidence: f64,
// Explanation of why the binding is suggested.
pub reason: String,
}
/// One appearance interval of a person in a video; loadable from a database row.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct PersonAppearance {
pub id: i32,
pub person_id: String,
pub video_uuid: String,
// Interval of the appearance. NOTE(review): presumably seconds — confirm.
pub start_time: f64,
pub end_time: f64,
pub duration: f64,
// Optional link to the face detection behind this appearance.
pub face_detection_id: Option<i32>,
// Optional bounds of an associated ASRX segment.
pub asrx_segment_start: Option<f64>,
pub asrx_segment_end: Option<f64>,
pub confidence: f64,
pub metadata: serde_json::Value,
pub created_at: DateTime<Utc>,
}
/// A face/speaker pairing with its confidence and match count; loadable from a
/// database row.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct PersonMatch {
pub face_id: String,
pub speaker_id: String,
pub confidence: f64,
pub match_count: i64,
}
/// One interval in a person's timeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonTimelineEntry {
pub start_time: f64,
pub end_time: f64,
pub duration: f64,
pub confidence: f64,
}
/// Aggregate appearance statistics for a person.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonStatistics {
pub total_appearances: i32,
pub total_duration: f64,
// `None` when no first/last appearance is recorded.
pub first_appearance: Option<f64>,
pub last_appearance: Option<f64>,
pub average_confidence: f64,
}
/// Request payload for creating a person identity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreatePersonIdentityRequest {
pub video_uuid: String,
// Every field below is optional in the payload.
pub face_identity_id: Option<i32>,
pub speaker_id: Option<String>,
pub name: Option<String>,
pub metadata: Option<serde_json::Value>,
}
/// Request payload for updating a person identity; every field is optional.
// NOTE(review): presumably `None` means "leave unchanged" — confirm in the handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdatePersonIdentityRequest {
pub name: Option<String>,
pub metadata: Option<serde_json::Value>,
pub is_confirmed: Option<bool>,
}
/// API view of a `PersonIdentity`.
///
/// Carries the identity's public fields; the row id, video UUID, metadata and
/// timestamps of `PersonIdentity` are not part of it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonIdentityResponse {
pub person_id: String,
pub name: Option<String>,
pub face_identity_id: Option<i32>,
pub speaker_id: Option<String>,
pub confidence: f64,
pub appearance_count: i32,
pub total_appearance_duration: f64,
pub first_appearance_time: Option<f64>,
pub last_appearance_time: Option<f64>,
pub is_confirmed: bool,
}
impl From<PersonIdentity> for PersonIdentityResponse {
fn from(person: PersonIdentity) -> Self {
Self {
person_id: person.person_id,
name: person.name,
face_identity_id: person.face_identity_id,
speaker_id: person.speaker_id,
confidence: person.confidence,
appearance_count: person.appearance_count,
total_appearance_duration: person.total_appearance_duration,
first_appearance_time: person.first_appearance_time,
last_appearance_time: person.last_appearance_time,
is_confirmed: person.is_confirmed,
}
}
}
/// Timeline response for one person: the appearance intervals plus aggregate
/// statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonTimelineResponse {
pub person_id: String,
pub name: Option<String>,
pub timeline: Vec<PersonTimelineEntry>,
pub statistics: PersonStatistics,
}
/// Person information attached to a chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkPersonInfo {
pub person_id: String,
pub name: Option<String>,
pub confidence: f64,
// NOTE(review): presumably the overlap between the person's appearance and the
// chunk, in seconds — confirm against the code that fills it in.
pub overlap_duration: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `json` contains every one of `fragments`.
    fn assert_contains_all(json: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(json.contains(*fragment));
        }
    }

    #[test]
    fn test_person_identity_serialization() {
        let person = PersonIdentity {
            id: 1,
            person_id: "person_001".to_string(),
            face_identity_id: Some(123),
            speaker_id: Some("SPEAKER_00".to_string()),
            video_uuid: "video_abc".to_string(),
            confidence: 0.85,
            name: Some("张三".to_string()),
            metadata: serde_json::json!({"role": "host"}),
            first_appearance_time: Some(10.5),
            last_appearance_time: Some(350.2),
            total_appearance_duration: 120.5,
            appearance_count: 15,
            created_at: Utc::now(),
            updated_at: Utc::now(),
            is_confirmed: true,
        };

        let json = serde_json::to_string(&person).unwrap();
        assert_contains_all(&json, &["person_001", "SPEAKER_00", "张三"]);
    }

    #[test]
    fn test_person_appearance_serialization() {
        let appearance = PersonAppearance {
            id: 1,
            person_id: "person_001".to_string(),
            video_uuid: "video_abc".to_string(),
            start_time: 10.5,
            end_time: 25.3,
            duration: 14.8,
            face_detection_id: Some(456),
            asrx_segment_start: Some(10.0),
            asrx_segment_end: Some(26.0),
            confidence: 0.92,
            metadata: serde_json::json!({}),
            created_at: Utc::now(),
        };

        let json = serde_json::to_string(&appearance).unwrap();
        assert_contains_all(&json, &["person_001", "14.8"]);
    }

    #[test]
    fn test_person_match() {
        let match_result = PersonMatch {
            face_id: "face_123".to_string(),
            speaker_id: "SPEAKER_00".to_string(),
            confidence: 0.85,
            match_count: 15,
        };

        assert_eq!(match_result.face_id, "face_123");
        // Confidence must lie in the closed interval [0, 1].
        assert!((0.0..=1.0).contains(&match_result.confidence));
    }

    #[test]
    fn test_person_statistics() {
        let stats = PersonStatistics {
            total_appearances: 15,
            total_duration: 120.5,
            first_appearance: Some(10.5),
            last_appearance: Some(350.2),
            average_confidence: 0.88,
        };

        assert_eq!(stats.total_appearances, 15);
        assert!(stats.total_duration > 0.0);
    }
}

View File

@@ -0,0 +1,124 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::executor::PythonExecutor;
use crate::core::config::processor;
/// Parsed output of the ASR script: optional language information plus the
/// transcribed segments.
#[derive(Debug, Serialize, Deserialize)]
pub struct AsrResult {
/// Detected language code, when reported.
pub language: Option<String>,
/// Probability attached to the detected language, when reported.
pub language_probability: Option<f64>,
/// Transcribed segments.
pub segments: Vec<AsrSegment>,
}
/// One transcribed span of speech.
#[derive(Debug, Serialize, Deserialize)]
pub struct AsrSegment {
/// Segment start (presumably seconds from the start of the media — TODO confirm).
pub start: f64,
/// Segment end, in the same unit as `start`.
pub end: f64,
/// Transcribed text of the segment.
pub text: String,
}
/// Runs the `asr_processor.py` script on `video_path` and returns the parsed
/// transcription.
///
/// The script is expected to write its JSON result to `output_path`, which is
/// then read back and deserialized into an [`AsrResult`]. `uuid` is forwarded
/// to the executor unchanged.
///
/// # Errors
///
/// Fails when the executor cannot be created, the script run fails, or the
/// output file cannot be read or parsed.
pub async fn process_asr(
    video_path: &str,
    output_path: &str,
    uuid: Option<&str>,
) -> Result<AsrResult> {
    const SCRIPT_NAME: &str = "asr_processor.py";

    let executor = PythonExecutor::new()?;
    // Resolved only so a failure can report which script was being run.
    let script_path = executor.script_path(SCRIPT_NAME);
    tracing::info!("[ASR] Starting ASR processing: {}", video_path);

    let timeout = Duration::from_secs(*processor::ASR_TIMEOUT_SECS);
    let script_args = [video_path, output_path];
    let run_outcome = executor
        .run(SCRIPT_NAME, &script_args, uuid, "ASR", Some(timeout))
        .await;
    run_outcome.with_context(|| format!("Failed to run {:?}", script_path))?;

    let raw_output =
        std::fs::read_to_string(output_path).context("Failed to read ASR output")?;
    let parsed: AsrResult =
        serde_json::from_str(&raw_output).context("Failed to parse ASR output")?;

    tracing::info!(
        "[ASR] Result: {} segments, language: {:?}",
        parsed.segments.len(),
        parsed.language
    );
    Ok(parsed)
}
// Unit tests: serde round-trips and plain construction of the ASR result types.
#[cfg(test)]
mod tests {
use super::*;
// Serialized JSON contains the segment text and the language code.
#[test]
fn test_asr_result_serialization() {
let result = AsrResult {
language: Some("en".to_string()),
language_probability: Some(0.95),
segments: vec![
AsrSegment {
start: 0.0,
end: 2.5,
text: "Hello world".to_string(),
},
AsrSegment {
start: 2.5,
end: 5.0,
text: "Test speech".to_string(),
},
],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("Hello world"));
assert!(json.contains("en"));
}
// JSON in the script's output shape (including non-ASCII text) deserializes correctly.
#[test]
fn test_asr_result_deserialization() {
let json = r#"{
"language": "zh",
"language_probability": 0.98,
"segments": [
{"start": 0.0, "end": 1.5, "text": "測試"}
]
}"#;
let result: AsrResult = serde_json::from_str(json).unwrap();
assert_eq!(result.language, Some("zh".to_string()));
assert_eq!(result.language_probability, Some(0.98));
assert_eq!(result.segments.len(), 1);
assert_eq!(result.segments[0].text, "測試");
}
// A segment built with empty text keeps the values it was given.
#[test]
fn test_asr_segment_default() {
let segment = AsrSegment {
start: 0.0,
end: 1.0,
text: String::new(),
};
assert_eq!(segment.start, 0.0);
assert_eq!(segment.end, 1.0);
assert!(segment.text.is_empty());
}
// A result with no language and no segments is representable.
#[test]
fn test_asr_result_empty_segments() {
let result = AsrResult {
language: None,
language_probability: None,
segments: vec![],
};
assert!(result.language.is_none());
assert!(result.segments.is_empty());
}
}

View File

@@ -0,0 +1,345 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::executor::PythonExecutor;
/// Upper bound passed to the executor for one face-recognition script run.
const FACE_RECOGNITION_TIMEOUT: Duration = Duration::from_secs(10800); // 3 hours for recognition
/// Parsed output of the face-recognition script for one video.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FaceRecognitionResult {
/// Frame count reported by the script.
pub frame_count: u64,
/// Frame rate reported by the script.
pub fps: f64,
/// Per-frame face detections.
pub frames: Vec<FaceRecognitionFrame>,
/// Faces recognized across the video.
pub recognized_faces: Vec<RecognizedFace>,
/// Clusters of recognized faces.
pub face_clusters: Vec<FaceCluster>,
}
/// Face detections belonging to a single frame.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FaceRecognitionFrame {
/// Frame number.
pub frame: u64,
/// Timestamp of the frame (presumably seconds — TODO confirm).
pub timestamp: f64,
/// Faces detected in this frame.
pub faces: Vec<RecognizedFaceDetection>,
}
/// A single detected face: bounding box, confidence and optional recognition data.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RecognizedFaceDetection {
/// Identifier of the face, when one was assigned.
pub face_id: Option<String>,
/// Bounding-box x coordinate (presumably pixels, top-left origin — TODO confirm).
pub x: i32,
/// Bounding-box y coordinate.
pub y: i32,
/// Bounding-box width.
pub width: i32,
/// Bounding-box height.
pub height: i32,
/// Detection confidence.
pub confidence: f32,
/// Face embedding vector, when present.
pub embedding: Option<Vec<f32>>,
/// Face attributes, when present.
pub attributes: Option<FaceAttributes>,
/// Identity matched for this face, when present.
pub identity: Option<FaceIdentity>,
}
/// Optional per-face attributes; every field may be absent from the script output.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FaceAttributes {
/// Age attribute.
pub age: Option<u8>,
/// Gender label.
pub gender: Option<String>,
/// Emotion label.
pub emotion: Option<String>,
/// Whether glasses were detected.
pub glasses: Option<bool>,
/// Whether a mask was detected.
pub mask: Option<bool>,
/// Head pose.
pub pose: Option<FacePose>,
}
/// Head orientation as yaw/pitch/roll (angle unit not visible here — TODO confirm).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FacePose {
/// Yaw component.
pub yaw: f32,
/// Pitch component.
pub pitch: f32,
/// Roll component.
pub roll: f32,
}
/// An identity matched to a face.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FaceIdentity {
/// Name of the identity, when known.
pub name: Option<String>,
/// Confidence of the identity match.
pub confidence: f32,
/// Identifier of the matching database entry, when present.
pub database_id: Option<String>,
/// Free-form JSON metadata attached to the identity.
pub metadata: Option<serde_json::Value>,
}
/// A face recognized across the video, with its embedding and appearance summary.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RecognizedFace {
/// Identifier of the face.
pub face_id: String,
/// Face embedding vector.
pub embedding: Vec<f32>,
/// When the face was first seen (presumably seconds — TODO confirm).
pub first_seen: f64,
/// When the face was last seen, in the same unit as `first_seen`.
pub last_seen: f64,
/// Number of appearances.
pub total_appearances: u32,
/// Face attributes, when present.
pub attributes: Option<FaceAttributes>,
/// Identities matched to this face.
pub identities: Vec<FaceIdentity>,
/// Cluster this face belongs to, when clustered.
pub cluster_id: Option<String>,
}
/// A group of face ids clustered together.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FaceCluster {
/// Identifier of the cluster.
pub cluster_id: String,
/// Ids of the faces in this cluster.
pub face_ids: Vec<String>,
/// Centroid vector of the cluster.
pub centroid: Vec<f32>,
/// Cluster size as reported by the script.
pub size: u32,
/// Id of the face chosen to represent the cluster, when present.
pub representative_face_id: Option<String>,
/// Free-form JSON metadata attached to the cluster.
pub metadata: Option<serde_json::Value>,
}
/// Runs `face_recognition_processor.py` on `video_path` and returns the parsed
/// result.
///
/// The three `enable_*` switches are passed to the script as `"1"` / `"0"`
/// positional arguments after the video and output paths. When the script file
/// does not exist, a warning is logged and an empty result is returned instead
/// of an error.
///
/// # Errors
///
/// Fails when the executor cannot be created, the script run fails, or the
/// output file at `output_path` cannot be read or parsed.
pub async fn process_face_recognition(
    video_path: &str,
    output_path: &str,
    uuid: Option<&str>,
    enable_recognition: bool,
    enable_tracking: bool,
    enable_clustering: bool,
) -> Result<FaceRecognitionResult> {
    const SCRIPT_NAME: &str = "face_recognition_processor.py";

    let executor = PythonExecutor::new()?;
    let script_path = executor.script_path(SCRIPT_NAME);
    tracing::info!(
        "[FACE_RECOGNITION] Starting face recognition: {}",
        video_path
    );

    // Missing script is treated as "feature unavailable", not as a failure.
    if !script_path.exists() {
        tracing::warn!("[FACE_RECOGNITION] Script not found, returning empty result");
        let empty = FaceRecognitionResult {
            frame_count: 0,
            fps: 0.0,
            frames: vec![],
            recognized_faces: vec![],
            face_clusters: vec![],
        };
        return Ok(empty);
    }

    // The script takes its switches as "1"/"0" strings.
    let as_flag = |enabled: bool| if enabled { "1" } else { "0" };
    let args = vec![
        video_path,
        output_path,
        as_flag(enable_recognition),
        as_flag(enable_tracking),
        as_flag(enable_clustering),
    ];

    let run_outcome = executor
        .run(
            SCRIPT_NAME,
            &args,
            uuid,
            "FACE_RECOGNITION",
            Some(FACE_RECOGNITION_TIMEOUT),
        )
        .await;
    run_outcome.with_context(|| format!("Failed to run {:?}", script_path))?;

    let raw_output =
        std::fs::read_to_string(output_path).context("Failed to read FACE_RECOGNITION output")?;
    let parsed: FaceRecognitionResult =
        serde_json::from_str(&raw_output).context("Failed to parse FACE_RECOGNITION output")?;

    tracing::info!(
        "[FACE_RECOGNITION] Result: {} frames, {} recognized faces, {} clusters",
        parsed.frames.len(),
        parsed.recognized_faces.len(),
        parsed.face_clusters.len()
    );
    Ok(parsed)
}
pub async fn register_face(
image_path: &str,
name: &str,
metadata: Option<serde_json::Value>,
) -> Result<FaceRegistrationResult> {
let executor = PythonExecutor::new()?;
let script_path = executor.script_path("face_registration.py");
tracing::info!("[FACE_REGISTRATION] Registering face: {}", name);
if !script_path.exists() {
anyhow::bail!("Face registration script not found");
}
let output_path = format!("/tmp/face_registration_{}.json", uuid::Uuid::new_v4());
// Handle metadata separately to avoid lifetime issues
let meta_temp_file = metadata.as_ref().map(|meta| {
let meta_path = format!("/tmp/face_metadata_{}.json", uuid::Uuid::new_v4());
std::fs::write(&meta_path, serde_json::to_string(meta).unwrap()).unwrap();
meta_path
});
// Build arguments - use output_path as database path so Python writes there
let mut args = vec![
image_path.to_string(),
output_path.clone(),
name.to_string(),
];
// Add database parameter (point to same output for now)
let database_path = output_path.clone();
args.push("--database".to_string());
args.push(database_path.clone());
if let Some(ref meta_path) = meta_temp_file {
args.push("--metadata".to_string());
args.push(meta_path.clone());
}
let args_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
executor
.run(
"face_registration.py",
&args_refs,
None,
"FACE_REGISTRATION",
Some(Duration::from_secs(300)),
)
.await
.with_context(|| format!("Failed to run {:?}", script_path))?;
let json_str =
std::fs::read_to_string(&output_path).context("Failed to read registration output")?;
let result: FaceRegistrationResult =
serde_json::from_str(&json_str).context("Failed to parse registration output")?;
// Clean up temp files
let _ = std::fs::remove_file(&output_path);
if let Some(meta_path) = meta_temp_file {
let _ = std::fs::remove_file(&meta_path);
}
tracing::info!("[FACE_REGISTRATION] Registered face: {}", result.face_id);
Ok(result)
}
/// Parsed output of the face-registration script.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FaceRegistrationResult {
/// Identifier assigned to the registered face.
pub face_id: String,
/// Embedding vector of the registered face.
pub embedding: Vec<f32>,
/// Face attributes, when present.
pub attributes: Option<FaceAttributes>,
/// Success flag reported by the script.
pub success: bool,
/// Message reported by the script.
pub message: String,
}
// Unit tests: JSON serialization of the face-recognition result types.
#[cfg(test)]
mod tests {
use super::*;
// A fully populated result (frame, recognized face, cluster) serializes with
// its ids and names present in the JSON.
#[test]
fn test_face_recognition_result_serialization() {
let result = FaceRecognitionResult {
frame_count: 100,
fps: 30.0,
frames: vec![FaceRecognitionFrame {
frame: 0,
timestamp: 0.0,
faces: vec![RecognizedFaceDetection {
face_id: Some("face_1".to_string()),
x: 100,
y: 100,
width: 50,
height: 60,
confidence: 0.95,
embedding: Some(vec![0.1, 0.2, 0.3]),
attributes: Some(FaceAttributes {
age: Some(30),
gender: Some("male".to_string()),
emotion: Some("neutral".to_string()),
glasses: Some(false),
mask: Some(false),
pose: Some(FacePose {
yaw: 0.1,
pitch: 0.2,
roll: 0.3,
}),
}),
identity: Some(FaceIdentity {
name: Some("John Doe".to_string()),
confidence: 0.85,
database_id: Some("user_123".to_string()),
metadata: Some(serde_json::json!({"role": "employee"})),
}),
}],
}],
recognized_faces: vec![RecognizedFace {
face_id: "face_1".to_string(),
embedding: vec![0.1, 0.2, 0.3],
first_seen: 0.0,
last_seen: 10.0,
total_appearances: 5,
attributes: Some(FaceAttributes {
age: Some(30),
gender: Some("male".to_string()),
emotion: Some("neutral".to_string()),
glasses: Some(false),
mask: Some(false),
pose: Some(FacePose {
yaw: 0.1,
pitch: 0.2,
roll: 0.3,
}),
}),
identities: vec![FaceIdentity {
name: Some("John Doe".to_string()),
confidence: 0.85,
database_id: Some("user_123".to_string()),
metadata: Some(serde_json::json!({"role": "employee"})),
}],
cluster_id: Some("cluster_1".to_string()),
}],
face_clusters: vec![FaceCluster {
cluster_id: "cluster_1".to_string(),
face_ids: vec!["face_1".to_string()],
centroid: vec![0.1, 0.2, 0.3],
size: 1,
representative_face_id: Some("face_1".to_string()),
metadata: Some(serde_json::json!({"description": "main person"})),
}],
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("face_1"));
assert!(json.contains("John Doe"));
assert!(json.contains("cluster_1"));
}
// FaceAttributes serializes with the expected key/value pairs.
#[test]
fn test_face_attributes_serialization() {
let attributes = FaceAttributes {
age: Some(25),
gender: Some("female".to_string()),
emotion: Some("happy".to_string()),
glasses: Some(true),
mask: Some(false),
pose: Some(FacePose {
yaw: -0.1,
pitch: 0.05,
roll: 0.02,
}),
};
let json = serde_json::to_string(&attributes).unwrap();
assert!(json.contains("\"age\":25"));
assert!(json.contains("\"gender\":\"female\""));
assert!(json.contains("\"emotion\":\"happy\""));
}
// FaceIdentity serializes with its name, confidence and nested metadata.
#[test]
fn test_face_identity_serialization() {
let identity = FaceIdentity {
name: Some("Alice Smith".to_string()),
confidence: 0.92,
database_id: Some("employee_456".to_string()),
metadata: Some(serde_json::json!({
"department": "engineering",
"position": "senior developer"
})),
};
let json = serde_json::to_string(&identity).unwrap();
assert!(json.contains("Alice Smith"));
assert!(json.contains("\"confidence\":0.92"));
assert!(json.contains("engineering"));
}
}

View File

@@ -0,0 +1,562 @@
//! Visual chunk processor (Phase 2.2)
//!
//! Generates visual chunks from YOLO results.
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::executor::PythonExecutor;
use super::yolo::{YoloFrame, YoloResult};
/// Upper bound (3600 s, i.e. one hour) passed to the executor for one
/// visual-chunk script run.
const VISUAL_CHUNK_TIMEOUT: Duration = Duration::from_secs(3600);
/// Result of visual chunk processing.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct VisualChunkResult {
/// Number of visual chunks generated.
pub chunk_count: u32,
/// Total number of frames processed.
pub total_frames: u32,
/// Total number of objects detected.
pub total_objects: u32,
/// Number of unique object classes.
pub unique_classes: u32,
/// The generated visual chunks.
pub chunks: Vec<crate::core::chunk::Chunk>,
}
/// Generates visual chunks from a YOLO detection result.
///
/// Frames are grouped with the fixed-frame strategy
/// (`create_fixed_frame_chunks`); chunk indices start at `chunk_index_offset`.
/// Alongside the chunks, the result reports the number of frames, the total
/// number of detected objects and the number of distinct object class names.
/// `video_path` is only used for logging.
///
/// Returns an all-zero, empty result when `yolo_result` has no frames. The
/// function currently never returns `Err`.
pub async fn process_visual_chunk(
    file_id: i32,
    uuid: String,
    video_path: &str,
    yolo_result: &YoloResult,
    chunk_index_offset: u32,
    fps: f64,
) -> Result<VisualChunkResult> {
    let frames = &yolo_result.frames;
    tracing::info!(
        "[VisualChunk] Starting visual chunk generation for video: {}, {} frames",
        video_path,
        frames.len()
    );

    if frames.is_empty() {
        tracing::warn!("[VisualChunk] No YOLO frames to process");
        return Ok(VisualChunkResult {
            chunk_count: 0,
            total_frames: 0,
            total_objects: 0,
            unique_classes: 0,
            chunks: vec![],
        });
    }

    // Strategy 1: fixed-size chunks (one chunk every N frames).
    let chunks = create_fixed_frame_chunks(file_id, &uuid, yolo_result, chunk_index_offset, fps);

    // Statistics.
    let total_objects: u32 = frames.iter().map(|f| f.objects.len() as u32).sum();
    // Borrow the class names straight into the set; no per-object String clones
    // and no intermediate Vec are needed just to count distinct names.
    let unique_classes = frames
        .iter()
        .flat_map(|f| f.objects.iter().map(|o| o.class_name.as_str()))
        .collect::<std::collections::HashSet<&str>>()
        .len() as u32;

    tracing::info!(
        "[VisualChunk] Generated {} visual chunks from {} frames, {} total objects, {} unique classes",
        chunks.len(),
        frames.len(),
        total_objects,
        unique_classes
    );

    Ok(VisualChunkResult {
        chunk_count: chunks.len() as u32,
        total_frames: frames.len() as u32,
        total_objects,
        unique_classes,
        chunks,
    })
}
/// Splits the YOLO frames into consecutive groups of at most 30 frames and
/// builds one visual chunk per group.
///
/// Each chunk covers `[first.frame, last.frame + 1)` of its group (end is
/// exclusive) and chunk indices count up from `chunk_index_offset`. An empty
/// frame list yields an empty vector.
fn create_fixed_frame_chunks(
    file_id: i32,
    uuid: &str,
    yolo_result: &YoloResult,
    chunk_index_offset: u32,
    fps: f64,
) -> Vec<crate::core::chunk::Chunk> {
    // One chunk per 30 frames (about one second when fps = 30).
    const FRAMES_PER_CHUNK: usize = 30;

    let mut chunks = Vec::new();
    let mut next_index = chunk_index_offset;

    for window in yolo_result.frames.chunks(FRAMES_PER_CHUNK) {
        let (first, last) = match (window.first(), window.last()) {
            (Some(first), Some(last)) => (first, last),
            _ => break,
        };

        // Frame range covered by this group; the end is exclusive.
        let start_frame = first.frame as i64;
        let end_frame = last.frame as i64 + 1;

        let group: Vec<YoloFrame> = window.to_vec();
        chunks.push(crate::core::chunk::Chunk::from_yolo_frames(
            file_id,
            uuid.to_string(),
            next_index,
            start_frame,
            end_frame,
            fps,
            group,
        ));
        next_index += 1;
    }

    chunks
}
/// Builds visual chunks by grouping consecutive frames whose object classes
/// are similar.
///
/// A frame joins the current run when its similarity to the run's last frame
/// (see `calculate_frame_similarity`) is at least `similarity_threshold`;
/// otherwise the run is closed and a new one starts at that frame. Runs
/// shorter than `min_frames_per_chunk` are dropped and do not consume a chunk
/// index. Each chunk covers `[first.frame, last.frame + 1)` of its run.
fn create_similarity_based_chunks(
    file_id: i32,
    uuid: &str,
    yolo_result: &YoloResult,
    chunk_index_offset: u32,
    fps: f64,
    similarity_threshold: f32,
    min_frames_per_chunk: usize,
) -> Vec<crate::core::chunk::Chunk> {
    let mut chunks = Vec::new();
    let mut chunk_index = chunk_index_offset;
    let mut run: Vec<YoloFrame> = Vec::new();
    let mut run_start_frame: i64 = 0;

    for frame in &yolo_result.frames {
        let extends_run = match run.last() {
            // Simplified similarity check: compares object class names only.
            Some(previous) => calculate_frame_similarity(previous, frame) >= similarity_threshold,
            // Very first frame: it opens the first run.
            None => {
                run_start_frame = frame.frame as i64;
                true
            }
        };
        if extends_run {
            run.push(frame.clone());
            continue;
        }

        // Similarity dropped: close the current run and start a new one at this
        // frame. `replace` hands over the finished run without cloning it.
        let finished = std::mem::replace(&mut run, vec![frame.clone()]);
        let finished_start = std::mem::replace(&mut run_start_frame, frame.frame as i64);
        if let Some(chunk) = chunk_from_frame_run(
            file_id,
            uuid,
            chunk_index,
            finished_start,
            fps,
            finished,
            min_frames_per_chunk,
        ) {
            chunks.push(chunk);
            chunk_index += 1;
        }
    }

    // Flush the trailing run.
    if let Some(chunk) = chunk_from_frame_run(
        file_id,
        uuid,
        chunk_index,
        run_start_frame,
        fps,
        run,
        min_frames_per_chunk,
    ) {
        chunks.push(chunk);
    }

    chunks
}

/// Turns one run of frames into a chunk, or returns `None` when the run is
/// empty or shorter than `min_frames`.
fn chunk_from_frame_run(
    file_id: i32,
    uuid: &str,
    chunk_index: u32,
    start_frame: i64,
    fps: f64,
    frames: Vec<YoloFrame>,
    min_frames: usize,
) -> Option<crate::core::chunk::Chunk> {
    if frames.len() < min_frames {
        return None;
    }
    // End of the range is exclusive, hence the +1.
    let end_frame = frames.last()?.frame as i64 + 1;
    Some(crate::core::chunk::Chunk::from_yolo_frames(
        file_id,
        uuid.to_string(),
        chunk_index,
        start_frame,
        end_frame,
        fps,
        frames,
    ))
}
/// Computes the similarity of two frames from their object class names.
///
/// The score is the Jaccard index of the two sets of class names
/// (`|A ∩ B| / |A ∪ B|`), so duplicates within a frame do not matter.
/// Two frames without objects score `1.0`; if exactly one frame has no
/// objects the score is `0.0`.
fn calculate_frame_similarity(frame1: &YoloFrame, frame2: &YoloFrame) -> f32 {
    use std::collections::HashSet;

    match (frame1.objects.is_empty(), frame2.objects.is_empty()) {
        (true, true) => return 1.0,
        (true, false) | (false, true) => return 0.0,
        (false, false) => {}
    }

    // Borrow the names instead of cloning every String into the sets.
    let classes1: HashSet<&str> = frame1
        .objects
        .iter()
        .map(|o| o.class_name.as_str())
        .collect();
    let classes2: HashSet<&str> = frame2
        .objects
        .iter()
        .map(|o| o.class_name.as_str())
        .collect();

    let shared = classes1.intersection(&classes2).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|; at least 1 because both frames have objects.
    let total = classes1.len() + classes2.len() - shared;
    shared as f32 / total as f32
}
/// Generates visual chunks by running the `visual_chunk_processor.py` script
/// (advanced variant).
///
/// The script is expected to write its JSON result to `output_path`, which is
/// read back and deserialized. When the script file does not exist, a warning
/// is logged and an all-zero, empty result is returned instead of an error.
///
/// # Errors
///
/// Fails when the executor cannot be created, the script run fails, or the
/// output file cannot be read or parsed.
pub async fn process_visual_chunk_advanced(
    video_path: &str,
    output_path: &str,
    uuid: Option<&str>,
) -> Result<VisualChunkResult> {
    const SCRIPT_NAME: &str = "visual_chunk_processor.py";

    let executor = PythonExecutor::new()?;
    let script_path = executor.script_path(SCRIPT_NAME);
    tracing::info!(
        "[VisualChunk] Starting advanced visual chunk generation: {}",
        video_path
    );

    if !script_path.exists() {
        tracing::warn!("[VisualChunk] Script not found, using basic generation");
        // A fallback to the basic generation method could be added here.
        let empty = VisualChunkResult {
            chunk_count: 0,
            total_frames: 0,
            total_objects: 0,
            unique_classes: 0,
            chunks: vec![],
        };
        return Ok(empty);
    }

    let script_args = [video_path, output_path];
    let run_outcome = executor
        .run(
            SCRIPT_NAME,
            &script_args,
            uuid,
            "VisualChunk",
            Some(VISUAL_CHUNK_TIMEOUT),
        )
        .await;
    run_outcome.with_context(|| format!("Failed to run {:?}", script_path))?;

    let raw_output =
        std::fs::read_to_string(output_path).context("Failed to read visual chunk output")?;
    let parsed: VisualChunkResult =
        serde_json::from_str(&raw_output).context("Failed to parse visual chunk output")?;

    tracing::info!(
        "[VisualChunk] Advanced generation result: {} chunks, {} frames",
        parsed.chunk_count,
        parsed.total_frames
    );
    Ok(parsed)
}
// Unit tests for frame similarity and the two chunking strategies.
#[cfg(test)]
mod tests {
use super::*;
// Similarity is 1.0 for identical class sets, 0.0 for disjoint ones, and 1.0
// for two empty frames.
#[test]
fn test_calculate_frame_similarity() {
use crate::core::processor::yolo::{YoloFrame, YoloObject};
let frame1 = YoloFrame {
frame: 0,
timestamp: 0.0,
objects: vec![
YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 100,
y: 200,
width: 50,
height: 100,
confidence: 0.95,
},
YoloObject {
class_name: "car".to_string(),
class_id: 2,
x: 300,
y: 150,
width: 80,
height: 60,
confidence: 0.87,
},
],
};
let frame2 = YoloFrame {
frame: 1,
timestamp: 0.033,
objects: vec![
YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 110,
y: 210,
width: 52,
height: 102,
confidence: 0.92,
},
YoloObject {
class_name: "car".to_string(),
class_id: 2,
x: 310,
y: 155,
width: 82,
height: 62,
confidence: 0.85,
},
],
};
let frame3 = YoloFrame {
frame: 2,
timestamp: 0.066,
objects: vec![YoloObject {
class_name: "dog".to_string(),
class_id: 16,
x: 150,
y: 250,
width: 40,
height: 60,
confidence: 0.78,
}],
};
// Frames with the same object classes should be fully similar
let similarity_same = calculate_frame_similarity(&frame1, &frame2);
assert!((similarity_same - 1.0).abs() < 0.001);
// Frames with different object classes should not be similar
let similarity_diff = calculate_frame_similarity(&frame1, &frame3);
assert!((similarity_diff - 0.0).abs() < 0.001);
// Empty frames should be fully similar
let empty_frame = YoloFrame {
frame: 3,
timestamp: 0.1,
objects: vec![],
};
let similarity_empty = calculate_frame_similarity(&empty_frame, &empty_frame);
assert!((similarity_empty - 1.0).abs() < 0.001);
}
// 60 frames are split into two 30-frame chunks with exclusive end frames.
#[tokio::test]
async fn test_create_fixed_frame_chunks() {
use crate::core::processor::yolo::{YoloFrame, YoloObject, YoloResult};
// Build a test YOLO result (60 frames, each with one object)
let mut frames = Vec::new();
for i in 0..60 {
frames.push(YoloFrame {
frame: i as u64,
timestamp: i as f64 / 30.0, // assuming fps=30
objects: vec![YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 100,
y: 200,
width: 50,
height: 100,
confidence: 0.9,
}],
});
}
let yolo_result = YoloResult {
frame_count: 60,
fps: 30.0,
frames,
};
let chunks = create_fixed_frame_chunks(1, "test-uuid", &yolo_result, 0, 30.0);
// 60 frames at 30 frames per chunk should give 2 chunks
assert_eq!(chunks.len(), 2);
// Check the first chunk
let first_chunk = &chunks[0];
assert_eq!(
first_chunk.chunk_type,
crate::core::chunk::ChunkType::Visual
);
assert_eq!(first_chunk.start_frame, 0);
assert_eq!(first_chunk.end_frame, 30); // exclusive
assert_eq!(first_chunk.frame_count, 30);
// Check the second chunk
let second_chunk = &chunks[1];
assert_eq!(
second_chunk.chunk_type,
crate::core::chunk::ChunkType::Visual
);
assert_eq!(second_chunk.start_frame, 30);
assert_eq!(second_chunk.end_frame, 60); // exclusive
assert_eq!(second_chunk.frame_count, 30);
}
// Two person+car frames followed by two dog frames yield two chunks.
#[test]
fn test_create_similarity_based_chunks() {
use crate::core::processor::yolo::{YoloFrame, YoloObject, YoloResult};
// Build a test YOLO result
let frames = vec![
YoloFrame {
// Frame 0: person and car
frame: 0,
timestamp: 0.0,
objects: vec![
YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 100,
y: 200,
width: 50,
height: 100,
confidence: 0.9,
},
YoloObject {
class_name: "car".to_string(),
class_id: 2,
x: 300,
y: 150,
width: 80,
height: 60,
confidence: 0.8,
},
],
},
YoloFrame {
// Frame 1: person and car
frame: 1,
timestamp: 0.033,
objects: vec![
YoloObject {
class_name: "person".to_string(),
class_id: 0,
x: 110,
y: 210,
width: 52,
height: 102,
confidence: 0.88,
},
YoloObject {
class_name: "car".to_string(),
class_id: 2,
x: 310,
y: 155,
width: 82,
height: 62,
confidence: 0.78,
},
],
},
YoloFrame {
// Frame 5: dog only
frame: 5,
timestamp: 0.166,
objects: vec![YoloObject {
class_name: "dog".to_string(),
class_id: 16,
x: 150,
y: 250,
width: 40,
height: 60,
confidence: 0.7,
}],
},
YoloFrame {
// Frame 6: dog only
frame: 6,
timestamp: 0.2,
objects: vec![YoloObject {
class_name: "dog".to_string(),
class_id: 16,
x: 155,
y: 255,
width: 42,
height: 62,
confidence: 0.68,
}],
},
];
let yolo_result = YoloResult {
frame_count: 7,
fps: 30.0,
frames,
};
let chunks = create_similarity_based_chunks(
1,
"test-uuid",
&yolo_result,
0,
30.0,
0.5, // similarity threshold
2, // min frames per chunk
);
// Expect 2 chunks: one for person+car, one for dog
assert_eq!(chunks.len(), 2);
}
}

9
src/core/text/mod.rs Normal file
View File

@@ -0,0 +1,9 @@
pub mod online_synonym_expander;
pub mod synonym;
pub mod synonym_expander;
pub mod tokenizer;
pub use online_synonym_expander::{global_online_expander, OnlineSynonymExpander};
pub use synonym::{normalize_chinese_query, simplified_to_traditional, traditional_to_simplified};
pub use synonym_expander::{global_synonym_expander, SynonymExpander};
pub use tokenizer::{contains_chinese, extract_and_tokenize_text, tokenize_chinese_text};

View File

@@ -0,0 +1,242 @@
use anyhow::{Context, Result};
use once_cell::sync::Lazy;
use serde::Deserialize;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Subset of the LLM server's `/v1/chat/completions` response that
/// `fetch_from_llm` reads.
#[derive(Debug, Deserialize)]
struct LlmResponse {
choices: Vec<LlmChoice>,
}
/// One completion choice in the LLM response.
#[derive(Debug, Deserialize)]
struct LlmChoice {
message: LlmMessage,
}
/// Message of a completion choice; `content` carries the model's text output.
#[derive(Debug, Deserialize)]
struct LlmMessage {
content: String,
}
/// Online Synonym Expander
/// Fetches synonyms from LLM (llama.cpp server) on-demand and caches them.
///
/// Environment variables:
/// - `MOMENTRY_ONLINE_SYNONYM` - Set to enable online synonym expansion (disabled when unset)
/// - `MOMENTRY_LLM_SYNONYM_URL` - LLM server URL (default: `http://127.0.0.1:8081`)
/// - `MOMENTRY_LLM_SYNONYM_MODEL` - Model name (default: `gemma4`)
/// - `MOMENTRY_LLM_SYNONYM_TIMEOUT` - Request timeout in seconds (default: 60)
#[derive(Debug)]
pub struct OnlineSynonymExpander {
/// Local synonym cache (loaded from file)
local_map: HashMap<String, Vec<String>>,
/// Runtime cache for LLM-fetched synonyms
runtime_cache: Arc<Mutex<HashMap<String, Vec<String>>>>,
/// LLM server URL
api_url: String,
/// Model name
model: String,
/// Request timeout
timeout_secs: u64,
}
/// System message sent with every synonym request; it instructs the model to
/// answer with a bare JSON array of same-language synonyms.
static SYSTEM_PROMPT: &str = r#"You are a synonym generation assistant. For each given word, provide 8-12 synonyms in the same language.
Rules:
1. Return ONLY a JSON array of strings, nothing else
2. Synonyms should be contextually relevant for video content search
3. Include common words, informal terms, and related concepts
4. Do NOT include the input word in the output
5. All synonyms must be in the SAME language as the input word
6. No explanations, no markdown, just the JSON array
Example input: "money"
Example output: ["cash", "dollar", "currency", "funds", "bucks", "greenbacks", "coins", "wealth", "payment"]"#;
impl OnlineSynonymExpander {
    /// Creates an expander.
    ///
    /// When `local_file_path` is given, the JSON synonym map at that path is
    /// loaded; a load failure is logged as a warning and an empty map is used.
    /// The server URL, model name and timeout are read from
    /// `MOMENTRY_LLM_SYNONYM_URL`, `MOMENTRY_LLM_SYNONYM_MODEL` and
    /// `MOMENTRY_LLM_SYNONYM_TIMEOUT`, falling back to
    /// `http://127.0.0.1:8081`, `gemma4` and 60 seconds.
    pub fn new(local_file_path: Option<&str>) -> Self {
        let local_map = match local_file_path {
            Some(path) => Self::load_local_file(path).unwrap_or_else(|e| {
                tracing::warn!("Failed to load local synonym file {}: {}", path, e);
                HashMap::new()
            }),
            None => HashMap::new(),
        };
        let api_url = env::var("MOMENTRY_LLM_SYNONYM_URL")
            .unwrap_or_else(|_| "http://127.0.0.1:8081".to_string());
        let model = env::var("MOMENTRY_LLM_SYNONYM_MODEL").unwrap_or_else(|_| "gemma4".to_string());
        let timeout_secs = env::var("MOMENTRY_LLM_SYNONYM_TIMEOUT")
            .ok()
            .and_then(|v| v.parse().ok())
            .unwrap_or(60);
        Self {
            local_map,
            runtime_cache: Arc::new(Mutex::new(HashMap::new())),
            api_url,
            model,
            timeout_secs,
        }
    }

    /// Reads a JSON object mapping each word to its list of synonyms.
    fn load_local_file(path: &str) -> Result<HashMap<String, Vec<String>>> {
        let content = std::fs::read_to_string(path).context("Failed to read local synonym file")?;
        let map: HashMap<String, Vec<String>> =
            serde_json::from_str(&content).context("Failed to parse local synonym JSON")?;
        Ok(map)
    }

    /// Formats `word` and its synonyms as `(word | syn1 | syn2 ...)`.
    fn format_expansion(word: &str, synonyms: &[String]) -> String {
        let mut parts: Vec<&str> = Vec::with_capacity(synonyms.len() + 1);
        parts.push(word);
        parts.extend(synonyms.iter().map(String::as_str));
        format!("({})", parts.join(" | "))
    }

    /// Returns `true` when `term` can be embedded in the `|`-joined expansion
    /// as a single operand: non-empty, no whitespace, and none of the
    /// characters `to_tsquery` treats as operators, grouping or quoting.
    fn is_tsquery_safe(term: &str) -> bool {
        !term.is_empty()
            && !term.chars().any(|c| {
                c.is_whitespace()
                    || matches!(
                        c,
                        '&' | '|' | '!' | '(' | ')' | ':' | '<' | '>' | '\'' | '\\' | '*'
                    )
            })
    }

    /// Get synonyms for a word. Checks local map first, then runtime cache, then fetches from LLM.
    ///
    /// Returns `(word | syn1 | ...)` when synonyms are found, otherwise the
    /// word unchanged. LLM failures are logged and fall back to the plain word.
    pub async fn expand_word(&self, word: &str) -> String {
        // 1. Check local map
        if let Some(syns) = self.local_map.get(word).filter(|s| !s.is_empty()) {
            return Self::format_expansion(word, syns);
        }

        // 2. Check runtime cache; the guard is released at the end of this
        //    scope so the lock is not held during the network request.
        {
            let cache = self.runtime_cache.lock().await;
            if let Some(syns) = cache.get(word).filter(|s| !s.is_empty()) {
                return Self::format_expansion(word, syns);
            }
        }

        // 3. Fetch from LLM
        match self.fetch_from_llm(word).await {
            Ok(synonyms) if !synonyms.is_empty() => {
                let expansion = Self::format_expansion(word, &synonyms);
                self.runtime_cache
                    .lock()
                    .await
                    .insert(word.to_string(), synonyms);
                return expansion;
            }
            Ok(_) => {}
            Err(e) => {
                tracing::warn!("LLM synonym lookup failed for '{}': {:#}", word, e);
            }
        }

        // 4. Fallback: return original word
        word.to_string()
    }

    /// Asks the LLM server for synonyms of `word`.
    ///
    /// The first `[` … last `]` span of the reply is parsed as a JSON array of
    /// strings; entries are trimmed and lower-cased, and entries that are not
    /// safe single tsquery operands are dropped.
    ///
    /// # Errors
    ///
    /// Fails on request or HTTP errors, when the reply has no choices or no
    /// JSON array, when the array cannot be parsed, or when no usable synonym
    /// remains after filtering.
    async fn fetch_from_llm(&self, word: &str) -> Result<Vec<String>> {
        let client = reqwest::Client::new();
        let prompt = format!(
            r#"Give synonyms for: "{}"
Return ONLY a JSON array of strings, nothing else. Do NOT include the input word."#,
            word
        );
        let payload = serde_json::json!({
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": 0.3,
            "stream": false,
            "max_tokens": 256,
        });
        let response = client
            .post(format!("{}/v1/chat/completions", self.api_url))
            .json(&payload)
            .timeout(std::time::Duration::from_secs(self.timeout_secs))
            .send()
            .await
            .context("LLM request failed")?;
        if !response.status().is_success() {
            anyhow::bail!("LLM request failed with status: {}", response.status());
        }
        let llm_resp: LlmResponse = response
            .json()
            .await
            .context("Failed to parse LLM response")?;
        let content = &llm_resp
            .choices
            .first()
            .context("No choices in LLM response")?
            .message
            .content;

        // Extract JSON from response (handle markdown code blocks). The
        // `start < end` guard matters: a reply whose last `]` comes before its
        // first `[` would otherwise make the slice below panic.
        let json_str = match (content.find('['), content.rfind(']')) {
            (Some(start), Some(end)) if start < end => &content[start..=end],
            _ => anyhow::bail!("No JSON array found in LLM response"),
        };
        let synonyms: Vec<String> =
            serde_json::from_str(json_str).context("Failed to parse LLM synonyms JSON")?;

        // Filter and normalize. Multi-word entries and entries containing
        // tsquery operator characters are dropped for to_tsquery compatibility.
        let cleaned: Vec<String> = synonyms
            .into_iter()
            .map(|s| s.trim().to_lowercase())
            .filter(|s| Self::is_tsquery_safe(s))
            .collect();
        if cleaned.is_empty() {
            anyhow::bail!("No valid synonyms returned");
        }
        tracing::info!(
            "LLM fetched {} synonyms for '{}': {:?}",
            cleaned.len(),
            word,
            cleaned.iter().take(5).collect::<Vec<_>>()
        );
        Ok(cleaned)
    }

    /// Returns the number of words currently held in the runtime cache.
    pub async fn cache_size(&self) -> usize {
        self.runtime_cache.lock().await.len()
    }
}
/// Global online synonym expander (lazy-loaded).
///
/// Holds `Some` only when `MOMENTRY_ONLINE_SYNONYM` is set to an enabling
/// value. Unset, empty, `0`, `false`, `no` and `off` (case-insensitive) leave
/// the expander disabled, so `MOMENTRY_ONLINE_SYNONYM=false` really turns the
/// feature off; any other value enables it. The optional local synonym file is
/// taken from `MOMENTRY_SYNONYM_FILE`.
static ONLINE_EXPANDER: Lazy<Option<OnlineSynonymExpander>> = Lazy::new(|| {
    let enabled = env::var("MOMENTRY_ONLINE_SYNONYM")
        .map(|raw| {
            !matches!(
                raw.trim().to_ascii_lowercase().as_str(),
                "" | "0" | "false" | "no" | "off"
            )
        })
        .unwrap_or(false);
    if !enabled {
        return None;
    }

    let local_file = env::var("MOMENTRY_SYNONYM_FILE").ok();
    tracing::info!("Initializing online synonym expander");
    Some(OnlineSynonymExpander::new(local_file.as_deref()))
});
/// Get the global online synonym expander (if enabled).
///
/// The first call initializes the lazily-built `ONLINE_EXPANDER`; `None` is
/// returned when that initialization produced no expander.
pub fn global_online_expander() -> Option<&'static OnlineSynonymExpander> {
ONLINE_EXPANDER.as_ref()
}

71
src/core/text/synonym.rs Normal file
View File

@@ -0,0 +1,71 @@
use ferrous_opencc::{config::BuiltinConfig, OpenCC};
use once_cell::sync::Lazy;
/// Lazily-built OpenCC converter using the built-in `S2t` configuration
/// (Simplified to Traditional). Panics on first use if the converter cannot
/// be initialized.
static OPENCC_S2T: Lazy<OpenCC> = Lazy::new(|| {
OpenCC::from_config(BuiltinConfig::S2t)
.expect("Failed to initialize OpenCC Simplified to Traditional converter")
});
/// Lazily-built OpenCC converter using the built-in `T2s` configuration
/// (Traditional to Simplified). Panics on first use if the converter cannot
/// be initialized.
static OPENCC_T2S: Lazy<OpenCC> = Lazy::new(|| {
OpenCC::from_config(BuiltinConfig::T2s)
.expect("Failed to initialize OpenCC Traditional to Simplified converter")
});
/// Convert Simplified Chinese text to Traditional Chinese.
///
/// Delegates to the shared `OPENCC_S2T` converter and returns the converted
/// text as a new `String`.
pub fn simplified_to_traditional(text: &str) -> String {
OPENCC_S2T.convert(text)
}
/// Convert Traditional Chinese text to Simplified Chinese.
///
/// Delegates to the shared `OPENCC_T2S` converter and returns the converted
/// text as a new `String`.
pub fn traditional_to_simplified(text: &str) -> String {
OPENCC_T2S.convert(text)
}
/// Normalize a Chinese query for search.
///
/// Converts Simplified Chinese to Traditional Chinese (on the assumption that
/// the database stores Traditional text) and returns the converted text. This
/// is currently a direct call to [`simplified_to_traditional`].
pub fn normalize_chinese_query(text: &str) -> String {
simplified_to_traditional(text)
}
// Unit tests for the Simplified/Traditional conversion helpers.
#[cfg(test)]
mod tests {
use super::*;
// Simplified input changes; Traditional input passes through unchanged.
#[test]
fn test_simplified_to_traditional() {
// Example: Simplified "计算机" -> Traditional "計算機"
let simplified = "计算机";
let traditional = simplified_to_traditional(simplified);
// The conversion might produce "計算機" (depending on dictionary)
// We'll just verify it's not empty and different from input
assert!(!traditional.is_empty());
assert_ne!(traditional, simplified);
// Traditional input should remain unchanged (or nearly unchanged)
let traditional_input = "計算機";
let converted = simplified_to_traditional(traditional_input);
assert_eq!(converted, traditional_input);
}
// Traditional input is converted to something non-empty and different.
#[test]
fn test_traditional_to_simplified() {
let traditional = "計算機";
let simplified = traditional_to_simplified(traditional);
assert!(!simplified.is_empty());
assert_ne!(simplified, traditional);
}
// Normalization converts Simplified input and leaves Traditional input as is.
#[test]
fn test_normalize_chinese_query() {
let simplified = "计算机";
let normalized = normalize_chinese_query(simplified);
// Should be Traditional
assert_ne!(normalized, simplified);
let traditional = "計算機";
let normalized2 = normalize_chinese_query(traditional);
// Should remain Traditional
assert_eq!(normalized2, traditional);
}
}

View File

@@ -0,0 +1,247 @@
use anyhow::{Context, Result};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::env;
use std::fs;
use std::path::Path;
/// Synonym expander.
/// Loads a custom synonym mapping from JSON files.
#[derive(Debug, Clone, Default)]
pub struct SynonymExpander {
/// Mapping from a word to its list of synonyms.
map: HashMap<String, Vec<String>>,
}
impl SynonymExpander {
    /// Creates a synonym expander from a single JSON file.
    ///
    /// The file must contain a JSON object whose values are arrays of strings
    /// (word -> list of synonyms).
    ///
    /// # Errors
    /// Fails if the file cannot be read or does not parse into that shape.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = fs::read_to_string(path).context("Failed to read synonym file")?;
        let map: HashMap<String, Vec<String>> =
            serde_json::from_str(&content).context("Failed to parse synonym JSON")?;
        Ok(Self { map })
    }

    /// Creates a synonym expander from several JSON files.
    ///
    /// Files are merged in order; when the same word appears in more than one file,
    /// the entry from the later file replaces the earlier one. An empty `paths`
    /// slice yields an empty expander.
    ///
    /// # Errors
    /// Fails on the first file that cannot be read or parsed; the error context
    /// names the offending path.
    pub fn from_files<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
        let mut combined_map: HashMap<String, Vec<String>> = HashMap::new();
        for path in paths {
            let content = fs::read_to_string(path)
                .with_context(|| format!("Failed to read synonym file: {:?}", path.as_ref()))?;
            let map: HashMap<String, Vec<String>> =
                serde_json::from_str(&content).with_context(|| {
                    format!("Failed to parse synonym JSON from {:?}", path.as_ref())
                })?;
            // `extend` overwrites existing keys, so later files win.
            combined_map.extend(map);
        }
        Ok(Self { map: combined_map })
    }

    /// Creates the built-in default expander.
    ///
    /// This currently returns an empty mapping; custom synonyms are expected to be
    /// supplied through synonym files.
    pub fn from_default() -> Self {
        Self::empty()
    }

    /// Returns the synonym list registered for `word`, if any.
    pub fn get_synonyms(&self, word: &str) -> Option<&[String]> {
        self.map.get(word).map(|v| v.as_slice())
    }

    /// Expands a single word into `(word | synonym1 | synonym2 ...)`.
    ///
    /// A word with no entry, or with an empty synonym list, is returned unchanged.
    pub fn expand_word(&self, word: &str) -> String {
        match self.get_synonyms(word) {
            Some(syns) if !syns.is_empty() => {
                let mut parts = Vec::with_capacity(syns.len() + 1);
                parts.push(word.to_string());
                parts.extend_from_slice(syns);
                format!("({})", parts.join(" | "))
            }
            _ => word.to_string(),
        }
    }

    /// Expands a whitespace-separated query: every word is expanded with
    /// [`expand_word`](Self::expand_word) and the results are joined with ` & `.
    pub fn expand_query(&self, query: &str) -> String {
        query
            .split_whitespace()
            .map(|word| self.expand_word(word))
            .collect::<Vec<_>>()
            .join(" & ")
    }

    /// Expands a Chinese query that has no whitespace between words.
    ///
    /// 1. A non-empty query of at most four characters that is itself a key is
    ///    expanded directly into `(query | synonyms...)`.
    /// 2. Otherwise the query is scanned left to right. At each position the longest
    ///    key that is a prefix of the remaining text is expanded; if no key matches,
    ///    a single character is emitted as its own part. Parts are joined with ` & `.
    /// 3. If no key matched anywhere, the original query is returned unchanged
    ///    (presumably to be tokenized later by the caller).
    ///
    /// Empty keys in the mapping are ignored: an empty string is a prefix of every
    /// string and consumes no input, so matching it would never make progress.
    pub fn expand_chinese_query(&self, query: &str) -> String {
        // Short queries: try the whole query as a key first.
        if !query.is_empty() && query.chars().count() <= 4 {
            if let Some(syns) = self.get_synonyms(query) {
                let mut parts = vec![query.to_string()];
                parts.extend_from_slice(syns);
                return format!("({})", parts.join(" | "));
            }
        }

        // Candidate keys, longest (by character count) first so the greedy scan
        // prefers the longest match. Empty keys are dropped (see doc comment).
        let mut keys: Vec<&str> = self
            .map
            .keys()
            .map(String::as_str)
            .filter(|key| !key.is_empty())
            .collect();
        keys.sort_by_key(|key| std::cmp::Reverse(key.chars().count()));

        let mut expanded_parts = Vec::new();
        let mut remaining_query = query;
        let mut found_synonym = false;

        while let Some(first_char) = remaining_query.chars().next() {
            let matched_key = keys
                .iter()
                .copied()
                .find(|key| remaining_query.starts_with(*key));

            // Number of bytes consumed this step; always > 0, so the loop terminates.
            let consumed = match matched_key {
                Some(key) => {
                    expanded_parts.push(self.expand_word(key));
                    found_synonym = true;
                    // `starts_with` succeeded, so `key.len()` is a valid char boundary.
                    key.len()
                }
                None => {
                    // No key starts here: emit one character and move on.
                    let char_len = first_char.len_utf8();
                    expanded_parts.push(remaining_query[..char_len].to_string());
                    char_len
                }
            };
            remaining_query = &remaining_query[consumed..];
        }

        if found_synonym {
            expanded_parts.join(" & ")
        } else {
            query.to_string()
        }
    }

    /// Creates an expander with no synonym mappings.
    pub fn empty() -> Self {
        Self {
            map: HashMap::new(),
        }
    }
}
/// Global synonym expander (lazily initialized on first access).
static SYNONYM_EXPANDER: Lazy<SynonymExpander> = Lazy::new(|| {
    // First choice: MOMENTRY_SYNONYM_FILES, a comma-separated list of files.
    let multi_file_setting = env::var("MOMENTRY_SYNONYM_FILES").ok();
    if let Some(files_var) = multi_file_setting.as_deref() {
        let mut file_paths: Vec<&str> = Vec::new();
        for candidate in files_var.split(',') {
            let candidate = candidate.trim();
            if !candidate.is_empty() {
                file_paths.push(candidate);
            }
        }

        if !file_paths.is_empty() {
            match SynonymExpander::from_files(&file_paths) {
                Ok(expander) => {
                    tracing::info!(
                        "Loaded synonym expander from {} files: {:?}",
                        file_paths.len(),
                        file_paths
                    );
                    return expander;
                }
                // Loading failed: log it and fall through to the single-file
                // variable or the default below.
                Err(e) => tracing::warn!(
                    "Failed to load synonym expander from files {:?}: {}",
                    file_paths,
                    e
                ),
            }
        }
    }

    // Fallback: the single-file variable MOMENTRY_SYNONYM_FILE (backward compatible).
    let file_path = match env::var("MOMENTRY_SYNONYM_FILE") {
        Ok(value) => value,
        // Neither variable is usable: use the built-in default.
        Err(_) => return SynonymExpander::from_default(),
    };

    match SynonymExpander::from_file(&file_path) {
        Ok(expander) => {
            tracing::info!("Loaded synonym expander from {}", file_path);
            expander
        }
        Err(e) => {
            tracing::warn!("Failed to load synonym expander from {}: {}", file_path, e);
            SynonymExpander::empty()
        }
    }
});
/// Returns the global synonym expander instance.
/// The backing static is a `Lazy`, so it is built the first time it is accessed.
pub fn global_synonym_expander() -> &'static SynonymExpander {
&SYNONYM_EXPANDER
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the two-entry expander shared by the expansion tests.
    fn sample_expander() -> SynonymExpander {
        let entries = [
            ("電腦", ["計算機", "微机"]),
            ("工作", ["任務", "作業"]),
        ];
        let map: HashMap<String, Vec<String>> = entries
            .iter()
            .map(|(word, syns)| {
                let syns = syns.iter().map(|s| s.to_string()).collect();
                (word.to_string(), syns)
            })
            .collect();
        SynonymExpander { map }
    }

    /// Known words expand to an alternation group; unknown words pass through.
    #[test]
    fn test_expand_word() {
        let expander = sample_expander();
        let cases = [
            ("電腦", "(電腦 | 計算機 | 微机)"),
            ("工作", "(工作 | 任務 | 作業)"),
            ("未知", "未知"),
        ];
        for (word, expected) in cases.iter() {
            assert_eq!(expander.expand_word(word), *expected);
        }
    }

    /// Whitespace-separated words are expanded individually and AND-ed together.
    #[test]
    fn test_expand_query() {
        let expander = sample_expander();
        let cases = [
            (
                "電腦 工作",
                "(電腦 | 計算機 | 微机) & (工作 | 任務 | 作業)",
            ),
            ("單個詞", "單個詞"),
            ("", ""),
        ];
        for (query, expected) in cases.iter() {
            assert_eq!(expander.expand_query(query), *expected);
        }
    }

    /// An empty path list produces an expander with no entries.
    #[test]
    fn test_from_files_empty() {
        let no_paths: Vec<&str> = Vec::new();
        let expander = SynonymExpander::from_files(&no_paths).unwrap();
        assert!(expander.map.is_empty());
    }
}

121
src/core/text/tokenizer.rs Normal file
View File

@@ -0,0 +1,121 @@
use jieba_rs::Jieba;
use once_cell::sync::Lazy;
/// Shared `Jieba` segmenter, constructed with `Jieba::new` the first time it is dereferenced.
static JIEBA: Lazy<Jieba> = Lazy::new(Jieba::new);
/// Reports whether `text` contains at least one Chinese character.
///
/// A character counts as Chinese when it lies in CJK Unified Ideographs
/// (U+4E00-U+9FFF) or CJK Unified Ideographs Extension A (U+3400-U+4DBF).
pub fn contains_chinese(text: &str) -> bool {
    for ch in text.chars() {
        if matches!(ch, '\u{3400}'..='\u{4dbf}' | '\u{4e00}'..='\u{9fff}') {
            return true;
        }
    }
    false
}
/// Segments Chinese text into words and joins the resulting tokens with spaces.
/// Text that contains no Chinese character is returned unchanged.
///
/// # Examples
/// ```
/// use momentry_core::core::text::tokenizer::tokenize_chinese_text;
///
/// assert_eq!(tokenize_chinese_text("這是一個測試"), "這 是 一 個 測 試");
/// assert_eq!(tokenize_chinese_text("Hello world"), "Hello world");
/// assert_eq!(tokenize_chinese_text("中文English混合"), "中文 English 混合");
/// ```
pub fn tokenize_chinese_text(text: &str) -> String {
    // Nothing to segment: hand back a copy of the input as-is.
    if !contains_chinese(text) {
        return text.to_string();
    }
    // The second argument to `cut` is the `hmm` flag; `false` requests
    // dictionary-based segmentation without the HMM step.
    JIEBA.cut(text, false).join(" ")
}
/// Extracts the text from a JSON chunk payload and tokenizes it.
///
/// Two layouts are supported, tried in this order:
/// 1. `content["data"]["text"]` (layout used by Chinese videos)
/// 2. `content["text"]` (layout used by English videos)
///
/// If neither location holds a JSON string, the empty string is tokenized.
pub fn extract_and_tokenize_text(content: &serde_json::Value) -> String {
    let nested_text = content
        .get("data")
        .and_then(|data| data.get("text"))
        .and_then(serde_json::Value::as_str);

    let raw_text = match nested_text {
        Some(text) => text,
        // Fall back to the top-level "text" field, then to "".
        None => content
            .get("text")
            .and_then(serde_json::Value::as_str)
            .unwrap_or(""),
    };

    tokenize_chinese_text(raw_text)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection succeeds for any text with a Chinese character and fails otherwise.
    #[test]
    fn test_contains_chinese() {
        let with_chinese = ["中文", "這是一個測試", "混合文本 English 中文"];
        for text in with_chinese.iter() {
            assert!(contains_chinese(text));
        }

        let without_chinese = ["English only", "123", ""];
        for text in without_chinese.iter() {
            assert!(!contains_chinese(text));
        }
    }

    /// Segmentation output for pure Chinese, pure English, mixed, empty and
    /// digit/punctuation inputs.
    #[test]
    fn test_tokenize_chinese_text() {
        let cases = [
            // Pure Chinese
            ("這是一個測試", "這 是 一 個 測 試"),
            // Pure English
            ("Hello world", "Hello world"),
            // Mixed Chinese and English
            ("中文English混合", "中文 English 混合"),
            // Empty string
            ("", ""),
            // Digits and punctuation
            ("測試123。", "測 試 123 。"),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(tokenize_chinese_text(input), *expected);
        }
    }

    /// Extraction covers both payload layouts, their precedence, and missing or
    /// non-string text.
    #[test]
    fn test_extract_and_tokenize_text() {
        let cases = vec![
            // Chinese layout: content["data"]["text"]
            (
                serde_json::json!({ "data": { "text": "這是一個測試" } }),
                "這 是 一 個 測 試",
            ),
            // English layout: content["text"]
            (serde_json::json!({ "text": "Hello world" }), "Hello world"),
            // Both present: data.text takes precedence
            (
                serde_json::json!({ "data": { "text": "中文測試" }, "text": "English text" }),
                "中文 測 試",
            ),
            // No text at all
            (serde_json::json!({}), ""),
            // Text present but not a string
            (serde_json::json!({ "data": { "text": 123 } }), ""),
        ];
        for (content, expected) in &cases {
            assert_eq!(extract_and_tokenize_text(content), *expected);
        }
    }
}

40
src/core/tmdb/ingest.rs Normal file
View File

@@ -0,0 +1,40 @@
use anyhow::{Context, Result};
use serde::Deserialize;
use std::path::Path;
use tracing::{info, warn};
use crate::core::db::PostgresDb;
/// One cast member record as deserialized from the cast JSON file.
#[derive(Debug, Deserialize)]
pub struct CastEntry {
/// Person's name; `ingest_cast` passes it to `get_or_create_identity`.
pub name: String,
/// Role string; within this file it is only used in the ingestion log message.
pub role: String,
/// Optional image value; not read by `ingest_cast`. Presumably a URL or path — TODO confirm.
pub image: Option<String>,
}
/// Ingests TMDB cast data from the JSON file generated by `tmdb_cast_fetcher.py`
pub async fn ingest_cast(db: &PostgresDb, json_path: &str) -> Result<usize> {
let path = Path::new(json_path);
if !path.exists() {
return Err(anyhow::anyhow!("Cast JSON file not found: {}", json_path));
}
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read cast JSON: {}", json_path))?;
let cast_list: Vec<CastEntry> =
serde_json::from_str(&content).with_context(|| "Invalid cast JSON format")?;
let mut count = 0;
for entry in &cast_list {
match db.get_or_create_identity(&entry.name).await {
Ok(_talent) => {
info!("Ingested TMDB cast: {} as {}", entry.name, entry.role);
count += 1;
}
Err(e) => warn!("Failed to create talent '{}': {}", entry.name, e),
}
}
Ok(count)
}

1
src/core/tmdb/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod ingest;

View File

@@ -0,0 +1,144 @@
use sqlx::PgPool;
use tokio::time::{sleep, Duration};
use tracing;
use uuid::Uuid;
use crate::core::chunk;
/// Background worker that polls `dev.jobs` for queued jobs and processes them one at a time.
pub struct JobWorker {
/// Postgres connection pool used for every job and video query.
pool: PgPool,
/// How long `run` sleeps after a poll that found no queued job.
poll_interval: Duration,
}
impl JobWorker {
pub fn new(pool: PgPool, poll_interval_secs: u64) -> Self {
Self {
pool,
poll_interval: Duration::from_secs(poll_interval_secs),
}
}
pub async fn run(&self) {
tracing::info!(
"🤖 Job Worker started (Polling every {}s)",
self.poll_interval.as_secs()
);
loop {
match self.process_next_job().await {
Ok(has_work) => {
if !has_work {
// No work found, wait before polling again
sleep(self.poll_interval).await;
}
// If we processed a job, loop immediately to check for more
}
Err(e) => {
tracing::error!("❌ Job Worker error: {}", e);
sleep(Duration::from_secs(5)).await;
}
}
}
}
async fn process_next_job(&self) -> anyhow::Result<bool> {
// 1. Fetch a QUEUED job
// We use a transaction to ensure no two workers pick the same job (atomic update)
let job_row: Option<(String, String, String, String, String, i64)> = sqlx::query_as(
r#"
UPDATE dev.jobs
SET status = 'RUNNING', updated_at = NOW()
WHERE id = (
SELECT id FROM dev.jobs
WHERE status = 'QUEUED'
ORDER BY created_at ASC
LIMIT 1
FOR UPDATE SKIP LOCKED
)
RETURNING id::text, asset_uuid, rule, status, processor_list, total_frames
"#,
)
.fetch_optional(&self.pool)
.await?;
if let Some((job_id, asset_uuid, rule, _status, _processors, total_frames)) = job_row {
let job_uuid =
Uuid::parse_str(&job_id).map_err(|e| anyhow::anyhow!("Invalid job UUID: {}", e))?;
tracing::info!(
"🚀 Processing Job {} for Asset {} (Rule: {})",
job_id,
asset_uuid,
rule
);
// 2. Execute Logic based on Rule
let result = match rule.as_str() {
"rule1" => {
let fps = self.get_asset_fps(&asset_uuid).await?;
chunk::rule1_ingest::ingest_rule1(&self.pool, &asset_uuid, fps).await
}
_ => {
tracing::warn!("Unknown rule type: {}", rule);
Ok(0)
}
};
// 3. Update Job Status
match result {
Ok(chunk_count) => {
tracing::info!(
"✅ Job {} completed. Processed {} items.",
job_id,
chunk_count
);
sqlx::query!(
"UPDATE dev.jobs SET status = 'COMPLETED', processed_frames = total_frames, updated_at = NOW() WHERE id = $1",
job_uuid
)
.execute(&self.pool)
.await?;
sqlx::query!(
"UPDATE dev.videos SET processing_status = 'COMPLETED' WHERE uuid = $1",
asset_uuid
)
.execute(&self.pool)
.await?;
}
Err(e) => {
tracing::error!("❌ Job {} failed: {}", job_id, e);
let err_msg = e.to_string();
let safe_msg = if err_msg.len() > 500 {
&err_msg[..500]
} else {
&err_msg
};
sqlx::query!(
"UPDATE dev.jobs SET status = 'FAILED', error_message = $2, updated_at = NOW() WHERE id = $1",
job_uuid,
safe_msg
)
.execute(&self.pool)
.await?;
}
}
return Ok(true); // Processed a job
}
Ok(false) // No job found
}
async fn get_asset_fps(&self, uuid: &str) -> anyhow::Result<f64> {
let fps: Option<f64> =
sqlx::query_scalar("SELECT (metadata->>'fps')::float FROM dev.videos WHERE uuid = $1")
.bind(uuid)
.fetch_optional(&self.pool)
.await?;
// Fallback to 29.97 if not found
Ok(fps.unwrap_or(29.97))
}
}

2
src/core/worker/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod job_runner;
pub use job_runner::JobWorker;

View File

@@ -915,6 +915,7 @@ async fn main() -> Result<()> {
user_id: None,
job_id: None,
created_at: String::new(),
registration_time: None,
};
let video_id = db.register_video(&record).await?;

View File

@@ -924,6 +924,7 @@ async fn main() -> Result<()> {
user_id: None,
job_id: None,
created_at: String::new(),
registration_time: None,
};
let video_id = db.register_video(&record).await?;
@@ -2373,20 +2374,25 @@ async fn main() -> Result<()> {
target
);
for chunk in sentence_chunks {
println!("Starting to process {} chunks...", sentence_chunks.len());
for (i, chunk) in sentence_chunks.iter().enumerate() {
if i < 3 {
println!("Processing chunk {}/{}: {}", i+1, sentence_chunks.len(), chunk.chunk_id);
}
let text = chunk
.content
.get("data")
.and_then(|data| data.get("text"))
.get("text")
.and_then(|v| v.as_str())
.or_else(|| chunk.content.get("data").and_then(|data| data.get("text")).and_then(|v| v.as_str()))
.or(chunk.text_content.as_deref())
.unwrap_or("");
eprintln!("Embedding chunk {}/{}: {} (text len: {})...", i+1, sentence_chunks.len(), chunk.chunk_id, text.len());
if text.is_empty() {
continue;
}
print!("Embedding chunk {}... ", chunk.chunk_id);
match embedder.embed_document(text).await {
Ok(vector) => {
let vector_id = format!("{}_{}", chunk.uuid, chunk.chunk_id);
@@ -2420,10 +2426,12 @@ async fn main() -> Result<()> {
}
stored_count += 1;
println!("done ({} dims)", vector.len());
if stored_count % 100 == 0 || stored_count <= 3 {
println!("Stored {}/1867 vectors", stored_count);
}
}
Err(e) => {
println!("failed: {}", e);
eprintln!("embed_document error for {}: {}", chunk.chunk_id, e);
}
}
}

10
src/test_embed.rs Normal file
View File

@@ -0,0 +1,10 @@
use momentry_core::core::embedding::comic_embed::Embedder;
/// Smoke test: embeds one fixed string and prints either the vector length or the error.
#[tokio::main]
async fn main() {
    let model_name = String::from("nomic-embed-text-v2-moe:latest");
    let embedder = Embedder::new(model_name);

    let report = match embedder.embed_document("test embedding").await {
        Ok(vector) => format!("Success! Vector length: {}", vector.len()),
        Err(e) => format!("Error: {}", e),
    };
    println!("{}", report);
}