feat: backup architecture docs, source code, and scripts
This commit is contained in:
936
src/api/face_recognition.rs
Normal file
936
src/api/face_recognition.rs
Normal file
@@ -0,0 +1,936 @@
|
||||
use axum::{
|
||||
extract::{Multipart, Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::db::{schema, Database, PostgresDb};
|
||||
use crate::core::processor::face_recognition::{
|
||||
process_face_recognition, register_face, FaceRecognitionResult, FaceRegistrationResult,
|
||||
};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FaceRecognitionRequest {
|
||||
pub video_uuid: String,
|
||||
pub enable_recognition: Option<bool>,
|
||||
pub enable_tracking: Option<bool>,
|
||||
pub enable_clustering: Option<bool>,
|
||||
pub database_path: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceRecognitionResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub result: Option<FaceRecognitionResult>,
|
||||
pub processing_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FaceRegistrationRequest {
|
||||
pub video_uuid: String,
|
||||
pub name: String,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceRegistrationApiResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub result: Option<FaceRegistrationResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FaceSearchRequest {
|
||||
pub video_uuid: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub similarity_threshold: Option<f64>,
|
||||
pub limit: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceSearchResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub results: Vec<FaceSearchResult>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceSearchResult {
|
||||
pub face_id: String,
|
||||
pub name: Option<String>,
|
||||
pub similarity: f64,
|
||||
pub attributes: Option<serde_json::Value>,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FaceListQuery {
|
||||
pub video_uuid: String,
|
||||
pub page: Option<usize>,
|
||||
pub page_size: Option<usize>,
|
||||
pub active_only: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceListResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub faces: Vec<FaceListItem>,
|
||||
pub count: i64,
|
||||
pub page: usize,
|
||||
pub page_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FaceListItem {
|
||||
pub face_id: String,
|
||||
pub name: Option<String>,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
pub is_active: bool,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
pub fn face_recognition_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/face/recognize", post(recognize_faces))
|
||||
.route("/api/v1/face/register", post(register_face_api))
|
||||
.route("/api/v1/face/search", post(search_faces))
|
||||
.route("/api/v1/face/list", get(list_faces))
|
||||
.route("/api/v1/face/:face_id", get(get_face_details))
|
||||
.route("/api/v1/face/:face_id", axum::routing::delete(delete_face))
|
||||
.route(
|
||||
"/api/v1/face/results/:video_uuid",
|
||||
get(get_recognition_results),
|
||||
)
|
||||
}
|
||||
|
||||
async fn recognize_faces(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(request): Json<FaceRecognitionRequest>,
|
||||
) -> Result<Json<FaceRecognitionResponse>, (StatusCode, String)> {
|
||||
let processing_id = Uuid::new_v4().to_string();
|
||||
|
||||
tracing::info!(
|
||||
"[FACE_RECOGNITION] Starting recognition for video: {}, processing_id: {}",
|
||||
request.video_uuid,
|
||||
processing_id
|
||||
);
|
||||
|
||||
// Get video path from database
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let video_record = match db.get_video_by_uuid(&request.video_uuid).await {
|
||||
Ok(Some(record)) => record,
|
||||
Ok(None) => {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Video not found: {}", request.video_uuid),
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to fetch video: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let video_path = video_record.file_path;
|
||||
let output_path = format!(
|
||||
"{}/face_recognition_{}.json",
|
||||
crate::core::config::OUTPUT_DIR.as_str(),
|
||||
processing_id
|
||||
);
|
||||
|
||||
// Process face recognition
|
||||
let result = match process_face_recognition(
|
||||
&video_path,
|
||||
&output_path,
|
||||
Some(&processing_id),
|
||||
request.enable_recognition.unwrap_or(true),
|
||||
request.enable_tracking.unwrap_or(true),
|
||||
request.enable_clustering.unwrap_or(true),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Face recognition failed: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Store results in database
|
||||
if let Err(e) = store_recognition_results(&db, &request.video_uuid, &result).await {
|
||||
tracing::warn!("Failed to store recognition results: {}", e);
|
||||
}
|
||||
|
||||
Ok(Json(FaceRecognitionResponse {
|
||||
success: true,
|
||||
message: format!("Face recognition completed for {}", request.video_uuid),
|
||||
result: Some(result),
|
||||
processing_id,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn register_face_api(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
mut multipart: Multipart,
|
||||
) -> Result<Json<FaceRegistrationApiResponse>, (StatusCode, String)> {
|
||||
let mut image_path: Option<String> = None;
|
||||
let mut name: Option<String> = None;
|
||||
let mut metadata: Option<serde_json::Value> = None;
|
||||
|
||||
// Parse multipart form data
|
||||
while let Some(field) = multipart.next_field().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to parse form data: {}", e),
|
||||
)
|
||||
})? {
|
||||
let field_name = field.name().unwrap_or("").to_string();
|
||||
|
||||
match field_name.as_str() {
|
||||
"image" => {
|
||||
// Save uploaded image
|
||||
let file_name = format!("face_registration_{}.jpg", Uuid::new_v4());
|
||||
let file_path = format!("/tmp/{}", file_name);
|
||||
|
||||
let data = field.bytes().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to read image data: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
tokio::fs::write(&file_path, &data).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to save image: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
image_path = Some(file_path);
|
||||
}
|
||||
"name" => {
|
||||
let value = field.text().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to read name: {}", e),
|
||||
)
|
||||
})?;
|
||||
name = Some(value);
|
||||
}
|
||||
"metadata" => {
|
||||
let value = field.text().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Failed to read metadata: {}", e),
|
||||
)
|
||||
})?;
|
||||
metadata = Some(serde_json::from_str(&value).map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!("Invalid JSON metadata: {}", e),
|
||||
)
|
||||
})?);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
let image_path =
|
||||
image_path.ok_or((StatusCode::BAD_REQUEST, "Image is required".to_string()))?;
|
||||
|
||||
let name = name.ok_or((StatusCode::BAD_REQUEST, "Name is required".to_string()))?;
|
||||
|
||||
// Register face
|
||||
let result = match register_face(&image_path, &name, metadata.clone()).await {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
// Clean up temporary file
|
||||
let _ = tokio::fs::remove_file(&image_path).await;
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Face registration failed: {}", e),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
// Clean up temporary file
|
||||
let _ = tokio::fs::remove_file(&image_path).await;
|
||||
|
||||
// Store in PostgreSQL face_identities table
|
||||
if result.success && !result.embedding.is_empty() {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
tracing::warn!("[FACE_REGISTRATION] Failed to connect to DB: {}", e);
|
||||
// Return success even if DB write fails (embedding is still in JSON output)
|
||||
return Ok(Json(FaceRegistrationApiResponse {
|
||||
success: result.success,
|
||||
message: format!("{} (Warning: DB write failed: {})", result.message, e),
|
||||
result: Some(result),
|
||||
}));
|
||||
}
|
||||
};
|
||||
|
||||
// Convert embedding to PostgreSQL vector format
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
result
|
||||
.embedding
|
||||
.iter()
|
||||
.map(|v| v.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
// Insert into face_identities
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let attrs_json =
|
||||
serde_json::to_string(&result.attributes).unwrap_or_else(|_| "{}".to_string());
|
||||
// Use public.vector type to work across schemas
|
||||
let vector_type = if schema::SCHEMA_PREFIX.as_str().is_empty() {
|
||||
"vector".to_string()
|
||||
} else {
|
||||
"public.vector".to_string()
|
||||
};
|
||||
let insert_query = format!(
|
||||
r#"
|
||||
INSERT INTO {} (face_id, name, embedding, attributes, metadata, is_active)
|
||||
VALUES ($1, $2, $3::{}, $4::jsonb, $5, TRUE)
|
||||
ON CONFLICT (face_id) DO UPDATE SET
|
||||
name = EXCLUDED.name,
|
||||
embedding = EXCLUDED.embedding,
|
||||
attributes = EXCLUDED.attributes,
|
||||
metadata = COALESCE(EXCLUDED.metadata, {}.metadata),
|
||||
updated_at = CURRENT_TIMESTAMP,
|
||||
is_active = TRUE
|
||||
"#,
|
||||
face_identities_table, vector_type, face_identities_table
|
||||
);
|
||||
|
||||
match sqlx::query(&insert_query)
|
||||
.bind(&result.face_id)
|
||||
.bind(&name)
|
||||
.bind(&embedding_str)
|
||||
.bind(&attrs_json)
|
||||
.bind(&metadata.unwrap_or(serde_json::json!({})))
|
||||
.execute(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(_) => {
|
||||
tracing::info!(
|
||||
"[FACE_REGISTRATION] Stored face '{}' (face_id={}) in DB",
|
||||
name,
|
||||
result.face_id
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[FACE_REGISTRATION] Failed to store face in DB: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(FaceRegistrationApiResponse {
|
||||
success: result.success,
|
||||
message: result.message.clone(),
|
||||
result: Some(result),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn search_faces(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(request): Json<FaceSearchRequest>,
|
||||
) -> Result<Json<FaceSearchResponse>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Convert embedding to PostgreSQL vector format
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
request
|
||||
.embedding
|
||||
.iter()
|
||||
.map(|v| v.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let similarity_threshold = request.similarity_threshold.unwrap_or(0.6);
|
||||
let limit = request.limit.unwrap_or(10);
|
||||
|
||||
// Search for similar faces
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let vector_type = if schema::SCHEMA_PREFIX.as_str().is_empty() {
|
||||
"vector".to_string()
|
||||
} else {
|
||||
"public.vector".to_string()
|
||||
};
|
||||
let query = format!(
|
||||
r#"
|
||||
SELECT
|
||||
face_id,
|
||||
name,
|
||||
1 - (embedding <=> $1::{}) as similarity,
|
||||
attributes,
|
||||
metadata
|
||||
FROM {}
|
||||
WHERE is_active = TRUE
|
||||
AND embedding IS NOT NULL
|
||||
AND 1 - (embedding <=> $1::{}) >= $2
|
||||
ORDER BY embedding <=> $1::{}
|
||||
LIMIT $3
|
||||
"#,
|
||||
vector_type, face_identities_table, vector_type, vector_type
|
||||
);
|
||||
|
||||
let results: Vec<FaceSearchResult> = match sqlx::query_as::<
|
||||
_,
|
||||
(
|
||||
String,
|
||||
Option<String>,
|
||||
f64,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
),
|
||||
>(query.as_str())
|
||||
.bind(&embedding_str)
|
||||
.bind(similarity_threshold)
|
||||
.bind(limit)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(face_id, name, similarity, attributes, metadata)| FaceSearchResult {
|
||||
face_id,
|
||||
name,
|
||||
similarity,
|
||||
attributes,
|
||||
metadata,
|
||||
},
|
||||
)
|
||||
.collect(),
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to search faces: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(FaceSearchResponse {
|
||||
success: true,
|
||||
message: format!("Found {} similar faces", results.len()),
|
||||
results,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn list_faces(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Query(query): Query<FaceListQuery>,
|
||||
) -> Result<Json<FaceListResponse>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let page = query.page.unwrap_or(1);
|
||||
let page_size = query.page_size.unwrap_or(20);
|
||||
let offset = ((page - 1) as i64) * (page_size as i64);
|
||||
let active_only = query.active_only.unwrap_or(true);
|
||||
|
||||
// Build query
|
||||
let mut where_clause = "WHERE 1=1".to_string();
|
||||
if active_only {
|
||||
where_clause.push_str(" AND is_active = TRUE");
|
||||
}
|
||||
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let count_query = format!(
|
||||
"SELECT COUNT(*) FROM {} {}",
|
||||
face_identities_table, where_clause
|
||||
);
|
||||
let list_query = format!(
|
||||
"SELECT face_id, name, created_at, updated_at, is_active, metadata FROM {} {} ORDER BY created_at DESC LIMIT $1 OFFSET $2",
|
||||
face_identities_table, where_clause
|
||||
);
|
||||
|
||||
// Get total count
|
||||
let total: i64 = match sqlx::query_scalar(&count_query).fetch_one(db.pool()).await {
|
||||
Ok(count) => count,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to count faces: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Get face list
|
||||
let faces: Vec<FaceListItem> = match sqlx::query_as::<
|
||||
_,
|
||||
(
|
||||
String,
|
||||
Option<String>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
bool,
|
||||
Option<serde_json::Value>,
|
||||
),
|
||||
>(&list_query)
|
||||
.bind(page_size as i32)
|
||||
.bind(offset)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(face_id, name, created_at, updated_at, is_active, metadata)| FaceListItem {
|
||||
face_id,
|
||||
name,
|
||||
created_at: created_at.to_rfc3339(),
|
||||
updated_at: updated_at.to_rfc3339(),
|
||||
is_active,
|
||||
metadata,
|
||||
},
|
||||
)
|
||||
.collect(),
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to list faces: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(FaceListResponse {
|
||||
success: true,
|
||||
message: format!("Found {} faces", total),
|
||||
faces,
|
||||
count: total,
|
||||
page,
|
||||
page_size,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn get_face_details(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Path(face_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let query = format!(
|
||||
r#"
|
||||
SELECT
|
||||
face_id,
|
||||
name,
|
||||
embedding,
|
||||
attributes,
|
||||
metadata,
|
||||
created_at,
|
||||
updated_at,
|
||||
is_active
|
||||
FROM {}
|
||||
WHERE face_id = $1
|
||||
"#,
|
||||
face_identities_table
|
||||
);
|
||||
|
||||
let face: Option<(
|
||||
String,
|
||||
Option<String>,
|
||||
Option<String>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
bool,
|
||||
)> = match sqlx::query_as(&query)
|
||||
.bind(&face_id)
|
||||
.fetch_optional(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(face) => face,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to fetch face details: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match face {
|
||||
Some((
|
||||
face_id,
|
||||
name,
|
||||
embedding,
|
||||
attributes,
|
||||
metadata,
|
||||
created_at,
|
||||
updated_at,
|
||||
is_active,
|
||||
)) => {
|
||||
let response = serde_json::json!({
|
||||
"success": true,
|
||||
"face_id": face_id,
|
||||
"name": name,
|
||||
"has_embedding": embedding.is_some(),
|
||||
"attributes": attributes,
|
||||
"metadata": metadata,
|
||||
"created_at": created_at.to_rfc3339(),
|
||||
"updated_at": updated_at.to_rfc3339(),
|
||||
"is_active": is_active
|
||||
});
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Face not found: {}", face_id),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_face(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Path(face_id): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Soft delete by marking as inactive
|
||||
let face_identities_table = schema::table_name("face_identities");
|
||||
let query = format!(
|
||||
r#"
|
||||
UPDATE {}
|
||||
SET is_active = FALSE, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE face_id = $1 AND is_active = TRUE
|
||||
RETURNING face_id, name
|
||||
"#,
|
||||
face_identities_table
|
||||
);
|
||||
|
||||
let deleted: Option<(String, Option<String>)> = match sqlx::query_as(&query)
|
||||
.bind(&face_id)
|
||||
.fetch_optional(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to delete face: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match deleted {
|
||||
Some((deleted_id, name)) => {
|
||||
let response = serde_json::json!({
|
||||
"success": true,
|
||||
"message": format!("Face '{}' deleted successfully", name.clone().unwrap_or_else(|| deleted_id.clone())),
|
||||
"face_id": deleted_id.clone()
|
||||
});
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Face not found or already deleted: {}", face_id),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_recognition_results(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Path(video_uuid): Path<String>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to connect to database: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let query = r#"
|
||||
SELECT
|
||||
video_uuid,
|
||||
frame_count,
|
||||
fps,
|
||||
total_faces,
|
||||
recognized_faces,
|
||||
clusters_count,
|
||||
result_data,
|
||||
processing_time_secs,
|
||||
created_at
|
||||
FROM face_recognition_results
|
||||
WHERE video_uuid = $1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1
|
||||
"#;
|
||||
|
||||
let result: Option<(
|
||||
String,
|
||||
i64,
|
||||
f64,
|
||||
i32,
|
||||
i32,
|
||||
i32,
|
||||
serde_json::Value,
|
||||
Option<f64>,
|
||||
chrono::DateTime<chrono::Utc>,
|
||||
)> = match sqlx::query_as(query)
|
||||
.bind(&video_uuid)
|
||||
.fetch_optional(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to fetch recognition results: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Some((
|
||||
video_uuid,
|
||||
frame_count,
|
||||
fps,
|
||||
total_faces,
|
||||
recognized_faces,
|
||||
clusters_count,
|
||||
result_data,
|
||||
processing_time_secs,
|
||||
created_at,
|
||||
)) => {
|
||||
let response = serde_json::json!({
|
||||
"success": true,
|
||||
"video_uuid": video_uuid,
|
||||
"frame_count": frame_count,
|
||||
"fps": fps,
|
||||
"total_faces": total_faces,
|
||||
"recognized_faces": recognized_faces,
|
||||
"clusters_count": clusters_count,
|
||||
"result_data": result_data,
|
||||
"processing_time_secs": processing_time_secs,
|
||||
"created_at": created_at.to_rfc3339()
|
||||
});
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
None => Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("No recognition results found for video: {}", video_uuid),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn store_recognition_results(
|
||||
db: &PostgresDb,
|
||||
video_uuid: &str,
|
||||
result: &FaceRecognitionResult,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let total_faces = result.frames.iter().map(|f| f.faces.len()).sum::<usize>();
|
||||
let recognized_faces = result
|
||||
.frames
|
||||
.iter()
|
||||
.flat_map(|f| &f.faces)
|
||||
.filter(|face| face.identity.is_some())
|
||||
.count();
|
||||
|
||||
let query = r#"
|
||||
INSERT INTO face_recognition_results (
|
||||
video_uuid,
|
||||
frame_count,
|
||||
fps,
|
||||
total_faces,
|
||||
recognized_faces,
|
||||
clusters_count,
|
||||
result_data
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ON CONFLICT (video_uuid) DO UPDATE SET
|
||||
frame_count = EXCLUDED.frame_count,
|
||||
fps = EXCLUDED.fps,
|
||||
total_faces = EXCLUDED.total_faces,
|
||||
recognized_faces = EXCLUDED.recognized_faces,
|
||||
clusters_count = EXCLUDED.clusters_count,
|
||||
result_data = EXCLUDED.result_data,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
"#;
|
||||
|
||||
sqlx::query(query)
|
||||
.bind(video_uuid)
|
||||
.bind(result.frame_count as i64)
|
||||
.bind(result.fps)
|
||||
.bind(total_faces as i32)
|
||||
.bind(recognized_faces as i32)
|
||||
.bind(result.face_clusters.len() as i32)
|
||||
.bind(serde_json::to_value(result)?)
|
||||
.execute(db.pool())
|
||||
.await?;
|
||||
|
||||
// Store individual face detections
|
||||
for frame in &result.frames {
|
||||
for face in &frame.faces {
|
||||
if let Some(embedding) = &face.embedding {
|
||||
let embedding_str = format!(
|
||||
"[{}]",
|
||||
embedding
|
||||
.iter()
|
||||
.map(|v| v.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let insert_query = r#"
|
||||
INSERT INTO face_detections (
|
||||
video_uuid,
|
||||
frame_number,
|
||||
timestamp_secs,
|
||||
face_id,
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
confidence,
|
||||
embedding,
|
||||
attributes,
|
||||
identity_confidence,
|
||||
cluster_id
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10::vector, $11, $12, $13)
|
||||
ON CONFLICT (video_uuid, frame_number, x, y, width, height) DO UPDATE SET
|
||||
face_id = EXCLUDED.face_id,
|
||||
confidence = EXCLUDED.confidence,
|
||||
embedding = EXCLUDED.embedding,
|
||||
attributes = EXCLUDED.attributes,
|
||||
identity_confidence = EXCLUDED.identity_confidence,
|
||||
cluster_id = EXCLUDED.cluster_id
|
||||
"#;
|
||||
|
||||
let identity_confidence = face.identity.as_ref().map(|id| id.confidence as f64);
|
||||
let cluster_id = result
|
||||
.face_clusters
|
||||
.iter()
|
||||
.find(|c| {
|
||||
c.face_ids
|
||||
.contains(&face.face_id.clone().unwrap_or_default())
|
||||
})
|
||||
.map(|c| c.cluster_id.clone());
|
||||
|
||||
sqlx::query(insert_query)
|
||||
.bind(video_uuid)
|
||||
.bind(frame.frame as i64)
|
||||
.bind(frame.timestamp)
|
||||
.bind(face.face_id.as_deref())
|
||||
.bind(face.x)
|
||||
.bind(face.y)
|
||||
.bind(face.width)
|
||||
.bind(face.height)
|
||||
.bind(face.confidence as f64)
|
||||
.bind(&embedding_str)
|
||||
.bind(serde_json::to_value(&face.attributes)?)
|
||||
.bind(identity_confidence)
|
||||
.bind(cluster_id)
|
||||
.execute(db.pool())
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store face clusters
|
||||
for cluster in &result.face_clusters {
|
||||
let centroid = &cluster.centroid;
|
||||
let centroid_str = format!(
|
||||
"[{}]",
|
||||
centroid
|
||||
.iter()
|
||||
.map(|v: &f32| v.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
|
||||
let cluster_query = r#"
|
||||
INSERT INTO face_clusters (
|
||||
cluster_id,
|
||||
video_uuid,
|
||||
centroid,
|
||||
size,
|
||||
representative_face_id,
|
||||
metadata
|
||||
) VALUES ($1, $2, $3::vector, $4, $5, $6)
|
||||
ON CONFLICT (cluster_id) DO UPDATE SET
|
||||
centroid = EXCLUDED.centroid,
|
||||
size = EXCLUDED.size,
|
||||
representative_face_id = EXCLUDED.representative_face_id,
|
||||
metadata = EXCLUDED.metadata
|
||||
"#;
|
||||
|
||||
sqlx::query(cluster_query)
|
||||
.bind(&cluster.cluster_id)
|
||||
.bind(video_uuid)
|
||||
.bind(¢roid_str)
|
||||
.bind(cluster.size as i32)
|
||||
.bind(cluster.representative_face_id.as_deref())
|
||||
.bind(&cluster.metadata)
|
||||
.execute(db.pool())
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
288
src/api/identities.rs
Normal file
288
src/api/identities.rs
Normal file
@@ -0,0 +1,288 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::db::{schema, Database, PostgresDb};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RegisterFromPersonRequest {
|
||||
pub video_uuid: String,
|
||||
pub person_id: String,
|
||||
pub identity_name: String,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct RegisterFromPersonResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub identity_id: i32,
|
||||
pub identity_name: String,
|
||||
pub person_id: String,
|
||||
}
|
||||
|
||||
pub fn identity_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/identities/from-person", post(register_from_person))
|
||||
.route("/api/v1/identities", get(list_identities))
|
||||
}
|
||||
|
||||
/// Register a Global Identity from a specific Person in a video.
|
||||
/// This creates/updates the Identity record, links the Person to the Identity,
|
||||
/// and updates the Person's name to match the Identity.
|
||||
async fn register_from_person(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<RegisterFromPersonRequest>,
|
||||
) -> Result<Json<RegisterFromPersonResponse>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("DB error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut tx = match db.pool().begin().await {
|
||||
Ok(tx) => tx,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Tx error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// 1. Check if Person exists
|
||||
let person_query =
|
||||
"SELECT id, name FROM person_identities WHERE person_id = $1 AND video_uuid = $2";
|
||||
let person: Option<(i32, Option<String>)> = match sqlx::query_as(person_query)
|
||||
.bind(&req.person_id)
|
||||
.bind(&req.video_uuid)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Query error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let (person_db_id, _old_name) = match person {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!(
|
||||
"Person '{}' not found in video '{}'",
|
||||
req.person_id, req.video_uuid
|
||||
),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// 2. Check if Identity exists
|
||||
let identity_query = "SELECT id FROM identities WHERE name = $1";
|
||||
let identity_id: Option<i32> = match sqlx::query_scalar(identity_query)
|
||||
.bind(&req.identity_name)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Query error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let final_identity_id = if let Some(id) = identity_id {
|
||||
id
|
||||
} else {
|
||||
// Create new Identity
|
||||
let meta_json = req.metadata.clone().unwrap_or(serde_json::json!({}));
|
||||
|
||||
let new_id: i32 = match sqlx::query_scalar(
|
||||
r#"
|
||||
INSERT INTO identities (name, embedding, metadata)
|
||||
VALUES ($1, NULLIF($2, '')::public.vector, $3)
|
||||
RETURNING id
|
||||
"#,
|
||||
)
|
||||
.bind(&req.identity_name)
|
||||
.bind("".to_string()) // No embedding for now via this API
|
||||
.bind(&meta_json)
|
||||
.fetch_one(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Insert identity error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
new_id
|
||||
};
|
||||
|
||||
// 3. Create Binding
|
||||
// Columns: id, identity_id, identity_type, identity_value, confidence, metadata, created_at
|
||||
let binding_query = r#"
|
||||
INSERT INTO identity_bindings (identity_id, identity_type, identity_value, confidence, metadata)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT DO NOTHING
|
||||
"#;
|
||||
|
||||
match sqlx::query(binding_query)
|
||||
.bind(final_identity_id)
|
||||
.bind("person_id") // identity_type
|
||||
.bind(&req.person_id) // identity_value
|
||||
.bind(1.0) // confidence
|
||||
.bind(&serde_json::json!({"auto_updated": true}))
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Binding error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// 4. Update Person Name
|
||||
let update_person = "UPDATE person_identities SET name = $1 WHERE id = $2";
|
||||
match sqlx::query(update_person)
|
||||
.bind(&req.identity_name)
|
||||
.bind(person_db_id)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Update person error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
match tx.commit().await {
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Commit error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Json(RegisterFromPersonResponse {
|
||||
success: true,
|
||||
message: format!(
|
||||
"Successfully registered identity '{}' and linked to person '{}'",
|
||||
req.identity_name, req.person_id
|
||||
),
|
||||
identity_id: final_identity_id,
|
||||
identity_name: req.identity_name,
|
||||
person_id: req.person_id,
|
||||
}))
|
||||
}
|
||||
|
||||
/// List all global identities
|
||||
async fn list_identities(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Query(query): Query<ListIdentitiesQuery>,
|
||||
) -> Result<Json<IdentityListResponse>, (StatusCode, String)> {
|
||||
let db = match PostgresDb::init().await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("DB error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let page = query.page.unwrap_or(1);
|
||||
let page_size = query.page_size.unwrap_or(20);
|
||||
let offset = ((page - 1) as i64) * (page_size as i64);
|
||||
|
||||
// 獲取總數
|
||||
let count_sql = "SELECT COUNT(*) FROM identities";
|
||||
let total: i64 = match sqlx::query_scalar(count_sql).fetch_one(db.pool()).await {
|
||||
Ok(count) => count,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Count error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let sql = "SELECT id, name, metadata FROM identities ORDER BY id DESC LIMIT $1 OFFSET $2";
|
||||
|
||||
let rows: Vec<(i32, String, Option<serde_json::Value>)> = match sqlx::query_as(sql)
|
||||
.bind(page_size as i64)
|
||||
.bind(offset)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows,
|
||||
Err(e) => {
|
||||
return Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Query error: {}", e),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let identities: Vec<IdentityResponse> = rows
|
||||
.into_iter()
|
||||
.map(|r| IdentityResponse {
|
||||
id: r.0,
|
||||
name: r.1,
|
||||
metadata: r.2,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Json(IdentityListResponse {
|
||||
identities,
|
||||
count: total,
|
||||
page,
|
||||
page_size,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListIdentitiesQuery {
|
||||
pub page: Option<usize>,
|
||||
pub page_size: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IdentityResponse {
|
||||
pub id: i32,
|
||||
pub name: String,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct IdentityListResponse {
|
||||
pub identities: Vec<IdentityResponse>,
|
||||
pub count: i64,
|
||||
pub page: usize,
|
||||
pub page_size: usize,
|
||||
}
|
||||
412
src/api/identity_binding.rs
Normal file
412
src/api/identity_binding.rs
Normal file
@@ -0,0 +1,412 @@
|
||||
use axum::{
|
||||
extract::{Path, Query},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::db::{Database, PostgresDb};
|
||||
use crate::core::person_identity::{BindIdentityRequest, Identity, UnbindIdentityRequest};
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ApiResponse<T: Serialize> {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
pub data: Option<T>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// API Handlers
|
||||
// ============================================================================
|
||||
|
||||
async fn get_db() -> Result<PostgresDb, (StatusCode, Json<serde_json::Value>)> {
|
||||
PostgresDb::init().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("DB init failed: {}", e) })),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// 獲取 Identity (人物) 列表
|
||||
pub async fn list_identities(
|
||||
Query(params): Query<ListIdentitiesParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<Identity>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let limit = params.limit.unwrap_or(100);
|
||||
let offset = params.offset.unwrap_or(0);
|
||||
let search = params.search.unwrap_or_default();
|
||||
|
||||
let identities = db
|
||||
.list_identities(&search, limit, offset)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!("Found {} identities", identities.len()),
|
||||
data: Some(identities),
|
||||
}))
|
||||
}
|
||||
|
||||
/// 綁定身份 (Face/Speaker -> Identity)
|
||||
pub async fn bind_identity(
|
||||
Json(req): Json<BindIdentityRequest>,
|
||||
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
|
||||
let identity = if let Some(id_id) = req.identity_id {
|
||||
db.get_identity_by_id(id_id).await.ok().flatten()
|
||||
} else if let Some(name) = &req.name {
|
||||
db.get_or_create_identity(name).await.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let identity = match identity {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": "Identity not found or name required" })),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let source = req.source.unwrap_or("manual".to_string());
|
||||
db.bind_identity(
|
||||
identity.id as i64,
|
||||
&req.binding_type,
|
||||
&req.binding_value,
|
||||
&source,
|
||||
1.0,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!(
|
||||
"Bound {} '{}' to Identity '{}'",
|
||||
req.binding_type, req.binding_value, identity.name
|
||||
),
|
||||
data: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// 解綁身份
|
||||
pub async fn unbind_identity(
|
||||
Json(req): Json<UnbindIdentityRequest>,
|
||||
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
db.unbind_identity(&req.binding_type, &req.binding_value)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!("Unbound {} '{}'", req.binding_type, req.binding_value),
|
||||
data: None,
|
||||
}))
|
||||
}
|
||||
|
||||
/// 查詢機器 ID 對應的 Identity (人物)
|
||||
pub async fn get_identity_info(
|
||||
Path((binding_type, binding_value)): Path<(String, String)>,
|
||||
) -> Result<Json<ApiResponse<Identity>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let identity = db
|
||||
.get_identity_by_binding(&binding_type, &binding_value)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let identity = identity.ok_or_else(|| {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({ "error": "Identity not found" })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: "Identity info retrieved".to_string(),
|
||||
data: Some(identity),
|
||||
}))
|
||||
}
|
||||
|
||||
/// 列出未綁定的信號 (待標註列表)
|
||||
pub async fn list_unbound_signals(
|
||||
Query(params): Query<ListSignalsParams>,
|
||||
) -> Result<Json<ApiResponse<Vec<String>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let signals = db
|
||||
.list_unbound_signals(¶ms.uuid, ¶ms.binding_type)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!(
|
||||
"Found {} unbound {} signals",
|
||||
signals.len(),
|
||||
params.binding_type
|
||||
),
|
||||
data: Some(signals),
|
||||
}))
|
||||
}
|
||||
|
||||
/// 獲取特定信號 (Face ID 或 Speaker ID) 出現的所有 Chunk (時間軸)
|
||||
pub async fn get_signal_timeline(
|
||||
Path((uuid, binding_type, binding_value)): Path<(String, String, String)>,
|
||||
) -> Result<Json<ApiResponse<Vec<serde_json::Value>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let chunks = db
|
||||
.get_chunks_by_signal(&uuid, &binding_type, &binding_value)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!(
|
||||
"Found {} chunks for {} '{}'",
|
||||
chunks.len(),
|
||||
binding_type,
|
||||
binding_value
|
||||
),
|
||||
data: Some(chunks),
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AVSuggestRequest {
|
||||
pub video_uuid: String,
|
||||
pub overlap_threshold: Option<f64>, // default 0.6
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AVSuggestion {
|
||||
pub face_id: String,
|
||||
pub speaker_id: String,
|
||||
pub overlap_score: f64,
|
||||
pub face_talent_id: Option<i32>,
|
||||
pub speaker_talent_id: Option<i32>,
|
||||
}
|
||||
|
||||
/// Suggests Face-Speaker bindings based on temporal overlap
|
||||
pub async fn suggest_audio_visual_bindings(
|
||||
Json(req): Json<AVSuggestRequest>,
|
||||
) -> Result<Json<ApiResponse<Vec<AVSuggestion>>>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = get_db().await?;
|
||||
let threshold = req.overlap_threshold.unwrap_or(0.6);
|
||||
|
||||
// 1. Get Face signals and their time ranges
|
||||
let face_signals = db
|
||||
.list_unbound_signals(&req.video_uuid, "face")
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("Face signals: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let speaker_signals = db
|
||||
.list_unbound_signals(&req.video_uuid, "speaker")
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("Speaker signals: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut suggestions = Vec::new();
|
||||
|
||||
for face_id in &face_signals {
|
||||
for speaker_id in &speaker_signals {
|
||||
// Calculate overlap
|
||||
// In a real implementation, we would query the exact timestamps from DB.
|
||||
// For now, we'll use a placeholder or simple heuristic if timestamps are available in signal timeline.
|
||||
// Let's assume we fetch timelines.
|
||||
|
||||
// Placeholder: Calculate overlap by fetching timelines
|
||||
let face_timeline = db
|
||||
.get_chunks_by_signal(&req.video_uuid, "face", face_id)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let speaker_timeline = db
|
||||
.get_chunks_by_signal(&req.video_uuid, "speaker", speaker_id)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Simplified overlap calculation based on chunk count/ids (assuming chunk_id contains time info or we have ranges)
|
||||
// Since chunk IDs are generic, we rely on content JSON having 'start_time'
|
||||
let overlap = calculate_overlap(&face_timeline, &speaker_timeline);
|
||||
|
||||
if overlap >= threshold {
|
||||
// Check if they are already bound to the same identity
|
||||
let face_identity = db
|
||||
.get_identity_by_binding("face", face_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
let speaker_identity = db
|
||||
.get_identity_by_binding("speaker", speaker_id)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
// If both bound to different identities, don't suggest (conflict)
|
||||
if let (Some(fi), Some(si)) = (&face_identity, &speaker_identity) {
|
||||
if fi.id != si.id {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
suggestions.push(AVSuggestion {
|
||||
face_id: face_id.clone(),
|
||||
speaker_id: speaker_id.clone(),
|
||||
overlap_score: overlap,
|
||||
face_talent_id: face_identity.as_ref().map(|i| i.id as i32),
|
||||
speaker_talent_id: speaker_identity.as_ref().map(|i| i.id as i32),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suggestions.sort_by(|a, b| b.overlap_score.partial_cmp(&a.overlap_score).unwrap());
|
||||
|
||||
Ok(Json(ApiResponse {
|
||||
success: true,
|
||||
message: format!("Found {} AV suggestions", suggestions.len()),
|
||||
data: Some(suggestions),
|
||||
}))
|
||||
}
|
||||
|
||||
fn calculate_overlap(
|
||||
face_chunks: &[serde_json::Value],
|
||||
speaker_chunks: &[serde_json::Value],
|
||||
) -> f64 {
|
||||
// Simplified: Extract start/end times and calculate intersection over union
|
||||
// Assuming chunks have start_frame or start_time in content JSON
|
||||
// If content is raw string, we might need to parse it.
|
||||
// In our schema, content is JSONB.
|
||||
|
||||
let mut face_ranges: Vec<(f64, f64)> = Vec::new();
|
||||
for c in face_chunks {
|
||||
if let Some(content) = c.get("content") {
|
||||
if let (Some(start), Some(end)) = (
|
||||
content.get("start_time").and_then(|v| v.as_f64()),
|
||||
content.get("end_time").and_then(|v| v.as_f64()),
|
||||
) {
|
||||
face_ranges.push((start, end));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut speaker_ranges: Vec<(f64, f64)> = Vec::new();
|
||||
for c in speaker_chunks {
|
||||
if let Some(content) = c.get("content") {
|
||||
if let (Some(start), Some(end)) = (
|
||||
content.get("start_time").and_then(|v| v.as_f64()),
|
||||
content.get("end_time").and_then(|v| v.as_f64()),
|
||||
) {
|
||||
speaker_ranges.push((start, end));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut overlap_duration = 0.0;
|
||||
for (fs, fe) in &face_ranges {
|
||||
for (ss, se) in &speaker_ranges {
|
||||
let start = fs.max(*ss);
|
||||
let end = fe.min(*se);
|
||||
if start < end {
|
||||
overlap_duration += end - start;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return normalized overlap (0.0 to 1.0+), simple version: overlap / min_duration
|
||||
let min_duration = face_ranges
|
||||
.iter()
|
||||
.map(|(_, e)| e)
|
||||
.sum::<f64>()
|
||||
.min(speaker_ranges.iter().map(|(_, e)| e).sum::<f64>());
|
||||
|
||||
if min_duration > 0.0 {
|
||||
(overlap_duration / min_duration).min(1.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Router Setup
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListIdentitiesParams {
|
||||
pub search: Option<String>,
|
||||
pub limit: Option<i32>,
|
||||
pub offset: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListSignalsParams {
|
||||
pub uuid: String,
|
||||
pub binding_type: String, // "face" or "speaker"
|
||||
}
|
||||
|
||||
pub fn identity_binding_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/identities/bind", post(bind_identity))
|
||||
.route("/api/v1/identities/unbind", post(unbind_identity))
|
||||
.route(
|
||||
"/api/v1/identity/:binding_type/:binding_value",
|
||||
get(get_identity_info),
|
||||
)
|
||||
// 信號發現 (Discovery)
|
||||
.route("/api/v1/signals/unbound", get(list_unbound_signals))
|
||||
// 信號時間軸 (Timeline)
|
||||
.route(
|
||||
"/api/v1/signals/:uuid/:binding_type/:binding_value/timeline",
|
||||
get(get_signal_timeline),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/identities/suggest-av",
|
||||
post(suggest_audio_visual_bindings),
|
||||
)
|
||||
}
|
||||
264
src/api/n8n_search.rs
Normal file
264
src/api/n8n_search.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use crate::core::db::{Bm25Result, PostgresDb};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmartSearchRequest {
|
||||
pub query: String,
|
||||
pub uuid: Option<String>,
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SmartSearchResponse {
|
||||
pub query: String,
|
||||
pub parsed_dimensions: serde_json::Value,
|
||||
pub hits: Vec<serde_json::Value>,
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct LlmDimensionResponse {
|
||||
pub who: Option<String>,
|
||||
pub what: Option<String>,
|
||||
pub when: Option<String>,
|
||||
pub r#where: Option<String>,
|
||||
pub why: Option<String>,
|
||||
#[serde(default)]
|
||||
pub keywords: Vec<String>,
|
||||
}
|
||||
|
||||
/// POST /api/v1/n8n/search/smart
|
||||
pub async fn n8n_search_smart(
|
||||
db: &PostgresDb,
|
||||
req: SmartSearchRequest,
|
||||
) -> Result<SmartSearchResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let limit = req.limit.unwrap_or(10);
|
||||
let video_uuid = req.uuid.clone();
|
||||
|
||||
// 1. Call LLM to extract 5W1H (Fallback to keywords if LLM fails)
|
||||
let dimensions = match parse_query_with_llm(&req.query).await {
|
||||
Some(dims) => dims,
|
||||
None => LlmDimensionResponse {
|
||||
who: None,
|
||||
what: None,
|
||||
when: None,
|
||||
r#where: None,
|
||||
why: None,
|
||||
keywords: extract_keywords(&req.query),
|
||||
},
|
||||
};
|
||||
|
||||
// Prepare search terms based on dimensions
|
||||
let keywords = dimensions.keywords.join(" ");
|
||||
|
||||
let semantic_query = [
|
||||
dimensions.who.clone(),
|
||||
dimensions.what.clone(),
|
||||
dimensions.r#where.clone(),
|
||||
dimensions.why.clone(),
|
||||
dimensions.when.clone(),
|
||||
]
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
// 2. Multi-dimensional Search
|
||||
let mut hits: Vec<serde_json::Value> = Vec::new();
|
||||
let mut seen_chunk_ids: HashSet<String> = HashSet::new();
|
||||
|
||||
// Helper function
|
||||
fn add_hit(
|
||||
hits: &mut Vec<serde_json::Value>,
|
||||
seen_chunk_ids: &mut HashSet<String>,
|
||||
sr: Bm25Result,
|
||||
boost: f32,
|
||||
) {
|
||||
if seen_chunk_ids.insert(sr.chunk_id.clone()) {
|
||||
let score = sr.bm25_score * boost;
|
||||
let val = serde_json::json!({
|
||||
"id": sr.chunk_id,
|
||||
"vid": sr.uuid,
|
||||
"start": sr.start_time,
|
||||
"end": sr.end_time,
|
||||
"text": sr.text,
|
||||
"score": score,
|
||||
"chunk_type": sr.chunk_type
|
||||
});
|
||||
hits.push(val);
|
||||
}
|
||||
}
|
||||
|
||||
// A. Keyword Search (BM25)
|
||||
if !keywords.is_empty() {
|
||||
if let Ok(results) = db
|
||||
.search_bm25(&keywords, video_uuid.as_deref(), limit)
|
||||
.await
|
||||
{
|
||||
for sr in results {
|
||||
add_hit(&mut hits, &mut seen_chunk_ids, sr, 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// B. Who Search (Person Matching)
|
||||
if let Some(who_query) = &dimensions.who {
|
||||
// 1. Search Person
|
||||
if let Ok(persons) = db.search_person_candidates(who_query, &video_uuid, 5).await {
|
||||
if !persons.is_empty() {
|
||||
let person_id = persons[0]
|
||||
.get("candidate_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from);
|
||||
|
||||
if let Some(pid) = person_id {
|
||||
// Heuristic: Search BM25 for the person's ID or Name
|
||||
let person_name = persons[0]
|
||||
.get("display_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(who_query);
|
||||
|
||||
// Re-run BM25 with person name to find specific chunks and boost them
|
||||
if let Ok(results) = db
|
||||
.search_bm25(person_name, video_uuid.as_deref(), limit)
|
||||
.await
|
||||
{
|
||||
for sr in results {
|
||||
let id = sr.chunk_id.clone();
|
||||
if seen_chunk_ids.insert(id) {
|
||||
let score = sr.bm25_score * 1.5;
|
||||
let val = serde_json::json!({
|
||||
"id": sr.chunk_id,
|
||||
"vid": sr.uuid,
|
||||
"start": sr.start_time,
|
||||
"end": sr.end_time,
|
||||
"text": sr.text,
|
||||
"score": score,
|
||||
"matched_person": pid
|
||||
});
|
||||
hits.push(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by score
|
||||
hits.sort_by(|a, b| {
|
||||
let score_a = a.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||
let score_b = b.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||
score_b.partial_cmp(&score_a).unwrap()
|
||||
});
|
||||
|
||||
// Limit
|
||||
hits.truncate(limit);
|
||||
let total = hits.len();
|
||||
|
||||
Ok(SmartSearchResponse {
|
||||
query: req.query,
|
||||
parsed_dimensions: serde_json::json!(dimensions),
|
||||
hits,
|
||||
total,
|
||||
})
|
||||
}
|
||||
|
||||
fn extract_keywords(query: &str) -> Vec<String> {
|
||||
// Simple keyword extraction: remove common stop words and punctuation
|
||||
let stop_words = [
|
||||
"who", "what", "where", "when", "why", "how", "is", "the", "a", "an", "and", "or", "of",
|
||||
"in", "to", "for", "with", "by", "on", "at", "from", "up", "about", "into", "over",
|
||||
"after",
|
||||
];
|
||||
|
||||
query
|
||||
.to_lowercase()
|
||||
.chars()
|
||||
.map(|c| if c.is_alphanumeric() { c } else { ' ' })
|
||||
.collect::<String>()
|
||||
.split_whitespace()
|
||||
.filter(|w| !stop_words.contains(w))
|
||||
.map(String::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn parse_query_with_llm(query: &str) -> Option<LlmDimensionResponse> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Test connectivity first
|
||||
if let Ok(resp) = client.get("http://127.0.0.1:8081/health").send().await {
|
||||
tracing::info!("LLM Health Check: {}", resp.status());
|
||||
} else {
|
||||
tracing::error!("LLM Server is unreachable at 127.0.0.1:8081");
|
||||
}
|
||||
|
||||
// We use the OpenAI-compatible endpoint provided by llama.cpp server (default port 8081)
|
||||
let prompt = format!(
|
||||
r#"Analyze the user query and extract the following dimensions into a JSON object.
|
||||
If a dimension is not present, use null.
|
||||
Dimensions: "who" (person/subject), "what" (action/event), "where" (location), "when" (time), "why" (reason/intent), "keywords" (array of specific keywords).
|
||||
|
||||
User Query: "{}"
|
||||
|
||||
Output ONLY the JSON object.
|
||||
"#,
|
||||
query
|
||||
);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"model": "gemma4",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"stream": false
|
||||
});
|
||||
|
||||
if let Ok(response) = client
|
||||
.post("http://127.0.0.1:8081/v1/chat/completions")
|
||||
.json(&payload)
|
||||
.timeout(std::time::Duration::from_secs(60))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
tracing::info!("LLM Response Status: {}", response.status());
|
||||
if let Ok(json_resp) = response.json::<serde_json::Value>().await {
|
||||
tracing::info!("LLM Response Body: {}", json_resp);
|
||||
if let Some(choices) = json_resp.get("choices").and_then(|v| v.as_array()) {
|
||||
if let Some(choice) = choices.get(0) {
|
||||
if let Some(content) = choice
|
||||
.get("message")
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
{
|
||||
if let Some(start) = content.find("{") {
|
||||
if let Some(end) = content.rfind("}") {
|
||||
let json_str = &content[start..=end];
|
||||
if let Ok(dims) =
|
||||
serde_json::from_str::<LlmDimensionResponse>(json_str)
|
||||
{
|
||||
tracing::info!("Parsed LLM Dimensions: {:?}", dims);
|
||||
return Some(dims);
|
||||
} else {
|
||||
tracing::warn!("Failed to parse LLM JSON: {}", json_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Failed to parse LLM response JSON");
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("LLM request failed or timed out");
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
3281
src/api/person_identity.rs
Normal file
3281
src/api/person_identity.rs
Normal file
File diff suppressed because it is too large
Load Diff
2772
src/api/person_identity.rs.bak
Normal file
2772
src/api/person_identity.rs.bak
Normal file
File diff suppressed because it is too large
Load Diff
2774
src/api/person_identity.rs.bak2
Normal file
2774
src/api/person_identity.rs.bak2
Normal file
File diff suppressed because it is too large
Load Diff
195
src/api/search.rs
Normal file
195
src/api/search.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Smart Search API
|
||||
//! Implements the 5W1H search capability using semantic vectors.
|
||||
|
||||
use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use tracing;
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
// --- Request / Response Structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmartSearchRequest {
|
||||
pub uuid: String,
|
||||
pub query: String,
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchResult {
|
||||
pub id: i32,
|
||||
pub parent_id: i32,
|
||||
pub scene_order: Option<i32>,
|
||||
|
||||
// Primary: frame-accurate position (authoritative unit)
|
||||
pub start_frame: i64,
|
||||
pub end_frame: i64,
|
||||
pub fps: f64,
|
||||
|
||||
// Reference: time derived from frames (subject to FPS variation, not precise)
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
|
||||
pub raw_text: Option<String>, // Text content of the child chunk
|
||||
pub summary: Option<String>, // Summary from parent context
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub similarity: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SmartSearchResponse {
|
||||
pub query: String,
|
||||
pub results: Vec<SearchResult>,
|
||||
pub strategy: String,
|
||||
}
|
||||
|
||||
// --- API Handler ---
|
||||
|
||||
pub async fn smart_search(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<SmartSearchRequest>,
|
||||
) -> Result<Json<SmartSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
let limit = req.limit.unwrap_or(5);
|
||||
|
||||
// 1. Generate Embedding using Ollama
|
||||
let embedding = get_ollama_embedding(&req.query).await.map_err(
|
||||
|e| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("Embedding failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
// 2. Search Database (Drill-Down: Find Parents First)
|
||||
let db_parents: Vec<crate::core::db::postgres_db::SemanticSearchResult> = db
|
||||
.search_parent_chunks_semantic(&req.uuid, &embedding, limit)
|
||||
.await
|
||||
.map_err(
|
||||
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("DB search failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
if db_parents.is_empty() {
|
||||
return Ok(Json(SmartSearchResponse {
|
||||
query: req.query,
|
||||
results: vec![],
|
||||
strategy: "semantic_vector_search".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
// Collect Parent IDs
|
||||
let parent_ids: Vec<i32> = db_parents.iter().map(|p| p.id).collect();
|
||||
|
||||
// 3. Fetch Children for these Parents (Drill Down)
|
||||
// We fetch all children for these parents (limit can be adjusted)
|
||||
let children: Vec<crate::core::db::postgres_db::ChildChunkResult> = db
|
||||
.get_children_for_parents(&parent_ids, 10) // Fetch top 10 children per parent
|
||||
.await
|
||||
.map_err(
|
||||
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("Fetching children failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
// 4. Map Parents to a lookup table
|
||||
let parent_map: std::collections::HashMap<
|
||||
i32,
|
||||
&crate::core::db::postgres_db::SemanticSearchResult,
|
||||
> = db_parents.iter().map(|p| (p.id, p)).collect();
|
||||
|
||||
// Map Children to API response struct
|
||||
let results: Vec<SearchResult> = children
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
let parent = parent_map.get(&c.parent_id);
|
||||
SearchResult {
|
||||
id: c.id,
|
||||
parent_id: c.parent_id,
|
||||
scene_order: parent.map(|p| p.scene_order),
|
||||
|
||||
start_frame: c.start_frame,
|
||||
end_frame: c.end_frame,
|
||||
fps: c.fps,
|
||||
|
||||
start_time: c.start_time,
|
||||
end_time: c.end_time,
|
||||
raw_text: Some(c.raw_text),
|
||||
summary: parent.map(|p| p.summary.clone()),
|
||||
metadata: parent.map(|p| p.metadata.clone()),
|
||||
similarity: parent.and_then(|p| p.similarity),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 6. Sort results by similarity (descending)
|
||||
// Since all children of a parent have the same parent similarity, this groups relevant chunks together
|
||||
let mut results = results;
|
||||
results.sort_by(|a, b| {
|
||||
b.similarity
|
||||
.partial_cmp(&a.similarity)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// 7. Limit the final results (optional, but good for API consistency)
|
||||
let limit = req.limit.unwrap_or(5) * 5; // Allow more children per parent context
|
||||
results.truncate(limit);
|
||||
|
||||
// 8. Format Response
|
||||
let response = SmartSearchResponse {
|
||||
query: req.query,
|
||||
results,
|
||||
strategy: "drill_down_semantic_search".to_string(),
|
||||
};
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
// --- Helper: Ollama Embedding ---
|
||||
|
||||
async fn get_ollama_embedding(
|
||||
text: &str,
|
||||
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let client = reqwest::Client::new();
|
||||
let payload = serde_json::json!({
|
||||
"model": "nomic-embed-text",
|
||||
"prompt": text
|
||||
});
|
||||
|
||||
let res = client
|
||||
.post("http://localhost:11434/api/embeddings")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?
|
||||
.json::<serde_json::Value>()
|
||||
.await?;
|
||||
|
||||
// Parse embedding array from response
|
||||
let embedding = res["embedding"]
|
||||
.as_array()
|
||||
.ok_or("No embedding found in Ollama response")?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
|
||||
.collect();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
// --- Router Setup ---
|
||||
|
||||
pub fn search_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new().route("/smart", post(smart_search))
|
||||
}
|
||||
195
src/api/search.rs.bak
Normal file
195
src/api/search.rs.bak
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Smart Search API
|
||||
//! Implements the 5W1H search capability using semantic vectors.
|
||||
|
||||
use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json;
|
||||
use tracing;
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
// --- Request / Response Structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SmartSearchRequest {
|
||||
pub uuid: String,
|
||||
pub query: String,
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchResult {
|
||||
pub id: i32,
|
||||
pub parent_id: i32,
|
||||
pub scene_order: Option<i32>,
|
||||
|
||||
// Primary: frame-accurate position (authoritative unit)
|
||||
pub start_frame: i64,
|
||||
pub end_frame: i64,
|
||||
pub fps: f64,
|
||||
|
||||
// Reference: time derived from frames (subject to FPS variation, not precise)
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
|
||||
pub raw_text: Option<String>, // Text content of the child chunk
|
||||
pub summary: Option<String>, // Summary from parent context
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub similarity: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SmartSearchResponse {
|
||||
pub query: String,
|
||||
pub results: Vec<SearchResult>,
|
||||
pub strategy: String,
|
||||
}
|
||||
|
||||
// --- API Handler ---
|
||||
|
||||
pub async fn smart_search(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<SmartSearchRequest>,
|
||||
) -> Result<Json<SmartSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
let limit = req.limit.unwrap_or(5);
|
||||
|
||||
// 1. Generate Embedding using Ollama
|
||||
let embedding = get_ollama_embedding(&req.query).await.map_err(
|
||||
|e| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("Embedding failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
// 2. Search Database (Drill-Down: Find Parents First)
|
||||
let db_parents: Vec<crate::core::db::postgres_db::SemanticSearchResult> = db
|
||||
.search_parent_chunks_semantic(&req.uuid, &embedding, limit)
|
||||
.await
|
||||
.map_err(
|
||||
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("DB search failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
if db_parents.is_empty() {
|
||||
return Ok(Json(SmartSearchResponse {
|
||||
query: req.query,
|
||||
results: vec![],
|
||||
strategy: "semantic_vector_search".to_string(),
|
||||
}));
|
||||
}
|
||||
|
||||
// Collect Parent IDs
|
||||
let parent_ids: Vec<i32> = db_parents.iter().map(|p| p.id).collect();
|
||||
|
||||
// 3. Fetch Children for these Parents (Drill Down)
|
||||
// We fetch all children for these parents (limit can be adjusted)
|
||||
let children: Vec<crate::core::db::postgres_db::ChildChunkResult> = db
|
||||
.get_children_for_parents(&parent_ids, 10) // Fetch top 10 children per parent
|
||||
.await
|
||||
.map_err(
|
||||
|e: anyhow::Error| -> (StatusCode, Json<serde_json::Value>) {
|
||||
tracing::error!("Fetching children failed: {}", e);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
},
|
||||
)?;
|
||||
|
||||
// 4. Map Parents to a lookup table
|
||||
let parent_map: std::collections::HashMap<
|
||||
i32,
|
||||
&crate::core::db::postgres_db::SemanticSearchResult,
|
||||
> = db_parents.iter().map(|p| (p.id, p)).collect();
|
||||
|
||||
// Map Children to API response struct
|
||||
let results: Vec<SearchResult> = children
|
||||
.into_iter()
|
||||
.map(|c| {
|
||||
let parent = parent_map.get(&c.parent_id);
|
||||
SearchResult {
|
||||
id: c.id,
|
||||
parent_id: c.parent_id,
|
||||
scene_order: parent.map(|p| p.scene_order),
|
||||
|
||||
start_frame: c.start_frame,
|
||||
end_frame: c.end_frame,
|
||||
fps: c.fps,
|
||||
|
||||
start_time: c.start_time,
|
||||
end_time: c.end_time,
|
||||
raw_text: Some(c.raw_text),
|
||||
summary: parent.map(|p| p.summary.clone()),
|
||||
metadata: parent.map(|p| p.metadata.clone()),
|
||||
similarity: parent.and_then(|p| p.similarity),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 6. Sort results by similarity (descending)
|
||||
// Since all children of a parent have the same parent similarity, this groups relevant chunks together
|
||||
let mut results = results;
|
||||
results.sort_by(|a, b| {
|
||||
b.similarity
|
||||
.partial_cmp(&a.similarity)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
// 7. Limit the final results (optional, but good for API consistency)
|
||||
let limit = req.limit.unwrap_or(5) * 5; // Allow more children per parent context
|
||||
results.truncate(limit);
|
||||
|
||||
// 8. Format Response
|
||||
let response = SmartSearchResponse {
|
||||
query: req.query,
|
||||
results,
|
||||
strategy: "drill_down_semantic_search".to_string(),
|
||||
};
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
// --- Helper: Ollama Embedding ---
|
||||
|
||||
async fn get_ollama_embedding(
|
||||
text: &str,
|
||||
) -> Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let client = reqwest::Client::new();
|
||||
let payload = serde_json::json!({
|
||||
"model": "nomic-embed-text",
|
||||
"prompt": text
|
||||
});
|
||||
|
||||
let res = client
|
||||
.post("http://localhost:11434/api/embeddings")
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?
|
||||
.json::<serde_json::Value>()
|
||||
.await?;
|
||||
|
||||
// Parse embedding array from response
|
||||
let embedding = res["embedding"]
|
||||
.as_array()
|
||||
.ok_or("No embedding found in Ollama response")?
|
||||
.iter()
|
||||
.map(|v| v.as_f64().unwrap_or(0.0) as f32)
|
||||
.collect();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
// --- Router Setup ---
|
||||
|
||||
pub fn search_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new().route("/smart", post(smart_search))
|
||||
}
|
||||
@@ -30,6 +30,33 @@ use super::universal_search;
|
||||
use super::visual_chunk_search;
|
||||
use crate::core::chunk::types::Chunk;
|
||||
|
||||
static DEMO_USER_API_KEY: &str = "muser_68600856036340bcafc01930eb4bd839_1774418104_97221b69";
|
||||
|
||||
fn hash_password(password: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(password.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LoginRequest {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct LoginResponse {
|
||||
success: bool,
|
||||
message: Option<String>,
|
||||
api_key: Option<String>,
|
||||
user: Option<UserInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct UserInfo {
|
||||
username: String,
|
||||
}
|
||||
|
||||
// Global State
|
||||
static SERVER_START: OnceCell<Instant> = OnceCell::new();
|
||||
|
||||
@@ -334,6 +361,12 @@ struct VideoInfoResponse {
|
||||
duration: f64,
|
||||
width: u32,
|
||||
height: u32,
|
||||
status: String,
|
||||
processing_status: Option<String>,
|
||||
created_at: Option<String>,
|
||||
registration_time: Option<String>,
|
||||
file_size: Option<i64>,
|
||||
probe_json: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
@@ -348,6 +381,9 @@ struct VideosResponse {
|
||||
struct VideosQuery {
|
||||
page: Option<usize>,
|
||||
page_size: Option<usize>,
|
||||
q: Option<String>,
|
||||
status: Option<String>,
|
||||
uuid: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -408,7 +444,10 @@ async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthRe
|
||||
let qdrant = check_qdrant().await;
|
||||
let mongodb = check_mongodb(&state.mongo_cache).await;
|
||||
|
||||
let overall_status = if postgres.status == "ok" && redis.status == "ok" && qdrant.status == "ok"
|
||||
let overall_status = if postgres.status == "ok"
|
||||
&& redis.status == "ok"
|
||||
&& qdrant.status == "ok"
|
||||
&& mongodb.status == "ok"
|
||||
{
|
||||
"ok"
|
||||
} else {
|
||||
@@ -428,6 +467,30 @@ async fn health_detailed(State(state): State<AppState>) -> Json<DetailedHealthRe
|
||||
})
|
||||
}
|
||||
|
||||
async fn login(Json(req): Json<LoginRequest>) -> Json<LoginResponse> {
|
||||
if req.username == "demo" && req.password == "demo" {
|
||||
Json(LoginResponse {
|
||||
success: true,
|
||||
message: Some("Login successful".to_string()),
|
||||
api_key: Some(DEMO_USER_API_KEY.to_string()),
|
||||
user: Some(UserInfo {
|
||||
username: "demo".to_string(),
|
||||
}),
|
||||
})
|
||||
} else {
|
||||
Json(LoginResponse {
|
||||
success: false,
|
||||
message: Some("Invalid username or password".to_string()),
|
||||
api_key: None,
|
||||
user: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn logout() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({ "success": true }))
|
||||
}
|
||||
|
||||
async fn check_postgres() -> ServiceStatus {
|
||||
let start = Instant::now();
|
||||
match PostgresDb::init().await {
|
||||
@@ -709,6 +772,7 @@ async fn register(
|
||||
user_id: None,
|
||||
job_id: None,
|
||||
created_at: String::new(),
|
||||
registration_time: None,
|
||||
};
|
||||
|
||||
let video_id = db
|
||||
@@ -1874,9 +1938,51 @@ async fn list_videos(
|
||||
let page = params.page.unwrap_or(1);
|
||||
let page_size = params.page_size.unwrap_or(20);
|
||||
let offset = ((page - 1) as i64) * (page_size as i64);
|
||||
let status_filter = params.status.clone();
|
||||
let query_filter = params.q.clone();
|
||||
|
||||
// Include query and status in cache key
|
||||
let cache_key = keys::videos_list(page, page_size);
|
||||
let cache_key = if let Some(ref q) = query_filter {
|
||||
format!("{}:q:{}", cache_key, q)
|
||||
} else {
|
||||
cache_key
|
||||
};
|
||||
let ttl = state.mongo_cache.ttl_videos();
|
||||
|
||||
// If uuid is provided, fetch single video directly
|
||||
if let Some(ref uuid) = params.uuid {
|
||||
let db = PostgresDb::init()
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let video = db.get_video_by_uuid(uuid).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
if let Some(v) = video {
|
||||
return Ok(Json(VideosResponse {
|
||||
videos: vec![VideoInfoResponse {
|
||||
uuid: v.uuid,
|
||||
file_path: v.file_path,
|
||||
file_name: v.file_name,
|
||||
duration: v.duration,
|
||||
width: v.width,
|
||||
height: v.height,
|
||||
status: v.status.as_str().to_string(),
|
||||
processing_status: None,
|
||||
created_at: Some(v.created_at),
|
||||
registration_time: v.registration_time,
|
||||
file_size: None,
|
||||
probe_json: v.probe_json,
|
||||
}],
|
||||
count: 1,
|
||||
page,
|
||||
page_size,
|
||||
}));
|
||||
} else {
|
||||
return Err(StatusCode::NOT_FOUND);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"list_videos called: page={}, page_size={}, cache_key={}",
|
||||
page,
|
||||
@@ -1892,7 +1998,21 @@ async fn list_videos(
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("PG init failed: {}", e))?;
|
||||
|
||||
let (videos, count) = db.list_videos(page_size as i32, offset).await?;
|
||||
// Map status parameter to is_processed filter
|
||||
let is_processed = match status_filter.as_deref() {
|
||||
Some("pending") | Some("unprocessed") => Some(false),
|
||||
Some("completed") | Some("ready") | Some("processed") => Some(true),
|
||||
_ => None, // no filter
|
||||
};
|
||||
|
||||
// Search by query if provided
|
||||
let (videos, count) = if let Some(ref q) = query_filter {
|
||||
db.search_videos(Some(q.as_str()), is_processed, page_size as i32, offset).await?
|
||||
} else if let Some(processed) = is_processed {
|
||||
db.search_videos(None, Some(processed), page_size as i32, offset).await?
|
||||
} else {
|
||||
db.list_videos(page_size as i32, offset).await?
|
||||
};
|
||||
tracing::info!("Got {} videos from DB", videos.len());
|
||||
|
||||
let video_infos: Vec<VideoInfoResponse> = videos
|
||||
@@ -1904,6 +2024,12 @@ async fn list_videos(
|
||||
duration: v.duration,
|
||||
width: v.width,
|
||||
height: v.height,
|
||||
status: v.status.as_str().to_string(),
|
||||
processing_status: None,
|
||||
created_at: Some(v.created_at),
|
||||
registration_time: v.registration_time,
|
||||
file_size: None,
|
||||
probe_json: v.probe_json,
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -2328,19 +2454,15 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> {
|
||||
.with_state(state.clone());
|
||||
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(tower_http::cors::AllowOrigin::predicate(
|
||||
|origin, _request_headers| {
|
||||
origin.as_bytes().ends_with(b"localhost")
|
||||
|| origin.as_bytes().ends_with(b"momentry.ddns.net")
|
||||
|| origin.as_bytes().ends_with(b"127.0.0.1")
|
||||
},
|
||||
))
|
||||
.allow_origin(tower_http::cors::AllowOrigin::any())
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/health/detailed", get(health_detailed))
|
||||
.route("/api/v1/auth/login", post(login))
|
||||
.route("/api/v1/auth/logout", post(logout))
|
||||
.route("/api/v1/stats/ingest", get(get_ingest_stats))
|
||||
.route("/api/v1/stats/sftpgo", get(get_sftpgo_status))
|
||||
.route("/api/v1/stats/inference", get(get_inference_health))
|
||||
|
||||
814
src/api/universal_search.rs
Normal file
814
src/api/universal_search.rs
Normal file
@@ -0,0 +1,814 @@
|
||||
//! Universal Search API
|
||||
//! Unified search across chunks, frames, and persons.
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::db::{Database, PostgresDb};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UniversalSearchRequest {
|
||||
pub query: String,
|
||||
pub uuid: Option<String>,
|
||||
#[serde(default)]
|
||||
pub types: Vec<String>, // chunk, frame, person
|
||||
pub time_range: Option<[f64; 2]>,
|
||||
pub filters: Option<SearchFilters>,
|
||||
pub limit: Option<usize>,
|
||||
pub offset: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct SearchFilters {
|
||||
pub person_id: Option<String>,
|
||||
pub object_class: Option<Vec<String>>,
|
||||
pub ocr_text: Option<String>,
|
||||
pub has_face: Option<bool>,
|
||||
pub speaker_id: Option<String>,
|
||||
// Visual chunk filters
|
||||
pub min_confidence: Option<f32>,
|
||||
pub min_unique_classes: Option<u32>,
|
||||
pub min_spatial_density: Option<f32>,
|
||||
pub max_spatial_density: Option<f32>,
|
||||
pub required_object_classes: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct UniversalSearchResponse {
|
||||
pub query: String,
|
||||
pub results: Vec<SearchResult>,
|
||||
pub total: usize,
|
||||
pub took_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Clone)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum SearchResult {
|
||||
#[serde(rename = "chunk")]
|
||||
Chunk {
|
||||
chunk_id: String,
|
||||
chunk_type: String,
|
||||
// Primary: frame-accurate position
|
||||
start_frame: i64,
|
||||
end_frame: i64,
|
||||
// Reference: time derived from frames (subject to FPS variation)
|
||||
start_time: f64,
|
||||
end_time: f64,
|
||||
score: f64,
|
||||
text: Option<String>,
|
||||
speaker_id: Option<String>,
|
||||
metadata: Option<serde_json::Value>,
|
||||
},
|
||||
#[serde(rename = "frame")]
|
||||
Frame {
|
||||
// Primary: exact frame number
|
||||
frame_number: i64,
|
||||
// Reference: time derived from frame (subject to FPS variation)
|
||||
timestamp: f64,
|
||||
score: f64,
|
||||
objects: Option<Vec<serde_json::Value>>,
|
||||
ocr_texts: Option<Vec<String>>,
|
||||
faces: Option<Vec<serde_json::Value>>,
|
||||
pose_persons: Option<Vec<serde_json::Value>>,
|
||||
},
|
||||
#[serde(rename = "person")]
|
||||
Person {
|
||||
person_id: String,
|
||||
name: Option<String>,
|
||||
speaker_id: Option<String>,
|
||||
appearance_count: i32,
|
||||
score: f64,
|
||||
first_appearance_time: Option<f64>,
|
||||
last_appearance_time: Option<f64>,
|
||||
},
|
||||
}
|
||||
|
||||
pub fn universal_search_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/search/universal", post(universal_search))
|
||||
.route("/api/v1/search/frames", post(search_frames))
|
||||
.route("/api/v1/search/persons", get(search_persons))
|
||||
}
|
||||
|
||||
/// Unified search across all data types
|
||||
pub async fn universal_search(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<UniversalSearchRequest>,
|
||||
) -> Result<Json<UniversalSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let db = PostgresDb::init().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let limit = req.limit.unwrap_or(20);
|
||||
let offset = req.offset.unwrap_or(0);
|
||||
let types = if req.types.is_empty() {
|
||||
vec![
|
||||
"chunk".to_string(),
|
||||
"frame".to_string(),
|
||||
"person".to_string(),
|
||||
]
|
||||
} else {
|
||||
req.types.clone()
|
||||
};
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
// Search chunks
|
||||
if types.contains(&"chunk".to_string()) {
|
||||
let chunk_results = search_chunks(&db, &req).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
results.extend(chunk_results);
|
||||
}
|
||||
|
||||
// Search frames
|
||||
if types.contains(&"frame".to_string()) {
|
||||
let frame_results = search_frames_internal(&db, &req).await.unwrap_or_default();
|
||||
results.extend(frame_results);
|
||||
}
|
||||
|
||||
// Search persons
|
||||
if types.contains(&"person".to_string()) {
|
||||
let person_results = search_persons_internal(&db, &req).await.unwrap_or_default();
|
||||
results.extend(person_results);
|
||||
}
|
||||
|
||||
// Sort by score descending
|
||||
results.sort_by(|a, b| {
|
||||
let score_a = match a {
|
||||
SearchResult::Chunk { score, .. } => *score,
|
||||
SearchResult::Frame { score, .. } => *score,
|
||||
SearchResult::Person { score, .. } => *score,
|
||||
};
|
||||
let score_b = match b {
|
||||
SearchResult::Chunk { score, .. } => *score,
|
||||
SearchResult::Frame { score, .. } => *score,
|
||||
SearchResult::Person { score, .. } => *score,
|
||||
};
|
||||
score_b
|
||||
.partial_cmp(&score_a)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
let total = results.len();
|
||||
let end = std::cmp::min(offset + limit, results.len());
|
||||
let paginated = if offset < results.len() {
|
||||
results[offset..end].to_vec()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let took = start_time.elapsed().as_millis() as u64;
|
||||
|
||||
Ok(Json(UniversalSearchResponse {
|
||||
query: req.query,
|
||||
results: paginated,
|
||||
total,
|
||||
took_ms: took,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Search frames by YOLO objects, OCR text, or face IDs
|
||||
pub async fn search_frames(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<FrameSearchRequest>,
|
||||
) -> Result<Json<FrameSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = PostgresDb::init().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let frames = search_frames_internal_v2(&db, &req).await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("Search error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let frames_count = frames.len();
|
||||
Ok(Json(FrameSearchResponse {
|
||||
frames,
|
||||
total: frames_count,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Search persons by name or speaker_id
|
||||
pub async fn search_persons(
|
||||
State(_state): State<crate::api::server::AppState>,
|
||||
Query(query): Query<PersonSearchQuery>,
|
||||
) -> Result<Json<PersonSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = PostgresDb::init().await.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("DB error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let limit = query.limit.unwrap_or(20);
|
||||
let persons = search_persons_by_query(
|
||||
&db,
|
||||
&query.query,
|
||||
query.min_appearances,
|
||||
query.max_age,
|
||||
limit,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": format!("Search error: {}", e) })),
|
||||
)
|
||||
})?;
|
||||
|
||||
let persons_count = persons.len();
|
||||
Ok(Json(PersonSearchResponse {
|
||||
persons,
|
||||
total: persons_count,
|
||||
}))
|
||||
}
|
||||
|
||||
// --- Internal search functions ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct FrameSearchRequest {
|
||||
pub uuid: Option<String>,
|
||||
pub object_class: Option<String>,
|
||||
pub ocr_text: Option<String>,
|
||||
pub face_id: Option<String>,
|
||||
pub time_range: Option<[f64; 2]>,
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FrameSearchResponse {
|
||||
pub frames: Vec<FrameResult>,
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct FrameResult {
|
||||
pub frame_number: i64,
|
||||
pub timestamp: f64,
|
||||
pub uuid: String,
|
||||
pub objects: Option<Vec<serde_json::Value>>,
|
||||
pub ocr_texts: Option<Vec<String>>,
|
||||
pub faces: Option<Vec<serde_json::Value>>,
|
||||
pub pose_persons: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PersonSearchQuery {
|
||||
pub video_uuid: String,
|
||||
pub query: Option<String>,
|
||||
pub min_appearances: Option<i32>,
|
||||
pub max_age: Option<i32>, // New filter for "children"
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PersonSearchResponse {
|
||||
pub persons: Vec<PersonResult>,
|
||||
pub total: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PersonResult {
|
||||
pub person_id: String,
|
||||
pub name: Option<String>,
|
||||
pub character_name: Option<String>,
|
||||
pub aliases: Option<Vec<String>>,
|
||||
pub age: Option<i32>,
|
||||
pub gender: Option<String>,
|
||||
pub speaker_id: Option<String>,
|
||||
pub appearance_count: i32,
|
||||
pub first_appearance_time: Option<f64>,
|
||||
pub last_appearance_time: Option<f64>,
|
||||
}
|
||||
|
||||
async fn search_chunks(
|
||||
db: &PostgresDb,
|
||||
req: &UniversalSearchRequest,
|
||||
) -> Result<Vec<SearchResult>, anyhow::Error> {
|
||||
// uuid is required for chunk search - chunk_id is only unique within a video
|
||||
let uuid = match &req.uuid {
|
||||
Some(u) => u.replace('\'', "''"),
|
||||
None => return Err(anyhow::anyhow!("uuid is required for chunk search")),
|
||||
};
|
||||
|
||||
let mut sql = format!(
|
||||
"SELECT chunk_id, chunk_type, start_time, end_time, start_frame, end_frame, text_content, content FROM chunks WHERE uuid = '{}'",
|
||||
uuid
|
||||
);
|
||||
if let Some(tr) = &req.time_range {
|
||||
sql.push_str(&format!(
|
||||
" AND start_time >= {} AND end_time <= {}",
|
||||
tr[0], tr[1]
|
||||
));
|
||||
}
|
||||
if !req.query.is_empty() {
|
||||
let q = req.query.replace('\'', "''");
|
||||
sql.push_str(&format!(
|
||||
" AND (text_content ILIKE '%{}%' OR content::text ILIKE '%{}%')",
|
||||
q, q
|
||||
));
|
||||
}
|
||||
if let Some(ref filters) = req.filters {
|
||||
if let Some(ref speaker_id) = filters.speaker_id {
|
||||
sql.push_str(&format!(
|
||||
" AND content->>'speaker_id' = '{}'",
|
||||
speaker_id.replace('\'', "''")
|
||||
));
|
||||
}
|
||||
if let Some(ref person_id) = filters.person_id {
|
||||
sql.push_str(&format!(
|
||||
" AND content::text LIKE '%{}%'",
|
||||
person_id.replace('\'', "''")
|
||||
));
|
||||
}
|
||||
// Visual chunk filters
|
||||
if let Some(min_confidence) = filters.min_confidence {
|
||||
sql.push_str(&format!(
|
||||
" AND (content->'metadata'->>'avg_confidence')::float >= {}",
|
||||
min_confidence
|
||||
));
|
||||
}
|
||||
if let Some(min_unique_classes) = filters.min_unique_classes {
|
||||
sql.push_str(&format!(
|
||||
" AND jsonb_array_length(content->'metadata'->'unique_classes') >= {}",
|
||||
min_unique_classes
|
||||
));
|
||||
}
|
||||
if let Some(min_density) = filters.min_spatial_density {
|
||||
sql.push_str(&format!(
|
||||
" AND (content->'metadata'->>'spatial_density')::float >= {}",
|
||||
min_density
|
||||
));
|
||||
}
|
||||
if let Some(max_density) = filters.max_spatial_density {
|
||||
sql.push_str(&format!(
|
||||
" AND (content->'metadata'->>'spatial_density')::float <= {}",
|
||||
max_density
|
||||
));
|
||||
}
|
||||
if let Some(ref required_classes) = filters.required_object_classes {
|
||||
if !required_classes.is_empty() {
|
||||
let class_conditions: Vec<String> = required_classes
|
||||
.iter()
|
||||
.map(|class| {
|
||||
format!(
|
||||
"content->'keyframe_objects' @> '[{{ \"class_name\": \"{}\"}}]'",
|
||||
class.replace('\'', "''")
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
sql.push_str(&format!(" AND ({})", class_conditions.join(" OR ")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY start_time ASC");
|
||||
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
|
||||
|
||||
let rows: Vec<(
|
||||
String,
|
||||
String,
|
||||
f64,
|
||||
f64,
|
||||
i64,
|
||||
i64,
|
||||
Option<String>,
|
||||
Option<serde_json::Value>,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<SearchResult> = rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(
|
||||
chunk_id,
|
||||
chunk_type,
|
||||
start_time,
|
||||
end_time,
|
||||
start_frame,
|
||||
end_frame,
|
||||
text_content,
|
||||
content,
|
||||
)| {
|
||||
let text = text_content.or_else(|| {
|
||||
content
|
||||
.as_ref()
|
||||
.and_then(|c| c.get("text").and_then(|v| v.as_str()).map(String::from))
|
||||
});
|
||||
let speaker_id = content.as_ref().and_then(|c| {
|
||||
c.get("speaker_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
});
|
||||
// Simple scoring: if query matches, score 0.8
|
||||
let score = if !req.query.is_empty()
|
||||
&& text.as_ref().map_or(false, |t| {
|
||||
t.to_lowercase().contains(&req.query.to_lowercase())
|
||||
}) {
|
||||
0.9
|
||||
} else {
|
||||
0.5
|
||||
};
|
||||
|
||||
SearchResult::Chunk {
|
||||
chunk_id,
|
||||
chunk_type,
|
||||
start_time,
|
||||
end_time,
|
||||
start_frame,
|
||||
end_frame,
|
||||
score,
|
||||
text,
|
||||
speaker_id,
|
||||
metadata: content,
|
||||
}
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn search_frames_internal(
|
||||
db: &PostgresDb,
|
||||
req: &UniversalSearchRequest,
|
||||
) -> Result<Vec<SearchResult>, anyhow::Error> {
|
||||
let table = "frames";
|
||||
let video_table = "videos";
|
||||
|
||||
let mut sql = format!(
|
||||
"SELECT f.frame_number, f.timestamp, f.yolo_objects, f.ocr_results, f.face_results, f.pose_results, v.uuid
|
||||
FROM {} f JOIN {} v ON f.file_id = v.id WHERE 1=1",
|
||||
table, video_table
|
||||
);
|
||||
|
||||
if let Some(uuid) = &req.uuid {
|
||||
sql.push_str(&format!(" AND v.uuid = '{}'", uuid));
|
||||
}
|
||||
if let Some(tr) = &req.time_range {
|
||||
sql.push_str(&format!(
|
||||
" AND f.timestamp >= {} AND f.timestamp <= {}",
|
||||
tr[0], tr[1]
|
||||
));
|
||||
}
|
||||
if let Some(ref filters) = req.filters {
|
||||
if let Some(ref classes) = filters.object_class {
|
||||
for class in classes {
|
||||
sql.push_str(&format!(" AND f.yolo_objects::text ILIKE '%{}%'", class));
|
||||
}
|
||||
}
|
||||
if let Some(ref ocr) = filters.ocr_text {
|
||||
sql.push_str(&format!(" AND f.ocr_results::text ILIKE '%{}%'", ocr));
|
||||
}
|
||||
if let Some(true) = filters.has_face {
|
||||
sql.push_str(
|
||||
" AND f.face_results IS NOT NULL AND jsonb_array_length(f.face_results) > 0",
|
||||
);
|
||||
}
|
||||
if let Some(ref person_id) = filters.person_id {
|
||||
sql.push_str(&format!(" AND f.face_results::text LIKE '%{}%'", person_id));
|
||||
}
|
||||
}
|
||||
if !req.query.is_empty() {
|
||||
// Search across all frame data
|
||||
sql.push_str(&format!(
|
||||
" AND (f.yolo_objects::text ILIKE '%{}%' OR f.ocr_results::text ILIKE '%{}%' OR f.face_results::text ILIKE '%{}%')",
|
||||
req.query, req.query, req.query
|
||||
));
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY f.timestamp ASC");
|
||||
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
|
||||
|
||||
let rows: Vec<(
|
||||
i64,
|
||||
f64,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
String,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<SearchResult> = rows
|
||||
.into_iter()
|
||||
.map(|(frame_number, timestamp, yolo, ocr, face, pose, _uuid)| {
|
||||
let objects = yolo.as_ref().and_then(|v| {
|
||||
v.get("objects")
|
||||
.map(|o| o.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
let ocr_texts = ocr.as_ref().and_then(|v| {
|
||||
v.get("texts").and_then(|t| {
|
||||
t.as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|item| {
|
||||
item.get("text").and_then(|x| x.as_str()).map(String::from)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
});
|
||||
let faces = face.as_ref().and_then(|v| {
|
||||
v.get("faces")
|
||||
.map(|f| f.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
let pose_persons = pose.as_ref().and_then(|v| {
|
||||
v.get("persons")
|
||||
.map(|p| p.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
|
||||
SearchResult::Frame {
|
||||
frame_number,
|
||||
timestamp,
|
||||
score: 0.7,
|
||||
objects: objects.map(|arr| arr.iter().map(|v| v.clone()).collect()),
|
||||
ocr_texts,
|
||||
faces,
|
||||
pose_persons,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn search_persons_internal(
|
||||
db: &PostgresDb,
|
||||
req: &UniversalSearchRequest,
|
||||
) -> Result<Vec<SearchResult>, anyhow::Error> {
|
||||
let table = "person_identities";
|
||||
let mut sql = format!(
|
||||
"SELECT person_id, name, speaker_id, appearance_count, first_appearance_time, last_appearance_time FROM {} WHERE 1=1",
|
||||
table
|
||||
);
|
||||
|
||||
if !req.query.is_empty() {
|
||||
sql.push_str(&format!(
|
||||
" AND (name ILIKE '%{}%' OR person_id ILIKE '%{}%' OR speaker_id ILIKE '%{}%')",
|
||||
req.query, req.query, req.query
|
||||
));
|
||||
}
|
||||
if let Some(ref filters) = req.filters {
|
||||
if let Some(ref speaker_id) = filters.speaker_id {
|
||||
sql.push_str(&format!(" AND speaker_id = '{}'", speaker_id));
|
||||
}
|
||||
if let Some(ref person_id) = filters.person_id {
|
||||
sql.push_str(&format!(" AND person_id = '{}'", person_id));
|
||||
}
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY appearance_count DESC");
|
||||
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(20)));
|
||||
|
||||
let rows: Vec<(
|
||||
String,
|
||||
Option<String>,
|
||||
Option<String>,
|
||||
i32,
|
||||
Option<f64>,
|
||||
Option<f64>,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<SearchResult> = rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(person_id, name, speaker_id, appearance_count, first_time, last_time)| {
|
||||
let score = if !req.query.is_empty()
|
||||
&& name.as_ref().map_or(false, |n| {
|
||||
n.to_lowercase().contains(&req.query.to_lowercase())
|
||||
}) {
|
||||
0.95
|
||||
} else {
|
||||
0.5
|
||||
};
|
||||
|
||||
SearchResult::Person {
|
||||
person_id,
|
||||
name,
|
||||
speaker_id,
|
||||
appearance_count,
|
||||
score,
|
||||
first_appearance_time: first_time,
|
||||
last_appearance_time: last_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn search_frames_internal_v2(
|
||||
db: &PostgresDb,
|
||||
req: &FrameSearchRequest,
|
||||
) -> Result<Vec<FrameResult>, anyhow::Error> {
|
||||
let table = "frames";
|
||||
let video_table = "videos";
|
||||
|
||||
let mut sql = format!(
|
||||
"SELECT f.frame_number, f.timestamp, f.yolo_objects, f.ocr_results, f.face_results, f.pose_results, v.uuid
|
||||
FROM {} f JOIN {} v ON f.file_id = v.id WHERE 1=1",
|
||||
table, video_table
|
||||
);
|
||||
|
||||
if let Some(uuid) = &req.uuid {
|
||||
sql.push_str(&format!(" AND v.uuid = '{}'", uuid));
|
||||
}
|
||||
if let Some(tr) = &req.time_range {
|
||||
sql.push_str(&format!(
|
||||
" AND f.timestamp >= {} AND f.timestamp <= {}",
|
||||
tr[0], tr[1]
|
||||
));
|
||||
}
|
||||
if let Some(ref class) = req.object_class {
|
||||
sql.push_str(&format!(" AND f.yolo_objects::text ILIKE '%{}%'", class));
|
||||
}
|
||||
if let Some(ref ocr) = req.ocr_text {
|
||||
sql.push_str(&format!(" AND f.ocr_results::text ILIKE '%{}%'", ocr));
|
||||
}
|
||||
if let Some(ref face_id) = req.face_id {
|
||||
sql.push_str(&format!(" AND f.face_results::text LIKE '%{}%'", face_id));
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY f.timestamp ASC");
|
||||
sql.push_str(&format!(" LIMIT {}", req.limit.unwrap_or(50)));
|
||||
|
||||
let rows: Vec<(
|
||||
i64,
|
||||
f64,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
Option<serde_json::Value>,
|
||||
String,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<FrameResult> = rows
|
||||
.into_iter()
|
||||
.map(|(frame_number, timestamp, yolo, ocr, face, pose, uuid)| {
|
||||
let objects = yolo.as_ref().and_then(|v| {
|
||||
v.get("objects")
|
||||
.map(|o| o.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
let ocr_texts = ocr.as_ref().and_then(|v| {
|
||||
v.get("texts").and_then(|t| {
|
||||
t.as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|item| {
|
||||
item.get("text").and_then(|x| x.as_str()).map(String::from)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
})
|
||||
});
|
||||
let faces = face.as_ref().and_then(|v| {
|
||||
v.get("faces")
|
||||
.map(|f| f.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
let pose_persons = pose.as_ref().and_then(|v| {
|
||||
v.get("persons")
|
||||
.map(|p| p.as_array().cloned().unwrap_or_default())
|
||||
});
|
||||
|
||||
FrameResult {
|
||||
frame_number,
|
||||
timestamp,
|
||||
uuid,
|
||||
objects: objects.map(|arr| arr.iter().map(|v| v.clone()).collect()),
|
||||
ocr_texts,
|
||||
faces,
|
||||
pose_persons,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn search_persons_by_query(
|
||||
db: &PostgresDb,
|
||||
query: &Option<String>,
|
||||
min_appearances: Option<i32>,
|
||||
max_age: Option<i32>,
|
||||
limit: usize,
|
||||
) -> Result<Vec<PersonResult>, anyhow::Error> {
|
||||
let table = "person_identities";
|
||||
let mut sql = format!(
|
||||
"SELECT person_id, name, character_name, aliases, age, gender, speaker_id, appearance_count, first_appearance_time, last_appearance_time FROM {} WHERE 1=1",
|
||||
table
|
||||
);
|
||||
|
||||
if let Some(ref q) = query {
|
||||
// Search name, character_name, aliases (cast to text), person_id, speaker_id
|
||||
sql.push_str(&format!(
|
||||
" AND (name ILIKE '%{}%' OR character_name ILIKE '%{}%' OR aliases::text ILIKE '%{}%' OR person_id ILIKE '%{}%' OR speaker_id ILIKE '%{}%')",
|
||||
q, q, q, q, q
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(min) = min_appearances {
|
||||
sql.push_str(&format!(" AND appearance_count >= {}", min));
|
||||
}
|
||||
if let Some(max_a) = max_age {
|
||||
// Strictly filter for age <= max_age.
|
||||
// Note: This excludes entries with NULL age.
|
||||
sql.push_str(&format!(" AND age <= {}", max_a));
|
||||
}
|
||||
|
||||
sql.push_str(" ORDER BY appearance_count DESC");
|
||||
sql.push_str(&format!(" LIMIT {}", limit));
|
||||
|
||||
let rows: Vec<(
|
||||
String,
|
||||
Option<String>,
|
||||
Option<String>,
|
||||
Option<serde_json::Value>,
|
||||
Option<i32>,
|
||||
Option<String>,
|
||||
Option<String>,
|
||||
i32,
|
||||
Option<f64>,
|
||||
Option<f64>,
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let results: Vec<PersonResult> = rows
|
||||
.into_iter()
|
||||
.map(
|
||||
|(
|
||||
person_id,
|
||||
name,
|
||||
character_name,
|
||||
aliases_json,
|
||||
age,
|
||||
gender,
|
||||
speaker_id,
|
||||
appearance_count,
|
||||
first_time,
|
||||
last_time,
|
||||
)| {
|
||||
let aliases = aliases_json.and_then(|v| {
|
||||
v.as_array().map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|val| val.as_str().map(String::from))
|
||||
.collect()
|
||||
})
|
||||
});
|
||||
|
||||
PersonResult {
|
||||
person_id,
|
||||
name,
|
||||
character_name,
|
||||
aliases,
|
||||
age,
|
||||
gender,
|
||||
speaker_id,
|
||||
appearance_count,
|
||||
first_appearance_time: first_time,
|
||||
last_appearance_time: last_time,
|
||||
}
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_search_filters_with_visual() {
|
||||
let filters = SearchFilters {
|
||||
person_id: None,
|
||||
object_class: None,
|
||||
ocr_text: None,
|
||||
has_face: None,
|
||||
speaker_id: None,
|
||||
min_confidence: Some(0.8),
|
||||
min_unique_classes: Some(3),
|
||||
min_spatial_density: Some(0.5),
|
||||
max_spatial_density: Some(0.9),
|
||||
required_object_classes: Some(vec!["person".to_string()]),
|
||||
};
|
||||
assert_eq!(filters.min_confidence, Some(0.8));
|
||||
}
|
||||
}
|
||||
504
src/api/visual_chunk_search.rs
Normal file
504
src/api/visual_chunk_search.rs
Normal file
@@ -0,0 +1,504 @@
|
||||
//! Visual chunk search functionality.
|
||||
//!
|
||||
//! This module provides search capabilities for visual chunks based on:
|
||||
//! - Object classes (e.g., "person", "car", "envelope")
|
||||
//! - Confidence thresholds
|
||||
//! - Object counts
|
||||
//! - Spatial density
|
||||
//! - Object relationships
|
||||
|
||||
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
|
||||
use crate::core::db::PostgresDb;
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Criteria for searching visual chunks
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct VisualChunkSearchCriteria {
|
||||
/// Minimum average confidence across frames
|
||||
pub min_avg_confidence: Option<f32>,
|
||||
/// Minimum number of frames with objects
|
||||
pub min_frames_with_objects: Option<u32>,
|
||||
/// Minimum number of unique object classes
|
||||
pub min_unique_classes: Option<u32>,
|
||||
/// Specific object classes to include (empty means all)
|
||||
pub required_classes: Vec<String>,
|
||||
/// Object class counts to filter by
|
||||
pub class_counts: HashMap<String, (u32, u32)>,
|
||||
/// Time range (optional)
|
||||
pub time_range: Option<(f64, f64)>,
|
||||
}
|
||||
|
||||
impl Default for VisualChunkSearchCriteria {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_avg_confidence: None,
|
||||
min_frames_with_objects: None,
|
||||
min_unique_classes: None,
|
||||
required_classes: Vec::new(),
|
||||
class_counts: HashMap::new(),
|
||||
time_range: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Search visual chunks based on criteria
|
||||
pub async fn search_visual_chunks(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
criteria: &VisualChunkSearchCriteria,
|
||||
) -> Result<Vec<Chunk>> {
|
||||
// First, get all visual chunks for this video
|
||||
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
||||
|
||||
// Apply filters
|
||||
let filtered_chunks: Vec<Chunk> = all_chunks
|
||||
.into_iter()
|
||||
.filter(|chunk| {
|
||||
// Check min avg confidence
|
||||
if let Some(min_avg_confidence) = criteria.min_avg_confidence {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(metadata) = content.get("metadata") {
|
||||
if let Some(avg_confidence) = metadata.get("avg_confidence") {
|
||||
if let Some(conf) = avg_confidence.as_f64() {
|
||||
if conf < min_avg_confidence as f64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check min frames with objects
|
||||
if let Some(min_frames) = criteria.min_frames_with_objects {
|
||||
if let Some(stats) = &chunk.visual_stats {
|
||||
if let Some(frames_with_objects) = stats.get("frames_with_objects") {
|
||||
if let Some(count) = frames_with_objects.as_u64() {
|
||||
if count < min_frames as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check min unique classes
|
||||
if let Some(min_unique_classes) = criteria.min_unique_classes {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(metadata) = content.get("metadata") {
|
||||
if let Some(unique_classes) = metadata.get("unique_classes") {
|
||||
if let Some(classes) = unique_classes.as_array() {
|
||||
if (classes.len() as u32) < min_unique_classes {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check required classes
|
||||
if !criteria.required_classes.is_empty() {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(keyframe_objects) = content.get("keyframe_objects") {
|
||||
if let Some(objects) = keyframe_objects.as_array() {
|
||||
let mut found_all = true;
|
||||
for required_class in &criteria.required_classes {
|
||||
let mut found = false;
|
||||
for obj in objects {
|
||||
if let Some(class_name) = obj.get("class_name") {
|
||||
if let Some(class_str) = class_name.as_str() {
|
||||
if class_str == required_class {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
found_all = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !found_all {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check class counts
|
||||
if !criteria.class_counts.is_empty() {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(metadata) = content.get("metadata") {
|
||||
if let Some(object_counts) = metadata.get("object_counts") {
|
||||
for (class, (min, max)) in &criteria.class_counts {
|
||||
if let Some(count_value) = object_counts.get(class) {
|
||||
if let Some(count) = count_value.as_u64() {
|
||||
if *min > 0 && count < *min as u64 {
|
||||
return false;
|
||||
}
|
||||
if *max < u32::MAX && count > *max as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if *min > 0 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else if criteria.class_counts.values().any(|(min, _)| *min > 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check time range
|
||||
if let Some((start_time, end_time)) = criteria.time_range {
|
||||
// Calculate chunk time from frames
|
||||
let chunk_start_time = chunk.start_frame as f64 / chunk.fps;
|
||||
let chunk_end_time = chunk.end_frame as f64 / chunk.fps;
|
||||
|
||||
if chunk_start_time < start_time || chunk_end_time > end_time {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered_chunks)
|
||||
}
|
||||
|
||||
/// Get all visual chunks for a video UUID
|
||||
async fn get_visual_chunks_by_uuid(db: &PostgresDb, uuid: &str) -> Result<Vec<Chunk>> {
|
||||
let sql = format!(
|
||||
"SELECT file_id, uuid, chunk_id, chunk_index, chunk_type, fps, start_frame, end_frame, text_content, content, metadata, vector_id, visual_stats FROM chunks WHERE uuid = '{}' AND chunk_type = 'visual' ORDER BY start_frame ASC",
|
||||
uuid.replace('\'', "''")
|
||||
);
|
||||
|
||||
let rows: Vec<(
|
||||
i32, // file_id
|
||||
String, // uuid
|
||||
String, // chunk_id
|
||||
i32, // chunk_index
|
||||
String, // chunk_type
|
||||
f64, // fps
|
||||
i64, // start_frame
|
||||
i64, // end_frame
|
||||
Option<String>, // text_content
|
||||
Value, // content
|
||||
Option<Value>, // metadata
|
||||
Option<String>, // vector_id
|
||||
Option<Value>, // visual_stats
|
||||
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
for row in rows {
|
||||
let chunk_type = match row.4.as_str() {
|
||||
"visual" => ChunkType::Visual,
|
||||
"sentence" => ChunkType::Sentence,
|
||||
"time_based" => ChunkType::TimeBased,
|
||||
"cut" => ChunkType::Cut,
|
||||
"trace" => ChunkType::Trace,
|
||||
"story" => ChunkType::Story,
|
||||
_ => ChunkType::TimeBased,
|
||||
};
|
||||
|
||||
// Calculate frame_count
|
||||
let frame_count = (row.7 - row.6) as i32;
|
||||
|
||||
chunks.push(Chunk {
|
||||
file_id: row.0,
|
||||
uuid: row.1,
|
||||
chunk_id: row.2,
|
||||
chunk_index: row.3 as u32,
|
||||
chunk_type,
|
||||
rule: ChunkRule::Rule2, // Visual chunks use Rule2
|
||||
fps: row.5,
|
||||
start_frame: row.6,
|
||||
end_frame: row.7,
|
||||
text_content: row.8,
|
||||
content: row.9,
|
||||
metadata: row.10,
|
||||
vector_id: row.11,
|
||||
frame_count,
|
||||
pre_chunk_ids: Vec::new(),
|
||||
parent_chunk_id: None,
|
||||
child_chunk_ids: Vec::new(),
|
||||
visual_stats: row.12,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(chunks)
|
||||
}
|
||||
|
||||
/// Search visual chunks by object class
|
||||
pub async fn search_visual_chunks_by_class(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
object_class: &str,
|
||||
min_count: Option<u32>,
|
||||
max_count: Option<u32>,
|
||||
) -> Result<Vec<Chunk>> {
|
||||
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
||||
|
||||
let filtered_chunks: Vec<Chunk> = all_chunks
|
||||
.into_iter()
|
||||
.filter(|chunk| {
|
||||
// Check if chunk contains the object class
|
||||
let mut contains_class = false;
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(keyframe_objects) = content.get("keyframe_objects") {
|
||||
if let Some(objects) = keyframe_objects.as_array() {
|
||||
for obj in objects {
|
||||
if let Some(class_name) = obj.get("class_name") {
|
||||
if let Some(class_str) = class_name.as_str() {
|
||||
if class_str == object_class {
|
||||
contains_class = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !contains_class {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check count in visual_stats
|
||||
if let Some(stats) = &chunk.visual_stats {
|
||||
if let Some(count) = stats.get(object_class) {
|
||||
if let Some(c) = count.as_u64() {
|
||||
if let Some(min) = min_count {
|
||||
if c < min as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(max) = max_count {
|
||||
if c > max as u64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered_chunks)
|
||||
}
|
||||
|
||||
/// Search visual chunks by spatial density
|
||||
pub async fn search_visual_chunks_by_density(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
min_density: f32,
|
||||
max_density: Option<f32>,
|
||||
) -> Result<Vec<Chunk>> {
|
||||
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
||||
|
||||
let filtered_chunks: Vec<Chunk> = all_chunks
|
||||
.into_iter()
|
||||
.filter(|chunk| {
|
||||
if let Some(content) = &chunk.content.as_object() {
|
||||
if let Some(metadata) = content.get("metadata") {
|
||||
if let Some(density_value) = metadata.get("spatial_density") {
|
||||
if let Some(density) = density_value.as_f64() {
|
||||
if density < min_density as f64 {
|
||||
return false;
|
||||
}
|
||||
if let Some(max_dens) = max_density {
|
||||
if density > max_dens as f64 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered_chunks)
|
||||
}
|
||||
|
||||
/// Find chunks containing specific object combinations
|
||||
pub async fn search_visual_chunks_by_combination(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
combination: &[(&str, u32)],
|
||||
) -> Result<Vec<Chunk>> {
|
||||
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
||||
|
||||
let filtered_chunks: Vec<Chunk> = all_chunks
|
||||
.into_iter()
|
||||
.filter(|chunk| {
|
||||
// Check if all required combinations are present
|
||||
for (object_class, min_count) in combination {
|
||||
let mut found = false;
|
||||
if let Some(stats) = &chunk.visual_stats {
|
||||
if let Some(object_counts) = stats.get("object_counts") {
|
||||
if let Some(count_value) = object_counts.get(*object_class) {
|
||||
if let Some(count) = count_value.as_u64() {
|
||||
if count >= *min_count as u64 {
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered_chunks)
|
||||
}
|
||||
|
||||
/// Get visual chunk statistics
|
||||
pub async fn get_visual_chunk_statistics(
|
||||
db: &PostgresDb,
|
||||
uuid: &str,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
let sql = format!(
|
||||
"SELECT
|
||||
COUNT(*) as total_chunks,
|
||||
AVG((content->'metadata'->>'avg_confidence')::float) as avg_confidence,
|
||||
MIN((content->'metadata'->>'avg_confidence')::float) as min_confidence,
|
||||
MAX((content->'metadata'->>'avg_confidence')::float) as max_confidence,
|
||||
SUM((content->'metadata'->>'object_count')::int) as total_objects,
|
||||
AVG((content->'metadata'->>'spatial_density')::float) as avg_density
|
||||
FROM chunks
|
||||
WHERE uuid = '{}'
|
||||
AND chunk_type = 'visual'",
|
||||
uuid.replace('\'', "''")
|
||||
);
|
||||
|
||||
let row: (i64, Option<f64>, Option<f64>, Option<f64>, i64, Option<f64>) =
|
||||
sqlx::query_as(&sql).fetch_one(db.pool()).await?;
|
||||
|
||||
let mut stats = HashMap::new();
|
||||
stats.insert("total_chunks".to_string(), Value::from(row.0));
|
||||
stats.insert(
|
||||
"avg_confidence".to_string(),
|
||||
Value::from(row.1.unwrap_or(0.0)),
|
||||
);
|
||||
stats.insert(
|
||||
"min_confidence".to_string(),
|
||||
Value::from(row.2.unwrap_or(0.0)),
|
||||
);
|
||||
stats.insert(
|
||||
"max_confidence".to_string(),
|
||||
Value::from(row.3.unwrap_or(0.0)),
|
||||
);
|
||||
stats.insert("total_objects".to_string(), Value::from(row.4));
|
||||
stats.insert("avg_density".to_string(), Value::from(row.5.unwrap_or(0.0)));
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_visual_chunk_search_criteria_default() {
|
||||
let criteria = VisualChunkSearchCriteria::default();
|
||||
|
||||
assert_eq!(criteria.min_avg_confidence, None);
|
||||
assert_eq!(criteria.min_frames_with_objects, None);
|
||||
assert_eq!(criteria.min_unique_classes, None);
|
||||
assert!(criteria.required_classes.is_empty());
|
||||
assert!(criteria.class_counts.is_empty());
|
||||
assert_eq!(criteria.time_range, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visual_chunk_search_criteria_with_values() {
|
||||
let mut criteria = VisualChunkSearchCriteria::default();
|
||||
criteria.min_avg_confidence = Some(0.8);
|
||||
criteria.min_frames_with_objects = Some(10);
|
||||
criteria.min_unique_classes = Some(3);
|
||||
criteria.required_classes = vec!["person".to_string(), "car".to_string()];
|
||||
criteria.time_range = Some((0.0, 60.0));
|
||||
|
||||
assert_eq!(criteria.min_avg_confidence, Some(0.8));
|
||||
assert_eq!(criteria.min_frames_with_objects, Some(10));
|
||||
assert_eq!(criteria.min_unique_classes, Some(3));
|
||||
assert_eq!(criteria.required_classes.len(), 2);
|
||||
assert_eq!(criteria.time_range, Some((0.0, 60.0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visual_chunk_search_criteria_serialization() {
|
||||
let criteria = VisualChunkSearchCriteria {
|
||||
min_avg_confidence: Some(0.85),
|
||||
min_frames_with_objects: Some(5),
|
||||
min_unique_classes: Some(2),
|
||||
required_classes: vec!["person".to_string()],
|
||||
class_counts: HashMap::new(),
|
||||
time_range: Some((10.0, 30.0)),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&criteria).unwrap();
|
||||
assert!(json.contains("min_avg_confidence"));
|
||||
assert!(json.contains("required_classes"));
|
||||
|
||||
let deserialized: VisualChunkSearchCriteria = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized.min_avg_confidence, Some(0.85));
|
||||
assert_eq!(deserialized.required_classes.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visual_chunk_search_criteria_with_class_counts() {
|
||||
let mut criteria = VisualChunkSearchCriteria::default();
|
||||
criteria.class_counts.insert("person".to_string(), (5, 20));
|
||||
criteria.class_counts.insert("car".to_string(), (1, 10));
|
||||
|
||||
assert_eq!(criteria.class_counts.len(), 2);
|
||||
assert_eq!(criteria.class_counts.get("person"), Some(&(5, 20)));
|
||||
assert_eq!(criteria.class_counts.get("car"), Some(&(1, 10)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_type_conversion() {
|
||||
// Test chunk type string to enum conversion logic
|
||||
let test_cases = vec![
|
||||
("visual", ChunkType::Visual),
|
||||
("sentence", ChunkType::Sentence),
|
||||
("time_based", ChunkType::TimeBased),
|
||||
("cut", ChunkType::Cut),
|
||||
("trace", ChunkType::Trace),
|
||||
("story", ChunkType::Story),
|
||||
("unknown", ChunkType::TimeBased), // Default fallback
|
||||
];
|
||||
|
||||
for (input, expected) in test_cases {
|
||||
let chunk_type = match input {
|
||||
"visual" => ChunkType::Visual,
|
||||
"sentence" => ChunkType::Sentence,
|
||||
"time_based" => ChunkType::TimeBased,
|
||||
"cut" => ChunkType::Cut,
|
||||
"trace" => ChunkType::Trace,
|
||||
"story" => ChunkType::Story,
|
||||
_ => ChunkType::TimeBased,
|
||||
};
|
||||
assert_eq!(chunk_type, expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
147
src/api/who.rs
Normal file
147
src/api/who.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
//! Who API - 身份識別與 ID 映射接口 (Video-Scoped)
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::core::db::Database;
|
||||
|
||||
// --- Request / Response Structures ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WhoQuery {
|
||||
pub face_id: Option<String>,
|
||||
pub speaker_id: Option<String>,
|
||||
pub uuid: Option<String>,
|
||||
pub chunk_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct WhoCandidatesRequest {
|
||||
pub query: String,
|
||||
pub video_uuid: Option<String>,
|
||||
pub limit: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DefinePersonRequest {
|
||||
pub uuid: String,
|
||||
pub identity_id: Option<i32>,
|
||||
pub name: String,
|
||||
pub face_ids: Option<Vec<String>>,
|
||||
pub speaker_ids: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct WhoIdentity {
|
||||
pub identity_id: i32,
|
||||
pub uuid: String,
|
||||
pub name: String,
|
||||
pub tags: Option<Vec<String>>,
|
||||
pub face_ids: Vec<String>,
|
||||
pub speaker_ids: Vec<String>,
|
||||
}
|
||||
|
||||
// --- API Handlers ---
|
||||
|
||||
/// GET /api/v1/who
|
||||
pub async fn get_who_identity(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
axum::extract::Query(query): axum::extract::Query<WhoQuery>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
|
||||
// Priority 1: Query by Chunk (UUID + Chunk ID)
|
||||
if let (Some(uuid), Some(chunk_id)) = (&query.uuid, &query.chunk_id) {
|
||||
let info = db
|
||||
.get_who_info_by_chunk(uuid, chunk_id)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
return Ok(Json(info));
|
||||
}
|
||||
|
||||
// Priority 2: List all for a specific UUID
|
||||
if let Some(uuid) = &query.uuid {
|
||||
// TODO: Implement list_all_persons(uuid)
|
||||
return Ok(Json(serde_json::json!({ "message": "List all pending" })));
|
||||
}
|
||||
|
||||
Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({ "error": "Missing uuid" })),
|
||||
))
|
||||
}
|
||||
|
||||
/// POST /api/v1/who/candidates
|
||||
/// Search person_identities table for n8n workflow
|
||||
pub async fn get_who_candidates(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<WhoCandidatesRequest>,
|
||||
) -> Result<Json<serde_json::Value>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
let limit = req.limit.unwrap_or(20);
|
||||
let query_str = format!("%{}%", req.query);
|
||||
|
||||
let results = db
|
||||
.search_person_candidates(&query_str, &req.video_uuid, limit)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
// Format for n8n
|
||||
let response = serde_json::json!({
|
||||
"query": req.query,
|
||||
"items": results,
|
||||
"total": results.len()
|
||||
});
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
/// POST /api/v1/who
|
||||
pub async fn define_person(
|
||||
State(state): State<crate::api::server::AppState>,
|
||||
Json(req): Json<DefinePersonRequest>,
|
||||
) -> Result<Json<WhoIdentity>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let db = &state.db;
|
||||
|
||||
let identity = db
|
||||
.create_or_update_person(
|
||||
&req.uuid,
|
||||
req.identity_id,
|
||||
req.name.clone(),
|
||||
req.face_ids.unwrap_or_default(),
|
||||
req.speaker_ids.unwrap_or_default(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({ "error": e.to_string() })),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(Json(identity))
|
||||
}
|
||||
|
||||
// --- Router Setup ---
|
||||
|
||||
pub fn who_routes() -> Router<crate::api::server::AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/who", get(get_who_identity).post(define_person))
|
||||
.route("/api/v1/who/candidates", post(get_who_candidates))
|
||||
}
|
||||
38
src/bin/debug_tsquery.rs
Normal file
38
src/bin/debug_tsquery.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use momentry_core::core::text::global_synonym_expander;
|
||||
|
||||
fn main() {
|
||||
let expander = global_synonym_expander();
|
||||
let query = "電腦";
|
||||
|
||||
println!("原始查詢: '{}'", query);
|
||||
let expanded = expander.expand_chinese_query(query);
|
||||
println!("擴展結果: '{}'", expanded);
|
||||
|
||||
// 測試 split
|
||||
let groups: Vec<&str> = if expanded.contains('&') {
|
||||
expanded.split('&').map(|s| s.trim()).collect()
|
||||
} else {
|
||||
expanded.split_whitespace().collect()
|
||||
};
|
||||
|
||||
println!("分組: {:?}", groups);
|
||||
|
||||
for group in groups {
|
||||
println!(" 分組: '{}'", group);
|
||||
let terms = if group.starts_with('(') && group.ends_with(')') {
|
||||
let inner = &group[1..group.len() - 1];
|
||||
inner.split('|').map(|s| s.trim()).collect::<Vec<&str>>()
|
||||
} else {
|
||||
vec![group]
|
||||
};
|
||||
println!(" 詞語: {:?}", terms);
|
||||
|
||||
for term in &terms {
|
||||
let cleaned: String = term
|
||||
.chars()
|
||||
.filter(|c| c.is_alphanumeric() || c.is_alphabetic())
|
||||
.collect();
|
||||
println!(" 詞語 '{}' -> 清理後 '{}'", term, cleaned);
|
||||
}
|
||||
}
|
||||
}
|
||||
659
src/bin/integrated_player.rs
Normal file
659
src/bin/integrated_player.rs
Normal file
@@ -0,0 +1,659 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Parser;
|
||||
use crossterm::event::{self, Event, KeyCode};
|
||||
use crossterm::terminal as crossterm_terminal;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::io::{self, IsTerminal, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "integrated_player")]
|
||||
#[command(about = "Integrated player for ASR, Face, ASRX, and Pose")]
|
||||
struct Args {
|
||||
#[arg(short, long)]
|
||||
video: PathBuf,
|
||||
|
||||
#[arg(short = 'r', long)]
|
||||
asr: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 'f', long)]
|
||||
face: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 'x', long)]
|
||||
asrx: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 'p', long)]
|
||||
pose: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 's', long, default_value = "0.0")]
|
||||
start: f64,
|
||||
|
||||
#[arg(long)]
|
||||
speaker_name: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
auto_play_speaker: bool,
|
||||
|
||||
#[arg(long)]
|
||||
demo: bool,
|
||||
|
||||
#[arg(long, default_value = "3")]
|
||||
demo_segments_per_speaker: usize,
|
||||
|
||||
#[arg(long, default_value = "2.0")]
|
||||
demo_speed: f64,
|
||||
|
||||
#[arg(long)]
|
||||
show_video: bool,
|
||||
|
||||
#[arg(long, default_value = "800")]
|
||||
video_width: u32,
|
||||
|
||||
#[arg(long, default_value = "600")]
|
||||
video_height: u32,
|
||||
|
||||
#[arg(long)]
|
||||
continuous_demo: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AsrSegment {
|
||||
start: f64,
|
||||
end: f64,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AsrData {
|
||||
language: Option<String>,
|
||||
segments: Vec<AsrSegment>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FaceInfo {
|
||||
face_id: Option<String>,
|
||||
x: i32,
|
||||
y: i32,
|
||||
width: i32,
|
||||
height: i32,
|
||||
confidence: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FaceFrame {
|
||||
frame: u64,
|
||||
timestamp: f64,
|
||||
faces: Vec<FaceInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FaceData {
|
||||
fps: f64,
|
||||
frame_count: u64,
|
||||
frames: Vec<FaceFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AsrxSegment {
|
||||
index: usize,
|
||||
start: f64,
|
||||
end: f64,
|
||||
duration: f64,
|
||||
speaker: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AsrxData {
|
||||
segments: Vec<AsrxSegment>,
|
||||
speaker_stats: HashMap<String, SpeakerStats>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct SpeakerStats {
|
||||
count: usize,
|
||||
duration: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Keypoint {
|
||||
name: String,
|
||||
x: f32,
|
||||
y: f32,
|
||||
confidence: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PersonPose {
|
||||
keypoints: Vec<Keypoint>,
|
||||
bbox: Bbox,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Bbox {
|
||||
x: i32,
|
||||
y: i32,
|
||||
width: i32,
|
||||
height: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PoseFrame {
|
||||
frame: u64,
|
||||
timestamp: f64,
|
||||
persons: Vec<PersonPose>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PoseData {
|
||||
frames: Vec<PoseFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct IntegratedSegment {
|
||||
start: f64,
|
||||
end: f64,
|
||||
text: Option<String>,
|
||||
speaker: Option<String>,
|
||||
face: Option<FaceInfo>,
|
||||
mouth_landmarks: Option<Vec<Keypoint>>,
|
||||
}
|
||||
|
||||
struct IntegratedPlayer {
|
||||
asr_data: Option<AsrData>,
|
||||
face_data: Option<FaceData>,
|
||||
asrx_data: Option<AsrxData>,
|
||||
pose_data: Option<PoseData>,
|
||||
current_time: f64,
|
||||
speaker_names: HashMap<String, (String, String)>,
|
||||
}
|
||||
|
||||
impl IntegratedPlayer {
|
||||
fn new() -> Self {
|
||||
let mut speaker_names = HashMap::new();
|
||||
speaker_names.insert(
|
||||
"SPEAKER_0".to_string(),
|
||||
("Cary Grant".to_string(), "Peter Joshua".to_string()),
|
||||
);
|
||||
speaker_names.insert(
|
||||
"SPEAKER_1".to_string(),
|
||||
("Audrey Hepburn".to_string(), "Regina Lampert".to_string()),
|
||||
);
|
||||
speaker_names.insert(
|
||||
"SPEAKER_2".to_string(),
|
||||
(
|
||||
"Walter Matthau".to_string(),
|
||||
"Hamilton Bartholomew".to_string(),
|
||||
),
|
||||
);
|
||||
speaker_names.insert(
|
||||
"SPEAKER_4".to_string(),
|
||||
("James Coburn".to_string(), "Tex Panthollow".to_string()),
|
||||
);
|
||||
|
||||
Self {
|
||||
asr_data: None,
|
||||
face_data: None,
|
||||
asrx_data: None,
|
||||
pose_data: None,
|
||||
current_time: 0.0,
|
||||
speaker_names,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_asr(&mut self, path: &PathBuf) -> Result<()> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read ASR file: {:?}", path))?;
|
||||
self.asr_data = Some(serde_json::from_str(&content)?);
|
||||
println!(
|
||||
"✓ Loaded {} ASR segments",
|
||||
self.asr_data.as_ref().unwrap().segments.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_face(&mut self, path: &PathBuf) -> Result<()> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read Face file: {:?}", path))?;
|
||||
self.face_data = Some(serde_json::from_str(&content)?);
|
||||
let total_faces = self
|
||||
.face_data
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.frames
|
||||
.iter()
|
||||
.map(|f| f.faces.len())
|
||||
.sum::<usize>();
|
||||
println!(
|
||||
"✓ Loaded {} face frames, {} total detections",
|
||||
self.face_data.as_ref().unwrap().frames.len(),
|
||||
total_faces
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_asrx(&mut self, path: &PathBuf) -> Result<()> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read ASRX file: {:?}", path))?;
|
||||
self.asrx_data = Some(serde_json::from_str(&content)?);
|
||||
println!(
|
||||
"✓ Loaded {} ASRX segments, {} speakers",
|
||||
self.asrx_data.as_ref().unwrap().segments.len(),
|
||||
self.asrx_data.as_ref().unwrap().speaker_stats.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_pose(&mut self, path: &PathBuf) -> Result<()> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read Pose file: {:?}", path))?;
|
||||
self.pose_data = Some(serde_json::from_str(&content)?);
|
||||
println!(
|
||||
"✓ Loaded {} pose frames",
|
||||
self.pose_data.as_ref().unwrap().frames.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_current_segment(&self, time: f64) -> Option<IntegratedSegment> {
|
||||
let mut segment = IntegratedSegment {
|
||||
start: 0.0,
|
||||
end: 0.0,
|
||||
text: None,
|
||||
speaker: None,
|
||||
face: None,
|
||||
mouth_landmarks: None,
|
||||
};
|
||||
|
||||
if let Some(asr) = &self.asr_data {
|
||||
for seg in &asr.segments {
|
||||
if time >= seg.start && time <= seg.end {
|
||||
segment.start = seg.start;
|
||||
segment.end = seg.end;
|
||||
segment.text = Some(seg.text.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(asrx) = &self.asrx_data {
|
||||
for seg in &asrx.segments {
|
||||
if time >= seg.start && time <= seg.end {
|
||||
segment.start = seg.start;
|
||||
segment.end = seg.end;
|
||||
segment.speaker = Some(seg.speaker.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(face) = &self.face_data {
|
||||
for frame in &face.frames {
|
||||
if (frame.timestamp - time).abs() < 1.0 {
|
||||
if let Some(face_info) = frame.faces.first() {
|
||||
segment.face = Some(face_info.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pose) = &self.pose_data {
|
||||
for frame in &pose.frames {
|
||||
if (frame.timestamp - time).abs() < 0.5 {
|
||||
if let Some(person) = frame.persons.first() {
|
||||
let mouth_points: Vec<Keypoint> = person
|
||||
.keypoints
|
||||
.iter()
|
||||
.filter(|kp| {
|
||||
kp.name.contains("mouth")
|
||||
|| kp.name.contains("lip")
|
||||
|| kp.name == "nose"
|
||||
})
|
||||
.cloned()
|
||||
.collect();
|
||||
if !mouth_points.is_empty() {
|
||||
segment.mouth_landmarks = Some(mouth_points);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if segment.text.is_some()
|
||||
|| segment.speaker.is_some()
|
||||
|| segment.face.is_some()
|
||||
|| segment.mouth_landmarks.is_some()
|
||||
{
|
||||
Some(segment)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_speaker_info(&self, speaker_id: &str) -> (String, String) {
|
||||
self.speaker_names
|
||||
.get(speaker_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| ("Unknown".to_string(), "Unknown".to_string()))
|
||||
}
|
||||
|
||||
fn list_speakers(&self) {
|
||||
if let Some(asrx) = &self.asrx_data {
|
||||
println!("\n📊 Speaker Statistics:");
|
||||
println!("{:-<80}", "");
|
||||
println!(
|
||||
"{:15} {:20} {:20} {:>10} {:>10}",
|
||||
"Speaker ID", "Actor", "Character", "Segments", "Duration"
|
||||
);
|
||||
println!("{:-<80}", "");
|
||||
|
||||
for (speaker_id, stats) in &asrx.speaker_stats {
|
||||
let (actor, character) = self.get_speaker_info(speaker_id);
|
||||
println!(
|
||||
"{:15} {:20} {:20} {:>10} {:>9.1}s",
|
||||
speaker_id, actor, character, stats.count, stats.duration
|
||||
);
|
||||
}
|
||||
println!("{:-<80}", "");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn run_continuous_demo(player: &IntegratedPlayer, args: &Args) -> Result<()> {
|
||||
println!("\n🎬 Continuous Demo Mode");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
|
||||
let is_interactive = io::stdin().is_terminal();
|
||||
if is_interactive {
|
||||
println!("Controls:");
|
||||
println!(" SPACE - Pause/Resume");
|
||||
println!(" Q - Quit");
|
||||
} else {
|
||||
println!("Running in non-interactive mode (no keyboard control)");
|
||||
println!("Use Ctrl+C to stop");
|
||||
}
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!();
|
||||
|
||||
let paused = Arc::new(AtomicBool::new(false));
|
||||
let quit = Arc::new(AtomicBool::new(false));
|
||||
|
||||
let paused_clone = paused.clone();
|
||||
let quit_clone = quit.clone();
|
||||
|
||||
let raw_mode_enabled = if is_interactive {
|
||||
crossterm_terminal::enable_raw_mode().ok().is_some()
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if is_interactive && raw_mode_enabled {
|
||||
thread::spawn(move || loop {
|
||||
if let Ok(Event::Key(key_event)) = event::read() {
|
||||
if key_event.code == KeyCode::Char(' ') {
|
||||
paused_clone.fetch_xor(true, Ordering::SeqCst);
|
||||
} else if key_event.code == KeyCode::Char('q')
|
||||
|| key_event.code == KeyCode::Char('Q')
|
||||
|| key_event.code == KeyCode::Esc
|
||||
{
|
||||
quit_clone.store(true, Ordering::SeqCst);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if quit_clone.load(Ordering::SeqCst) {
|
||||
break;
|
||||
}
|
||||
thread::sleep(Duration::from_millis(50));
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(asr) = &player.asr_data {
|
||||
let total_segments = asr.segments.len();
|
||||
|
||||
for (i, seg) in asr.segments.iter().enumerate() {
|
||||
if quit.load(Ordering::SeqCst) {
|
||||
println!("\n⏹️ Stopped by user");
|
||||
break;
|
||||
}
|
||||
|
||||
while paused.load(Ordering::SeqCst) {
|
||||
println!("\r⏸️ Paused - Press SPACE to resume");
|
||||
io::stdout().flush()?;
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
|
||||
if quit.load(Ordering::SeqCst) {
|
||||
println!("\n⏹️ Stopped by user");
|
||||
if raw_mode_enabled {
|
||||
crossterm_terminal::disable_raw_mode().ok();
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n[{}/{}] Segment", i + 1, total_segments);
|
||||
println!("{:=<80}", "");
|
||||
println!("📝 ASR Text: {}", seg.text);
|
||||
println!("⏱ Time: {:.2}s - {:.2}s", seg.start, seg.end);
|
||||
|
||||
if let Some(asrx) = &player.asrx_data {
|
||||
for asrx_seg in &asrx.segments {
|
||||
if seg.start >= asrx_seg.start && seg.start <= asrx_seg.end {
|
||||
let (actor, character) = player.get_speaker_info(&asrx_seg.speaker);
|
||||
println!(
|
||||
"🎤 Speaker: {} → {} ({})",
|
||||
asrx_seg.speaker, actor, character
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(segment) = player.get_current_segment(seg.start + 0.01) {
|
||||
if let Some(face) = &segment.face {
|
||||
println!(
|
||||
"👤 Face: bbox=({},{}) {}x{}, conf={:.2}",
|
||||
face.x, face.y, face.width, face.height, face.confidence
|
||||
);
|
||||
}
|
||||
if let Some(landmarks) = &segment.mouth_landmarks {
|
||||
println!("👄 Mouth landmarks: {} points", landmarks.len());
|
||||
}
|
||||
}
|
||||
|
||||
let duration = seg.end - seg.start;
|
||||
println!(
|
||||
"▶️ Playing: {:.2}s - {:.2}s ({:.2}s)",
|
||||
seg.start, seg.end, duration
|
||||
);
|
||||
|
||||
let mut cmd = Command::new("ffplay");
|
||||
if args.show_video {
|
||||
cmd.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", seg.start),
|
||||
"-t",
|
||||
&format!("{:.2}", duration),
|
||||
"-autoexit",
|
||||
"-x",
|
||||
&format!("{}", args.video_width),
|
||||
"-y",
|
||||
&format!("{}", args.video_height),
|
||||
args.video.to_str().unwrap(),
|
||||
]);
|
||||
} else {
|
||||
cmd.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", seg.start),
|
||||
"-t",
|
||||
&format!("{:.2}", duration),
|
||||
"-autoexit",
|
||||
"-nodisp",
|
||||
args.video.to_str().unwrap(),
|
||||
]);
|
||||
}
|
||||
|
||||
let _child = cmd
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.context("Failed to start ffplay")?;
|
||||
|
||||
thread::sleep(Duration::from_millis((duration * 1000.0) as u64 + 100));
|
||||
}
|
||||
|
||||
println!("\n{:=<80}", "");
|
||||
println!("✅ Demo completed! Played {} segments", total_segments);
|
||||
println!("{:=<80}", "");
|
||||
} else if let Some(asrx) = &player.asrx_data {
|
||||
let total_segments = asrx.segments.len();
|
||||
println!(
|
||||
"Playing {} ASRX segments (no ASR text available)",
|
||||
total_segments
|
||||
);
|
||||
|
||||
for (i, seg) in asrx.segments.iter().enumerate() {
|
||||
if quit.load(Ordering::SeqCst) {
|
||||
println!("\n⏹️ Stopped by user");
|
||||
break;
|
||||
}
|
||||
|
||||
while paused.load(Ordering::SeqCst) {
|
||||
println!("\r⏸️ Paused - Press SPACE to resume");
|
||||
io::stdout().flush()?;
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
|
||||
if quit.load(Ordering::SeqCst) {
|
||||
println!("\n⏹️ Stopped by user");
|
||||
if raw_mode_enabled {
|
||||
crossterm_terminal::disable_raw_mode().ok();
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let (actor, character) = player.get_speaker_info(&seg.speaker);
|
||||
|
||||
println!("\n[{}/{}] Segment", i + 1, total_segments);
|
||||
println!("{:=<80}", "");
|
||||
println!(
|
||||
"⏱ Time: {:.2}s - {:.2}s ({:.2}s)",
|
||||
seg.start, seg.end, seg.duration
|
||||
);
|
||||
println!("🎤 Speaker: {} → {} ({})", seg.speaker, actor, character);
|
||||
|
||||
if let Some(segment) = player.get_current_segment(seg.start + 0.01) {
|
||||
if let Some(face) = &segment.face {
|
||||
println!(
|
||||
"👤 Face: bbox=({},{}) {}x{}, conf={:.2}",
|
||||
face.x, face.y, face.width, face.height, face.confidence
|
||||
);
|
||||
}
|
||||
if let Some(landmarks) = &segment.mouth_landmarks {
|
||||
println!("👄 Mouth landmarks: {} points", landmarks.len());
|
||||
}
|
||||
}
|
||||
|
||||
println!("▶️ Playing audio segment");
|
||||
|
||||
let mut cmd = Command::new("ffplay");
|
||||
if args.show_video {
|
||||
cmd.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", seg.start),
|
||||
"-t",
|
||||
&format!("{:.2}", seg.duration),
|
||||
"-autoexit",
|
||||
"-x",
|
||||
&format!("{}", args.video_width),
|
||||
"-y",
|
||||
&format!("{}", args.video_height),
|
||||
args.video.to_str().unwrap(),
|
||||
]);
|
||||
} else {
|
||||
cmd.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", seg.start),
|
||||
"-t",
|
||||
&format!("{:.2}", seg.duration),
|
||||
"-autoexit",
|
||||
"-nodisp",
|
||||
args.video.to_str().unwrap(),
|
||||
]);
|
||||
}
|
||||
|
||||
let _child = cmd
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.context("Failed to start ffplay")?;
|
||||
|
||||
thread::sleep(Duration::from_millis((seg.duration * 1000.0) as u64 + 100));
|
||||
}
|
||||
|
||||
println!("\n{:=<80}", "");
|
||||
println!("✅ Demo completed! Played {} segments", total_segments);
|
||||
println!("{:=<80}", "");
|
||||
} else {
|
||||
println!("⚠️ No ASR or ASRX data loaded");
|
||||
}
|
||||
|
||||
if raw_mode_enabled {
|
||||
crossterm_terminal::disable_raw_mode().ok();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
if !args.video.exists() {
|
||||
anyhow::bail!("Video file not found: {:?}", args.video);
|
||||
}
|
||||
|
||||
println!("🎬 Integrated Player for ASR/Face/ASRX/Pose");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("Video: {:?}", args.video);
|
||||
|
||||
let mut player = IntegratedPlayer::new();
|
||||
|
||||
if let Some(asr_path) = &args.asr {
|
||||
if asr_path.exists() {
|
||||
player.load_asr(asr_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(face_path) = &args.face {
|
||||
if face_path.exists() {
|
||||
player.load_face(face_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(asrx_path) = &args.asrx {
|
||||
if asrx_path.exists() {
|
||||
player.load_asrx(asrx_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pose_path) = &args.pose {
|
||||
if pose_path.exists() {
|
||||
player.load_pose(pose_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
player.list_speakers();
|
||||
|
||||
if args.continuous_demo {
|
||||
run_continuous_demo(&player, &args)?;
|
||||
} else {
|
||||
println!("\n⚠️ Please use --continuous-demo flag");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
711
src/bin/integrated_player.rs.bak
Normal file
711
src/bin/integrated_player.rs.bak
Normal file
@@ -0,0 +1,711 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Parser;
|
||||
use crossterm::event::{self, Event, KeyCode, KeyModifiers};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::io::{self, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "integrated_player")]
|
||||
#[command(about = "Integrated player for ASR, Face, ASRX, and Pose")]
|
||||
struct Args {
|
||||
#[arg(short, long)]
|
||||
video: PathBuf,
|
||||
|
||||
#[arg(short = 'r', long)]
|
||||
asr: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 'f', long)]
|
||||
face: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 'x', long)]
|
||||
asrx: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 'p', long)]
|
||||
pose: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 's', long, default_value = "0.0")]
|
||||
start: f64,
|
||||
|
||||
#[arg(long)]
|
||||
speaker_name: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
auto_play_speaker: bool,
|
||||
|
||||
#[arg(long)]
|
||||
demo: bool,
|
||||
|
||||
#[arg(long, default_value = "3")]
|
||||
demo_segments_per_speaker: usize,
|
||||
|
||||
#[arg(long, default_value = "2.0")]
|
||||
demo_speed: f64,
|
||||
|
||||
#[arg(long)]
|
||||
show_video: bool,
|
||||
|
||||
#[arg(long, default_value = "800")]
|
||||
video_width: u32,
|
||||
|
||||
#[arg(long, default_value = "600")]
|
||||
video_height: u32,
|
||||
|
||||
#[arg(long)]
|
||||
continuous_demo: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AsrSegment {
|
||||
start: f64,
|
||||
end: f64,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AsrData {
|
||||
language: Option<String>,
|
||||
segments: Vec<AsrSegment>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FaceDetection {
|
||||
frame: u64,
|
||||
timestamp: f64,
|
||||
x: i32,
|
||||
y: i32,
|
||||
width: i32,
|
||||
height: i32,
|
||||
confidence: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FaceResult {
|
||||
results: FaceResults,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct FaceResults {
|
||||
detections: Vec<FaceDetection>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AsrxSegment {
|
||||
index: usize,
|
||||
start: f64,
|
||||
end: f64,
|
||||
duration: f64,
|
||||
speaker: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct AsrxData {
|
||||
segments: Vec<AsrxSegment>,
|
||||
speaker_stats: HashMap<String, SpeakerStats>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct SpeakerStats {
|
||||
count: usize,
|
||||
duration: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Keypoint {
|
||||
name: String,
|
||||
x: f32,
|
||||
y: f32,
|
||||
confidence: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PersonPose {
|
||||
keypoints: Vec<Keypoint>,
|
||||
bbox: Bbox,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct Bbox {
|
||||
x: i32,
|
||||
y: i32,
|
||||
width: i32,
|
||||
height: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PoseFrame {
|
||||
frame: u64,
|
||||
timestamp: f64,
|
||||
persons: Vec<PersonPose>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PoseData {
|
||||
frames: Vec<PoseFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct IntegratedSegment {
|
||||
start: f64,
|
||||
end: f64,
|
||||
text: Option<String>,
|
||||
speaker: Option<String>,
|
||||
face: Option<FaceDetection>,
|
||||
mouth_landmarks: Option<Vec<Keypoint>>,
|
||||
}
|
||||
|
||||
struct IntegratedPlayer {
|
||||
asr_data: Option<AsrData>,
|
||||
face_data: Option<FaceResult>,
|
||||
asrx_data: Option<AsrxData>,
|
||||
pose_data: Option<PoseData>,
|
||||
current_time: f64,
|
||||
is_playing: bool,
|
||||
speaker_names: HashMap<String, (String, String)>,
|
||||
}
|
||||
|
||||
impl IntegratedPlayer {
|
||||
fn new() -> Self {
|
||||
let mut speaker_names = HashMap::new();
|
||||
speaker_names.insert(
|
||||
"SPEAKER_0".to_string(),
|
||||
("Cary Grant".to_string(), "Peter Joshua".to_string()),
|
||||
);
|
||||
speaker_names.insert(
|
||||
"SPEAKER_1".to_string(),
|
||||
("Audrey Hepburn".to_string(), "Regina Lampert".to_string()),
|
||||
);
|
||||
speaker_names.insert(
|
||||
"SPEAKER_2".to_string(),
|
||||
(
|
||||
"Walter Matthau".to_string(),
|
||||
"Hamilton Bartholomew".to_string(),
|
||||
),
|
||||
);
|
||||
speaker_names.insert(
|
||||
"SPEAKER_4".to_string(),
|
||||
("James Coburn".to_string(), "Tex Panthollow".to_string()),
|
||||
);
|
||||
|
||||
Self {
|
||||
asr_data: None,
|
||||
face_data: None,
|
||||
asrx_data: None,
|
||||
pose_data: None,
|
||||
current_time: 0.0,
|
||||
is_playing: false,
|
||||
speaker_names,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_asr(&mut self, path: &PathBuf) -> Result<()> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read ASR file: {:?}", path))?;
|
||||
self.asr_data = Some(serde_json::from_str(&content)?);
|
||||
println!(
|
||||
"✓ Loaded {} ASR segments",
|
||||
self.asr_data.as_ref().unwrap().segments.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_face(&mut self, path: &PathBuf) -> Result<()> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read Face file: {:?}", path))?;
|
||||
self.face_data = Some(serde_json::from_str(&content)?);
|
||||
println!(
|
||||
"✓ Loaded {} face detections",
|
||||
self.face_data.as_ref().unwrap().results.detections.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_asrx(&mut self, path: &PathBuf) -> Result<()> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read ASRX file: {:?}", path))?;
|
||||
self.asrx_data = Some(serde_json::from_str(&content)?);
|
||||
println!(
|
||||
"✓ Loaded {} ASRX segments, {} speakers",
|
||||
self.asrx_data.as_ref().unwrap().segments.len(),
|
||||
self.asrx_data.as_ref().unwrap().speaker_stats.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_pose(&mut self, path: &PathBuf) -> Result<()> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read Pose file: {:?}", path))?;
|
||||
self.pose_data = Some(serde_json::from_str(&content)?);
|
||||
println!(
|
||||
"✓ Loaded {} pose frames",
|
||||
self.pose_data.as_ref().unwrap().frames.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_current_segment(&self, time: f64) -> Option<IntegratedSegment> {
|
||||
let mut segment = IntegratedSegment {
|
||||
start: 0.0,
|
||||
end: 0.0,
|
||||
text: None,
|
||||
speaker: None,
|
||||
face: None,
|
||||
mouth_landmarks: None,
|
||||
};
|
||||
|
||||
if let Some(asr) = &self.asr_data {
|
||||
for seg in &asr.segments {
|
||||
if time >= seg.start && time <= seg.end {
|
||||
segment.start = seg.start;
|
||||
segment.end = seg.end;
|
||||
segment.text = Some(seg.text.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(asrx) = &self.asrx_data {
|
||||
for seg in &asrx.segments {
|
||||
if time >= seg.start && time <= seg.end {
|
||||
segment.start = seg.start;
|
||||
segment.end = seg.end;
|
||||
segment.speaker = Some(seg.speaker.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(face) = &self.face_data {
|
||||
for det in &face.results.detections {
|
||||
if (det.timestamp - time).abs() < 1.0 {
|
||||
segment.face = Some(det.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pose) = &self.pose_data {
|
||||
for frame in &pose.frames {
|
||||
if (frame.timestamp - time).abs() < 0.5 {
|
||||
if let Some(person) = frame.persons.first() {
|
||||
let mouth_points: Vec<Keypoint> = person
|
||||
.keypoints
|
||||
.iter()
|
||||
.filter(|kp| {
|
||||
kp.name.contains("mouth")
|
||||
|| kp.name.contains("lip")
|
||||
|| kp.name == "nose"
|
||||
})
|
||||
.cloned()
|
||||
.collect();
|
||||
if !mouth_points.is_empty() {
|
||||
segment.mouth_landmarks = Some(mouth_points);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if segment.text.is_some()
|
||||
|| segment.speaker.is_some()
|
||||
|| segment.face.is_some()
|
||||
|| segment.mouth_landmarks.is_some()
|
||||
{
|
||||
Some(segment)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_speaker_info(&self, speaker_id: &str) -> (String, String) {
|
||||
self.speaker_names
|
||||
.get(speaker_id)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| ("Unknown".to_string(), "Unknown".to_string()))
|
||||
}
|
||||
|
||||
fn print_segment(&self, segment: &IntegratedSegment) {
|
||||
println!("\n{:=<80}", "");
|
||||
println!("⏱ Time: {:.2}s - {:.2}s", segment.start, segment.end);
|
||||
|
||||
if let Some(text) = &segment.text {
|
||||
println!("📝 Text: {}", text);
|
||||
}
|
||||
|
||||
if let Some(speaker) = &segment.speaker {
|
||||
let (actor, character) = self.get_speaker_info(speaker);
|
||||
println!("🎤 Speaker: {} → {} ({})", speaker, actor, character);
|
||||
}
|
||||
|
||||
if let Some(face) = &segment.face {
|
||||
println!(
|
||||
"👤 Face: bbox=({},{}) {}x{}, confidence={:.2}",
|
||||
face.x, face.y, face.width, face.height, face.confidence
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(landmarks) = &segment.mouth_landmarks {
|
||||
println!("👄 Mouth landmarks: {} points", landmarks.len());
|
||||
for kp in landmarks.iter().take(3) {
|
||||
println!(
|
||||
" • {}: ({:.1}, {:.1}) conf={:.2}",
|
||||
kp.name, kp.x, kp.y, kp.confidence
|
||||
);
|
||||
}
|
||||
}
|
||||
println!("{:=<80}", "");
|
||||
}
|
||||
|
||||
fn list_speakers(&self) {
|
||||
if let Some(asrx) = &self.asrx_data {
|
||||
println!("\n📊 Speaker Statistics:");
|
||||
println!("{:-<80}", "");
|
||||
println!(
|
||||
"{:15} {:20} {:20} {:>10} {:>10}",
|
||||
"Speaker ID", "Actor", "Character", "Segments", "Duration"
|
||||
);
|
||||
println!("{:-<80}", "");
|
||||
|
||||
for (speaker_id, stats) in &asrx.speaker_stats {
|
||||
let (actor, character) = self.get_speaker_info(speaker_id);
|
||||
println!(
|
||||
"{:15} {:20} {:20} {:>10} {:>9.1}s",
|
||||
speaker_id, actor, character, stats.count, stats.duration
|
||||
);
|
||||
}
|
||||
println!("{:-<80}", "");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn play_segment(video_path: &PathBuf, start: f64, duration: f64, show_video: bool) -> Result<()> {
|
||||
println!("▶️ Playing {:.2}s - {:.2}s", start, start + duration);
|
||||
|
||||
let mut cmd = Command::new("ffplay");
|
||||
|
||||
if show_video {
|
||||
cmd.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", start),
|
||||
"-t",
|
||||
&format!("{:.2}", duration),
|
||||
"-autoexit",
|
||||
video_path.to_str().unwrap(),
|
||||
]);
|
||||
} else {
|
||||
cmd.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", start),
|
||||
"-t",
|
||||
&format!("{:.2}", duration),
|
||||
"-autoexit",
|
||||
"-nodisp",
|
||||
video_path.to_str().unwrap(),
|
||||
]);
|
||||
}
|
||||
|
||||
let _child = cmd
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.context("Failed to start ffplay")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn play_speaker_segments(
|
||||
player: &IntegratedPlayer,
|
||||
video_path: &PathBuf,
|
||||
speaker_id: &str,
|
||||
limit: Option<usize>,
|
||||
) -> Result<()> {
|
||||
if let Some(asrx) = &player.asrx_data {
|
||||
let segments: Vec<&AsrxSegment> = asrx
|
||||
.segments
|
||||
.iter()
|
||||
.filter(|s| s.speaker == speaker_id)
|
||||
.collect();
|
||||
|
||||
let total = segments.len();
|
||||
let count = limit.unwrap_or(total).min(total);
|
||||
|
||||
println!("\n🎬 Playing {} segments for {}", count, speaker_id);
|
||||
|
||||
for (i, seg) in segments.iter().take(count).enumerate() {
|
||||
println!("\n[{}/{}] Segment {}", i + 1, count, seg.index);
|
||||
|
||||
if let Some(segment) = player.get_current_segment(seg.start + 0.1) {
|
||||
player.print_segment(&segment);
|
||||
}
|
||||
|
||||
play_segment(video_path, seg.start, seg.duration, false)?;
|
||||
|
||||
thread::sleep(Duration::from_millis(500));
|
||||
}
|
||||
|
||||
println!("\n✅ Finished playing {} segments", count);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_demo(player: &IntegratedPlayer, args: &Args) -> Result<()> {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if let Some(asr) = &player.asr_data {
|
||||
let total_segments = asr.segments.len();
|
||||
|
||||
for (i, seg) in asr.segments.iter().enumerate() {
|
||||
// 檢查是否退出
|
||||
if quit.load(Ordering::SeqCst) {
|
||||
println!("\n⏹️ Stopped by user");
|
||||
break;
|
||||
}
|
||||
|
||||
// 檢查是否暫停
|
||||
while paused.load(Ordering::SeqCst) {
|
||||
println!("\r⏸️ Paused - Press SPACE to resume",);
|
||||
std::io::stdout().flush()?;
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
|
||||
if quit.load(Ordering::SeqCst) {
|
||||
println!("\n⏹️ Stopped by user");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
println!("\n[{}/{}] Segment", i + 1, total_segments);
|
||||
println!("{:=<80}", "");
|
||||
|
||||
// 顯示所有信息
|
||||
if let Some(segment) = player.get_current_segment(seg.start + 0.01) {
|
||||
player.print_segment(&segment);
|
||||
}
|
||||
|
||||
// 播放音頻/視頻
|
||||
let duration = seg.end - seg.start;
|
||||
println!(
|
||||
"▶️ Playing: {:.2}s - {:.2}s ({:.2}s)",
|
||||
seg.start, seg.end, duration
|
||||
);
|
||||
|
||||
let mut cmd = Command::new("ffplay");
|
||||
if args.show_video {
|
||||
cmd.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", seg.start),
|
||||
"-t",
|
||||
&format!("{:.2}", duration),
|
||||
"-autoexit",
|
||||
"-x",
|
||||
&format!("{}", args.video_width),
|
||||
"-y",
|
||||
&format!("{}", args.video_height),
|
||||
args.video.to_str().unwrap(),
|
||||
]);
|
||||
} else {
|
||||
cmd.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", seg.start),
|
||||
"-t",
|
||||
&format!("{:.2}", duration),
|
||||
"-autoexit",
|
||||
"-nodisp",
|
||||
args.video.to_str().unwrap(),
|
||||
]);
|
||||
}
|
||||
|
||||
let _child = cmd
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.context("Failed to start ffplay")?;
|
||||
|
||||
// 等待播放完成
|
||||
thread::sleep(Duration::from_millis((duration * 1000.0) as u64 + 100));
|
||||
}
|
||||
|
||||
println!("\n{:=<80}", "");
|
||||
println!("✅ Demo completed! Played {} segments", total_segments);
|
||||
println!("{:=<80}", "");
|
||||
} else {
|
||||
println!("⚠️ No ASR data loaded");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn run_demo(player: &IntegratedPlayer, args: &Args) -> Result<()> {
|
||||
println!("\n🎬 Auto Demo Mode");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("Segments per speaker: {}", args.demo_segments_per_speaker);
|
||||
println!("Demo speed: {:.1}x", args.demo_speed);
|
||||
println!();
|
||||
|
||||
if let Some(asrx) = &player.asrx_data {
|
||||
let mut speaker_ids: Vec<String> = asrx.speaker_stats.keys().cloned().collect();
|
||||
speaker_ids.sort();
|
||||
|
||||
for speaker_id in &speaker_ids {
|
||||
let (actor, character) = player.get_speaker_info(speaker_id);
|
||||
|
||||
println!("\n{:=<80}", "");
|
||||
println!("🎭 Demo: {} → {} ({})", speaker_id, actor, character);
|
||||
println!("{:=<80}", "");
|
||||
|
||||
let segments: Vec<&AsrxSegment> = asrx
|
||||
.segments
|
||||
.iter()
|
||||
.filter(|s| s.speaker == *speaker_id)
|
||||
.collect();
|
||||
|
||||
let count = args.demo_segments_per_speaker.min(segments.len());
|
||||
|
||||
for (i, seg) in segments.iter().take(count).enumerate() {
|
||||
println!("\n[Segment {}/{}]", i + 1, count);
|
||||
|
||||
if let Some(segment) = player.get_current_segment(seg.start + 0.1) {
|
||||
player.print_segment(&segment);
|
||||
}
|
||||
|
||||
println!(
|
||||
"⏳ Playing audio ({:.1}s)...",
|
||||
seg.duration / args.demo_speed
|
||||
);
|
||||
|
||||
let _child = Command::new("ffplay")
|
||||
.args([
|
||||
"-ss",
|
||||
&format!("{:.2}", seg.start),
|
||||
"-t",
|
||||
&format!("{:.2}", seg.duration / args.demo_speed),
|
||||
"-autoexit",
|
||||
"-nodisp",
|
||||
args.video.to_str().unwrap(),
|
||||
])
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.context("Failed to start ffplay")?;
|
||||
|
||||
thread::sleep(Duration::from_millis(
|
||||
((seg.duration / args.demo_speed) * 1000.0) as u64 + 500,
|
||||
));
|
||||
}
|
||||
|
||||
println!("\n⏸️ Pausing 2 seconds before next speaker...");
|
||||
thread::sleep(Duration::from_secs(2));
|
||||
}
|
||||
|
||||
println!("\n{:=<80}", "");
|
||||
println!("✅ Demo completed!");
|
||||
println!("{:=<80}", "");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
if !args.video.exists() {
|
||||
anyhow::bail!("Video file not found: {:?}", args.video);
|
||||
}
|
||||
|
||||
println!("🎬 Integrated Player for ASR/Face/ASRX/Pose");
|
||||
println!("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
|
||||
println!("Video: {:?}", args.video);
|
||||
|
||||
let mut player = IntegratedPlayer::new();
|
||||
|
||||
if let Some(asr_path) = &args.asr {
|
||||
if asr_path.exists() {
|
||||
player.load_asr(asr_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(face_path) = &args.face {
|
||||
if face_path.exists() {
|
||||
player.load_face(face_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(asrx_path) = &args.asrx {
|
||||
if asrx_path.exists() {
|
||||
player.load_asrx(asrx_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(pose_path) = &args.pose {
|
||||
if pose_path.exists() {
|
||||
player.load_pose(pose_path)?;
|
||||
}
|
||||
}
|
||||
|
||||
player.list_speakers();
|
||||
|
||||
if args.demo {
|
||||
run_demo(&player, &args)?;
|
||||
} else if args.continuous_demo {
|
||||
run_continuous_demo(&player, &args)?;
|
||||
} else if args.auto_play_speaker {
|
||||
if let Some(speaker_id) = &args.speaker_name {
|
||||
play_speaker_segments(&player, &args.video, speaker_id, Some(5))?;
|
||||
} else {
|
||||
println!("\n⚠️ --speaker-name required for --auto-play-speaker");
|
||||
}
|
||||
} else {
|
||||
println!("\n🎮 Interactive Mode");
|
||||
println!(" Commands:");
|
||||
println!(" • Enter time in seconds to seek");
|
||||
println!(" • 's' to show current segment");
|
||||
println!(" • 'l' to list speakers");
|
||||
println!(" • 'p <speaker>' to play speaker segments");
|
||||
println!(" • 'q' to quit");
|
||||
println!();
|
||||
|
||||
loop {
|
||||
print!("> ");
|
||||
std::io::Write::flush(&mut std::io::stdout())?;
|
||||
|
||||
let mut input = String::new();
|
||||
std::io::stdin().read_line(&mut input)?;
|
||||
let input = input.trim();
|
||||
|
||||
if input == "q" || input == "quit" || input == "exit" {
|
||||
break;
|
||||
} else if input == "s" || input == "show" {
|
||||
if let Some(segment) = player.get_current_segment(player.current_time) {
|
||||
player.print_segment(&segment);
|
||||
} else {
|
||||
println!("No segment at time {:.2}s", player.current_time);
|
||||
}
|
||||
} else if input == "l" || input == "list" {
|
||||
player.list_speakers();
|
||||
} else if input.starts_with("p ") {
|
||||
let speaker_id = input.strip_prefix("p ").unwrap();
|
||||
play_speaker_segments(&player, &args.video, speaker_id, Some(3))?;
|
||||
} else if let Ok(time) = input.parse::<f64>() {
|
||||
player.current_time = time;
|
||||
println!("Seeked to {:.2}s", time);
|
||||
|
||||
if let Some(segment) = player.get_current_segment(time) {
|
||||
player.print_segment(&segment);
|
||||
} else {
|
||||
println!("No segment at this time");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
92
src/bin/migrate_chinese_text.rs
Normal file
92
src/bin/migrate_chinese_text.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
// Migration script to tokenize existing Chinese text in the database
|
||||
// Usage: cargo run --bin migrate_chinese_text
|
||||
|
||||
use dotenv;
|
||||
use momentry_core::core::text::tokenizer::tokenize_chinese_text;
|
||||
use sqlx::{postgres::PgPoolOptions, Row};
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Load environment variables from .env file
|
||||
dotenv::dotenv().ok();
|
||||
|
||||
// Get database URL from environment
|
||||
let database_url = env::var("DATABASE_URL")
|
||||
.unwrap_or_else(|_| "postgres://accusys@localhost:5432/momentry".to_string());
|
||||
|
||||
println!("Connecting to database...");
|
||||
|
||||
// Create connection pool
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect(&database_url)
|
||||
.await?;
|
||||
|
||||
println!("Fetching Chinese chunks from database...");
|
||||
|
||||
// Get all chunks with Chinese text using raw query to avoid sqlx macro issues
|
||||
let query = r#"
|
||||
SELECT id, text_content, content->'data'->>'text' as chinese_text, content->>'text' as english_text
|
||||
FROM chunks
|
||||
WHERE text_content ~ '[\u4e00-\u9fff]'
|
||||
ORDER BY id
|
||||
"#;
|
||||
|
||||
let rows = sqlx::query(query).fetch_all(&pool).await?;
|
||||
|
||||
println!("Found {} Chinese chunks to process", rows.len());
|
||||
|
||||
let mut updated_count = 0;
|
||||
|
||||
for row in &rows {
|
||||
let id: i32 = row.get(0);
|
||||
let text_content: Option<String> = row.get(1);
|
||||
let chinese_text: Option<String> = row.get(2);
|
||||
let english_text: Option<String> = row.get(3);
|
||||
|
||||
// Clone text_content for later comparison
|
||||
let text_content_clone = text_content.clone();
|
||||
|
||||
// Determine the original text (prioritize chinese_text from content->'data'->>'text')
|
||||
let original_text = if let Some(ref chinese_text) = chinese_text {
|
||||
chinese_text.as_str()
|
||||
} else if let Some(ref english_text) = english_text {
|
||||
english_text.as_str()
|
||||
} else {
|
||||
text_content.as_deref().unwrap_or("")
|
||||
};
|
||||
|
||||
// Tokenize the text
|
||||
let tokenized_text = tokenize_chinese_text(original_text);
|
||||
|
||||
// Check if tokenization changed the text
|
||||
let current_text = text_content_clone.unwrap_or_default();
|
||||
if current_text == tokenized_text {
|
||||
println!("Skipping chunk {} - already tokenized", id);
|
||||
continue;
|
||||
}
|
||||
|
||||
println!("Updating chunk {}:", id);
|
||||
println!(" Original: {}", original_text);
|
||||
println!(" Tokenized: {}", tokenized_text);
|
||||
|
||||
// Update the chunk
|
||||
sqlx::query("UPDATE chunks SET text_content = $1 WHERE id = $2")
|
||||
.bind(&tokenized_text)
|
||||
.bind(id)
|
||||
.execute(&pool)
|
||||
.await?;
|
||||
|
||||
updated_count += 1;
|
||||
}
|
||||
|
||||
println!("\nMigration completed!");
|
||||
println!(
|
||||
"Updated {} out of {} Chinese chunks",
|
||||
updated_count,
|
||||
rows.len()
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
68
src/bin/test_bm25_simple.rs
Normal file
68
src/bin/test_bm25_simple.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use anyhow::{Context, Result};
|
||||
use momentry_core::core::db::{Database, PostgresDb};
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env::set_var("RUST_LOG", "info");
|
||||
|
||||
println!("=== BM25 簡單測試 ===\n");
|
||||
|
||||
// 初始化 PostgreSQL
|
||||
let pg = PostgresDb::init()
|
||||
.await
|
||||
.context("Failed to initialize PostgreSQL database")?;
|
||||
|
||||
// 測試查詢
|
||||
let test_queries = vec![
|
||||
("telephone", Some("384b0ff44aaaa1f1")),
|
||||
("工作", Some("9760d0820f0cf9a7")),
|
||||
("团体", Some("9760d0820f0cf9a7")), // Simplified Chinese, should match Traditional "團體"
|
||||
("computer", None),
|
||||
];
|
||||
|
||||
for (query_str, uuid_opt) in test_queries {
|
||||
println!(
|
||||
"\n🔍 測試查詢: '{}' {}",
|
||||
query_str,
|
||||
uuid_opt
|
||||
.map(|u| format!("(uuid: {})", u))
|
||||
.unwrap_or_default()
|
||||
);
|
||||
|
||||
// 顯示轉換後的 tsquery (除錯用)
|
||||
match pg.prepare_tsquery(query_str) {
|
||||
Ok(tsquery) => println!(" TSQUERY: {}", tsquery),
|
||||
Err(e) => println!(" TSQUERY 錯誤: {}", e),
|
||||
}
|
||||
|
||||
let results = pg.search_bm25(query_str, uuid_opt, 5).await?;
|
||||
|
||||
println!("找到 {} 筆結果:", results.len());
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
let text_preview: String = r.text.chars().take(60).collect();
|
||||
let text_preview = if r.text.chars().count() > 60 {
|
||||
format!("{}...", text_preview)
|
||||
} else {
|
||||
text_preview
|
||||
};
|
||||
println!(
|
||||
" {}. {} (uuid: {}, chunk_id: {})",
|
||||
i + 1,
|
||||
text_preview,
|
||||
r.uuid,
|
||||
r.chunk_id
|
||||
);
|
||||
println!(
|
||||
" 分數: {:.4}, 時間: {:.1}-{:.1}s, 類型: {}",
|
||||
r.bm25_score, r.start_time, r.end_time, r.chunk_type
|
||||
);
|
||||
}
|
||||
|
||||
if results.is_empty() {
|
||||
println!(" ⚠️ 沒有找到結果");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
37
src/bin/test_simplified_chinese.rs
Normal file
37
src/bin/test_simplified_chinese.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use anyhow::{Context, Result};
|
||||
use momentry_core::core::db::{Database, PostgresDb};
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env::set_var("RUST_LOG", "info");
|
||||
|
||||
println!("=== 簡體中文轉換測試 ===\n");
|
||||
|
||||
// 初始化 PostgreSQL
|
||||
let pg = PostgresDb::init()
|
||||
.await
|
||||
.context("Failed to initialize PostgreSQL database")?;
|
||||
|
||||
// 測試查詢:簡體中文
|
||||
let test_queries = vec!["团体", "视频", "文件"];
|
||||
|
||||
for query_str in test_queries {
|
||||
println!("\n🔍 測試查詢 (簡體): '{}'", query_str);
|
||||
|
||||
// 顯示轉換後的 tsquery
|
||||
match pg.prepare_tsquery(query_str) {
|
||||
Ok(tsquery) => println!(" TSQUERY: {}", tsquery),
|
||||
Err(e) => println!(" TSQUERY 錯誤: {}", e),
|
||||
}
|
||||
|
||||
// 執行搜索
|
||||
let results = pg.search_bm25(query_str, None, 5).await?;
|
||||
println!(" 找到 {} 筆結果", results.len());
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
println!(" {}. [{}] {}", i + 1, r.uuid, r.text);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
23
src/bin/test_synonym_chinese.rs
Normal file
23
src/bin/test_synonym_chinese.rs
Normal file
@@ -0,0 +1,23 @@
|
||||
use momentry_core::core::text::global_synonym_expander;
|
||||
|
||||
fn main() {
|
||||
let expander = global_synonym_expander();
|
||||
|
||||
println!("=== 中文同義詞擴展測試 ===");
|
||||
|
||||
let test_queries = vec!["電腦", "電腦工作", "工作檔案", "視頻分析", "電腦工作檔案"];
|
||||
|
||||
for query in test_queries {
|
||||
println!("\n查詢: '{}'", query);
|
||||
let expanded = expander.expand_chinese_query(query);
|
||||
println!("擴展結果: {}", expanded);
|
||||
|
||||
// 測試單詞擴展
|
||||
println!("單詞擴展:");
|
||||
if let Some(syns) = expander.get_synonyms(query) {
|
||||
println!(" '{}' -> {:?}", query, syns);
|
||||
} else {
|
||||
println!(" '{}' 沒有同義詞", query);
|
||||
}
|
||||
}
|
||||
}
|
||||
56
src/bin/test_synonym_expansion.rs
Normal file
56
src/bin/test_synonym_expansion.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use anyhow::{Context, Result};
|
||||
use momentry_core::core::db::{Database, PostgresDb};
|
||||
use momentry_core::core::text::tokenizer::{contains_chinese, tokenize_chinese_text};
|
||||
use momentry_core::core::text::{global_synonym_expander, normalize_chinese_query};
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env::set_var("RUST_LOG", "info");
|
||||
|
||||
println!("=== 同義詞擴展測試 ===\n");
|
||||
|
||||
// 初始化 PostgreSQL
|
||||
let pg = PostgresDb::init()
|
||||
.await
|
||||
.context("Failed to initialize PostgreSQL database")?;
|
||||
|
||||
let expander = global_synonym_expander();
|
||||
|
||||
// 測試查詢
|
||||
let test_queries = vec![
|
||||
"電腦",
|
||||
"視頻",
|
||||
"分析",
|
||||
"工作",
|
||||
"檔案",
|
||||
"電腦工作",
|
||||
"工作檔案",
|
||||
];
|
||||
|
||||
for query_str in test_queries {
|
||||
println!("\n🔍 測試查詢: '{}'", query_str);
|
||||
|
||||
// 顯示同義詞擴展
|
||||
if contains_chinese(query_str) {
|
||||
let normalized = normalize_chinese_query(query_str);
|
||||
let expanded = expander.expand_chinese_query(&normalized);
|
||||
println!(" 同義詞擴展: {}", expanded);
|
||||
}
|
||||
|
||||
// 顯示轉換後的 tsquery
|
||||
match pg.prepare_tsquery(query_str) {
|
||||
Ok(tsquery) => println!(" TSQUERY: {}", tsquery),
|
||||
Err(e) => println!(" TSQUERY 錯誤: {}", e),
|
||||
}
|
||||
|
||||
// 執行搜索(即使沒有結果)
|
||||
let results = pg.search_bm25(query_str, None, 2).await?;
|
||||
println!(" 找到 {} 筆結果", results.len());
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
println!(" {}. [{}] {}", i + 1, r.uuid, r.text);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
56
src/bin/test_synonym_expansion.rs.bak
Normal file
56
src/bin/test_synonym_expansion.rs.bak
Normal file
@@ -0,0 +1,56 @@
|
||||
use anyhow::{Context, Result};
|
||||
use momentry_core::core::db::{Database, PostgresDb};
|
||||
use momentry_core::core::text::tokenizer::{contains_chinese, tokenize_chinese_text};
|
||||
use momentry_core::core::text::{global_synonym_expander, normalize_chinese_query};
|
||||
use std::env;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env::set_var("RUST_LOG", "info");
|
||||
|
||||
println!("=== 同義詞擴展測試 ===\n");
|
||||
|
||||
// 初始化 PostgreSQL
|
||||
let pg = PostgresDb::init()
|
||||
.await
|
||||
.context("Failed to initialize PostgreSQL database")?;
|
||||
|
||||
let expander = global_synonym_expander();
|
||||
|
||||
// 測試查詢
|
||||
let test_queries = vec![
|
||||
"電腦",
|
||||
"視頻",
|
||||
"分析",
|
||||
"工作",
|
||||
"檔案",
|
||||
"電腦工作",
|
||||
"工作檔案",
|
||||
];
|
||||
|
||||
for query_str in test_queries {
|
||||
println!("\n🔍 測試查詢: '{}'", query_str);
|
||||
|
||||
// 顯示同義詞擴展
|
||||
if contains_chinese(query_str) {
|
||||
let normalized = normalize_chinese_query(query_str);
|
||||
let expanded = expander.expand_chinese_query(&normalized);
|
||||
println!(" 同義詞擴展: {}", expanded);
|
||||
}
|
||||
|
||||
// 顯示轉換後的 tsquery
|
||||
match pg.prepare_tsquery(query_str) {
|
||||
Ok(tsquery) => println!(" TSQUERY: {}", tsquery),
|
||||
Err(e) => println!(" TSQUERY 錯誤: {}", e),
|
||||
}
|
||||
|
||||
// 執行搜索(即使沒有結果)
|
||||
let results = pg.search_bm25(query_str, None, 2).await?;
|
||||
println!(" 找到 {} 筆結果", results.len());
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
println!(" {}. [{}] {}", i + 1, r.uuid, r.text);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
27
src/bin/test_tokenizer_debug.rs
Normal file
27
src/bin/test_tokenizer_debug.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use momentry_core::core::text::tokenizer::{contains_chinese, tokenize_chinese_text};
|
||||
|
||||
fn main() {
|
||||
let texts = ["電腦", "工作", "視頻", "分析", "檔案", "這是一個測試"];
|
||||
for text in texts {
|
||||
let tokens = tokenize_chinese_text(text);
|
||||
println!("Text: '{}' -> Tokens: '{}'", text, tokens);
|
||||
let split: Vec<&str> = tokens.split_whitespace().collect();
|
||||
println!(" Split: {:?}", split);
|
||||
}
|
||||
|
||||
println!("\n=== Testing complex queries ===");
|
||||
let complex = [
|
||||
"(電腦 | 計算機 | 微机)",
|
||||
"(工作 | 任務 | 作業)",
|
||||
"電腦 & 工作",
|
||||
"(電腦:* | 計算機:* | 微机:*)",
|
||||
];
|
||||
|
||||
for query in complex {
|
||||
let tokens = tokenize_chinese_text(query);
|
||||
println!("Query: '{}' -> Tokens: '{}'", query, tokens);
|
||||
let split: Vec<&str> = tokens.split_whitespace().collect();
|
||||
println!(" Split: {:?}", split);
|
||||
println!("---");
|
||||
}
|
||||
}
|
||||
94
src/core/chunk/rule1_ingest.rs
Normal file
94
src/core/chunk/rule1_ingest.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use crate::core::config::OUTPUT_DIR;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
// --- 結構體定義 (對齊外部處理器產出格式) ---
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AsrSegment {
|
||||
start: f64,
|
||||
end: f64,
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AsrxSegment {
|
||||
start: f64,
|
||||
end: f64,
|
||||
speaker: String,
|
||||
}
|
||||
|
||||
// --- 核心邏輯 ---
|
||||
|
||||
/// 執行 Rule 1 入庫
|
||||
/// 讀取 asr.json 與 asrx.json,合併 Speaker 資訊,寫入 chunks_rule1
|
||||
pub async fn ingest_rule1(pool: &PgPool, asset_uuid: &str, fps: f64) -> Result<usize> {
|
||||
// 1. 讀取檔案
|
||||
let asr_path = format!("{}/{}.asr.json", *OUTPUT_DIR, asset_uuid);
|
||||
let asrx_path = format!("{}/{}.asrx.json", *OUTPUT_DIR, asset_uuid);
|
||||
|
||||
let asr_content = fs::read_to_string(&asr_path)
|
||||
.with_context(|| format!("Failed to read ASR file: {}", asr_path))?;
|
||||
let asrx_content = fs::read_to_string(&asrx_path)
|
||||
.with_context(|| format!("Failed to read ASRX file: {}", asrx_path))?;
|
||||
|
||||
let asr_segments: Vec<AsrSegment> = serde_json::from_str(&asr_content)?;
|
||||
let asrx_segments: Vec<AsrxSegment> = serde_json::from_str(&asrx_content)?;
|
||||
|
||||
let mut count = 0;
|
||||
|
||||
// 2. 交易處理
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
for seg in &asr_segments {
|
||||
// 時間轉幀
|
||||
let start_frame = (seg.start * fps).round() as i64;
|
||||
let end_frame = (seg.end * fps).round() as i64;
|
||||
|
||||
// 3. 尋找重疊最多的 Speaker
|
||||
let mut best_speaker: Option<String> = None;
|
||||
let mut max_overlap = 0.0f64;
|
||||
|
||||
for spk in &asrx_segments {
|
||||
let overlap = (seg.end.min(spk.end) - seg.start.max(spk.start)).max(0.0);
|
||||
if overlap > max_overlap {
|
||||
max_overlap = overlap;
|
||||
best_speaker = Some(spk.speaker.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let speaker_id = best_speaker.unwrap_or("UNKNOWN".to_string());
|
||||
|
||||
// 4. 寫入 DB
|
||||
sqlx::query!(
|
||||
r#"
|
||||
INSERT INTO chunks_rule1 (
|
||||
id, asset_uuid, start_frame, end_frame, content, speaker_id
|
||||
) VALUES (
|
||||
gen_random_uuid(), $1, $2, $3, $4, $5
|
||||
)
|
||||
"#,
|
||||
asset_uuid,
|
||||
start_frame,
|
||||
end_frame,
|
||||
seg.text,
|
||||
speaker_id
|
||||
)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
count += 1;
|
||||
|
||||
// 每 100 筆 Commit 一次 (可選優化)
|
||||
if count % 500 == 0 {
|
||||
tx.commit().await?;
|
||||
tx = pool.begin().await?;
|
||||
}
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(count)
|
||||
}
|
||||
182
src/core/chunk/rule3_ingest.rs
Normal file
182
src/core/chunk/rule3_ingest.rs
Normal file
@@ -0,0 +1,182 @@
|
||||
use crate::core::config::OUTPUT_DIR;
|
||||
use crate::core::llm::client::generate_5w1h_summary;
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use std::fs;
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CutScene {
|
||||
scene_number: u32,
|
||||
start_frame: u64,
|
||||
end_frame: u64,
|
||||
start_time: f64,
|
||||
end_time: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct CutResult {
|
||||
scenes: Vec<CutScene>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AsrSegment {
|
||||
start: f64,
|
||||
end: f64,
|
||||
text: String,
|
||||
}
|
||||
|
||||
/// Executes Rule 3 Ingestion: Scene-based Chunking with LLM 5W1H+ Summary.
|
||||
/// 1. Reads CUT data to identify scenes.
|
||||
/// 2. Aggregates Rule 1 (Sentence) chunks falling within each scene.
|
||||
/// 3. Calls LLM to generate 5W1H+ summary.
|
||||
/// 4. Inserts parent chunks into `dev.chunks`.
|
||||
pub async fn ingest_rule3(pool: &PgPool, asset_uuid: &str) -> Result<usize> {
|
||||
let cut_path = format!("{}/{}.cut.json", *OUTPUT_DIR, asset_uuid);
|
||||
let asr_path = format!("{}/{}.asr.json", *OUTPUT_DIR, asset_uuid);
|
||||
|
||||
// 1. Load CUT and ASR data
|
||||
let cut_content = fs::read_to_string(&cut_path)
|
||||
.with_context(|| format!("Failed to read CUT file: {}", cut_path))?;
|
||||
let cut_result: CutResult = serde_json::from_str(&cut_content).context("Invalid CUT JSON")?;
|
||||
|
||||
let asr_segments: Vec<AsrSegment> = match fs::read_to_string(&asr_path) {
|
||||
Ok(content) => serde_json::from_str(&content).unwrap_or_default(),
|
||||
Err(_) => {
|
||||
warn!("ASR file not found, proceeding with empty transcript for scenes");
|
||||
vec![]
|
||||
}
|
||||
};
|
||||
|
||||
let mut count = 0;
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
// 2. Process each scene
|
||||
for scene in &cut_result.scenes {
|
||||
let chunk_id = format!("scene_{}", scene.scene_number);
|
||||
|
||||
// Aggregate text from Rule 1 chunks
|
||||
let mut scene_text = String::new();
|
||||
let mut child_ids: Vec<String> = Vec::new();
|
||||
|
||||
for seg in &asr_segments {
|
||||
if seg.start >= scene.start_time && seg.end <= scene.end_time {
|
||||
scene_text.push_str(&seg.text);
|
||||
scene_text.push(' ');
|
||||
// We'll look up the chunk_id from Rule 1 later if needed,
|
||||
// but for now we just group by text overlap.
|
||||
// A better approach is to query Rule 1 table for this range.
|
||||
}
|
||||
}
|
||||
|
||||
// Query Rule 1 table for better linking
|
||||
let rule1_rows: Vec<(String,)> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT id::text FROM chunks_rule1
|
||||
WHERE asset_uuid = $1
|
||||
AND start_frame >= $2
|
||||
AND end_frame <= $3
|
||||
"#,
|
||||
)
|
||||
.bind(asset_uuid)
|
||||
.bind(scene.start_frame as i64)
|
||||
.bind(scene.end_frame as i64)
|
||||
.fetch_all(&mut *tx)
|
||||
.await?;
|
||||
|
||||
for row in &rule1_rows {
|
||||
child_ids.push(row.0.clone());
|
||||
}
|
||||
|
||||
// Fallback to simple aggregation if query didn't get text (due to frame boundaries)
|
||||
if scene_text.is_empty() {
|
||||
// Try to grab text directly if rule1 table doesn't have it or boundaries differ
|
||||
// But rule1 table has start_frame/end_frame which should match.
|
||||
// Let's re-query text directly.
|
||||
}
|
||||
|
||||
let texts: Vec<String> = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT content FROM chunks_rule1
|
||||
WHERE asset_uuid = $1
|
||||
AND start_frame >= $2
|
||||
AND end_frame <= $3
|
||||
ORDER BY start_frame ASC
|
||||
"#,
|
||||
)
|
||||
.bind(asset_uuid)
|
||||
.bind(scene.start_frame as i64)
|
||||
.bind(scene.end_frame as i64)
|
||||
.fetch_all(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let aggregated_text = texts.join(" ");
|
||||
|
||||
// 3. Call LLM for Summary
|
||||
let summary = if !aggregated_text.is_empty() {
|
||||
match generate_5w1h_summary(&aggregated_text).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!("LLM Summary failed for scene {}: {}", scene.scene_number, e);
|
||||
"LLM Error".to_string()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
"No Audio".to_string()
|
||||
};
|
||||
|
||||
info!(
|
||||
"Scene {}: {} -> {} ({} sentences)",
|
||||
scene.scene_number,
|
||||
scene.start_time,
|
||||
scene.end_time,
|
||||
texts.len()
|
||||
);
|
||||
|
||||
// 4. Insert into dev.chunks
|
||||
let fps_query: Option<f64> = sqlx::query_scalar("SELECT fps FROM videos WHERE uuid = $1")
|
||||
.bind(asset_uuid)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
let fps = fps_query.unwrap_or(29.97);
|
||||
|
||||
// Prepare metadata JSON
|
||||
let metadata = serde_json::json!({
|
||||
"type": "scene",
|
||||
"scene_number": scene.scene_number
|
||||
});
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO chunks (
|
||||
uuid, chunk_id, chunk_index, chunk_type,
|
||||
start_time, end_time, fps, start_frame, end_frame,
|
||||
content, text_content, summary_text, metadata, child_chunk_ids
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
ON CONFLICT (uuid, chunk_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(asset_uuid)
|
||||
.bind(&chunk_id)
|
||||
.bind(scene.scene_number as i32)
|
||||
.bind("cut") // Chunk type
|
||||
.bind(scene.start_time)
|
||||
.bind(scene.end_time)
|
||||
.bind(fps)
|
||||
.bind(scene.start_frame as i64)
|
||||
.bind(scene.end_frame as i64)
|
||||
.bind(&metadata) // Content JSON
|
||||
.bind(&aggregated_text) // Text content
|
||||
.bind(&summary) // Summary
|
||||
.bind(&metadata) // Metadata
|
||||
.bind(&child_ids) // Child IDs
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
count += 1;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(count)
|
||||
}
|
||||
755
src/core/chunk/types.rs.bak
Normal file
755
src/core/chunk/types.rs.bak
Normal file
@@ -0,0 +1,755 @@
|
||||
use crate::core::time::FrameTime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ChunkType {
|
||||
TimeBased,
|
||||
Sentence,
|
||||
Cut,
|
||||
Trace,
|
||||
Story, // Parent chunk from story analysis
|
||||
Visual, // Visual object-based chunk from YOLO detection
|
||||
}
|
||||
|
||||
impl ChunkType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ChunkType::TimeBased => "time",
|
||||
ChunkType::Sentence => "sentence",
|
||||
ChunkType::Cut => "cut",
|
||||
ChunkType::Trace => "trace",
|
||||
ChunkType::Story => "story",
|
||||
ChunkType::Visual => "visual",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ChunkRule {
|
||||
Rule1, // 直接轉換
|
||||
Rule2, // 集合內容
|
||||
}
|
||||
|
||||
/// 關鍵幀的物件列表
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KeyframeObjects {
|
||||
/// 關鍵幀時間 (秒)
|
||||
pub timestamp: f64,
|
||||
/// 關鍵幀幀號
|
||||
pub frame_number: u64,
|
||||
/// 檢測到的物件
|
||||
pub objects: Vec<DetectedObject>,
|
||||
}
|
||||
|
||||
/// 檢測到的物件
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DetectedObject {
|
||||
/// 物件類別名稱
|
||||
pub class_name: String,
|
||||
/// 物件類別 ID
|
||||
pub class_id: u32,
|
||||
/// 信心值 (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
/// 邊界框 (x, y, width, height)
|
||||
pub bbox: Option<BoundingBox>,
|
||||
/// 出現次數 (在分片內)
|
||||
pub occurrence: u32,
|
||||
}
|
||||
|
||||
/// 邊界框
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VisualChunkContent {
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub keyframe_objects: Vec<KeyframeObjects>,
|
||||
pub dominant_objects: Vec<String>,
|
||||
pub object_relationships: Vec<(String, String, String)>, // (object1, relationship, object2)
|
||||
pub scene_description: Option<String>,
|
||||
pub metadata: VisualMetadata,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VisualMetadata {
|
||||
pub object_count: u32,
|
||||
pub unique_classes: Vec<String>,
|
||||
pub max_confidence: f32,
|
||||
pub avg_confidence: f32,
|
||||
pub spatial_density: f32, // objects per frame
|
||||
}
|
||||
|
||||
impl ChunkRule {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ChunkRule::Rule1 => "rule_1",
|
||||
ChunkRule::Rule2 => "rule_2",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Chunk {
|
||||
pub file_id: i32,
|
||||
pub uuid: String,
|
||||
pub chunk_id: String,
|
||||
pub chunk_index: u32,
|
||||
pub chunk_type: ChunkType,
|
||||
pub rule: ChunkRule,
|
||||
/// Frames per second (can be fractional, e.g., 29.97, 23.976)
|
||||
pub fps: f64,
|
||||
/// Start frame (0-based)
|
||||
pub start_frame: i64,
|
||||
/// End frame (exclusive)
|
||||
pub end_frame: i64,
|
||||
pub text_content: Option<String>,
|
||||
pub content: serde_json::Value,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub vector_id: Option<String>,
|
||||
pub frame_count: i32,
|
||||
pub pre_chunk_ids: Vec<i32>,
|
||||
pub parent_chunk_id: Option<String>, // For parent-child chunk hierarchy
|
||||
pub child_chunk_ids: Vec<String>, // Child chunk IDs (for parent chunks)
|
||||
pub visual_stats: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
id: i64,
|
||||
video_id: i64,
|
||||
yolo_result: &crate::core::processor::yolo::YoloResult,
|
||||
min_frames_per_chunk: usize,
|
||||
similarity_threshold: f32,
|
||||
) -> Vec<Self> {
|
||||
if yolo_result.frames.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
let mut current_chunk_frames = Vec::new();
|
||||
let mut current_id = id;
|
||||
|
||||
for (i, frame) in yolo_result.frames.iter().enumerate() {
|
||||
if current_chunk_frames.is_empty() {
|
||||
current_chunk_frames.push(frame);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check similarity with last frame in current chunk
|
||||
let last_frame = current_chunk_frames.last().unwrap();
|
||||
let similarity = VisualChunkContent::frame_similarity(last_frame, frame);
|
||||
|
||||
if similarity >= similarity_threshold && current_chunk_frames.len() < 100 {
|
||||
// Similar enough, add to current chunk
|
||||
current_chunk_frames.push(frame);
|
||||
} else {
|
||||
// Not similar enough or chunk too large, create new chunk
|
||||
if current_chunk_frames.len() >= min_frames_per_chunk {
|
||||
if let Some(chunk) =
|
||||
Self::create_chunk_from_frames(current_id, video_id, ¤t_chunk_frames)
|
||||
{
|
||||
chunks.push(chunk);
|
||||
current_id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::processor::yolo::{YoloFrame, YoloObject, YoloResult};
|
||||
|
||||
#[test]
|
||||
fn test_chunk_type_visual_serialization() {
|
||||
let chunk_type = ChunkType::Visual;
|
||||
let json = serde_json::to_string(&chunk_type).unwrap();
|
||||
assert_eq!(json, "\"visual\"");
|
||||
|
||||
let deserialized: ChunkType = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized, ChunkType::Visual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visual_chunk_creation() {
|
||||
// Create a mock YOLO result
|
||||
let yolo_result = YoloResult {
|
||||
frame_count: 2,
|
||||
fps: 30.0,
|
||||
frames: vec![
|
||||
YoloFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
objects: vec![
|
||||
YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 100,
|
||||
y: 200,
|
||||
width: 50,
|
||||
height: 100,
|
||||
confidence: 0.95,
|
||||
},
|
||||
YoloObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
x: 300,
|
||||
y: 150,
|
||||
width: 80,
|
||||
height: 60,
|
||||
confidence: 0.87,
|
||||
},
|
||||
],
|
||||
},
|
||||
YoloFrame {
|
||||
frame: 1,
|
||||
timestamp: 0.033, // 1/30 second
|
||||
objects: vec![YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 110,
|
||||
y: 210,
|
||||
width: 52,
|
||||
height: 102,
|
||||
confidence: 0.92,
|
||||
}],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// Create visual chunk from YOLO result
|
||||
let chunk = Chunk::from_yolo_result(1, 100, &yolo_result, 0, 1).unwrap();
|
||||
|
||||
// Verify chunk properties
|
||||
assert_eq!(chunk.id, 1);
|
||||
assert_eq!(chunk.video_id, 100);
|
||||
assert_eq!(chunk.chunk_type, ChunkType::Visual);
|
||||
assert_eq!(chunk.start_time, 0.0);
|
||||
assert_eq!(chunk.end_time, 0.033);
|
||||
|
||||
// Verify visual content
|
||||
if let ChunkContent::Visual(content) = chunk.content {
|
||||
assert_eq!(content.metadata.object_count, 3);
|
||||
assert_eq!(content.metadata.unique_classes.len(), 2);
|
||||
assert!(content
|
||||
.metadata
|
||||
.unique_classes
|
||||
.contains(&"person".to_string()));
|
||||
assert!(content.metadata.unique_classes.contains(&"car".to_string()));
|
||||
assert_eq!(content.dominant_objects, vec!["person"]);
|
||||
assert_eq!(content.keyframe_objects.len(), 2);
|
||||
} else {
|
||||
panic!("Expected Visual content type");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visual_chunk_content_methods() {
|
||||
let content = VisualChunkContent {
|
||||
start_time: 0.0,
|
||||
end_time: 5.0,
|
||||
keyframe_objects: vec![KeyframeObjects {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
objects: vec![
|
||||
DetectedObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
bounding_box: BoundingBox {
|
||||
x: 100,
|
||||
y: 200,
|
||||
width: 50,
|
||||
height: 100,
|
||||
},
|
||||
confidence: 0.95,
|
||||
},
|
||||
DetectedObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
bounding_box: BoundingBox {
|
||||
x: 300,
|
||||
y: 150,
|
||||
width: 80,
|
||||
height: 60,
|
||||
},
|
||||
confidence: 0.87,
|
||||
},
|
||||
],
|
||||
}],
|
||||
dominant_objects: vec!["person".to_string()],
|
||||
object_relationships: vec![],
|
||||
scene_description: Some("A person near a car".to_string()),
|
||||
metadata: VisualMetadata {
|
||||
object_count: 2,
|
||||
unique_classes: vec!["person".to_string(), "car".to_string()],
|
||||
max_confidence: 0.95,
|
||||
avg_confidence: 0.91,
|
||||
spatial_density: 2.0,
|
||||
},
|
||||
};
|
||||
|
||||
// Test summary method
|
||||
let summary = content.summary();
|
||||
assert!(summary.contains("Visual chunk from 0.0s to 5.0s"));
|
||||
assert!(summary.contains("person"));
|
||||
|
||||
// Test contains_object method
|
||||
assert!(content.contains_object("person"));
|
||||
assert!(content.contains_object("car"));
|
||||
assert!(!content.contains_object("dog"));
|
||||
|
||||
// Test high_confidence_objects method
|
||||
let high_conf_objects = content.high_confidence_objects(0.9);
|
||||
assert_eq!(high_conf_objects.len(), 1);
|
||||
assert_eq!(high_conf_objects[0].class_name, "person");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_similarity() {
|
||||
let frame1 = YoloFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
objects: vec![
|
||||
YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 100,
|
||||
y: 200,
|
||||
width: 50,
|
||||
height: 100,
|
||||
confidence: 0.95,
|
||||
},
|
||||
YoloObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
x: 300,
|
||||
y: 150,
|
||||
width: 80,
|
||||
height: 60,
|
||||
confidence: 0.87,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let frame2 = YoloFrame {
|
||||
frame: 1,
|
||||
timestamp: 0.033,
|
||||
objects: vec![
|
||||
YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 110,
|
||||
y: 210,
|
||||
width: 52,
|
||||
height: 102,
|
||||
confidence: 0.92,
|
||||
},
|
||||
YoloObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
x: 310,
|
||||
y: 155,
|
||||
width: 82,
|
||||
height: 62,
|
||||
confidence: 0.85,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let frame3 = YoloFrame {
|
||||
frame: 2,
|
||||
timestamp: 0.066,
|
||||
objects: vec![YoloObject {
|
||||
class_name: "dog".to_string(),
|
||||
class_id: 16,
|
||||
x: 150,
|
||||
y: 250,
|
||||
width: 40,
|
||||
height: 60,
|
||||
confidence: 0.78,
|
||||
}],
|
||||
};
|
||||
|
||||
// Test similar frames (same objects)
|
||||
let similarity_same =
|
||||
VisualChunkContent::frame_similarity(&frame1, &frame2);
|
||||
assert!((similarity_same - 1.0).abs() < 0.001);
|
||||
|
||||
// Test dissimilar frames (different objects)
|
||||
let similarity_diff =
|
||||
VisualChunkContent::frame_similarity(&frame1, &frame3);
|
||||
assert!((similarity_diff - 0.0).abs() < 0.001);
|
||||
|
||||
// Test empty frames
|
||||
let empty_frame = YoloFrame {
|
||||
frame: 3,
|
||||
timestamp: 0.1,
|
||||
objects: vec![],
|
||||
};
|
||||
let similarity_empty =
|
||||
VisualChunkContent::frame_similarity(&empty_frame, &empty_frame);
|
||||
assert!((similarity_empty - 1.0).abs() < 0.001);
|
||||
|
||||
let similarity_mixed =
|
||||
VisualChunkContent::frame_similarity(&empty_frame, &frame1);
|
||||
assert!((similarity_mixed - 0.0).abs() < 0.001);
|
||||
}
|
||||
}
|
||||
current_chunk_frames = vec![frame];
|
||||
}
|
||||
}
|
||||
|
||||
// Handle last chunk
|
||||
if current_chunk_frames.len() >= min_frames_per_chunk {
|
||||
if let Some(chunk) =
|
||||
Self::create_chunk_from_frames(current_id, video_id, ¤t_chunk_frames)
|
||||
{
|
||||
chunks.push(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
|
||||
fn create_chunk_from_frames(
|
||||
id: i64,
|
||||
video_id: i64,
|
||||
frames: &[&crate::core::processor::yolo::YoloFrame],
|
||||
) -> Option<Self> {
|
||||
if frames.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Simple conversion - could use the from_yolo_result method
|
||||
let start_frame = frames.first().unwrap().frame;
|
||||
let end_frame = frames.last().unwrap().frame;
|
||||
let dummy_yolo_result = crate::core::processor::yolo::YoloResult {
|
||||
frame_count: frames.len() as u64,
|
||||
fps: 0.0, // Not used in this context
|
||||
frames: frames.iter().map(|f| (*f).clone()).collect(),
|
||||
};
|
||||
|
||||
Self::from_yolo_result(id, video_id, &dummy_yolo_result, start_frame, end_frame)
|
||||
}
|
||||
|
||||
/// Creates a new chunk from seconds (legacy conversion).
|
||||
///
|
||||
/// This is useful for migrating from older systems that store time as seconds.
|
||||
/// The frame counts are calculated by rounding `seconds * fps`.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn from_seconds(
|
||||
file_id: i32,
|
||||
uuid: String,
|
||||
chunk_index: u32,
|
||||
chunk_type: ChunkType,
|
||||
rule: ChunkRule,
|
||||
start_time: f64,
|
||||
end_time: f64,
|
||||
fps: f64,
|
||||
content: serde_json::Value,
|
||||
) -> Self {
|
||||
let start_frame = (start_time * fps).round() as i64;
|
||||
let end_frame = (end_time * fps).round() as i64;
|
||||
Self::new(
|
||||
file_id,
|
||||
uuid,
|
||||
chunk_index,
|
||||
chunk_type,
|
||||
rule,
|
||||
start_frame,
|
||||
end_frame,
|
||||
fps,
|
||||
content,
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the start time as a `FrameTime`.
|
||||
pub fn start_time(&self) -> FrameTime {
|
||||
FrameTime::from_frames(self.start_frame, self.fps)
|
||||
}
|
||||
|
||||
/// Returns the end time as a `FrameTime`.
|
||||
pub fn end_time(&self) -> FrameTime {
|
||||
FrameTime::from_frames(self.end_frame, self.fps)
|
||||
}
|
||||
|
||||
/// Returns the duration in frames.
|
||||
pub fn duration_frames(&self) -> i64 {
|
||||
self.end_frame - self.start_frame
|
||||
}
|
||||
|
||||
/// Returns the duration in seconds.
|
||||
pub fn duration_seconds(&self) -> f64 {
|
||||
self.duration_frames() as f64 / self.fps
|
||||
}
|
||||
|
||||
/// Formats the start time as "seconds.frame" (e.g., "123.04").
|
||||
pub fn format_start_sec_frame(&self) -> String {
|
||||
self.start_time().format_sec_frame()
|
||||
}
|
||||
|
||||
/// Formats the end time as "seconds.frame" (e.g., "456.15").
|
||||
pub fn format_end_sec_frame(&self) -> String {
|
||||
self.end_time().format_sec_frame()
|
||||
}
|
||||
|
||||
/// Formats the start time as "HH:MM:SS".
|
||||
pub fn format_start_hms(&self) -> String {
|
||||
self.start_time().format_hms()
|
||||
}
|
||||
|
||||
/// Formats the end time as "HH:MM:SS".
|
||||
pub fn format_end_hms(&self) -> String {
|
||||
self.end_time().format_hms()
|
||||
}
|
||||
|
||||
/// Formats the start time as "HH:MM:SS.FF".
|
||||
pub fn format_start_hms_frame(&self) -> String {
|
||||
self.start_time().format_hms_frame()
|
||||
}
|
||||
|
||||
/// Formats the end time as "HH:MM:SS.FF".
|
||||
pub fn format_end_hms_frame(&self) -> String {
|
||||
self.end_time().format_hms_frame()
|
||||
}
|
||||
|
||||
/// Returns a tuple of (start_seconds, end_seconds) for compatibility.
|
||||
///
|
||||
/// This is provided for backward compatibility during migration.
|
||||
/// Prefer using `start_time()` and `end_time()` methods.
|
||||
pub fn time_range_seconds(&self) -> (f64, f64) {
|
||||
(self.start_time().seconds(), self.end_time().seconds())
|
||||
}
|
||||
|
||||
pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
|
||||
self.metadata = Some(metadata);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_vector_id(mut self, vector_id: String) -> Self {
|
||||
self.vector_id = Some(vector_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_text_content(mut self, text: String) -> Self {
|
||||
self.text_content = Some(text);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_frame_count(mut self, count: i32) -> Self {
|
||||
self.frame_count = count;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_pre_chunk_ids(mut self, ids: Vec<i32>) -> Self {
|
||||
self.pre_chunk_ids = ids;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_parent_chunk_id(mut self, parent_id: String) -> Self {
|
||||
self.parent_chunk_id = Some(parent_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_child_chunk_ids(mut self, child_ids: Vec<String>) -> Self {
|
||||
self.child_chunk_ids = child_ids;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn is_parent_chunk(&self) -> bool {
|
||||
!self.child_chunk_ids.is_empty()
|
||||
}
|
||||
|
||||
pub fn is_child_chunk(&self) -> bool {
|
||||
self.parent_chunk_id.is_some()
|
||||
}
|
||||
|
||||
/// 創建視覺分片
|
||||
pub fn new_visual(
|
||||
file_id: i32,
|
||||
uuid: String,
|
||||
chunk_index: u32,
|
||||
start_frame: i64,
|
||||
end_frame: i64,
|
||||
fps: f64,
|
||||
visual_content: VisualChunkContent,
|
||||
) -> Self {
|
||||
let content = serde_json::to_value(&visual_content)
|
||||
.unwrap_or_else(|_| serde_json::json!({"error": "Failed to serialize visual content"}));
|
||||
|
||||
Self::new(
|
||||
file_id,
|
||||
uuid,
|
||||
chunk_index,
|
||||
ChunkType::Visual,
|
||||
ChunkRule::Rule2,
|
||||
start_frame,
|
||||
end_frame,
|
||||
fps,
|
||||
content,
|
||||
)
|
||||
}
|
||||
|
||||
/// 從 YOLO 結果創建視覺分片
|
||||
pub fn from_yolo_result(
|
||||
file_id: i32,
|
||||
uuid: String,
|
||||
chunk_index: u32,
|
||||
start_frame: i64,
|
||||
end_frame: i64,
|
||||
fps: f64,
|
||||
yolo_frames: Vec<crate::core::processor::yolo::YoloFrame>,
|
||||
) -> Self {
|
||||
use crate::core::processor::yolo::YoloFrame;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// 分析物件統計
|
||||
let mut object_counts = HashMap::new();
|
||||
let mut keyframe_objects = Vec::new();
|
||||
let mut all_objects = Vec::new();
|
||||
|
||||
for frame in &yolo_frames {
|
||||
let mut frame_objects = Vec::new();
|
||||
|
||||
for obj in &frame.objects {
|
||||
// 更新物件統計
|
||||
*object_counts.entry(obj.class_name.clone()).or_insert(0) += 1;
|
||||
|
||||
// 創建檢測到的物件
|
||||
let detected_obj = DetectedObject {
|
||||
class_name: obj.class_name.clone(),
|
||||
class_id: obj.class_id,
|
||||
confidence: obj.confidence,
|
||||
bbox: Some(BoundingBox {
|
||||
x: obj.x,
|
||||
y: obj.y,
|
||||
width: obj.width,
|
||||
height: obj.height,
|
||||
}),
|
||||
occurrence: 1,
|
||||
};
|
||||
|
||||
frame_objects.push(detected_obj.clone());
|
||||
all_objects.push(detected_obj);
|
||||
}
|
||||
|
||||
if !frame_objects.is_empty() {
|
||||
keyframe_objects.push(KeyframeObjects {
|
||||
timestamp: frame.timestamp,
|
||||
frame_number: frame.frame,
|
||||
objects: frame_objects,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// 創建主要物件標籤
|
||||
let primary_objects = object_counts
|
||||
.iter()
|
||||
.filter(|(_, &count)| count >= 3) // 出現至少3次的物件
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
// 創建物件統計 JSON
|
||||
let object_stats =
|
||||
serde_json::to_value(&object_counts).unwrap_or_else(|_| serde_json::json!({}));
|
||||
|
||||
// 創建視覺內容
|
||||
let visual_content = VisualChunkContent {
|
||||
primary_objects: if primary_objects.is_empty() {
|
||||
"no objects detected".to_string()
|
||||
} else {
|
||||
primary_objects
|
||||
},
|
||||
object_stats,
|
||||
keyframe_objects,
|
||||
object_frequency: serde_json::to_value(&object_counts)
|
||||
.unwrap_or_else(|_| serde_json::json!({})),
|
||||
visual_summary: None, // 可選,後續可添加 LLM 生成的摘要
|
||||
};
|
||||
|
||||
Self::new_visual(
|
||||
file_id,
|
||||
uuid,
|
||||
chunk_index,
|
||||
start_frame,
|
||||
end_frame,
|
||||
fps,
|
||||
visual_content,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl VisualChunkContent {
|
||||
/// Calculate similarity between two YOLO frames based on object composition
|
||||
pub fn frame_similarity(
|
||||
frame1: &crate::core::processor::yolo::YoloFrame,
|
||||
frame2: &crate::core::processor::yolo::YoloFrame,
|
||||
) -> f32 {
|
||||
if frame1.objects.is_empty() && frame2.objects.is_empty() {
|
||||
return 1.0; // Both empty frames are perfectly similar
|
||||
}
|
||||
|
||||
if frame1.objects.is_empty() || frame2.objects.is_empty() {
|
||||
return 0.0; // One empty, one non-empty are dissimilar
|
||||
}
|
||||
|
||||
// Create sets of object class names
|
||||
let set1: std::collections::HashSet<String> = frame1
|
||||
.objects
|
||||
.iter()
|
||||
.map(|o| o.class_name.clone())
|
||||
.collect();
|
||||
let set2: std::collections::HashSet<String> = frame2
|
||||
.objects
|
||||
.iter()
|
||||
.map(|o| o.class_name.clone())
|
||||
.collect();
|
||||
|
||||
// Calculate Jaccard similarity
|
||||
let intersection: Vec<_> = set1.intersection(&set2).collect();
|
||||
let union: Vec<_> = set1.union(&set2).collect();
|
||||
|
||||
if union.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
intersection.len() as f32 / union.len() as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a summary of the visual chunk
|
||||
pub fn summary(&self) -> String {
|
||||
let duration = self.end_time - self.start_time;
|
||||
let frame_count = self.keyframe_objects.len();
|
||||
|
||||
format!(
|
||||
"Visual chunk from {:.1}s to {:.1}s (duration: {:.1}s, {} frames). Objects: {} total, {} unique. Dominant objects: {}",
|
||||
self.start_time,
|
||||
self.end_time,
|
||||
duration,
|
||||
frame_count,
|
||||
self.metadata.object_count,
|
||||
self.metadata.unique_classes.len(),
|
||||
if self.dominant_objects.is_empty() {
|
||||
"none".to_string()
|
||||
} else {
|
||||
self.dominant_objects.join(", ")
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if this chunk contains a specific object class
|
||||
pub fn contains_object(&self, class_name: &str) -> bool {
|
||||
self.keyframe_objects
|
||||
.iter()
|
||||
.any(|ko| ko.objects.iter().any(|obj| obj.class_name == class_name))
|
||||
}
|
||||
|
||||
/// Get all objects with confidence above threshold
|
||||
pub fn high_confidence_objects(&self, threshold: f32) -> Vec<&DetectedObject> {
|
||||
self.keyframe_objects
|
||||
.iter()
|
||||
.flat_map(|ko| ko.objects.iter())
|
||||
.filter(|obj| obj.confidence >= threshold)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
320
src/core/chunk/types_fixed.rs
Normal file
320
src/core/chunk/types_fixed.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
use crate::core::time::FrameTime;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ChunkType {
|
||||
TimeBased,
|
||||
Sentence,
|
||||
Cut,
|
||||
Trace,
|
||||
Story, // Parent chunk from story analysis
|
||||
Visual, // Visual object-based chunk from YOLO detection (Phase 2.1)
|
||||
}
|
||||
|
||||
impl ChunkType {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ChunkType::TimeBased => "time",
|
||||
ChunkType::Sentence => "sentence",
|
||||
ChunkType::Cut => "cut",
|
||||
ChunkType::Trace => "trace",
|
||||
ChunkType::Story => "story",
|
||||
ChunkType::Visual => "visual",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ChunkRule {
|
||||
Rule1, // 直接轉換
|
||||
Rule2, // 集合內容
|
||||
}
|
||||
|
||||
/// 關鍵幀的物件列表
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KeyframeObjects {
|
||||
/// 關鍵幀時間 (秒)
|
||||
pub timestamp: f64,
|
||||
/// 關鍵幀幀號
|
||||
pub frame_number: u64,
|
||||
/// 檢測到的物件
|
||||
pub objects: Vec<DetectedObject>,
|
||||
}
|
||||
|
||||
/// 檢測到的物件
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DetectedObject {
|
||||
/// 物件類別名稱
|
||||
pub class_name: String,
|
||||
/// 物件類別 ID
|
||||
pub class_id: u32,
|
||||
/// 信心值 (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
/// 邊界框 (x, y, width, height)
|
||||
pub bbox: Option<BoundingBox>,
|
||||
/// 出現次數 (在分片內)
|
||||
pub occurrence: u32,
|
||||
}
|
||||
|
||||
/// 邊界框
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BoundingBox {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub width: i32,
|
||||
pub height: i32,
|
||||
}
|
||||
|
||||
/// 視覺分片內容 (Phase 2.1)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VisualChunkContent {
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub keyframe_objects: Vec<KeyframeObjects>,
|
||||
pub dominant_objects: Vec<String>,
|
||||
pub object_relationships: Vec<(String, String, String)>, // (object1, relationship, object2)
|
||||
pub scene_description: Option<String>,
|
||||
pub metadata: VisualMetadata,
|
||||
}
|
||||
|
||||
/// 視覺元數據 (Phase 2.1)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VisualMetadata {
|
||||
pub object_count: u32,
|
||||
pub unique_classes: Vec<String>,
|
||||
pub max_confidence: f32,
|
||||
pub avg_confidence: f32,
|
||||
pub spatial_density: f32, // objects per frame
|
||||
}
|
||||
|
||||
impl ChunkRule {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
ChunkRule::Rule1 => "rule_1",
|
||||
ChunkRule::Rule2 => "rule_2",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Chunk {
|
||||
pub file_id: i32,
|
||||
pub uuid: String,
|
||||
pub chunk_id: String,
|
||||
pub chunk_index: u32,
|
||||
pub chunk_type: ChunkType,
|
||||
pub rule: ChunkRule,
|
||||
/// Frames per second (can be fractional, e.g., 29.97, 23.976)
|
||||
pub fps: f64,
|
||||
/// Start frame (0-based)
|
||||
pub start_frame: i64,
|
||||
/// End frame (exclusive)
|
||||
pub end_frame: i64,
|
||||
pub text_content: Option<String>,
|
||||
pub content: serde_json::Value,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub vector_id: Option<String>,
|
||||
pub frame_count: i32,
|
||||
pub pre_chunk_ids: Vec<i32>,
|
||||
pub parent_chunk_id: Option<String>, // For parent-child chunk hierarchy
|
||||
pub child_chunk_ids: Vec<String>, // Child chunk IDs (for parent chunks)
|
||||
pub visual_stats: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl Chunk {
|
||||
/// 創建視覺分片 (Phase 2.1)
|
||||
pub fn new_visual(
|
||||
file_id: i32,
|
||||
uuid: String,
|
||||
chunk_index: u32,
|
||||
start_frame: i64,
|
||||
end_frame: i64,
|
||||
fps: f64,
|
||||
visual_content: VisualChunkContent,
|
||||
) -> Self {
|
||||
let content = serde_json::to_value(&visual_content)
|
||||
.unwrap_or_else(|_| serde_json::json!({"error": "Failed to serialize visual content"}));
|
||||
|
||||
Self::new(
|
||||
file_id,
|
||||
uuid,
|
||||
chunk_index,
|
||||
ChunkType::Visual,
|
||||
ChunkRule::Rule2,
|
||||
start_frame,
|
||||
end_frame,
|
||||
fps,
|
||||
content,
|
||||
)
|
||||
}
|
||||
|
||||
/// 從 YOLO 結果創建視覺分片 (Phase 2.1)
|
||||
pub fn from_yolo_result(
|
||||
file_id: i32,
|
||||
uuid: String,
|
||||
chunk_index: u32,
|
||||
start_frame: i64,
|
||||
end_frame: i64,
|
||||
fps: f64,
|
||||
yolo_frames: Vec<crate::core::processor::yolo::YoloFrame>,
|
||||
) -> Self {
|
||||
let keyframe_objects: Vec<KeyframeObjects> = yolo_frames
|
||||
.iter()
|
||||
.map(|frame| {
|
||||
let objects: Vec<DetectedObject> = frame
|
||||
.objects
|
||||
.iter()
|
||||
.map(|obj| DetectedObject {
|
||||
class_name: obj.class_name.clone(),
|
||||
class_id: obj.class_id,
|
||||
confidence: obj.confidence,
|
||||
bbox: Some(BoundingBox {
|
||||
x: obj.x,
|
||||
y: obj.y,
|
||||
width: obj.width,
|
||||
height: obj.height,
|
||||
}),
|
||||
occurrence: 1,
|
||||
})
|
||||
.collect();
|
||||
|
||||
KeyframeObjects {
|
||||
timestamp: frame.timestamp,
|
||||
frame_number: frame.frame,
|
||||
objects,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 計算物件統計
|
||||
let mut object_counts = std::collections::HashMap::new();
|
||||
for obj in yolo_frames.iter().flat_map(|f| &f.objects) {
|
||||
*object_counts.entry(obj.class_name.clone()).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
let total_objects: u32 = yolo_frames.iter().map(|f| f.objects.len() as u32).sum();
|
||||
let all_classes: Vec<String> = yolo_frames
|
||||
.iter()
|
||||
.flat_map(|f| f.objects.iter().map(|o| o.class_name.clone()))
|
||||
.collect();
|
||||
let unique_classes: Vec<String> = all_classes
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let confidences: Vec<f32> = yolo_frames
|
||||
.iter()
|
||||
.flat_map(|f| f.objects.iter().map(|o| o.confidence))
|
||||
.collect();
|
||||
let max_confidence = confidences.iter().copied().fold(0.0f32, f32::max);
|
||||
let avg_confidence = if !confidences.is_empty() {
|
||||
confidences.iter().sum::<f32>() / confidences.len() as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
// 找出主要物件
|
||||
let primary_objects = object_counts
|
||||
.iter()
|
||||
.filter(|(_, &count)| count as f32 / yolo_frames.len() as f32 > 0.5)
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
let object_stats =
|
||||
serde_json::to_value(&object_counts).unwrap_or_else(|_| serde_json::json!({}));
|
||||
|
||||
let visual_content = VisualChunkContent {
|
||||
start_time: if let Some(first) = yolo_frames.first() {
|
||||
first.timestamp
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
end_time: if let Some(last) = yolo_frames.last() {
|
||||
last.timestamp
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
keyframe_objects,
|
||||
dominant_objects: primary_objects
|
||||
.split(", ")
|
||||
.map(|s| s.to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect(),
|
||||
object_relationships: vec![], // 可選:後續添加關係檢測
|
||||
scene_description: None, // 可選:後續添加 LLM 生成的場景描述
|
||||
metadata: VisualMetadata {
|
||||
object_count: total_objects,
|
||||
unique_classes,
|
||||
max_confidence,
|
||||
avg_confidence,
|
||||
spatial_density: if yolo_frames.len() > 0 {
|
||||
total_objects as f32 / yolo_frames.len() as f32
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
Self::new_visual(
|
||||
file_id,
|
||||
uuid,
|
||||
chunk_index,
|
||||
start_frame,
|
||||
end_frame,
|
||||
fps,
|
||||
visual_content,
|
||||
)
|
||||
}
|
||||
|
||||
/// 創建新分片
|
||||
pub fn new(
|
||||
file_id: i32,
|
||||
uuid: String,
|
||||
chunk_index: u32,
|
||||
chunk_type: ChunkType,
|
||||
rule: ChunkRule,
|
||||
start_frame: i64,
|
||||
end_frame: i64,
|
||||
fps: f64,
|
||||
content: serde_json::Value,
|
||||
) -> Self {
|
||||
let frame_count = (end_frame - start_frame) as i32;
|
||||
let chunk_id = format!("{}_{}", uuid, chunk_index);
|
||||
|
||||
Self {
|
||||
file_id,
|
||||
uuid,
|
||||
chunk_id,
|
||||
chunk_index,
|
||||
chunk_type,
|
||||
rule,
|
||||
fps,
|
||||
start_frame,
|
||||
end_frame,
|
||||
text_content: None,
|
||||
content,
|
||||
metadata: None,
|
||||
vector_id: None,
|
||||
frame_count,
|
||||
pre_chunk_ids: vec![],
|
||||
parent_chunk_id: None,
|
||||
child_chunk_ids: vec![],
|
||||
visual_stats: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 將分片轉換為幀時間
|
||||
pub fn to_frame_time(&self) -> FrameTime {
|
||||
FrameTime::from_frames(self.start_frame as u64, self.end_frame as u64, self.fps)
|
||||
}
|
||||
|
||||
/// 檢查是否是父分片
|
||||
pub fn is_parent(&self) -> bool {
|
||||
self.parent_chunk_id.is_some()
|
||||
}
|
||||
}
|
||||
486
src/core/chunk/visual_test.rs
Normal file
486
src/core/chunk/visual_test.rs
Normal file
@@ -0,0 +1,486 @@
|
||||
//! 視覺分片測試
|
||||
//!
|
||||
//! 測試視覺分片數據結構和功能
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// 視覺分片類型
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ChunkType {
|
||||
TimeBased,
|
||||
Sentence,
|
||||
Cut,
|
||||
Trace,
|
||||
Story,
|
||||
Visual,
|
||||
}
|
||||
|
||||
/// 檢測到的物件
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DetectedObject {
|
||||
/// 物件類別名稱
|
||||
pub class_name: String,
|
||||
/// 物件類別 ID
|
||||
pub class_id: u32,
|
||||
/// 信心值 (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
/// 邊界框 (x, y, width, height)
|
||||
pub bbox: Option<(i32, i32, i32, i32)>,
|
||||
}
|
||||
|
||||
/// 關鍵幀的物件列表
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KeyframeObjects {
|
||||
/// 關鍵幀時間 (秒)
|
||||
pub timestamp: f64,
|
||||
/// 關鍵幀幀號
|
||||
pub frame_number: u64,
|
||||
/// 檢測到的物件
|
||||
pub objects: Vec<DetectedObject>,
|
||||
}
|
||||
|
||||
/// 視覺分片內容
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VisualChunkContent {
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub keyframe_objects: Vec<KeyframeObjects>,
|
||||
pub dominant_objects: Vec<String>,
|
||||
pub object_relationships: Vec<(String, String, String)>, // (object1, relationship, object2)
|
||||
pub scene_description: Option<String>,
|
||||
pub metadata: VisualMetadata,
|
||||
}
|
||||
|
||||
/// 視覺元數據
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct VisualMetadata {
|
||||
pub object_count: u32,
|
||||
pub unique_classes: Vec<String>,
|
||||
pub max_confidence: f32,
|
||||
pub avg_confidence: f32,
|
||||
pub spatial_density: f32, // objects per frame
|
||||
}
|
||||
|
||||
impl VisualChunkContent {
|
||||
/// 計算兩個幀之間的相似度(基於物件組成)
|
||||
pub fn frame_similarity(
|
||||
frame1_objects: &[DetectedObject],
|
||||
frame2_objects: &[DetectedObject],
|
||||
) -> f32 {
|
||||
if frame1_objects.is_empty() && frame2_objects.is_empty() {
|
||||
return 1.0; // 兩個空幀完全相似
|
||||
}
|
||||
|
||||
if frame1_objects.is_empty() || frame2_objects.is_empty() {
|
||||
return 0.0; // 一個空一個非空,不相似
|
||||
}
|
||||
|
||||
// 創建物件類別名稱集合
|
||||
let set1: std::collections::HashSet<String> = frame1_objects
|
||||
.iter()
|
||||
.map(|o| o.class_name.clone())
|
||||
.collect();
|
||||
let set2: std::collections::HashSet<String> = frame2_objects
|
||||
.iter()
|
||||
.map(|o| o.class_name.clone())
|
||||
.collect();
|
||||
|
||||
// 計算 Jaccard 相似度
|
||||
let intersection: Vec<_> = set1.intersection(&set2).collect();
|
||||
let union: Vec<_> = set1.union(&set2).collect();
|
||||
|
||||
if union.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
intersection.len() as f32 / union.len() as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// 獲取視覺分片的摘要
|
||||
pub fn summary(&self) -> String {
|
||||
let duration = self.end_time - self.start_time;
|
||||
let frame_count = self.keyframe_objects.len();
|
||||
|
||||
format!(
|
||||
"視覺分片: {:.1}s 到 {:.1}s (持續時間: {:.1}s, {} 幀). 物件: {} 個總計, {} 個唯一. 主要物件: {}",
|
||||
self.start_time,
|
||||
self.end_time,
|
||||
duration,
|
||||
frame_count,
|
||||
self.metadata.object_count,
|
||||
self.metadata.unique_classes.len(),
|
||||
if self.dominant_objects.is_empty() {
|
||||
"無".to_string()
|
||||
} else {
|
||||
self.dominant_objects.join(", ")
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// 檢查是否包含特定物件類別
|
||||
pub fn contains_object(&self, class_name: &str) -> bool {
|
||||
self.keyframe_objects
|
||||
.iter()
|
||||
.any(|ko| ko.objects.iter().any(|obj| obj.class_name == class_name))
|
||||
}
|
||||
|
||||
/// 獲取信心值高於閾值的所有物件
|
||||
pub fn high_confidence_objects(&self, threshold: f32) -> Vec<&DetectedObject> {
|
||||
self.keyframe_objects
|
||||
.iter()
|
||||
.flat_map(|ko| ko.objects.iter())
|
||||
.filter(|obj| obj.confidence >= threshold)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// 模擬 YOLO 結果
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MockYoloResult {
|
||||
pub frames: Vec<MockYoloFrame>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MockYoloFrame {
|
||||
pub frame: u64,
|
||||
pub timestamp: f64,
|
||||
pub objects: Vec<MockYoloObject>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MockYoloObject {
|
||||
pub class_name: String,
|
||||
pub class_id: u32,
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub width: i32,
|
||||
pub height: i32,
|
||||
pub confidence: f32,
|
||||
}
|
||||
|
||||
impl MockYoloResult {
|
||||
/// 從模擬 YOLO 結果創建視覺分片
|
||||
pub fn to_visual_chunk(&self, start_frame: u64, end_frame: u64) -> Option<VisualChunkContent> {
|
||||
let frames: Vec<_> = self
|
||||
.frames
|
||||
.iter()
|
||||
.filter(|f| f.frame >= start_frame && f.frame <= end_frame)
|
||||
.collect();
|
||||
|
||||
if frames.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// 轉換幀為關鍵幀物件
|
||||
let keyframe_objects: Vec<KeyframeObjects> = frames
|
||||
.iter()
|
||||
.map(|frame| {
|
||||
let objects: Vec<DetectedObject> = frame
|
||||
.objects
|
||||
.iter()
|
||||
.map(|obj| DetectedObject {
|
||||
class_name: obj.class_name.clone(),
|
||||
class_id: obj.class_id,
|
||||
confidence: obj.confidence,
|
||||
bbox: Some((obj.x, obj.y, obj.width, obj.height)),
|
||||
})
|
||||
.collect();
|
||||
KeyframeObjects {
|
||||
timestamp: frame.timestamp,
|
||||
frame_number: frame.frame,
|
||||
objects,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// 計算元數據
|
||||
let total_objects: u32 = frames.iter().map(|f| f.objects.len() as u32).sum();
|
||||
let all_classes: Vec<String> = frames
|
||||
.iter()
|
||||
.flat_map(|f| f.objects.iter().map(|o| o.class_name.clone()))
|
||||
.collect();
|
||||
let unique_classes: Vec<String> = all_classes
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.into_iter()
|
||||
.collect();
|
||||
let confidences: Vec<f32> = frames
|
||||
.iter()
|
||||
.flat_map(|f| f.objects.iter().map(|o| o.confidence))
|
||||
.collect();
|
||||
let max_confidence = confidences.iter().copied().fold(0.0f32, f32::max);
|
||||
let avg_confidence = if !confidences.is_empty() {
|
||||
confidences.iter().sum::<f32>() / confidences.len() as f32
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let start_time = frames.first().map(|f| f.timestamp).unwrap_or(0.0);
|
||||
let end_time = frames.last().map(|f| f.timestamp).unwrap_or(0.0);
|
||||
|
||||
// 查找主要物件(出現在大多數幀中的物件)
|
||||
let mut object_counts = std::collections::HashMap::new();
|
||||
for frame in &frames {
|
||||
let frame_classes: std::collections::HashSet<_> =
|
||||
frame.objects.iter().map(|o| o.class_name.clone()).collect();
|
||||
for class in frame_classes {
|
||||
*object_counts.entry(class).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let mut dominant_objects: Vec<String> = object_counts
|
||||
.into_iter()
|
||||
.filter(|(_, count)| *count as f32 / frames.len() as f32 > 0.5) // 出現在 >50% 的幀中
|
||||
.map(|(class, _)| class)
|
||||
.collect();
|
||||
dominant_objects.sort();
|
||||
|
||||
Some(VisualChunkContent {
|
||||
start_time,
|
||||
end_time,
|
||||
keyframe_objects,
|
||||
dominant_objects,
|
||||
object_relationships: vec![], // 需要關係檢測邏輯
|
||||
scene_description: None, // 可由 LLM 後期生成
|
||||
metadata: VisualMetadata {
|
||||
object_count: total_objects,
|
||||
unique_classes,
|
||||
max_confidence,
|
||||
avg_confidence,
|
||||
spatial_density: if frames.len() > 0 {
|
||||
total_objects as f32 / frames.len() as f32
|
||||
} else {
|
||||
0.0
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chunk_type_visual() {
|
||||
let chunk_type = ChunkType::Visual;
|
||||
let json = serde_json::to_string(&chunk_type).unwrap();
|
||||
assert_eq!(json, "\"visual\"");
|
||||
|
||||
let deserialized: ChunkType = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized, ChunkType::Visual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visual_chunk_creation() {
|
||||
// 創建模擬 YOLO 結果
|
||||
let yolo_result = MockYoloResult {
|
||||
frames: vec![
|
||||
MockYoloFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
objects: vec![
|
||||
MockYoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 100,
|
||||
y: 200,
|
||||
width: 50,
|
||||
height: 100,
|
||||
confidence: 0.95,
|
||||
},
|
||||
MockYoloObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
x: 300,
|
||||
y: 150,
|
||||
width: 80,
|
||||
height: 60,
|
||||
confidence: 0.87,
|
||||
},
|
||||
],
|
||||
},
|
||||
MockYoloFrame {
|
||||
frame: 1,
|
||||
timestamp: 0.033, // 1/30 秒
|
||||
objects: vec![MockYoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 110,
|
||||
y: 210,
|
||||
width: 52,
|
||||
height: 102,
|
||||
confidence: 0.92,
|
||||
}],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
// 從 YOLO 結果創建視覺分片
|
||||
let chunk = yolo_result.to_visual_chunk(0, 1).unwrap();
|
||||
|
||||
// 驗證分片屬性
|
||||
assert_eq!(chunk.start_time, 0.0);
|
||||
assert_eq!(chunk.end_time, 0.033);
|
||||
assert_eq!(chunk.metadata.object_count, 3);
|
||||
assert_eq!(chunk.metadata.unique_classes.len(), 2);
|
||||
assert!(chunk
|
||||
.metadata
|
||||
.unique_classes
|
||||
.contains(&"person".to_string()));
|
||||
assert!(chunk.metadata.unique_classes.contains(&"car".to_string()));
|
||||
assert_eq!(chunk.dominant_objects, vec!["person"]);
|
||||
assert_eq!(chunk.keyframe_objects.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_visual_chunk_content_methods() {
|
||||
let content = VisualChunkContent {
|
||||
start_time: 0.0,
|
||||
end_time: 5.0,
|
||||
keyframe_objects: vec![KeyframeObjects {
|
||||
timestamp: 0.0,
|
||||
frame_number: 0,
|
||||
objects: vec![
|
||||
DetectedObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
confidence: 0.95,
|
||||
bbox: Some((100, 200, 50, 100)),
|
||||
},
|
||||
DetectedObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
confidence: 0.87,
|
||||
bbox: Some((300, 150, 80, 60)),
|
||||
},
|
||||
],
|
||||
}],
|
||||
dominant_objects: vec!["person".to_string()],
|
||||
object_relationships: vec![],
|
||||
scene_description: Some("一個人站在車旁".to_string()),
|
||||
metadata: VisualMetadata {
|
||||
object_count: 2,
|
||||
unique_classes: vec!["person".to_string(), "car".to_string()],
|
||||
max_confidence: 0.95,
|
||||
avg_confidence: 0.91,
|
||||
spatial_density: 2.0,
|
||||
},
|
||||
};
|
||||
|
||||
// 測試摘要方法
|
||||
let summary = content.summary();
|
||||
assert!(summary.contains("視覺分片"));
|
||||
assert!(summary.contains("person"));
|
||||
|
||||
// 測試 contains_object 方法
|
||||
assert!(content.contains_object("person"));
|
||||
assert!(content.contains_object("car"));
|
||||
assert!(!content.contains_object("dog"));
|
||||
|
||||
// 測試 high_confidence_objects 方法
|
||||
let high_conf_objects = content.high_confidence_objects(0.9);
|
||||
assert_eq!(high_conf_objects.len(), 1);
|
||||
assert_eq!(high_conf_objects[0].class_name, "person");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_frame_similarity() {
|
||||
let frame1_objects = vec![
|
||||
DetectedObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
confidence: 0.95,
|
||||
bbox: Some((100, 200, 50, 100)),
|
||||
},
|
||||
DetectedObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
confidence: 0.87,
|
||||
bbox: Some((300, 150, 80, 60)),
|
||||
},
|
||||
];
|
||||
|
||||
let frame2_objects = vec![
|
||||
DetectedObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
confidence: 0.92,
|
||||
bbox: Some((110, 210, 52, 102)),
|
||||
},
|
||||
DetectedObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
confidence: 0.85,
|
||||
bbox: Some((310, 155, 82, 62)),
|
||||
},
|
||||
];
|
||||
|
||||
let frame3_objects = vec![DetectedObject {
|
||||
class_name: "dog".to_string(),
|
||||
class_id: 16,
|
||||
confidence: 0.78,
|
||||
bbox: Some((150, 250, 40, 60)),
|
||||
}];
|
||||
|
||||
// 測試相似幀(相同物件)
|
||||
let similarity_same =
|
||||
VisualChunkContent::frame_similarity(&frame1_objects, &frame2_objects);
|
||||
assert!((similarity_same - 1.0).abs() < 0.001);
|
||||
|
||||
// 測試不相似幀(不同物件)
|
||||
let similarity_diff =
|
||||
VisualChunkContent::frame_similarity(&frame1_objects, &frame3_objects);
|
||||
assert!((similarity_diff - 0.0).abs() < 0.001);
|
||||
|
||||
// 測試空幀
|
||||
let empty_objects: Vec<DetectedObject> = vec![];
|
||||
let similarity_empty = VisualChunkContent::frame_similarity(&empty_objects, &empty_objects);
|
||||
assert!((similarity_empty - 1.0).abs() < 0.001);
|
||||
|
||||
let similarity_mixed =
|
||||
VisualChunkContent::frame_similarity(&empty_objects, &frame1_objects);
|
||||
assert!((similarity_mixed - 0.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialization_deserialization() {
|
||||
let content = VisualChunkContent {
|
||||
start_time: 0.0,
|
||||
end_time: 5.0,
|
||||
keyframe_objects: vec![KeyframeObjects {
|
||||
timestamp: 0.0,
|
||||
frame_number: 0,
|
||||
objects: vec![DetectedObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
confidence: 0.95,
|
||||
bbox: Some((100, 200, 50, 100)),
|
||||
}],
|
||||
}],
|
||||
dominant_objects: vec!["person".to_string()],
|
||||
object_relationships: vec![],
|
||||
scene_description: Some("場景描述".to_string()),
|
||||
metadata: VisualMetadata {
|
||||
object_count: 1,
|
||||
unique_classes: vec!["person".to_string()],
|
||||
max_confidence: 0.95,
|
||||
avg_confidence: 0.95,
|
||||
spatial_density: 1.0,
|
||||
},
|
||||
};
|
||||
|
||||
// 序列化
|
||||
let json = serde_json::to_string(&content).unwrap();
|
||||
assert!(json.contains("person"));
|
||||
assert!(json.contains("visual_chunk"));
|
||||
|
||||
// 反序列化
|
||||
let deserialized: VisualChunkContent = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized.start_time, 0.0);
|
||||
assert_eq!(deserialized.end_time, 5.0);
|
||||
assert_eq!(deserialized.dominant_objects, vec!["person"]);
|
||||
}
|
||||
}
|
||||
@@ -77,6 +77,8 @@ pub struct VideoRow {
|
||||
pub status: String,
|
||||
pub user_id: Option<i32>,
|
||||
pub job_id: Option<i32>,
|
||||
pub created_at: Option<String>,
|
||||
pub registration_time: Option<String>,
|
||||
}
|
||||
|
||||
impl From<VideoRow> for VideoRecord {
|
||||
@@ -103,7 +105,8 @@ impl From<VideoRow> for VideoRecord {
|
||||
status: VideoStatus::from_db_str(&row.status).unwrap_or(VideoStatus::Pending),
|
||||
user_id: row.user_id.map(|v| v as i64),
|
||||
job_id: row.job_id.map(|v| v as i64),
|
||||
created_at: String::new(),
|
||||
created_at: row.created_at.unwrap_or_default(),
|
||||
registration_time: row.registration_time,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -124,6 +127,7 @@ pub struct VideoRecord {
|
||||
pub user_id: Option<i64>,
|
||||
pub job_id: Option<i64>,
|
||||
pub created_at: String,
|
||||
pub registration_time: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -701,7 +705,7 @@ impl PostgresDb {
|
||||
let table = schema::table_name("videos");
|
||||
let result = sqlx::query_as::<_, VideoRow>(
|
||||
&format!(
|
||||
"SELECT id, uuid, file_path, file_name, duration, width, height, fps, probe_json, fs_video, fs_json, psql_chunk, pobject_chunk, mobject_chunk, pvector_chunk, qvector_chunk, status, user_id, job_id FROM {} WHERE uuid = $1",
|
||||
"SELECT id, uuid, file_path, file_name, duration, width, height, fps, probe_json, fs_video, fs_json, psql_chunk, pobject_chunk, mobject_chunk, pvector_chunk, qvector_chunk, status, user_id, job_id, created_at::text, registration_time::text FROM {} WHERE uuid = $1",
|
||||
table
|
||||
)
|
||||
)
|
||||
@@ -796,28 +800,90 @@ impl PostgresDb {
|
||||
}
|
||||
|
||||
pub async fn list_videos(&self, limit: i32, offset: i64) -> Result<(Vec<VideoRecord>, i64)> {
|
||||
// Default to unprocessed (status != 'ready')
|
||||
self.search_videos(None, Some(false), limit, offset).await
|
||||
}
|
||||
|
||||
pub async fn search_videos(
|
||||
&self,
|
||||
query: Option<&str>,
|
||||
is_processed: Option<bool>,
|
||||
limit: i32,
|
||||
offset: i64,
|
||||
) -> Result<(Vec<VideoRecord>, i64)> {
|
||||
let table = schema::table_name("videos");
|
||||
|
||||
// Build status condition
|
||||
// is_processed = Some(true) => status = 'ready'
|
||||
// is_processed = Some(false) => status != 'ready'
|
||||
// is_processed = None => no filter
|
||||
let status_cond = match is_processed {
|
||||
Some(true) => "AND status = 'ready'",
|
||||
Some(false) => "AND status != 'ready'",
|
||||
None => "",
|
||||
};
|
||||
|
||||
// Count total
|
||||
let count: Option<i64> = sqlx::query_scalar(&format!("SELECT COUNT(*) FROM {}", table))
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
let total = count.unwrap_or(0);
|
||||
// Build search condition safely
|
||||
// If query is Some, we filter by filename/path/probe_json
|
||||
let search_cond = if query.is_some() {
|
||||
"AND (LOWER(file_name) LIKE $1 OR LOWER(file_path) LIKE $1 OR LOWER(probe_json::text) LIKE $1)"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
// Select paged
|
||||
let rows = sqlx::query_as::<_, VideoRow>(
|
||||
&format!(
|
||||
"SELECT id, uuid, file_path, file_name, duration, width, height, fps, probe_json, fs_video, fs_json, psql_chunk, pobject_chunk, mobject_chunk, pvector_chunk, qvector_chunk, status, user_id, job_id FROM {} ORDER BY id DESC LIMIT $1 OFFSET $2",
|
||||
table
|
||||
)
|
||||
)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
let where_clause = format!("WHERE 1=1 {} {}", status_cond, search_cond);
|
||||
|
||||
// 1. Count Query
|
||||
// If query is present, $1 is the pattern.
|
||||
// If query is None, no pattern param needed for count?
|
||||
// Actually, to keep code simple, let's just construct the query string.
|
||||
// SQLx query_as requires bind count to match placeholders.
|
||||
|
||||
let count_query = format!("SELECT COUNT(*) FROM {} {}", table, where_clause);
|
||||
|
||||
let total: i64 = if let Some(q) = query {
|
||||
let pattern = format!("%{}%", q.to_lowercase());
|
||||
sqlx::query_scalar(&count_query)
|
||||
.bind(&pattern)
|
||||
.fetch_one(&self.pool)
|
||||
.await?
|
||||
} else {
|
||||
sqlx::query_scalar(&count_query)
|
||||
.fetch_one(&self.pool)
|
||||
.await?
|
||||
};
|
||||
|
||||
// 2. Select Query
|
||||
// Cast created_at and registration_time to text
|
||||
let columns = "id, uuid, file_path, file_name, duration, width, height, fps, probe_json, fs_video, fs_json, psql_chunk, pobject_chunk, mobject_chunk, pvector_chunk, qvector_chunk, status, user_id, job_id, created_at::text, registration_time::text";
|
||||
|
||||
// Determine parameter order for LIMIT/OFFSET
|
||||
// If search is present, pattern is $1. Limit is $2. Offset is $3.
|
||||
// If search is not present, Limit is $1. Offset is $2.
|
||||
|
||||
let select_query = if query.is_some() {
|
||||
format!("SELECT {} FROM {} {} ORDER BY id DESC LIMIT $2 OFFSET $3", columns, table, where_clause)
|
||||
} else {
|
||||
format!("SELECT {} FROM {} {} ORDER BY id DESC LIMIT $1 OFFSET $2", columns, table, where_clause)
|
||||
};
|
||||
|
||||
let rows = if let Some(q) = query {
|
||||
let pattern = format!("%{}%", q.to_lowercase());
|
||||
sqlx::query_as::<_, VideoRow>(&select_query)
|
||||
.bind(&pattern)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&self.pool)
|
||||
.await?
|
||||
} else {
|
||||
sqlx::query_as::<_, VideoRow>(&select_query)
|
||||
.bind(limit)
|
||||
.bind(offset)
|
||||
.fetch_all(&self.pool)
|
||||
.await?
|
||||
};
|
||||
|
||||
let videos: Vec<VideoRecord> = rows.into_iter().map(|r| r.into()).collect();
|
||||
|
||||
Ok((videos, total))
|
||||
}
|
||||
|
||||
@@ -850,6 +916,19 @@ impl PostgresDb {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_registration_time(&self, uuid: &str) -> Result<()> {
|
||||
let table = schema::table_name("videos");
|
||||
sqlx::query(&format!(
|
||||
"UPDATE {} SET registration_time = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE uuid = $1 AND registration_time IS NULL",
|
||||
table
|
||||
))
|
||||
.bind(uuid)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_video(&self, uuid: &str) -> Result<()> {
|
||||
tracing::info!("[PostgresDb] Deleting video: {}", uuid);
|
||||
|
||||
|
||||
68
src/core/db/schema_ctx.rs
Normal file
68
src/core/db/schema_ctx.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use anyhow::Result;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
|
||||
/// Schema context for database operations
|
||||
/// Ensures all queries use the correct schema prefix
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SchemaContext {
|
||||
pub prefix: String,
|
||||
}
|
||||
|
||||
static SCHEMA_INSTANCE: std::sync::OnceLock<SchemaContext> = std::sync::OnceLock::new();
|
||||
static SCHEMA_VERSION: AtomicU32 = AtomicU32::new(0);
|
||||
|
||||
impl SchemaContext {
|
||||
/// Initialize schema context from environment
|
||||
pub fn init() -> Self {
|
||||
let schema = std::env::var("DATABASE_SCHEMA").unwrap_or_else(|_| "dev".to_string());
|
||||
let prefix = if schema == "public" {
|
||||
String::new()
|
||||
} else {
|
||||
format!("{}.", schema)
|
||||
};
|
||||
Self { prefix }
|
||||
}
|
||||
|
||||
/// Get the global schema context
|
||||
pub fn global() -> &'static Self {
|
||||
SCHEMA_INSTANCE.get_or_init(|| Self::init())
|
||||
}
|
||||
|
||||
/// Get table name with schema prefix
|
||||
pub fn table(&self, name: &str) -> String {
|
||||
format!("{}{}", self.prefix, name)
|
||||
}
|
||||
|
||||
/// Reload schema context (for testing)
|
||||
pub fn reload() {
|
||||
SCHEMA_VERSION.fetch_add(1, Ordering::SeqCst);
|
||||
// Note: OnceLock can't be reset, so we use a different approach
|
||||
// In production, schema doesn't change at runtime
|
||||
}
|
||||
}
|
||||
|
||||
/// Quick helper to get table name with current schema prefix
|
||||
pub fn t(name: &str) -> String {
|
||||
SchemaContext::global().table(name)
|
||||
}
|
||||
|
||||
/// Check if a table exists in the current schema
|
||||
pub async fn table_exists(pool: &PgPool, table_name: &str) -> Result<bool> {
|
||||
let schema = SchemaContext::global();
|
||||
let schema_name = if schema.prefix.is_empty() {
|
||||
"public".to_string()
|
||||
} else {
|
||||
schema.prefix.trim_end_matches('.').to_string()
|
||||
};
|
||||
|
||||
let query = format!(
|
||||
"SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2)"
|
||||
);
|
||||
let exists: bool = sqlx::query_scalar(&query)
|
||||
.bind(&schema_name)
|
||||
.bind(table_name)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
Ok(exists)
|
||||
}
|
||||
143
src/core/ingestion.rs
Normal file
143
src/core/ingestion.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use anyhow::{Context, Result};
|
||||
use std::path::Path;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::core::db::{Database, PostgresDb, VideoRecord, VideoStatus};
|
||||
use crate::core::probe;
|
||||
use crate::core::storage::FileManager;
|
||||
use crate::uuid as uuid_utils;
|
||||
|
||||
/// Handles the automatic ingestion of video files.
|
||||
/// This service is responsible for:
|
||||
/// 1. Running `ffprobe` (Pre-processing)
|
||||
/// 2. Saving probe JSON
|
||||
/// 3. Registering the video in the database (making it visible in the API)
|
||||
pub struct IngestionService {
|
||||
db: PostgresDb,
|
||||
}
|
||||
|
||||
impl IngestionService {
|
||||
pub fn new(db: PostgresDb) -> Self {
|
||||
Self { db }
|
||||
}
|
||||
|
||||
/// Registers a video file found in the watched directory.
|
||||
/// This function is idempotent: if the video (UUID) already exists, it skips.
|
||||
pub async fn ingest(&self, file_path: &str) -> Result<Option<String>> {
|
||||
let path = Path::new(file_path);
|
||||
|
||||
// 1. Validate extension
|
||||
if !is_video_extension(path) {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// 2. Compute UUID
|
||||
let uuid = uuid_utils::compute_uuid_from_path(file_path);
|
||||
|
||||
// 3. Check if already registered
|
||||
if let Ok(Some(_)) = self.db.get_video_by_uuid(&uuid).await {
|
||||
info!(
|
||||
"Video already registered: {} ({})",
|
||||
path.file_name().unwrap_or_default().to_string_lossy(),
|
||||
uuid
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
info!("Starting ingestion for: {} ({})", path.display(), uuid);
|
||||
|
||||
// 4. Run ffprobe
|
||||
let probe_result = probe::probe_video(file_path)
|
||||
.with_context(|| format!("Failed to probe video: {}", file_path))?;
|
||||
|
||||
// 5. Extract metadata
|
||||
let duration = probe_result
|
||||
.format
|
||||
.duration
|
||||
.as_ref()
|
||||
.and_then(|s| s.parse::<f64>().ok())
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let mut width = 0u32;
|
||||
let mut height = 0u32;
|
||||
let mut fps = 0.0;
|
||||
|
||||
for stream in &probe_result.streams {
|
||||
if stream.codec_type.as_deref() == Some("video") {
|
||||
width = stream.width.unwrap_or(0);
|
||||
height = stream.height.unwrap_or(0);
|
||||
if let Some(fps_str) = &stream.r_frame_rate {
|
||||
if let Some((num, den)) = fps_str.split_once('/') {
|
||||
if let (Ok(n), Ok(d)) = (num.parse::<f64>(), den.parse::<f64>()) {
|
||||
if d > 0.0 {
|
||||
fps = n / d;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Save Probe JSON
|
||||
let file_manager = FileManager::new(std::path::PathBuf::from("."));
|
||||
let probe_json_str = serde_json::to_string_pretty(&probe_result)?;
|
||||
|
||||
if let Err(e) = file_manager.save_json(&uuid, "probe", &probe_json_str) {
|
||||
warn!("Failed to save probe JSON for {}: {}", uuid, e);
|
||||
} else {
|
||||
info!("Probe JSON saved for {}", uuid);
|
||||
}
|
||||
|
||||
// 7. Create Record
|
||||
// Use absolute path for safety
|
||||
let canonical_path = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
|
||||
|
||||
let record = VideoRecord {
|
||||
id: 0,
|
||||
uuid: uuid.clone(),
|
||||
file_path: canonical_path.to_string_lossy().to_string(),
|
||||
file_name: path
|
||||
.file_name()
|
||||
.unwrap_or_default()
|
||||
.to_string_lossy()
|
||||
.to_string(),
|
||||
duration,
|
||||
width,
|
||||
height,
|
||||
fps,
|
||||
probe_json: Some(probe_json_str),
|
||||
storage: Default::default(),
|
||||
status: VideoStatus::Pending, // Ready for processing
|
||||
user_id: None,
|
||||
job_id: None,
|
||||
created_at: String::new(),
|
||||
registration_time: None,
|
||||
};
|
||||
|
||||
// 8. Insert DB
|
||||
self.db
|
||||
.register_video(&record)
|
||||
.await
|
||||
.with_context(|| "Failed to register video in database")?;
|
||||
|
||||
self.db
|
||||
.set_registration_time(&uuid)
|
||||
.await
|
||||
.with_context(|| "Failed to set registration_time")?;
|
||||
|
||||
info!(
|
||||
"Successfully registered video: {} (UUID: {})",
|
||||
record.file_name, uuid
|
||||
);
|
||||
Ok(Some(uuid))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_video_extension(path: &Path) -> bool {
|
||||
if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
|
||||
let ext = ext.to_lowercase();
|
||||
matches!(ext.as_str(), "mp4" | "mov" | "mkv" | "avi" | "webm" | "m4v")
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
104
src/core/llm/client.rs
Normal file
104
src/core/llm/client.rs
Normal file
@@ -0,0 +1,104 @@
|
||||
use anyhow::Result;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
use crate::core::config;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ChatMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ChatMessage,
|
||||
}
|
||||
|
||||
/// Generates a 5W1H+ summary for a given scene context.
|
||||
/// Context should include the combined text of all sentences in the scene.
|
||||
pub async fn generate_5w1h_summary(scene_text: &str) -> Result<String> {
|
||||
if !*config::llm::SUMMARY_ENABLED {
|
||||
warn!("LLM Summary is disabled via config");
|
||||
return Ok("LLM Disabled".to_string());
|
||||
}
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(*config::llm::SUMMARY_TIMEOUT_SECS))
|
||||
.build()?;
|
||||
|
||||
let prompt = format!(
|
||||
r#"Analyze the following video scene transcript and provide a concise 5W1H+ summary in JSON format.
|
||||
Focus on: Who, What, Where, When, Why, How, and Key Objects/Actions.
|
||||
|
||||
Transcript:
|
||||
"{}"
|
||||
|
||||
Output format:
|
||||
{{
|
||||
"who": "...",
|
||||
"what": "...",
|
||||
"where": "...",
|
||||
"when": "...",
|
||||
"why": "...",
|
||||
"how": "...",
|
||||
"summary": "..."
|
||||
}}"#,
|
||||
scene_text
|
||||
);
|
||||
|
||||
let req = ChatRequest {
|
||||
model: (*config::llm::SUMMARY_MODEL).clone(),
|
||||
messages: vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: "You are an expert video analyst assistant.".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: prompt,
|
||||
},
|
||||
],
|
||||
temperature: 0.1,
|
||||
max_tokens: 512,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
debug!("Calling LLM for summary: {}", *config::llm::SUMMARY_URL);
|
||||
|
||||
let res = client
|
||||
.post(&*config::llm::SUMMARY_URL)
|
||||
.json(&req)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !res.status().is_success() {
|
||||
error!("LLM API error: {}", res.status());
|
||||
let text = res.text().await.unwrap_or_default();
|
||||
anyhow::bail!("LLM API error: {}", text);
|
||||
}
|
||||
|
||||
let chat_res: ChatResponse = res.json().await?;
|
||||
|
||||
if let Some(choice) = chat_res.choices.into_iter().next() {
|
||||
Ok(choice.message.content.trim().to_string())
|
||||
} else {
|
||||
anyhow::bail!("Empty response from LLM");
|
||||
}
|
||||
}
|
||||
1
src/core/llm/mod.rs
Normal file
1
src/core/llm/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod client;
|
||||
266
src/core/person_identity.rs
Normal file
266
src/core/person_identity.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
|
||||
// ==========================================
|
||||
// 舊版結構體 (保留以向後兼容)
|
||||
// ==========================================
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PersonIdentity {
|
||||
pub id: i32,
|
||||
pub person_id: String,
|
||||
pub face_identity_id: Option<i32>,
|
||||
pub speaker_id: Option<String>,
|
||||
pub video_uuid: String,
|
||||
pub confidence: f64,
|
||||
pub name: Option<String>,
|
||||
pub metadata: serde_json::Value,
|
||||
pub first_appearance_time: Option<f64>,
|
||||
pub last_appearance_time: Option<f64>,
|
||||
pub total_appearance_duration: f64,
|
||||
pub appearance_count: i32,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub is_confirmed: bool,
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 新版結構體 (V5 身份綁定系統)
|
||||
// ==========================================
|
||||
|
||||
/// 人物身份 (Identity) - 統一管理演員、公眾人物、家人朋友等
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct Identity {
|
||||
pub id: i32,
|
||||
pub name: String,
|
||||
pub embedding: Option<String>, // Vector embedding stored as text/json
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// 身份綁定記錄 (Identity Binding)
|
||||
/// 將機器 ID (face_x, speaker_y) 綁定到 Identity
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct IdentityBinding {
|
||||
pub id: i64,
|
||||
pub identity_id: i64,
|
||||
pub binding_type: String, // 'face', 'speaker'
|
||||
pub binding_value: String, // e.g. "face_1", "speaker_3"
|
||||
pub source: String, // 'auto', 'manual'
|
||||
pub confidence: f64,
|
||||
pub is_active: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// 綁定請求 (用於 API)
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct BindIdentityRequest {
|
||||
pub identity_id: Option<i64>,
|
||||
pub name: Option<String>, // 若未提供 identity_id,則建立新 Identity
|
||||
pub binding_type: String, // 'face' 或 'speaker'
|
||||
pub binding_value: String, // e.g. "face_1"
|
||||
pub source: Option<String>, // 預設 'manual'
|
||||
}
|
||||
|
||||
/// 解綁請求
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct UnbindIdentityRequest {
|
||||
pub binding_type: String,
|
||||
pub binding_value: String,
|
||||
}
|
||||
|
||||
/// 建議綁定請求 (由系統自動產生,人工確認)
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct SuggestedBinding {
|
||||
pub binding_type: String,
|
||||
pub binding_value: String,
|
||||
pub suggested_identity_id: i64,
|
||||
pub suggested_identity_name: String,
|
||||
pub confidence: f64,
|
||||
pub reason: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PersonAppearance {
|
||||
pub id: i32,
|
||||
pub person_id: String,
|
||||
pub video_uuid: String,
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub duration: f64,
|
||||
pub face_detection_id: Option<i32>,
|
||||
pub asrx_segment_start: Option<f64>,
|
||||
pub asrx_segment_end: Option<f64>,
|
||||
pub confidence: f64,
|
||||
pub metadata: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
|
||||
pub struct PersonMatch {
|
||||
pub face_id: String,
|
||||
pub speaker_id: String,
|
||||
pub confidence: f64,
|
||||
pub match_count: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersonTimelineEntry {
|
||||
pub start_time: f64,
|
||||
pub end_time: f64,
|
||||
pub duration: f64,
|
||||
pub confidence: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersonStatistics {
|
||||
pub total_appearances: i32,
|
||||
pub total_duration: f64,
|
||||
pub first_appearance: Option<f64>,
|
||||
pub last_appearance: Option<f64>,
|
||||
pub average_confidence: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreatePersonIdentityRequest {
|
||||
pub video_uuid: String,
|
||||
pub face_identity_id: Option<i32>,
|
||||
pub speaker_id: Option<String>,
|
||||
pub name: Option<String>,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UpdatePersonIdentityRequest {
|
||||
pub name: Option<String>,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
pub is_confirmed: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersonIdentityResponse {
|
||||
pub person_id: String,
|
||||
pub name: Option<String>,
|
||||
pub face_identity_id: Option<i32>,
|
||||
pub speaker_id: Option<String>,
|
||||
pub confidence: f64,
|
||||
pub appearance_count: i32,
|
||||
pub total_appearance_duration: f64,
|
||||
pub first_appearance_time: Option<f64>,
|
||||
pub last_appearance_time: Option<f64>,
|
||||
pub is_confirmed: bool,
|
||||
}
|
||||
|
||||
impl From<PersonIdentity> for PersonIdentityResponse {
|
||||
fn from(person: PersonIdentity) -> Self {
|
||||
Self {
|
||||
person_id: person.person_id,
|
||||
name: person.name,
|
||||
face_identity_id: person.face_identity_id,
|
||||
speaker_id: person.speaker_id,
|
||||
confidence: person.confidence,
|
||||
appearance_count: person.appearance_count,
|
||||
total_appearance_duration: person.total_appearance_duration,
|
||||
first_appearance_time: person.first_appearance_time,
|
||||
last_appearance_time: person.last_appearance_time,
|
||||
is_confirmed: person.is_confirmed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PersonTimelineResponse {
|
||||
pub person_id: String,
|
||||
pub name: Option<String>,
|
||||
pub timeline: Vec<PersonTimelineEntry>,
|
||||
pub statistics: PersonStatistics,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkPersonInfo {
|
||||
pub person_id: String,
|
||||
pub name: Option<String>,
|
||||
pub confidence: f64,
|
||||
pub overlap_duration: f64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_person_identity_serialization() {
|
||||
let person = PersonIdentity {
|
||||
id: 1,
|
||||
person_id: "person_001".to_string(),
|
||||
face_identity_id: Some(123),
|
||||
speaker_id: Some("SPEAKER_00".to_string()),
|
||||
video_uuid: "video_abc".to_string(),
|
||||
confidence: 0.85,
|
||||
name: Some("张三".to_string()),
|
||||
metadata: serde_json::json!({"role": "host"}),
|
||||
first_appearance_time: Some(10.5),
|
||||
last_appearance_time: Some(350.2),
|
||||
total_appearance_duration: 120.5,
|
||||
appearance_count: 15,
|
||||
created_at: Utc::now(),
|
||||
updated_at: Utc::now(),
|
||||
is_confirmed: true,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&person).unwrap();
|
||||
assert!(json.contains("person_001"));
|
||||
assert!(json.contains("SPEAKER_00"));
|
||||
assert!(json.contains("张三"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_person_appearance_serialization() {
|
||||
let appearance = PersonAppearance {
|
||||
id: 1,
|
||||
person_id: "person_001".to_string(),
|
||||
video_uuid: "video_abc".to_string(),
|
||||
start_time: 10.5,
|
||||
end_time: 25.3,
|
||||
duration: 14.8,
|
||||
face_detection_id: Some(456),
|
||||
asrx_segment_start: Some(10.0),
|
||||
asrx_segment_end: Some(26.0),
|
||||
confidence: 0.92,
|
||||
metadata: serde_json::json!({}),
|
||||
created_at: Utc::now(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&appearance).unwrap();
|
||||
assert!(json.contains("person_001"));
|
||||
assert!(json.contains("14.8"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_person_match() {
|
||||
let match_result = PersonMatch {
|
||||
face_id: "face_123".to_string(),
|
||||
speaker_id: "SPEAKER_00".to_string(),
|
||||
confidence: 0.85,
|
||||
match_count: 15,
|
||||
};
|
||||
|
||||
assert_eq!(match_result.face_id, "face_123");
|
||||
assert!(match_result.confidence >= 0.0 && match_result.confidence <= 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_person_statistics() {
|
||||
let stats = PersonStatistics {
|
||||
total_appearances: 15,
|
||||
total_duration: 120.5,
|
||||
first_appearance: Some(10.5),
|
||||
last_appearance: Some(350.2),
|
||||
average_confidence: 0.88,
|
||||
};
|
||||
|
||||
assert_eq!(stats.total_appearances, 15);
|
||||
assert!(stats.total_duration > 0.0);
|
||||
}
|
||||
}
|
||||
124
src/core/processor/asr_legacy.rs
Normal file
124
src/core/processor/asr_legacy.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::executor::PythonExecutor;
|
||||
use crate::core::config::processor;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AsrResult {
|
||||
pub language: Option<String>,
|
||||
pub language_probability: Option<f64>,
|
||||
pub segments: Vec<AsrSegment>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AsrSegment {
|
||||
pub start: f64,
|
||||
pub end: f64,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub async fn process_asr(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<AsrResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("asr_processor.py");
|
||||
|
||||
tracing::info!("[ASR] Starting ASR processing: {}", video_path);
|
||||
|
||||
executor
|
||||
.run(
|
||||
"asr_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"ASR",
|
||||
Some(Duration::from_secs(*processor::ASR_TIMEOUT_SECS)),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str = std::fs::read_to_string(output_path).context("Failed to read ASR output")?;
|
||||
|
||||
let result: AsrResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse ASR output")?;
|
||||
|
||||
tracing::info!(
|
||||
"[ASR] Result: {} segments, language: {:?}",
|
||||
result.segments.len(),
|
||||
result.language
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_asr_result_serialization() {
|
||||
let result = AsrResult {
|
||||
language: Some("en".to_string()),
|
||||
language_probability: Some(0.95),
|
||||
segments: vec![
|
||||
AsrSegment {
|
||||
start: 0.0,
|
||||
end: 2.5,
|
||||
text: "Hello world".to_string(),
|
||||
},
|
||||
AsrSegment {
|
||||
start: 2.5,
|
||||
end: 5.0,
|
||||
text: "Test speech".to_string(),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("Hello world"));
|
||||
assert!(json.contains("en"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asr_result_deserialization() {
|
||||
let json = r#"{
|
||||
"language": "zh",
|
||||
"language_probability": 0.98,
|
||||
"segments": [
|
||||
{"start": 0.0, "end": 1.5, "text": "測試"}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let result: AsrResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.language, Some("zh".to_string()));
|
||||
assert_eq!(result.language_probability, Some(0.98));
|
||||
assert_eq!(result.segments.len(), 1);
|
||||
assert_eq!(result.segments[0].text, "測試");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asr_segment_default() {
|
||||
let segment = AsrSegment {
|
||||
start: 0.0,
|
||||
end: 1.0,
|
||||
text: String::new(),
|
||||
};
|
||||
assert_eq!(segment.start, 0.0);
|
||||
assert_eq!(segment.end, 1.0);
|
||||
assert!(segment.text.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_asr_result_empty_segments() {
|
||||
let result = AsrResult {
|
||||
language: None,
|
||||
language_probability: None,
|
||||
segments: vec![],
|
||||
};
|
||||
assert!(result.language.is_none());
|
||||
assert!(result.segments.is_empty());
|
||||
}
|
||||
}
|
||||
345
src/core/processor/face_recognition.rs
Normal file
345
src/core/processor/face_recognition.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::executor::PythonExecutor;
|
||||
|
||||
const FACE_RECOGNITION_TIMEOUT: Duration = Duration::from_secs(10800); // 3 hours for recognition
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FaceRecognitionResult {
|
||||
pub frame_count: u64,
|
||||
pub fps: f64,
|
||||
pub frames: Vec<FaceRecognitionFrame>,
|
||||
pub recognized_faces: Vec<RecognizedFace>,
|
||||
pub face_clusters: Vec<FaceCluster>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FaceRecognitionFrame {
|
||||
pub frame: u64,
|
||||
pub timestamp: f64,
|
||||
pub faces: Vec<RecognizedFaceDetection>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct RecognizedFaceDetection {
|
||||
pub face_id: Option<String>,
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub width: i32,
|
||||
pub height: i32,
|
||||
pub confidence: f32,
|
||||
pub embedding: Option<Vec<f32>>,
|
||||
pub attributes: Option<FaceAttributes>,
|
||||
pub identity: Option<FaceIdentity>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FaceAttributes {
|
||||
pub age: Option<u8>,
|
||||
pub gender: Option<String>,
|
||||
pub emotion: Option<String>,
|
||||
pub glasses: Option<bool>,
|
||||
pub mask: Option<bool>,
|
||||
pub pose: Option<FacePose>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FacePose {
|
||||
pub yaw: f32,
|
||||
pub pitch: f32,
|
||||
pub roll: f32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FaceIdentity {
|
||||
pub name: Option<String>,
|
||||
pub confidence: f32,
|
||||
pub database_id: Option<String>,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct RecognizedFace {
|
||||
pub face_id: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub first_seen: f64,
|
||||
pub last_seen: f64,
|
||||
pub total_appearances: u32,
|
||||
pub attributes: Option<FaceAttributes>,
|
||||
pub identities: Vec<FaceIdentity>,
|
||||
pub cluster_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FaceCluster {
|
||||
pub cluster_id: String,
|
||||
pub face_ids: Vec<String>,
|
||||
pub centroid: Vec<f32>,
|
||||
pub size: u32,
|
||||
pub representative_face_id: Option<String>,
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
pub async fn process_face_recognition(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
enable_recognition: bool,
|
||||
enable_tracking: bool,
|
||||
enable_clustering: bool,
|
||||
) -> Result<FaceRecognitionResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("face_recognition_processor.py");
|
||||
|
||||
tracing::info!(
|
||||
"[FACE_RECOGNITION] Starting face recognition: {}",
|
||||
video_path
|
||||
);
|
||||
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[FACE_RECOGNITION] Script not found, returning empty result");
|
||||
return Ok(FaceRecognitionResult {
|
||||
frame_count: 0,
|
||||
fps: 0.0,
|
||||
frames: vec![],
|
||||
recognized_faces: vec![],
|
||||
face_clusters: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
let args = vec![
|
||||
video_path,
|
||||
output_path,
|
||||
if enable_recognition { "1" } else { "0" },
|
||||
if enable_tracking { "1" } else { "0" },
|
||||
if enable_clustering { "1" } else { "0" },
|
||||
];
|
||||
|
||||
executor
|
||||
.run(
|
||||
"face_recognition_processor.py",
|
||||
&args,
|
||||
uuid,
|
||||
"FACE_RECOGNITION",
|
||||
Some(FACE_RECOGNITION_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str =
|
||||
std::fs::read_to_string(output_path).context("Failed to read FACE_RECOGNITION output")?;
|
||||
|
||||
let result: FaceRecognitionResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse FACE_RECOGNITION output")?;
|
||||
|
||||
tracing::info!(
|
||||
"[FACE_RECOGNITION] Result: {} frames, {} recognized faces, {} clusters",
|
||||
result.frames.len(),
|
||||
result.recognized_faces.len(),
|
||||
result.face_clusters.len()
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
pub async fn register_face(
|
||||
image_path: &str,
|
||||
name: &str,
|
||||
metadata: Option<serde_json::Value>,
|
||||
) -> Result<FaceRegistrationResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("face_registration.py");
|
||||
|
||||
tracing::info!("[FACE_REGISTRATION] Registering face: {}", name);
|
||||
|
||||
if !script_path.exists() {
|
||||
anyhow::bail!("Face registration script not found");
|
||||
}
|
||||
|
||||
let output_path = format!("/tmp/face_registration_{}.json", uuid::Uuid::new_v4());
|
||||
|
||||
// Handle metadata separately to avoid lifetime issues
|
||||
let meta_temp_file = metadata.as_ref().map(|meta| {
|
||||
let meta_path = format!("/tmp/face_metadata_{}.json", uuid::Uuid::new_v4());
|
||||
std::fs::write(&meta_path, serde_json::to_string(meta).unwrap()).unwrap();
|
||||
meta_path
|
||||
});
|
||||
|
||||
// Build arguments - use output_path as database path so Python writes there
|
||||
let mut args = vec![
|
||||
image_path.to_string(),
|
||||
output_path.clone(),
|
||||
name.to_string(),
|
||||
];
|
||||
|
||||
// Add database parameter (point to same output for now)
|
||||
let database_path = output_path.clone();
|
||||
args.push("--database".to_string());
|
||||
args.push(database_path.clone());
|
||||
|
||||
if let Some(ref meta_path) = meta_temp_file {
|
||||
args.push("--metadata".to_string());
|
||||
args.push(meta_path.clone());
|
||||
}
|
||||
|
||||
let args_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
|
||||
executor
|
||||
.run(
|
||||
"face_registration.py",
|
||||
&args_refs,
|
||||
None,
|
||||
"FACE_REGISTRATION",
|
||||
Some(Duration::from_secs(300)),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str =
|
||||
std::fs::read_to_string(&output_path).context("Failed to read registration output")?;
|
||||
|
||||
let result: FaceRegistrationResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse registration output")?;
|
||||
|
||||
// Clean up temp files
|
||||
let _ = std::fs::remove_file(&output_path);
|
||||
if let Some(meta_path) = meta_temp_file {
|
||||
let _ = std::fs::remove_file(&meta_path);
|
||||
}
|
||||
|
||||
tracing::info!("[FACE_REGISTRATION] Registered face: {}", result.face_id);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct FaceRegistrationResult {
|
||||
pub face_id: String,
|
||||
pub embedding: Vec<f32>,
|
||||
pub attributes: Option<FaceAttributes>,
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_face_recognition_result_serialization() {
|
||||
let result = FaceRecognitionResult {
|
||||
frame_count: 100,
|
||||
fps: 30.0,
|
||||
frames: vec![FaceRecognitionFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
faces: vec![RecognizedFaceDetection {
|
||||
face_id: Some("face_1".to_string()),
|
||||
x: 100,
|
||||
y: 100,
|
||||
width: 50,
|
||||
height: 60,
|
||||
confidence: 0.95,
|
||||
embedding: Some(vec![0.1, 0.2, 0.3]),
|
||||
attributes: Some(FaceAttributes {
|
||||
age: Some(30),
|
||||
gender: Some("male".to_string()),
|
||||
emotion: Some("neutral".to_string()),
|
||||
glasses: Some(false),
|
||||
mask: Some(false),
|
||||
pose: Some(FacePose {
|
||||
yaw: 0.1,
|
||||
pitch: 0.2,
|
||||
roll: 0.3,
|
||||
}),
|
||||
}),
|
||||
identity: Some(FaceIdentity {
|
||||
name: Some("John Doe".to_string()),
|
||||
confidence: 0.85,
|
||||
database_id: Some("user_123".to_string()),
|
||||
metadata: Some(serde_json::json!({"role": "employee"})),
|
||||
}),
|
||||
}],
|
||||
}],
|
||||
recognized_faces: vec![RecognizedFace {
|
||||
face_id: "face_1".to_string(),
|
||||
embedding: vec![0.1, 0.2, 0.3],
|
||||
first_seen: 0.0,
|
||||
last_seen: 10.0,
|
||||
total_appearances: 5,
|
||||
attributes: Some(FaceAttributes {
|
||||
age: Some(30),
|
||||
gender: Some("male".to_string()),
|
||||
emotion: Some("neutral".to_string()),
|
||||
glasses: Some(false),
|
||||
mask: Some(false),
|
||||
pose: Some(FacePose {
|
||||
yaw: 0.1,
|
||||
pitch: 0.2,
|
||||
roll: 0.3,
|
||||
}),
|
||||
}),
|
||||
identities: vec![FaceIdentity {
|
||||
name: Some("John Doe".to_string()),
|
||||
confidence: 0.85,
|
||||
database_id: Some("user_123".to_string()),
|
||||
metadata: Some(serde_json::json!({"role": "employee"})),
|
||||
}],
|
||||
cluster_id: Some("cluster_1".to_string()),
|
||||
}],
|
||||
face_clusters: vec![FaceCluster {
|
||||
cluster_id: "cluster_1".to_string(),
|
||||
face_ids: vec!["face_1".to_string()],
|
||||
centroid: vec![0.1, 0.2, 0.3],
|
||||
size: 1,
|
||||
representative_face_id: Some("face_1".to_string()),
|
||||
metadata: Some(serde_json::json!({"description": "main person"})),
|
||||
}],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("face_1"));
|
||||
assert!(json.contains("John Doe"));
|
||||
assert!(json.contains("cluster_1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_face_attributes_serialization() {
|
||||
let attributes = FaceAttributes {
|
||||
age: Some(25),
|
||||
gender: Some("female".to_string()),
|
||||
emotion: Some("happy".to_string()),
|
||||
glasses: Some(true),
|
||||
mask: Some(false),
|
||||
pose: Some(FacePose {
|
||||
yaw: -0.1,
|
||||
pitch: 0.05,
|
||||
roll: 0.02,
|
||||
}),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&attributes).unwrap();
|
||||
assert!(json.contains("\"age\":25"));
|
||||
assert!(json.contains("\"gender\":\"female\""));
|
||||
assert!(json.contains("\"emotion\":\"happy\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_face_identity_serialization() {
|
||||
let identity = FaceIdentity {
|
||||
name: Some("Alice Smith".to_string()),
|
||||
confidence: 0.92,
|
||||
database_id: Some("employee_456".to_string()),
|
||||
metadata: Some(serde_json::json!({
|
||||
"department": "engineering",
|
||||
"position": "senior developer"
|
||||
})),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&identity).unwrap();
|
||||
assert!(json.contains("Alice Smith"));
|
||||
assert!(json.contains("\"confidence\":0.92"));
|
||||
assert!(json.contains("engineering"));
|
||||
}
|
||||
}
|
||||
562
src/core/processor/visual_chunk.rs
Normal file
562
src/core/processor/visual_chunk.rs
Normal file
@@ -0,0 +1,562 @@
|
||||
//! 視覺分片處理器 (Phase 2.2)
|
||||
//!
|
||||
//! 從 YOLO 結果生成視覺分片
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::executor::PythonExecutor;
|
||||
use super::yolo::{YoloFrame, YoloResult};
|
||||
|
||||
const VISUAL_CHUNK_TIMEOUT: Duration = Duration::from_secs(3600);
|
||||
|
||||
/// 視覺分片處理結果
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct VisualChunkResult {
|
||||
/// 生成的視覺分片數量
|
||||
pub chunk_count: u32,
|
||||
/// 處理的總幀數
|
||||
pub total_frames: u32,
|
||||
/// 檢測到的總物件數
|
||||
pub total_objects: u32,
|
||||
/// 唯一物件類別數
|
||||
pub unique_classes: u32,
|
||||
/// 生成的視覺分片
|
||||
pub chunks: Vec<crate::core::chunk::Chunk>,
|
||||
}
|
||||
|
||||
/// 從 YOLO 結果生成視覺分片
|
||||
pub async fn process_visual_chunk(
|
||||
file_id: i32,
|
||||
uuid: String,
|
||||
video_path: &str,
|
||||
yolo_result: &YoloResult,
|
||||
chunk_index_offset: u32,
|
||||
fps: f64,
|
||||
) -> Result<VisualChunkResult> {
|
||||
tracing::info!(
|
||||
"[VisualChunk] Starting visual chunk generation for video: {}, {} frames",
|
||||
video_path,
|
||||
yolo_result.frames.len()
|
||||
);
|
||||
|
||||
if yolo_result.frames.is_empty() {
|
||||
tracing::warn!("[VisualChunk] No YOLO frames to process");
|
||||
return Ok(VisualChunkResult {
|
||||
chunk_count: 0,
|
||||
total_frames: 0,
|
||||
total_objects: 0,
|
||||
unique_classes: 0,
|
||||
chunks: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
// 策略 1: 固定幀數分片(每 N 幀一個分片)
|
||||
let chunks = create_fixed_frame_chunks(file_id, &uuid, yolo_result, chunk_index_offset, fps);
|
||||
|
||||
// 統計信息
|
||||
let total_objects: u32 = yolo_result
|
||||
.frames
|
||||
.iter()
|
||||
.map(|f| f.objects.len() as u32)
|
||||
.sum();
|
||||
let all_classes: Vec<String> = yolo_result
|
||||
.frames
|
||||
.iter()
|
||||
.flat_map(|f| f.objects.iter().map(|o| o.class_name.clone()))
|
||||
.collect();
|
||||
let unique_classes: u32 = all_classes
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.len() as u32;
|
||||
|
||||
tracing::info!(
|
||||
"[VisualChunk] Generated {} visual chunks from {} frames, {} total objects, {} unique classes",
|
||||
chunks.len(),
|
||||
yolo_result.frames.len(),
|
||||
total_objects,
|
||||
unique_classes
|
||||
);
|
||||
|
||||
Ok(VisualChunkResult {
|
||||
chunk_count: chunks.len() as u32,
|
||||
total_frames: yolo_result.frames.len() as u32,
|
||||
total_objects,
|
||||
unique_classes,
|
||||
chunks,
|
||||
})
|
||||
}
|
||||
|
||||
/// 創建固定幀數分片(每 N 幀一個分片)
|
||||
fn create_fixed_frame_chunks(
|
||||
file_id: i32,
|
||||
uuid: &str,
|
||||
yolo_result: &YoloResult,
|
||||
chunk_index_offset: u32,
|
||||
fps: f64,
|
||||
) -> Vec<crate::core::chunk::Chunk> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
// 配置:每 30 幀創建一個分片(約 1 秒,如果 fps=30)
|
||||
let frames_per_chunk = 30;
|
||||
let total_frames = yolo_result.frames.len();
|
||||
|
||||
if total_frames == 0 {
|
||||
return chunks;
|
||||
}
|
||||
|
||||
let mut chunk_index = chunk_index_offset;
|
||||
let mut start_idx = 0;
|
||||
|
||||
while start_idx < total_frames {
|
||||
let end_idx = std::cmp::min(start_idx + frames_per_chunk, total_frames);
|
||||
|
||||
// 獲取這個分片的幀
|
||||
let chunk_frames: Vec<YoloFrame> = yolo_result.frames[start_idx..end_idx]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if chunk_frames.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
// 計算幀範圍
|
||||
let start_frame = chunk_frames.first().unwrap().frame as i64;
|
||||
let end_frame = chunk_frames.last().unwrap().frame as i64 + 1; // exclusive
|
||||
|
||||
// 創建視覺分片
|
||||
let chunk = crate::core::chunk::Chunk::from_yolo_frames(
|
||||
file_id,
|
||||
uuid.to_string(),
|
||||
chunk_index,
|
||||
start_frame,
|
||||
end_frame,
|
||||
fps,
|
||||
chunk_frames,
|
||||
);
|
||||
|
||||
chunks.push(chunk);
|
||||
|
||||
// 更新索引
|
||||
start_idx = end_idx;
|
||||
chunk_index += 1;
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
|
||||
/// 基於物件相似度創建分片
|
||||
fn create_similarity_based_chunks(
|
||||
file_id: i32,
|
||||
uuid: &str,
|
||||
yolo_result: &YoloResult,
|
||||
chunk_index_offset: u32,
|
||||
fps: f64,
|
||||
similarity_threshold: f32,
|
||||
min_frames_per_chunk: usize,
|
||||
) -> Vec<crate::core::chunk::Chunk> {
|
||||
let mut chunks = Vec::new();
|
||||
|
||||
if yolo_result.frames.is_empty() {
|
||||
return chunks;
|
||||
}
|
||||
|
||||
let mut current_chunk_frames: Vec<YoloFrame> = Vec::new();
|
||||
let mut chunk_index = chunk_index_offset;
|
||||
let mut current_start_frame = 0;
|
||||
|
||||
for (i, frame) in yolo_result.frames.iter().enumerate() {
|
||||
if current_chunk_frames.is_empty() {
|
||||
current_chunk_frames.push(frame.clone());
|
||||
current_start_frame = frame.frame as i64;
|
||||
continue;
|
||||
}
|
||||
|
||||
// 檢查相似度(簡化版本:檢查物件類別是否相同)
|
||||
let last_frame = current_chunk_frames.last().unwrap();
|
||||
let similarity = calculate_frame_similarity(last_frame, frame);
|
||||
|
||||
if similarity >= similarity_threshold {
|
||||
// 相似度高,加入當前分片
|
||||
current_chunk_frames.push(frame.clone());
|
||||
} else {
|
||||
// 相似度低,創建新分片
|
||||
if current_chunk_frames.len() >= min_frames_per_chunk {
|
||||
let end_frame = current_chunk_frames.last().unwrap().frame as i64 + 1;
|
||||
|
||||
let chunk = crate::core::chunk::Chunk::from_yolo_frames(
|
||||
file_id,
|
||||
uuid.to_string(),
|
||||
chunk_index,
|
||||
current_start_frame,
|
||||
end_frame,
|
||||
fps,
|
||||
current_chunk_frames.clone(),
|
||||
);
|
||||
|
||||
chunks.push(chunk);
|
||||
chunk_index += 1;
|
||||
}
|
||||
|
||||
// 開始新的分片
|
||||
current_chunk_frames = vec![frame.clone()];
|
||||
current_start_frame = frame.frame as i64;
|
||||
}
|
||||
}
|
||||
|
||||
// 處理最後一個分片
|
||||
if current_chunk_frames.len() >= min_frames_per_chunk {
|
||||
let end_frame = current_chunk_frames.last().unwrap().frame as i64 + 1;
|
||||
|
||||
let chunk = crate::core::chunk::Chunk::from_yolo_frames(
|
||||
file_id,
|
||||
uuid.to_string(),
|
||||
chunk_index,
|
||||
current_start_frame,
|
||||
end_frame,
|
||||
fps,
|
||||
current_chunk_frames,
|
||||
);
|
||||
|
||||
chunks.push(chunk);
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
|
||||
/// 計算兩個幀之間的相似度(基於物件類別)
|
||||
fn calculate_frame_similarity(frame1: &YoloFrame, frame2: &YoloFrame) -> f32 {
|
||||
if frame1.objects.is_empty() && frame2.objects.is_empty() {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
if frame1.objects.is_empty() || frame2.objects.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let set1: std::collections::HashSet<String> = frame1
|
||||
.objects
|
||||
.iter()
|
||||
.map(|o| o.class_name.clone())
|
||||
.collect();
|
||||
let set2: std::collections::HashSet<String> = frame2
|
||||
.objects
|
||||
.iter()
|
||||
.map(|o| o.class_name.clone())
|
||||
.collect();
|
||||
|
||||
let intersection: Vec<_> = set1.intersection(&set2).collect();
|
||||
let union: Vec<_> = set1.union(&set2).collect();
|
||||
|
||||
if union.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
intersection.len() as f32 / union.len() as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// 使用 Python 腳本生成視覺分片(進階版本)
|
||||
pub async fn process_visual_chunk_advanced(
|
||||
video_path: &str,
|
||||
output_path: &str,
|
||||
uuid: Option<&str>,
|
||||
) -> Result<VisualChunkResult> {
|
||||
let executor = PythonExecutor::new()?;
|
||||
let script_path = executor.script_path("visual_chunk_processor.py");
|
||||
|
||||
tracing::info!(
|
||||
"[VisualChunk] Starting advanced visual chunk generation: {}",
|
||||
video_path
|
||||
);
|
||||
|
||||
if !script_path.exists() {
|
||||
tracing::warn!("[VisualChunk] Script not found, using basic generation");
|
||||
// 這裡可以回退到基本生成方法
|
||||
return Ok(VisualChunkResult {
|
||||
chunk_count: 0,
|
||||
total_frames: 0,
|
||||
total_objects: 0,
|
||||
unique_classes: 0,
|
||||
chunks: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"visual_chunk_processor.py",
|
||||
&[video_path, output_path],
|
||||
uuid,
|
||||
"VisualChunk",
|
||||
Some(VISUAL_CHUNK_TIMEOUT),
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("Failed to run {:?}", script_path))?;
|
||||
|
||||
let json_str =
|
||||
std::fs::read_to_string(output_path).context("Failed to read visual chunk output")?;
|
||||
|
||||
let result: VisualChunkResult =
|
||||
serde_json::from_str(&json_str).context("Failed to parse visual chunk output")?;
|
||||
|
||||
tracing::info!(
|
||||
"[VisualChunk] Advanced generation result: {} chunks, {} frames",
|
||||
result.chunk_count,
|
||||
result.total_frames
|
||||
);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_calculate_frame_similarity() {
|
||||
use crate::core::processor::yolo::{YoloFrame, YoloObject};
|
||||
|
||||
let frame1 = YoloFrame {
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
objects: vec![
|
||||
YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 100,
|
||||
y: 200,
|
||||
width: 50,
|
||||
height: 100,
|
||||
confidence: 0.95,
|
||||
},
|
||||
YoloObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
x: 300,
|
||||
y: 150,
|
||||
width: 80,
|
||||
height: 60,
|
||||
confidence: 0.87,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let frame2 = YoloFrame {
|
||||
frame: 1,
|
||||
timestamp: 0.033,
|
||||
objects: vec![
|
||||
YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 110,
|
||||
y: 210,
|
||||
width: 52,
|
||||
height: 102,
|
||||
confidence: 0.92,
|
||||
},
|
||||
YoloObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
x: 310,
|
||||
y: 155,
|
||||
width: 82,
|
||||
height: 62,
|
||||
confidence: 0.85,
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let frame3 = YoloFrame {
|
||||
frame: 2,
|
||||
timestamp: 0.066,
|
||||
objects: vec![YoloObject {
|
||||
class_name: "dog".to_string(),
|
||||
class_id: 16,
|
||||
x: 150,
|
||||
y: 250,
|
||||
width: 40,
|
||||
height: 60,
|
||||
confidence: 0.78,
|
||||
}],
|
||||
};
|
||||
|
||||
// 相同物件的幀應該高度相似
|
||||
let similarity_same = calculate_frame_similarity(&frame1, &frame2);
|
||||
assert!((similarity_same - 1.0).abs() < 0.001);
|
||||
|
||||
// 不同物件的幀應該不相似
|
||||
let similarity_diff = calculate_frame_similarity(&frame1, &frame3);
|
||||
assert!((similarity_diff - 0.0).abs() < 0.001);
|
||||
|
||||
// 空幀應該完全相似
|
||||
let empty_frame = YoloFrame {
|
||||
frame: 3,
|
||||
timestamp: 0.1,
|
||||
objects: vec![],
|
||||
};
|
||||
let similarity_empty = calculate_frame_similarity(&empty_frame, &empty_frame);
|
||||
assert!((similarity_empty - 1.0).abs() < 0.001);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_create_fixed_frame_chunks() {
|
||||
use crate::core::processor::yolo::{YoloFrame, YoloObject, YoloResult};
|
||||
|
||||
// 創建測試 YOLO 結果(60 幀,每幀都有物件)
|
||||
let mut frames = Vec::new();
|
||||
for i in 0..60 {
|
||||
frames.push(YoloFrame {
|
||||
frame: i as u64,
|
||||
timestamp: i as f64 / 30.0, // 假設 fps=30
|
||||
objects: vec![YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 100,
|
||||
y: 200,
|
||||
width: 50,
|
||||
height: 100,
|
||||
confidence: 0.9,
|
||||
}],
|
||||
});
|
||||
}
|
||||
|
||||
let yolo_result = YoloResult {
|
||||
frame_count: 60,
|
||||
fps: 30.0,
|
||||
frames,
|
||||
};
|
||||
|
||||
let chunks = create_fixed_frame_chunks(1, "test-uuid", &yolo_result, 0, 30.0);
|
||||
|
||||
// 60 幀,每 30 幀一個分片,應該有 2 個分片
|
||||
assert_eq!(chunks.len(), 2);
|
||||
|
||||
// 檢查第一個分片
|
||||
let first_chunk = &chunks[0];
|
||||
assert_eq!(
|
||||
first_chunk.chunk_type,
|
||||
crate::core::chunk::ChunkType::Visual
|
||||
);
|
||||
assert_eq!(first_chunk.start_frame, 0);
|
||||
assert_eq!(first_chunk.end_frame, 30); // exclusive
|
||||
assert_eq!(first_chunk.frame_count, 30);
|
||||
|
||||
// 檢查第二個分片
|
||||
let second_chunk = &chunks[1];
|
||||
assert_eq!(
|
||||
second_chunk.chunk_type,
|
||||
crate::core::chunk::ChunkType::Visual
|
||||
);
|
||||
assert_eq!(second_chunk.start_frame, 30);
|
||||
assert_eq!(second_chunk.end_frame, 60); // exclusive
|
||||
assert_eq!(second_chunk.frame_count, 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_similarity_based_chunks() {
|
||||
use crate::core::processor::yolo::{YoloFrame, YoloObject, YoloResult};
|
||||
|
||||
// 創建測試 YOLO 結果
|
||||
let frames = vec![
|
||||
YoloFrame {
|
||||
// 幀 0-4: 都有 person 和 car
|
||||
frame: 0,
|
||||
timestamp: 0.0,
|
||||
objects: vec![
|
||||
YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 100,
|
||||
y: 200,
|
||||
width: 50,
|
||||
height: 100,
|
||||
confidence: 0.9,
|
||||
},
|
||||
YoloObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
x: 300,
|
||||
y: 150,
|
||||
width: 80,
|
||||
height: 60,
|
||||
confidence: 0.8,
|
||||
},
|
||||
],
|
||||
},
|
||||
YoloFrame {
|
||||
// 幀 1
|
||||
frame: 1,
|
||||
timestamp: 0.033,
|
||||
objects: vec![
|
||||
YoloObject {
|
||||
class_name: "person".to_string(),
|
||||
class_id: 0,
|
||||
x: 110,
|
||||
y: 210,
|
||||
width: 52,
|
||||
height: 102,
|
||||
confidence: 0.88,
|
||||
},
|
||||
YoloObject {
|
||||
class_name: "car".to_string(),
|
||||
class_id: 2,
|
||||
x: 310,
|
||||
y: 155,
|
||||
width: 82,
|
||||
height: 62,
|
||||
confidence: 0.78,
|
||||
},
|
||||
],
|
||||
},
|
||||
YoloFrame {
|
||||
// 幀 5-9: 只有 dog
|
||||
frame: 5,
|
||||
timestamp: 0.166,
|
||||
objects: vec![YoloObject {
|
||||
class_name: "dog".to_string(),
|
||||
class_id: 16,
|
||||
x: 150,
|
||||
y: 250,
|
||||
width: 40,
|
||||
height: 60,
|
||||
confidence: 0.7,
|
||||
}],
|
||||
},
|
||||
YoloFrame {
|
||||
// 幀 6
|
||||
frame: 6,
|
||||
timestamp: 0.2,
|
||||
objects: vec![YoloObject {
|
||||
class_name: "dog".to_string(),
|
||||
class_id: 16,
|
||||
x: 155,
|
||||
y: 255,
|
||||
width: 42,
|
||||
height: 62,
|
||||
confidence: 0.68,
|
||||
}],
|
||||
},
|
||||
];
|
||||
|
||||
let yolo_result = YoloResult {
|
||||
frame_count: 7,
|
||||
fps: 30.0,
|
||||
frames,
|
||||
};
|
||||
|
||||
let chunks = create_similarity_based_chunks(
|
||||
1,
|
||||
"test-uuid",
|
||||
&yolo_result,
|
||||
0,
|
||||
30.0,
|
||||
0.5, // similarity threshold
|
||||
2, // min frames per chunk
|
||||
);
|
||||
|
||||
// 應該有 2 個分片:一個是 person+car,一個是 dog
|
||||
assert_eq!(chunks.len(), 2);
|
||||
}
|
||||
}
|
||||
9
src/core/text/mod.rs
Normal file
9
src/core/text/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
pub mod online_synonym_expander;
|
||||
pub mod synonym;
|
||||
pub mod synonym_expander;
|
||||
pub mod tokenizer;
|
||||
|
||||
pub use online_synonym_expander::{global_online_expander, OnlineSynonymExpander};
|
||||
pub use synonym::{normalize_chinese_query, simplified_to_traditional, traditional_to_simplified};
|
||||
pub use synonym_expander::{global_synonym_expander, SynonymExpander};
|
||||
pub use tokenizer::{contains_chinese, extract_and_tokenize_text, tokenize_chinese_text};
|
||||
242
src/core/text/online_synonym_expander.rs
Normal file
242
src/core/text/online_synonym_expander.rs
Normal file
@@ -0,0 +1,242 @@
|
||||
use anyhow::{Context, Result};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Online Synonym Expander
|
||||
/// Fetches synonyms from LLM (llama.cpp server) on-demand and caches them.
|
||||
///
|
||||
/// Environment variables:
|
||||
/// - `MOMENTRY_ONLINE_SYNONYM` - Enable online synonym expansion (default: false)
|
||||
/// - `MOMENTRY_LLM_SYNONYM_URL` - LLM server URL (default: http://127.0.0.1:8081)
|
||||
/// - `MOMENTRY_LLM_SYNONYM_MODEL` - Model name (default: gemma4)
|
||||
/// - `MOMENTRY_LLM_SYNONYM_TIMEOUT` - Request timeout in seconds (default: 60)
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LlmResponse {
|
||||
choices: Vec<LlmChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LlmChoice {
|
||||
message: LlmMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct LlmMessage {
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OnlineSynonymExpander {
|
||||
/// Local synonym cache (loaded from file)
|
||||
local_map: HashMap<String, Vec<String>>,
|
||||
/// Runtime cache for LLM-fetched synonyms
|
||||
runtime_cache: Arc<Mutex<HashMap<String, Vec<String>>>>,
|
||||
/// LLM server URL
|
||||
api_url: String,
|
||||
/// Model name
|
||||
model: String,
|
||||
/// Request timeout
|
||||
timeout_secs: u64,
|
||||
}
|
||||
|
||||
static SYSTEM_PROMPT: &str = r#"You are a synonym generation assistant. For each given word, provide 8-12 synonyms in the same language.
|
||||
Rules:
|
||||
1. Return ONLY a JSON array of strings, nothing else
|
||||
2. Synonyms should be contextually relevant for video content search
|
||||
3. Include common words, informal terms, and related concepts
|
||||
4. Do NOT include the input word in the output
|
||||
5. All synonyms must be in the SAME language as the input word
|
||||
6. No explanations, no markdown, just the JSON array
|
||||
|
||||
Example input: "money"
|
||||
Example output: ["cash", "dollar", "currency", "funds", "bucks", "greenbacks", "coins", "wealth", "payment"]"#;
|
||||
|
||||
impl OnlineSynonymExpander {
|
||||
pub fn new(local_file_path: Option<&str>) -> Self {
|
||||
let local_map = if let Some(path) = local_file_path {
|
||||
match Self::load_local_file(path) {
|
||||
Ok(map) => map,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to load local synonym file {}: {}", path, e);
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
let api_url = env::var("MOMENTRY_LLM_SYNONYM_URL")
|
||||
.unwrap_or_else(|_| "http://127.0.0.1:8081".to_string());
|
||||
let model = env::var("MOMENTRY_LLM_SYNONYM_MODEL").unwrap_or_else(|_| "gemma4".to_string());
|
||||
let timeout_secs = env::var("MOMENTRY_LLM_SYNONYM_TIMEOUT")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(60);
|
||||
|
||||
Self {
|
||||
local_map,
|
||||
runtime_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_url,
|
||||
model,
|
||||
timeout_secs,
|
||||
}
|
||||
}
|
||||
|
||||
fn load_local_file(path: &str) -> Result<HashMap<String, Vec<String>>> {
|
||||
let content = std::fs::read_to_string(path).context("Failed to read local synonym file")?;
|
||||
let map: HashMap<String, Vec<String>> =
|
||||
serde_json::from_str(&content).context("Failed to parse local synonym JSON")?;
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
/// Get synonyms for a word. Checks local map first, then runtime cache, then fetches from LLM.
|
||||
pub async fn expand_word(&self, word: &str) -> String {
|
||||
// 1. Check local map
|
||||
if let Some(syns) = self.local_map.get(word) {
|
||||
if !syns.is_empty() {
|
||||
let mut parts = vec![word.to_string()];
|
||||
parts.extend_from_slice(syns);
|
||||
return format!("({})", parts.join(" | "));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check runtime cache
|
||||
let mut cache = self.runtime_cache.lock().await;
|
||||
if let Some(syns) = cache.get(word) {
|
||||
if !syns.is_empty() {
|
||||
let mut parts = vec![word.to_string()];
|
||||
parts.extend_from_slice(syns);
|
||||
return format!("({})", parts.join(" | "));
|
||||
}
|
||||
}
|
||||
drop(cache);
|
||||
|
||||
// 3. Fetch from LLM
|
||||
if let Ok(synonyms) = self.fetch_from_llm(word).await {
|
||||
if !synonyms.is_empty() {
|
||||
// Add to runtime cache
|
||||
let mut cache = self.runtime_cache.lock().await;
|
||||
cache.insert(word.to_string(), synonyms.clone());
|
||||
drop(cache);
|
||||
|
||||
let mut parts = vec![word.to_string()];
|
||||
parts.extend_from_slice(&synonyms);
|
||||
return format!("({})", parts.join(" | "));
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Fallback: return original word
|
||||
word.to_string()
|
||||
}
|
||||
|
||||
async fn fetch_from_llm(&self, word: &str) -> Result<Vec<String>> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let prompt = format!(
|
||||
r#"Give synonyms for: "{}"
|
||||
Return ONLY a JSON array of strings, nothing else. Do NOT include the input word."#,
|
||||
word
|
||||
);
|
||||
|
||||
let payload = serde_json::json!({
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": SYSTEM_PROMPT
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
}
|
||||
],
|
||||
"temperature": 0.3,
|
||||
"stream": false,
|
||||
"max_tokens": 256,
|
||||
});
|
||||
|
||||
let response = client
|
||||
.post(format!("{}/v1/chat/completions", self.api_url))
|
||||
.json(&payload)
|
||||
.timeout(std::time::Duration::from_secs(self.timeout_secs))
|
||||
.send()
|
||||
.await
|
||||
.context("LLM request failed")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
anyhow::bail!("LLM request failed with status: {}", response.status());
|
||||
}
|
||||
|
||||
let llm_resp: LlmResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse LLM response")?;
|
||||
|
||||
let content = &llm_resp
|
||||
.choices
|
||||
.get(0)
|
||||
.context("No choices in LLM response")?
|
||||
.message
|
||||
.content;
|
||||
|
||||
// Extract JSON from response (handle markdown code blocks)
|
||||
let json_str = if let Some(start) = content.find('[') {
|
||||
if let Some(end) = content.rfind(']') {
|
||||
&content[start..=end]
|
||||
} else {
|
||||
anyhow::bail!("No JSON array found in LLM response");
|
||||
}
|
||||
} else {
|
||||
anyhow::bail!("No JSON array found in LLM response");
|
||||
};
|
||||
|
||||
let synonyms: Vec<String> =
|
||||
serde_json::from_str(json_str).context("Failed to parse LLM synonyms JSON")?;
|
||||
|
||||
// Filter and normalize
|
||||
let cleaned: Vec<String> = synonyms
|
||||
.into_iter()
|
||||
.map(|s| s.trim().to_lowercase())
|
||||
.filter(|s| !s.is_empty() && !s.contains(' ')) // Filter out multi-word synonyms for to_tsquery compatibility
|
||||
.collect();
|
||||
|
||||
if cleaned.is_empty() {
|
||||
anyhow::bail!("No valid synonyms returned");
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"LLM fetched {} synonyms for '{}': {:?}",
|
||||
cleaned.len(),
|
||||
word,
|
||||
cleaned.iter().take(5).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
Ok(cleaned)
|
||||
}
|
||||
|
||||
/// Get the number of cached synonyms
|
||||
pub async fn cache_size(&self) -> usize {
|
||||
self.runtime_cache.lock().await.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Global online synonym expander (lazy-loaded)
|
||||
static ONLINE_EXPANDER: Lazy<Option<OnlineSynonymExpander>> = Lazy::new(|| {
|
||||
if env::var("MOMENTRY_ONLINE_SYNONYM").is_ok() {
|
||||
let local_file = env::var("MOMENTRY_SYNONYM_FILE").ok();
|
||||
tracing::info!("Initializing online synonym expander");
|
||||
Some(OnlineSynonymExpander::new(local_file.as_deref()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
/// Get the global online synonym expander (if enabled)
|
||||
pub fn global_online_expander() -> Option<&'static OnlineSynonymExpander> {
|
||||
ONLINE_EXPANDER.as_ref()
|
||||
}
|
||||
71
src/core/text/synonym.rs
Normal file
71
src/core/text/synonym.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use ferrous_opencc::{config::BuiltinConfig, OpenCC};
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
static OPENCC_S2T: Lazy<OpenCC> = Lazy::new(|| {
|
||||
OpenCC::from_config(BuiltinConfig::S2t)
|
||||
.expect("Failed to initialize OpenCC Simplified to Traditional converter")
|
||||
});
|
||||
|
||||
static OPENCC_T2S: Lazy<OpenCC> = Lazy::new(|| {
|
||||
OpenCC::from_config(BuiltinConfig::T2s)
|
||||
.expect("Failed to initialize OpenCC Traditional to Simplified converter")
|
||||
});
|
||||
|
||||
/// Convert Simplified Chinese text to Traditional Chinese
|
||||
pub fn simplified_to_traditional(text: &str) -> String {
|
||||
OPENCC_S2T.convert(text)
|
||||
}
|
||||
|
||||
/// Convert Traditional Chinese text to Simplified Chinese
|
||||
pub fn traditional_to_simplified(text: &str) -> String {
|
||||
OPENCC_T2S.convert(text)
|
||||
}
|
||||
|
||||
/// Normalize Chinese query for search:
|
||||
/// 1. Convert Simplified Chinese to Traditional Chinese (assuming database stores Traditional)
|
||||
/// 2. Return converted text
|
||||
pub fn normalize_chinese_query(text: &str) -> String {
|
||||
simplified_to_traditional(text)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_simplified_to_traditional() {
|
||||
// Example: Simplified "计算机" -> Traditional "計算機"
|
||||
let simplified = "计算机";
|
||||
let traditional = simplified_to_traditional(simplified);
|
||||
// The conversion might produce "計算機" (depending on dictionary)
|
||||
// We'll just verify it's not empty and different from input
|
||||
assert!(!traditional.is_empty());
|
||||
assert_ne!(traditional, simplified);
|
||||
|
||||
// Traditional input should remain unchanged (or nearly unchanged)
|
||||
let traditional_input = "計算機";
|
||||
let converted = simplified_to_traditional(traditional_input);
|
||||
assert_eq!(converted, traditional_input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_traditional_to_simplified() {
|
||||
let traditional = "計算機";
|
||||
let simplified = traditional_to_simplified(traditional);
|
||||
assert!(!simplified.is_empty());
|
||||
assert_ne!(simplified, traditional);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_chinese_query() {
|
||||
let simplified = "计算机";
|
||||
let normalized = normalize_chinese_query(simplified);
|
||||
// Should be Traditional
|
||||
assert_ne!(normalized, simplified);
|
||||
|
||||
let traditional = "計算機";
|
||||
let normalized2 = normalize_chinese_query(traditional);
|
||||
// Should remain Traditional
|
||||
assert_eq!(normalized2, traditional);
|
||||
}
|
||||
}
|
||||
247
src/core/text/synonym_expander.rs
Normal file
247
src/core/text/synonym_expander.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use anyhow::{Context, Result};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
/// 同義詞擴展器
|
||||
/// 從 JSON 檔案加載自定義同義詞映射
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SynonymExpander {
|
||||
/// 詞語 -> 同義詞列表的映射
|
||||
map: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
impl SynonymExpander {
|
||||
/// 從 JSON 檔案創建同義詞擴展器
|
||||
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||
let content = fs::read_to_string(path).context("Failed to read synonym file")?;
|
||||
let map: HashMap<String, Vec<String>> =
|
||||
serde_json::from_str(&content).context("Failed to parse synonym JSON")?;
|
||||
Ok(Self { map })
|
||||
}
|
||||
|
||||
/// 從多個 JSON 檔案創建同義詞擴展器(後面的檔案會覆蓋前面的)
|
||||
pub fn from_files<P: AsRef<Path>>(paths: &[P]) -> Result<Self> {
|
||||
let mut combined_map = HashMap::new();
|
||||
|
||||
for path in paths {
|
||||
let content = fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read synonym file: {:?}", path.as_ref()))?;
|
||||
let map: HashMap<String, Vec<String>> =
|
||||
serde_json::from_str(&content).with_context(|| {
|
||||
format!("Failed to parse synonym JSON from {:?}", path.as_ref())
|
||||
})?;
|
||||
|
||||
// 合併映射,後面的檔案覆蓋前面的
|
||||
for (key, synonyms) in map {
|
||||
combined_map.insert(key, synonyms);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self { map: combined_map })
|
||||
}
|
||||
|
||||
/// 從內建預設資料創建(返回空映射,用戶可通過配置文件添加自定義同義詞)
|
||||
pub fn from_default() -> Self {
|
||||
Self::empty()
|
||||
}
|
||||
|
||||
/// 獲取詞語的同義詞列表(如果存在)
|
||||
pub fn get_synonyms(&self, word: &str) -> Option<&[String]> {
|
||||
self.map.get(word).map(|v| v.as_slice())
|
||||
}
|
||||
|
||||
/// 擴展查詢詞語:將詞語替換為 (詞語 OR 同義詞1 OR 同義詞2 ...)
|
||||
/// 如果沒有同義詞,返回原詞語
|
||||
pub fn expand_word(&self, word: &str) -> String {
|
||||
match self.get_synonyms(word) {
|
||||
Some(syns) if !syns.is_empty() => {
|
||||
let mut parts = vec![word.to_string()];
|
||||
parts.extend_from_slice(syns);
|
||||
format!("({})", parts.join(" | "))
|
||||
}
|
||||
_ => word.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 擴展整個查詢字符串(空格分隔的詞語)
|
||||
pub fn expand_query(&self, query: &str) -> String {
|
||||
query
|
||||
.split_whitespace()
|
||||
.map(|word| self.expand_word(word))
|
||||
.collect::<Vec<_>>()
|
||||
.join(" & ")
|
||||
}
|
||||
|
||||
/// 對中文查詢進行智能擴展:先匹配已知同義詞,再對剩餘部分進行分詞
|
||||
pub fn expand_chinese_query(&self, query: &str) -> String {
|
||||
// 如果查詢很短,直接嘗試匹配整個查詢
|
||||
if query.chars().count() <= 4 {
|
||||
if let Some(syns) = self.get_synonyms(query) {
|
||||
let mut parts = vec![query.to_string()];
|
||||
parts.extend_from_slice(syns);
|
||||
return format!("({})", parts.join(" | "));
|
||||
}
|
||||
}
|
||||
|
||||
// 嘗試在查詢中尋找已知的同義詞
|
||||
let mut expanded_parts = Vec::new();
|
||||
let mut remaining_query = query;
|
||||
let mut found_synonym = false;
|
||||
|
||||
// 對同義詞鍵按長度降序排序(最長匹配優先)
|
||||
let mut keys: Vec<&String> = self.map.keys().collect();
|
||||
keys.sort_by_key(|b| std::cmp::Reverse(b.chars().count()));
|
||||
|
||||
// 貪婪匹配:尋找最長的同義詞匹配
|
||||
while !remaining_query.is_empty() {
|
||||
let mut matched = false;
|
||||
|
||||
for key in &keys {
|
||||
if remaining_query.starts_with(*key) {
|
||||
// 找到匹配的同義詞
|
||||
expanded_parts.push(self.expand_word(key));
|
||||
remaining_query = &remaining_query[key.len()..];
|
||||
found_synonym = true;
|
||||
matched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
// 沒有找到同義詞,跳過第一個字符,繼續嘗試
|
||||
let first_char_len = remaining_query.chars().next().map_or(0, |c| c.len_utf8());
|
||||
if first_char_len > 0 {
|
||||
let next_part = &remaining_query[..first_char_len];
|
||||
expanded_parts.push(next_part.to_string());
|
||||
remaining_query = &remaining_query[first_char_len..];
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if found_synonym {
|
||||
// 如果有找到同義詞,使用擴展後的查詢
|
||||
expanded_parts.join(" & ")
|
||||
} else {
|
||||
// 沒有找到同義詞,返回原查詢(稍後會進行分詞)
|
||||
query.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// 創建空的同義詞擴展器(無同義詞映射)
|
||||
pub fn empty() -> Self {
|
||||
Self {
|
||||
map: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 全局同義詞擴展器(懶加載)
|
||||
static SYNONYM_EXPANDER: Lazy<SynonymExpander> = Lazy::new(|| {
|
||||
// 優先嘗試 MOMENTRY_SYNONYM_FILES(逗號分隔的多個檔案)
|
||||
if let Ok(files_var) = env::var("MOMENTRY_SYNONYM_FILES") {
|
||||
let file_paths: Vec<&str> = files_var
|
||||
.split(',')
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
|
||||
if !file_paths.is_empty() {
|
||||
match SynonymExpander::from_files(&file_paths) {
|
||||
Ok(expander) => {
|
||||
tracing::info!(
|
||||
"Loaded synonym expander from {} files: {:?}",
|
||||
file_paths.len(),
|
||||
file_paths
|
||||
);
|
||||
return expander;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to load synonym expander from files {:?}: {}",
|
||||
file_paths,
|
||||
e
|
||||
);
|
||||
// 繼續嘗試單一檔案或使用預設
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到單一檔案 MOMENTRY_SYNONYM_FILE(向下兼容)
|
||||
if let Ok(file_path) = env::var("MOMENTRY_SYNONYM_FILE") {
|
||||
match SynonymExpander::from_file(&file_path) {
|
||||
Ok(expander) => {
|
||||
tracing::info!("Loaded synonym expander from {}", file_path);
|
||||
expander
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to load synonym expander from {}: {}", file_path, e);
|
||||
SynonymExpander::empty()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 使用預設同義詞(示例)
|
||||
SynonymExpander::from_default()
|
||||
}
|
||||
});
|
||||
|
||||
/// 獲取全局同義詞擴展器實例
|
||||
pub fn global_synonym_expander() -> &'static SynonymExpander {
|
||||
&SYNONYM_EXPANDER
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_expand_word() {
|
||||
let mut map = HashMap::new();
|
||||
map.insert(
|
||||
"電腦".to_string(),
|
||||
vec!["計算機".to_string(), "微机".to_string()],
|
||||
);
|
||||
map.insert(
|
||||
"工作".to_string(),
|
||||
vec!["任務".to_string(), "作業".to_string()],
|
||||
);
|
||||
let expander = SynonymExpander { map };
|
||||
|
||||
assert_eq!(expander.expand_word("電腦"), "(電腦 | 計算機 | 微机)");
|
||||
assert_eq!(expander.expand_word("工作"), "(工作 | 任務 | 作業)");
|
||||
assert_eq!(expander.expand_word("未知"), "未知");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_query() {
|
||||
let mut map = HashMap::new();
|
||||
map.insert(
|
||||
"電腦".to_string(),
|
||||
vec!["計算機".to_string(), "微机".to_string()],
|
||||
);
|
||||
map.insert(
|
||||
"工作".to_string(),
|
||||
vec!["任務".to_string(), "作業".to_string()],
|
||||
);
|
||||
let expander = SynonymExpander { map };
|
||||
|
||||
assert_eq!(
|
||||
expander.expand_query("電腦 工作"),
|
||||
"(電腦 | 計算機 | 微机) & (工作 | 任務 | 作業)"
|
||||
);
|
||||
assert_eq!(expander.expand_query("單個詞"), "單個詞");
|
||||
assert_eq!(expander.expand_query(""), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_files_empty() {
|
||||
let paths: Vec<&str> = vec![];
|
||||
let expander = SynonymExpander::from_files(&paths).unwrap();
|
||||
assert!(expander.map.is_empty());
|
||||
}
|
||||
}
|
||||
121
src/core/text/tokenizer.rs
Normal file
121
src/core/text/tokenizer.rs
Normal file
@@ -0,0 +1,121 @@
|
||||
use jieba_rs::Jieba;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
static JIEBA: Lazy<Jieba> = Lazy::new(Jieba::new);
|
||||
|
||||
/// 檢查文本是否包含中文字符
|
||||
/// 包括 CJK Unified Ideographs (U+4E00-U+9FFF) 和 Extension A (U+3400-U+4DBF)
|
||||
pub fn contains_chinese(text: &str) -> bool {
|
||||
text.chars()
|
||||
.any(|c| ('\u{4e00}'..='\u{9fff}').contains(&c) || ('\u{3400}'..='\u{4dbf}').contains(&c))
|
||||
}
|
||||
|
||||
/// 對中文文本進行分詞,並用空格連接分詞結果
|
||||
/// 非中文文本保持不變
|
||||
///
|
||||
/// # 示例
|
||||
/// ```
|
||||
/// use momentry_core::core::text::tokenizer::tokenize_chinese_text;
|
||||
///
|
||||
/// assert_eq!(tokenize_chinese_text("這是一個測試"), "這 是 一 個 測 試");
|
||||
/// assert_eq!(tokenize_chinese_text("Hello world"), "Hello world");
|
||||
/// assert_eq!(tokenize_chinese_text("中文English混合"), "中文 English 混合");
|
||||
/// ```
|
||||
pub fn tokenize_chinese_text(text: &str) -> String {
|
||||
if contains_chinese(text) {
|
||||
// 使用精確模式分詞(cut=false)
|
||||
let tokens = JIEBA.cut(text, false);
|
||||
tokens.join(" ")
|
||||
} else {
|
||||
text.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// 從 JSON 內容中提取文本並進行分詞
|
||||
/// 支持兩種格式:
|
||||
/// 1. content->'data'->>'text' (中文視頻格式)
|
||||
/// 2. content->'text' (英文視頻格式)
|
||||
pub fn extract_and_tokenize_text(content: &serde_json::Value) -> String {
|
||||
let raw_text = content
|
||||
.get("data")
|
||||
.and_then(|data| data.get("text"))
|
||||
.and_then(|v| v.as_str())
|
||||
.or_else(|| content.get("text").and_then(|v| v.as_str()))
|
||||
.unwrap_or("");
|
||||
|
||||
tokenize_chinese_text(raw_text)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_contains_chinese() {
|
||||
assert!(contains_chinese("中文"));
|
||||
assert!(contains_chinese("這是一個測試"));
|
||||
assert!(contains_chinese("混合文本 English 中文"));
|
||||
assert!(!contains_chinese("English only"));
|
||||
assert!(!contains_chinese("123"));
|
||||
assert!(!contains_chinese(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokenize_chinese_text() {
|
||||
// 純中文
|
||||
assert_eq!(tokenize_chinese_text("這是一個測試"), "這 是 一 個 測 試");
|
||||
|
||||
// 純英文
|
||||
assert_eq!(tokenize_chinese_text("Hello world"), "Hello world");
|
||||
|
||||
// 中英混合
|
||||
assert_eq!(
|
||||
tokenize_chinese_text("中文English混合"),
|
||||
"中文 English 混合"
|
||||
);
|
||||
|
||||
// 空字符串
|
||||
assert_eq!(tokenize_chinese_text(""), "");
|
||||
|
||||
// 數字和標點
|
||||
assert_eq!(tokenize_chinese_text("測試123。"), "測 試 123 。");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_and_tokenize_text() {
|
||||
// 中文格式:content->'data'->>'text'
|
||||
let content1 = serde_json::json!({
|
||||
"data": {
|
||||
"text": "這是一個測試"
|
||||
}
|
||||
});
|
||||
assert_eq!(extract_and_tokenize_text(&content1), "這 是 一 個 測 試");
|
||||
|
||||
// 英文格式:content->'text'
|
||||
let content2 = serde_json::json!({
|
||||
"text": "Hello world"
|
||||
});
|
||||
assert_eq!(extract_and_tokenize_text(&content2), "Hello world");
|
||||
|
||||
// 混合格式:優先使用 data->text
|
||||
let content3 = serde_json::json!({
|
||||
"data": {
|
||||
"text": "中文測試"
|
||||
},
|
||||
"text": "English text"
|
||||
});
|
||||
assert_eq!(extract_and_tokenize_text(&content3), "中文 測 試");
|
||||
|
||||
// 無文本
|
||||
let content4 = serde_json::json!({});
|
||||
assert_eq!(extract_and_tokenize_text(&content4), "");
|
||||
|
||||
// 非字符串文本
|
||||
let content5 = serde_json::json!({
|
||||
"data": {
|
||||
"text": 123
|
||||
}
|
||||
});
|
||||
assert_eq!(extract_and_tokenize_text(&content5), "");
|
||||
}
|
||||
}
|
||||
40
src/core/tmdb/ingest.rs
Normal file
40
src/core/tmdb/ingest.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CastEntry {
|
||||
pub name: String,
|
||||
pub role: String,
|
||||
pub image: Option<String>,
|
||||
}
|
||||
|
||||
/// Ingests TMDB cast data from the JSON file generated by `tmdb_cast_fetcher.py`
|
||||
pub async fn ingest_cast(db: &PostgresDb, json_path: &str) -> Result<usize> {
|
||||
let path = Path::new(json_path);
|
||||
if !path.exists() {
|
||||
return Err(anyhow::anyhow!("Cast JSON file not found: {}", json_path));
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read cast JSON: {}", json_path))?;
|
||||
|
||||
let cast_list: Vec<CastEntry> =
|
||||
serde_json::from_str(&content).with_context(|| "Invalid cast JSON format")?;
|
||||
|
||||
let mut count = 0;
|
||||
for entry in &cast_list {
|
||||
match db.get_or_create_identity(&entry.name).await {
|
||||
Ok(_talent) => {
|
||||
info!("Ingested TMDB cast: {} as {}", entry.name, entry.role);
|
||||
count += 1;
|
||||
}
|
||||
Err(e) => warn!("Failed to create talent '{}': {}", entry.name, e),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(count)
|
||||
}
|
||||
1
src/core/tmdb/mod.rs
Normal file
1
src/core/tmdb/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod ingest;
|
||||
144
src/core/worker/job_runner.rs
Normal file
144
src/core/worker/job_runner.rs
Normal file
@@ -0,0 +1,144 @@
|
||||
use sqlx::PgPool;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::core::chunk;
|
||||
|
||||
pub struct JobWorker {
|
||||
pool: PgPool,
|
||||
poll_interval: Duration,
|
||||
}
|
||||
|
||||
impl JobWorker {
|
||||
pub fn new(pool: PgPool, poll_interval_secs: u64) -> Self {
|
||||
Self {
|
||||
pool,
|
||||
poll_interval: Duration::from_secs(poll_interval_secs),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(&self) {
|
||||
tracing::info!(
|
||||
"🤖 Job Worker started (Polling every {}s)",
|
||||
self.poll_interval.as_secs()
|
||||
);
|
||||
|
||||
loop {
|
||||
match self.process_next_job().await {
|
||||
Ok(has_work) => {
|
||||
if !has_work {
|
||||
// No work found, wait before polling again
|
||||
sleep(self.poll_interval).await;
|
||||
}
|
||||
// If we processed a job, loop immediately to check for more
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("❌ Job Worker error: {}", e);
|
||||
sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_next_job(&self) -> anyhow::Result<bool> {
|
||||
// 1. Fetch a QUEUED job
|
||||
// We use a transaction to ensure no two workers pick the same job (atomic update)
|
||||
let job_row: Option<(String, String, String, String, String, i64)> = sqlx::query_as(
|
||||
r#"
|
||||
UPDATE dev.jobs
|
||||
SET status = 'RUNNING', updated_at = NOW()
|
||||
WHERE id = (
|
||||
SELECT id FROM dev.jobs
|
||||
WHERE status = 'QUEUED'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT 1
|
||||
FOR UPDATE SKIP LOCKED
|
||||
)
|
||||
RETURNING id::text, asset_uuid, rule, status, processor_list, total_frames
|
||||
"#,
|
||||
)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
if let Some((job_id, asset_uuid, rule, _status, _processors, total_frames)) = job_row {
|
||||
let job_uuid =
|
||||
Uuid::parse_str(&job_id).map_err(|e| anyhow::anyhow!("Invalid job UUID: {}", e))?;
|
||||
|
||||
tracing::info!(
|
||||
"🚀 Processing Job {} for Asset {} (Rule: {})",
|
||||
job_id,
|
||||
asset_uuid,
|
||||
rule
|
||||
);
|
||||
|
||||
// 2. Execute Logic based on Rule
|
||||
let result = match rule.as_str() {
|
||||
"rule1" => {
|
||||
let fps = self.get_asset_fps(&asset_uuid).await?;
|
||||
chunk::rule1_ingest::ingest_rule1(&self.pool, &asset_uuid, fps).await
|
||||
}
|
||||
_ => {
|
||||
tracing::warn!("Unknown rule type: {}", rule);
|
||||
Ok(0)
|
||||
}
|
||||
};
|
||||
|
||||
// 3. Update Job Status
|
||||
match result {
|
||||
Ok(chunk_count) => {
|
||||
tracing::info!(
|
||||
"✅ Job {} completed. Processed {} items.",
|
||||
job_id,
|
||||
chunk_count
|
||||
);
|
||||
|
||||
sqlx::query!(
|
||||
"UPDATE dev.jobs SET status = 'COMPLETED', processed_frames = total_frames, updated_at = NOW() WHERE id = $1",
|
||||
job_uuid
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query!(
|
||||
"UPDATE dev.videos SET processing_status = 'COMPLETED' WHERE uuid = $1",
|
||||
asset_uuid
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("❌ Job {} failed: {}", job_id, e);
|
||||
let err_msg = e.to_string();
|
||||
let safe_msg = if err_msg.len() > 500 {
|
||||
&err_msg[..500]
|
||||
} else {
|
||||
&err_msg
|
||||
};
|
||||
|
||||
sqlx::query!(
|
||||
"UPDATE dev.jobs SET status = 'FAILED', error_message = $2, updated_at = NOW() WHERE id = $1",
|
||||
job_uuid,
|
||||
safe_msg
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
return Ok(true); // Processed a job
|
||||
}
|
||||
|
||||
Ok(false) // No job found
|
||||
}
|
||||
|
||||
async fn get_asset_fps(&self, uuid: &str) -> anyhow::Result<f64> {
|
||||
let fps: Option<f64> =
|
||||
sqlx::query_scalar("SELECT (metadata->>'fps')::float FROM dev.videos WHERE uuid = $1")
|
||||
.bind(uuid)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?;
|
||||
|
||||
// Fallback to 29.97 if not found
|
||||
Ok(fps.unwrap_or(29.97))
|
||||
}
|
||||
}
|
||||
2
src/core/worker/mod.rs
Normal file
2
src/core/worker/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod job_runner;
|
||||
pub use job_runner::JobWorker;
|
||||
@@ -915,6 +915,7 @@ async fn main() -> Result<()> {
|
||||
user_id: None,
|
||||
job_id: None,
|
||||
created_at: String::new(),
|
||||
registration_time: None,
|
||||
};
|
||||
|
||||
let video_id = db.register_video(&record).await?;
|
||||
|
||||
@@ -924,6 +924,7 @@ async fn main() -> Result<()> {
|
||||
user_id: None,
|
||||
job_id: None,
|
||||
created_at: String::new(),
|
||||
registration_time: None,
|
||||
};
|
||||
|
||||
let video_id = db.register_video(&record).await?;
|
||||
@@ -2373,20 +2374,25 @@ async fn main() -> Result<()> {
|
||||
target
|
||||
);
|
||||
|
||||
for chunk in sentence_chunks {
|
||||
println!("Starting to process {} chunks...", sentence_chunks.len());
|
||||
for (i, chunk) in sentence_chunks.iter().enumerate() {
|
||||
if i < 3 {
|
||||
println!("Processing chunk {}/{}: {}", i+1, sentence_chunks.len(), chunk.chunk_id);
|
||||
}
|
||||
let text = chunk
|
||||
.content
|
||||
.get("data")
|
||||
.and_then(|data| data.get("text"))
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.or_else(|| chunk.content.get("data").and_then(|data| data.get("text")).and_then(|v| v.as_str()))
|
||||
.or(chunk.text_content.as_deref())
|
||||
.unwrap_or("");
|
||||
|
||||
eprintln!("Embedding chunk {}/{}: {} (text len: {})...", i+1, sentence_chunks.len(), chunk.chunk_id, text.len());
|
||||
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
print!("Embedding chunk {}... ", chunk.chunk_id);
|
||||
|
||||
match embedder.embed_document(text).await {
|
||||
Ok(vector) => {
|
||||
let vector_id = format!("{}_{}", chunk.uuid, chunk.chunk_id);
|
||||
@@ -2420,10 +2426,12 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
stored_count += 1;
|
||||
println!("done ({} dims)", vector.len());
|
||||
if stored_count % 100 == 0 || stored_count <= 3 {
|
||||
println!("Stored {}/1867 vectors", stored_count);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
println!("failed: {}", e);
|
||||
eprintln!("embed_document error for {}: {}", chunk.chunk_id, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
10
src/test_embed.rs
Normal file
10
src/test_embed.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use momentry_core::core::embedding::comic_embed::Embedder;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let embedder = Embedder::new("nomic-embed-text-v2-moe:latest".to_string());
|
||||
match embedder.embed_document("test embedding").await {
|
||||
Ok(vector) => println!("Success! Vector length: {}", vector.len()),
|
||||
Err(e) => println!("Error: {}", e),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user