- Extract scan.rs, files.rs, types.rs, processing.rs, visual_chunk_search.rs - Move AppState and AppConfig to types.rs - Each module exposes pub fn xxx_routes() -> Router<AppState> - server.rs reduced from 5005 to 118 lines (orchestrator only) - All stubs filled with real implementations from git history - Verify: cargo check, clippy, tests all pass
1075 lines
36 KiB
Rust
1075 lines
36 KiB
Rust
use axum::{
|
||
extract::{Multipart, State},
|
||
http::StatusCode,
|
||
response::Json,
|
||
routing::{get, post},
|
||
Router,
|
||
};
|
||
use serde::{Deserialize, Serialize};
|
||
use sqlx::Row;
|
||
use std::path::PathBuf;
|
||
|
||
use crate::api::types::AppState;
|
||
use crate::core::db::schema;
|
||
use crate::core::db::PostgresDb;
|
||
|
||
pub fn identity_agent_routes() -> Router<AppState> {
|
||
Router::new()
|
||
.route(
|
||
"/api/v1/agents/identity/match-from-photo",
|
||
post(match_from_photo),
|
||
)
|
||
.route(
|
||
"/api/v1/agents/identity/match-from-trace",
|
||
post(match_from_trace),
|
||
)
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
pub struct IdentityResult {
|
||
pub identity_id: String,
|
||
pub person_ids: Vec<String>,
|
||
pub speaker_ids: Vec<String>,
|
||
pub confidence: f64,
|
||
pub evidence: IdentityEvidence,
|
||
pub reasoning: String,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
pub struct IdentityEvidence {
|
||
pub face_similarity: Option<f64>,
|
||
pub speaker_overlap: f64,
|
||
pub time_overlap: f64,
|
||
pub frame_ratio: f64,
|
||
}
|
||
|
||
#[derive(Debug, Serialize)]
|
||
struct MatchFromPhotoResponse {
|
||
success: bool,
|
||
identity_uuid: String,
|
||
file_uuid: String,
|
||
matches: usize,
|
||
traces_matched: Vec<i32>,
|
||
message: String,
|
||
}
|
||
|
||
async fn match_from_photo(
|
||
State(state): State<AppState>,
|
||
mut multipart: Multipart,
|
||
) -> Result<Json<MatchFromPhotoResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||
let mut identity_uuid = String::new();
|
||
let mut file_uuid = String::new();
|
||
let mut image_data: Option<Vec<u8>> = None;
|
||
|
||
while let Ok(Some(field)) = multipart.next_field().await {
|
||
let name = field.name().unwrap_or("").to_string();
|
||
match name.as_str() {
|
||
"identity_uuid" => {
|
||
identity_uuid = field.text().await.unwrap_or_default();
|
||
}
|
||
"file_uuid" => {
|
||
file_uuid = field.text().await.unwrap_or_default();
|
||
}
|
||
"image" => {
|
||
image_data = Some(field.bytes().await.unwrap_or_default().to_vec());
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
|
||
let uuid_clean = identity_uuid.replace('-', "");
|
||
if uuid_clean.is_empty() || file_uuid.is_empty() {
|
||
return Err((
|
||
StatusCode::BAD_REQUEST,
|
||
Json(serde_json::json!({
|
||
"success": false, "message": "identity_uuid and file_uuid are required"
|
||
})),
|
||
));
|
||
}
|
||
let data = image_data.ok_or_else(|| {
|
||
(
|
||
StatusCode::BAD_REQUEST,
|
||
Json(serde_json::json!({
|
||
"success": false, "message": "No image field found. Use field name 'image'."
|
||
})),
|
||
)
|
||
})?;
|
||
|
||
// 1. Save uploaded image to temp
|
||
let scripts_dir = std::env::var("MOMENTRY_SCRIPTS_DIR")
|
||
.unwrap_or_else(|_| "/Users/accusys/momentry_core_0.1/scripts".to_string());
|
||
let python_path = std::env::var("MOMENTRY_PYTHON_PATH")
|
||
.unwrap_or_else(|_| "/opt/homebrew/bin/python3.11".to_string());
|
||
let temp_dir = std::env::temp_dir().join("momentry_match_face");
|
||
std::fs::create_dir_all(&temp_dir).map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": format!("Failed to create temp dir: {}", e)})),
|
||
)
|
||
})?;
|
||
let temp_img = temp_dir.join(format!("{}.jpg", uuid_clean));
|
||
std::fs::write(&temp_img, &data).map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": format!("Failed to save temp image: {}", e)})),
|
||
)
|
||
})?;
|
||
|
||
// 2. Extract face embedding via Python script
|
||
let extract_script = std::path::Path::new(&scripts_dir).join("extract_face_embedding.py");
|
||
let output = tokio::process::Command::new(&*python_path)
|
||
.arg(&extract_script)
|
||
.arg(&temp_img)
|
||
.output()
|
||
.await
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": format!("Failed to run extractor: {}", e)})),
|
||
)
|
||
})?;
|
||
|
||
let _ = std::fs::remove_file(&temp_img);
|
||
|
||
if !output.status.success() {
|
||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||
return Err((
|
||
StatusCode::BAD_REQUEST,
|
||
Json(serde_json::json!({
|
||
"success": false, "message": format!("Face extraction failed: {}", stderr)
|
||
})),
|
||
));
|
||
}
|
||
|
||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||
let extract_result: serde_json::Value = serde_json::from_str(&stdout).map_err(|_| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": "Failed to parse extractor output"})),
|
||
)
|
||
})?;
|
||
|
||
let embedding: Vec<f64> = serde_json::from_value(
|
||
extract_result
|
||
.get("embedding")
|
||
.ok_or_else(|| {
|
||
(
|
||
StatusCode::BAD_REQUEST,
|
||
Json(serde_json::json!({"message": "No embedding in extractor output"})),
|
||
)
|
||
})?
|
||
.clone(),
|
||
)
|
||
.map_err(|_| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": "Invalid embedding format"})),
|
||
)
|
||
})?;
|
||
|
||
let embedding_f32: Vec<f32> = embedding.into_iter().map(|v| v as f32).collect();
|
||
|
||
// 3. Look up identity internal ID
|
||
let id_table = schema::table_name("identities");
|
||
let identity_id_row: Option<(i32,)> = sqlx::query_as(&format!(
|
||
"SELECT id FROM {} WHERE REPLACE(uuid::text, '-', '') = $1",
|
||
id_table
|
||
))
|
||
.bind(&uuid_clean)
|
||
.fetch_optional(state.db.pool())
|
||
.await
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": format!("DB error: {}", e)})),
|
||
)
|
||
})?;
|
||
|
||
let identity_id = match identity_id_row {
|
||
Some((id,)) => id,
|
||
None => {
|
||
return Err((
|
||
StatusCode::NOT_FOUND,
|
||
Json(serde_json::json!({
|
||
"success": false, "message": "Identity not found"
|
||
})),
|
||
))
|
||
}
|
||
};
|
||
|
||
// 4. Find best matching trace (highest similarity, no threshold)
|
||
let fd_table = schema::table_name("face_detections");
|
||
let best_match: Option<(i32, i32, f64)> = sqlx::query_as(&format!(
|
||
r#"SELECT id, trace_id,
|
||
1 - (embedding::vector <=> $1::vector) as similarity
|
||
FROM {}
|
||
WHERE file_uuid = $2 AND embedding IS NOT NULL
|
||
ORDER BY embedding::vector <=> $1::vector
|
||
LIMIT 1"#,
|
||
fd_table
|
||
))
|
||
.bind(&embedding_f32)
|
||
.bind(&file_uuid)
|
||
.fetch_optional(state.db.pool())
|
||
.await
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": format!("Search failed: {}", e)})),
|
||
)
|
||
})?;
|
||
|
||
// 5. Update best match face_detection
|
||
let mut traces_matched: Vec<i32> = Vec::new();
|
||
if let Some((fb_id, fb_trace, fb_sim)) = best_match {
|
||
let _ = sqlx::query(&format!(
|
||
"UPDATE {} SET identity_id = $1 WHERE id = $2",
|
||
fd_table
|
||
))
|
||
.bind(identity_id)
|
||
.bind(fb_id)
|
||
.execute(state.db.pool())
|
||
.await;
|
||
traces_matched.push(fb_trace);
|
||
|
||
// 6. Save identity file
|
||
let _ = crate::core::identity::storage::save_identity_file(&*state.db, &uuid_clean).await;
|
||
|
||
Ok(Json(MatchFromPhotoResponse {
|
||
success: true,
|
||
identity_uuid: uuid_clean,
|
||
file_uuid,
|
||
matches: 1,
|
||
traces_matched,
|
||
message: format!(
|
||
"Best trace: trace_id={}, similarity={:.4}",
|
||
fb_trace, fb_sim
|
||
),
|
||
}))
|
||
} else {
|
||
Ok(Json(MatchFromPhotoResponse {
|
||
success: true,
|
||
identity_uuid: uuid_clean,
|
||
file_uuid,
|
||
matches: 0,
|
||
traces_matched,
|
||
message: "No matching face found in video".to_string(),
|
||
}))
|
||
}
|
||
}
|
||
|
||
#[derive(Debug, Deserialize)]
|
||
struct MatchFromTraceRequest {
|
||
file_uuid: String,
|
||
trace_id: i32,
|
||
identity_uuid: String,
|
||
}
|
||
|
||
async fn match_from_trace(
|
||
State(state): State<AppState>,
|
||
Json(req): Json<MatchFromTraceRequest>,
|
||
) -> Result<Json<MatchFromPhotoResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||
let uuid_clean = req.identity_uuid.replace('-', "");
|
||
|
||
// 1. Get 3 best face embeddings from this trace at different angles
|
||
// Divide trace frame range into 3 segments, pick best face from each
|
||
let fd_table = schema::table_name("face_detections");
|
||
let all_faces: Vec<(Vec<f32>, i64)> = sqlx::query_as::<_, (Vec<f32>, i64)>(&format!(
|
||
"SELECT embedding, frame_number FROM {} \
|
||
WHERE file_uuid = $1 AND trace_id = $2 AND embedding IS NOT NULL \
|
||
ORDER BY frame_number ASC",
|
||
fd_table
|
||
))
|
||
.bind(&req.file_uuid)
|
||
.bind(req.trace_id)
|
||
.fetch_all(state.db.pool())
|
||
.await
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": format!("DB error: {}", e)})),
|
||
)
|
||
})?;
|
||
|
||
if all_faces.is_empty() {
|
||
return Err((
|
||
StatusCode::NOT_FOUND,
|
||
Json(serde_json::json!({
|
||
"success": false, "message": "No embedding found for this trace"
|
||
})),
|
||
));
|
||
}
|
||
|
||
// Pick 3 samples: divide frame range into 3 segments, use face with largest area per segment
|
||
let total = all_faces.len();
|
||
let segments = [
|
||
(0, total / 3),
|
||
(total / 3, total * 2 / 3),
|
||
(total * 2 / 3, total),
|
||
];
|
||
|
||
let mut query_embeddings: Vec<Vec<f32>> = Vec::new();
|
||
|
||
// Get width*height info if available (not all pipelines store it)
|
||
let face_sizes: Vec<(i64, i32)> = sqlx::query_as::<_, (i64, i32)>(&format!(
|
||
"SELECT frame_number, COALESCE(width, 0) * COALESCE(height, 0) AS area \
|
||
FROM {} WHERE file_uuid = $1 AND trace_id = $2 AND embedding IS NOT NULL \
|
||
ORDER BY frame_number ASC",
|
||
fd_table
|
||
))
|
||
.bind(&req.file_uuid)
|
||
.bind(req.trace_id)
|
||
.fetch_all(state.db.pool())
|
||
.await
|
||
.unwrap_or_default();
|
||
|
||
let face_sizes_map: std::collections::HashMap<i64, i32> = face_sizes.into_iter().collect();
|
||
|
||
for (start, end) in segments {
|
||
let seg_start = start.min(total - 1);
|
||
let seg_end = end.min(total);
|
||
if seg_start >= seg_end {
|
||
continue;
|
||
}
|
||
let seg_slice = &all_faces[seg_start..seg_end];
|
||
// Pick the face with largest area within this segment
|
||
let best_idx = seg_slice
|
||
.iter()
|
||
.enumerate()
|
||
.max_by_key(|(_, f)| face_sizes_map.get(&f.1).copied().unwrap_or(0))
|
||
.map(|(i, _)| i)
|
||
.unwrap_or(0);
|
||
query_embeddings.push(seg_slice[best_idx].0.clone());
|
||
}
|
||
|
||
if query_embeddings.is_empty() {
|
||
query_embeddings.push(all_faces[total / 2].0.clone());
|
||
}
|
||
|
||
// 2. Three angles each find their best match; union all results
|
||
let mut validated: Vec<(i32, i32, f64)> = Vec::new();
|
||
let mut seen_trace_ids = std::collections::HashSet::new();
|
||
|
||
for qemb in &query_embeddings {
|
||
let top = sqlx::query_as::<_, (i32, i32, f64)>(&format!(
|
||
r#"SELECT id, trace_id,
|
||
1 - (embedding::vector <=> $1::vector) as similarity
|
||
FROM {}
|
||
WHERE file_uuid = $2
|
||
AND trace_id != $3
|
||
AND embedding IS NOT NULL
|
||
ORDER BY embedding::vector <=> $1::vector
|
||
LIMIT 1"#,
|
||
fd_table
|
||
))
|
||
.bind(qemb)
|
||
.bind(&req.file_uuid)
|
||
.bind(req.trace_id)
|
||
.fetch_optional(state.db.pool())
|
||
.await
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": format!("Search failed: {}", e)})),
|
||
)
|
||
})?;
|
||
|
||
if let Some((cface_id, c_trace_id, c_sim)) = top {
|
||
if seen_trace_ids.insert(c_trace_id) {
|
||
validated.push((cface_id, c_trace_id, c_sim));
|
||
}
|
||
}
|
||
}
|
||
|
||
// 3. Look up identity internal ID
|
||
let id_table = schema::table_name("identities");
|
||
let identity_id_row: Option<(i32,)> = sqlx::query_as(&format!(
|
||
"SELECT id FROM {} WHERE REPLACE(uuid::text, '-', '') = $1",
|
||
id_table
|
||
))
|
||
.bind(&uuid_clean)
|
||
.fetch_optional(state.db.pool())
|
||
.await
|
||
.map_err(|e| {
|
||
(
|
||
StatusCode::INTERNAL_SERVER_ERROR,
|
||
Json(serde_json::json!({"message": format!("DB error: {}", e)})),
|
||
)
|
||
})?;
|
||
|
||
let identity_id = match identity_id_row {
|
||
Some((id,)) => id,
|
||
None => {
|
||
return Err((
|
||
StatusCode::NOT_FOUND,
|
||
Json(serde_json::json!({
|
||
"success": false, "message": "Identity not found"
|
||
})),
|
||
))
|
||
}
|
||
};
|
||
|
||
// 4. Update matched face_detections
|
||
let mut traces_matched: Vec<i32> = Vec::new();
|
||
for (id, trace_id, _similarity) in &validated {
|
||
if let Err(e) = sqlx::query(&format!(
|
||
"UPDATE {} SET identity_id = $1 WHERE id = $2",
|
||
fd_table
|
||
))
|
||
.bind(identity_id)
|
||
.bind(id)
|
||
.execute(state.db.pool())
|
||
.await
|
||
{
|
||
tracing::warn!(
|
||
"[match-from-trace] Failed to update face_detection {}: {}",
|
||
id,
|
||
e
|
||
);
|
||
} else {
|
||
if !traces_matched.contains(trace_id) {
|
||
traces_matched.push(*trace_id);
|
||
}
|
||
}
|
||
}
|
||
|
||
// 5. Also bind the source trace itself
|
||
let _ = sqlx::query(&format!(
|
||
"UPDATE {} SET identity_id = $1 WHERE file_uuid = $2 AND trace_id = $3",
|
||
fd_table
|
||
))
|
||
.bind(identity_id)
|
||
.bind(&req.file_uuid)
|
||
.bind(req.trace_id)
|
||
.execute(state.db.pool())
|
||
.await;
|
||
|
||
if !traces_matched.contains(&req.trace_id) {
|
||
traces_matched.push(req.trace_id);
|
||
}
|
||
|
||
// 6. Save identity file
|
||
let _ = crate::core::identity::storage::save_identity_file(&*state.db, &uuid_clean).await;
|
||
|
||
let match_count = validated.len() + 1;
|
||
let trace_count = traces_matched.len();
|
||
Ok(Json(MatchFromPhotoResponse {
|
||
success: true,
|
||
identity_uuid: uuid_clean,
|
||
file_uuid: req.file_uuid,
|
||
matches: match_count,
|
||
traces_matched,
|
||
message: format!(
|
||
"Matched {} faces ({} unique traces)",
|
||
match_count, trace_count
|
||
),
|
||
}))
|
||
}
|
||
|
||
fn extract_persons_from_face_data(face_data: &serde_json::Value) -> Vec<PersonData> {
|
||
let mut persons = Vec::new();
|
||
if let Some(frames) = face_data.get("frames").and_then(|f| f.as_array()) {
|
||
let mut person_frames_map: std::collections::HashMap<String, Vec<i32>> =
|
||
std::collections::HashMap::new();
|
||
for frame in frames {
|
||
if let Some(frame_num) = frame.get("frame").and_then(|f| f.as_i64()) {
|
||
if let Some(person_id) = frame.get("person_id").and_then(|p| p.as_str()) {
|
||
person_frames_map
|
||
.entry(person_id.to_string())
|
||
.or_insert_with(Vec::new)
|
||
.push(frame_num as i32);
|
||
}
|
||
}
|
||
}
|
||
for (person_id, frames) in person_frames_map {
|
||
persons.push(PersonData {
|
||
person_id,
|
||
frames,
|
||
avg_embedding: None,
|
||
});
|
||
}
|
||
}
|
||
persons
|
||
}
|
||
|
||
fn extract_speakers_from_asrx_data(asrx_data: &Option<serde_json::Value>) -> Vec<SpeakerData> {
|
||
let mut speakers = Vec::new();
|
||
if let Some(data) = asrx_data {
|
||
if let Some(segments) = data.get("segments").and_then(|s| s.as_array()) {
|
||
for seg in segments {
|
||
if let (Some(start), Some(end), Some(speaker_id)) = (
|
||
seg.get("start_time").and_then(|v| v.as_f64()),
|
||
seg.get("end_time").and_then(|v| v.as_f64()),
|
||
seg.get("speaker_id").and_then(|v| v.as_str()),
|
||
) {
|
||
speakers.push(SpeakerData {
|
||
speaker_id: speaker_id.to_string(),
|
||
segments: vec![(start, end)],
|
||
});
|
||
}
|
||
}
|
||
}
|
||
}
|
||
speakers
|
||
}
|
||
|
||
fn analyze_person_speaker_overlap(
|
||
persons: &[PersonData],
|
||
speakers: &[SpeakerData],
|
||
) -> Vec<IdentityResult> {
|
||
let mut identities: Vec<IdentityResult> = Vec::new();
|
||
let mut visited_persons: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||
|
||
for person in persons {
|
||
if visited_persons.contains(&person.person_id) {
|
||
continue;
|
||
}
|
||
|
||
let mut matched_persons = vec![person.person_id.clone()];
|
||
let mut matched_speakers: Vec<String> = Vec::new();
|
||
visited_persons.insert(person.person_id.clone());
|
||
|
||
for other_person in persons {
|
||
if visited_persons.contains(&other_person.person_id) {
|
||
continue;
|
||
}
|
||
|
||
// Check if persons co-occur in time (frame proximity)
|
||
let overlap = person
|
||
.frames
|
||
.iter()
|
||
.any(|f| other_person.frames.contains(f));
|
||
if overlap {
|
||
matched_persons.push(other_person.person_id.clone());
|
||
visited_persons.insert(other_person.person_id.clone());
|
||
}
|
||
}
|
||
|
||
// Check speaker overlap
|
||
let person_time_range = (
|
||
person.frames.iter().min().copied().unwrap_or(0) as f64,
|
||
person.frames.iter().max().copied().unwrap_or(0) as f64,
|
||
);
|
||
for speaker in speakers {
|
||
let has_overlap = speaker
|
||
.segments
|
||
.iter()
|
||
.any(|(start, end)| *start <= person_time_range.1 && *end >= person_time_range.0);
|
||
if has_overlap {
|
||
if !matched_speakers.contains(&speaker.speaker_id) {
|
||
matched_speakers.push(speaker.speaker_id.clone());
|
||
}
|
||
}
|
||
}
|
||
|
||
let frame_count = person.frames.len() as f64;
|
||
let speaker_overlap = if matched_speakers.is_empty() {
|
||
0.0
|
||
} else {
|
||
matched_speakers.len() as f64 / speakers.len().max(1) as f64
|
||
};
|
||
|
||
identities.push(IdentityResult {
|
||
identity_id: person.person_id.clone(),
|
||
person_ids: matched_persons.clone(),
|
||
speaker_ids: matched_speakers.clone(),
|
||
confidence: 0.5 + (speaker_overlap * 0.3),
|
||
evidence: IdentityEvidence {
|
||
face_similarity: None,
|
||
speaker_overlap,
|
||
time_overlap: 1.0,
|
||
frame_ratio: frame_count / 100.0,
|
||
},
|
||
reasoning: format!(
|
||
"Matched {} persons with {} speakers, overlap={:.2}",
|
||
matched_persons.len(),
|
||
speaker_overlap,
|
||
speaker_overlap
|
||
),
|
||
});
|
||
}
|
||
|
||
identities
|
||
}
|
||
|
||
/// A person found in the face-clustering output.
#[derive(Debug)]
struct PersonData {
    /// Cluster id of the person.
    person_id: String,
    /// Frame numbers in which the person was detected.
    frames: Vec<i32>,
    /// Optional averaged face embedding for the person.
    avg_embedding: Option<Vec<f64>>,
}
|
||
|
||
/// A speaker found in the ASRX output.
#[derive(Debug)]
struct SpeakerData {
    /// Id of the speaker as given in the ASRX segments.
    speaker_id: String,
    /// `(start, end)` pairs of the speaker's segments.
    segments: Vec<(f64, f64)>,
}
|
||
|
||
/// Cosine similarity of two equally long vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// one has a zero norm; otherwise `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let inner: f32 = a.iter().zip(b.iter()).map(|(l, r)| l * r).sum();
    let (norm_a, norm_b) = (norm(a), norm(b));

    match (norm_a == 0.0, norm_b == 0.0) {
        (false, false) => inner / (norm_a * norm_b),
        _ => 0.0,
    }
}
|
||
|
||
/// 迭代多角度 face embedding 比對 + 傳播
|
||
/// Round 1: 用 TMDb seed face_embedding 比對 face_detections (threshold 0.50)
|
||
/// Round 2+: 用已匹配 trace 的所有 face 作為 seed,傳播到未匹配 trace
|
||
async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
|
||
// Step 1: 載入 TMDb identities (source='tmdb' 且有 face_embedding)
|
||
let identities_table = schema::table_name("identities");
|
||
let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec<f32>)>(
|
||
&format!("SELECT id, name, face_embedding::real[] FROM {} WHERE source='tmdb' AND face_embedding IS NOT NULL", identities_table)
|
||
)
|
||
.fetch_all(pool).await?;
|
||
|
||
if tmdb_rows.is_empty() {
|
||
tracing::warn!("[FaceMatch] No TMDb identities with face embeddings");
|
||
return Ok(0);
|
||
}
|
||
tracing::info!(
|
||
"[FaceMatch] Loaded {} TMDb seed identities",
|
||
tmdb_rows.len()
|
||
);
|
||
|
||
// Step 2: 載入所有 face_detections(含 frame_number),按 trace_id 分組
|
||
let fd_table = schema::table_name("face_detections");
|
||
let fd_rows = sqlx::query_as::<_, (i32, i32, Vec<f32>)>(&format!(
|
||
"SELECT trace_id, frame_number, embedding FROM {} \
|
||
WHERE file_uuid=$1 AND trace_id IS NOT NULL AND embedding IS NOT NULL \
|
||
ORDER BY trace_id, frame_number",
|
||
fd_table
|
||
))
|
||
.bind(file_uuid)
|
||
.fetch_all(pool)
|
||
.await?;
|
||
|
||
if fd_rows.is_empty() {
|
||
tracing::warn!("[FaceMatch] No face detections with embeddings");
|
||
return Ok(0);
|
||
}
|
||
|
||
// 分組:trace_id → (frame_number, embedding)
|
||
use std::collections::HashMap;
|
||
let mut trace_faces_raw: HashMap<i32, Vec<(i32, Vec<f32>)>> = HashMap::new();
|
||
for (tid, frame, emb) in &fd_rows {
|
||
trace_faces_raw
|
||
.entry(*tid)
|
||
.or_insert_with(Vec::new)
|
||
.push((*frame, emb.clone()));
|
||
}
|
||
|
||
// 從每個 trace 選取不同角度的 3 個 face embedding
|
||
// 策略:按 frame_number 排序,取前中後各 1 個
|
||
let mut trace_samples: HashMap<i32, Vec<Vec<f32>>> = HashMap::new();
|
||
for (tid, mut faces) in trace_faces_raw {
|
||
faces.sort_by_key(|(frame, _)| *frame);
|
||
let n = faces.len();
|
||
let indices = if n <= 3 {
|
||
(0..n).collect()
|
||
} else {
|
||
let mid = n / 2;
|
||
vec![0, mid, n - 1]
|
||
};
|
||
let samples: Vec<Vec<f32>> = indices.iter().map(|&i| faces[i].1.clone()).collect();
|
||
trace_samples.insert(tid, samples);
|
||
}
|
||
|
||
let total_traces = trace_samples.len();
|
||
let sample_count: usize = trace_samples.values().map(|v| v.len()).sum();
|
||
tracing::info!(
|
||
"[FaceMatch] Loaded {} traces, sampled {} embeddings (3-angle)",
|
||
total_traces,
|
||
sample_count
|
||
);
|
||
|
||
// Step 3: 建立 TMDb 查找表
|
||
let tmdb_seeds: Vec<(i32, String, Vec<f32>)> = tmdb_rows;
|
||
|
||
// Step 4: 迭代匹配
|
||
const TH: f32 = 0.50;
|
||
let mut matched: HashMap<i32, String> = HashMap::new(); // trace_id → identity_name
|
||
|
||
// Round 1: 用 3-angle samples 比對 TMDb
|
||
// 每個 trace 選 3 個不同角度 face,取最高 similarity
|
||
for (&tid, samples) in &trace_samples {
|
||
let mut best_name = String::new();
|
||
let mut best_sim = 0.0f32;
|
||
for (_, ref name, ref tmdb_emb) in &tmdb_seeds {
|
||
for face_emb in samples {
|
||
let s = cosine_similarity(face_emb, tmdb_emb);
|
||
if s > best_sim {
|
||
best_sim = s;
|
||
best_name = name.clone();
|
||
}
|
||
}
|
||
}
|
||
if best_sim >= TH {
|
||
matched.insert(tid, best_name);
|
||
}
|
||
}
|
||
tracing::info!(
|
||
"[FaceMatch] Round 1: {} matched ({}%) — writing to DB",
|
||
matched.len(),
|
||
matched.len() * 100 / total_traces
|
||
);
|
||
|
||
// Step 5: 寫入 DB — Round 1 結果先存
|
||
let identities_table = schema::table_name("identities");
|
||
let fd_table = schema::table_name("face_detections");
|
||
let mut updated = 0usize;
|
||
for (tid, name) in &matched {
|
||
let id_opt = sqlx::query_scalar::<_, Option<i32>>(&format!(
|
||
"SELECT id FROM {} WHERE name=$1 AND source='tmdb'",
|
||
identities_table
|
||
))
|
||
.bind(name)
|
||
.fetch_optional(pool)
|
||
.await?;
|
||
if let Some(identity_id) = id_opt {
|
||
let _ = sqlx::query(&format!(
|
||
"UPDATE {} SET identity_id=$1 WHERE file_uuid=$2 AND trace_id=$3",
|
||
fd_table
|
||
))
|
||
.bind(identity_id)
|
||
.bind(file_uuid)
|
||
.bind(tid)
|
||
.execute(pool)
|
||
.await;
|
||
updated += 1;
|
||
}
|
||
}
|
||
tracing::info!("[FaceMatch] Round 1: updated {} face_detections", updated);
|
||
|
||
// Round 2+: 用已匹配的 face 作為 seed 傳播(剩餘未匹配的 trace)
|
||
let initial_matched = matched.len();
|
||
for round_n in 2..=5 {
|
||
let prev = matched.len();
|
||
// 建立 seed pool: name → Vec<embedding>
|
||
let mut seed_pool: HashMap<String, Vec<&Vec<f32>>> = HashMap::new();
|
||
for (&tid, name) in &matched {
|
||
if let Some(samples) = trace_samples.get(&tid) {
|
||
seed_pool
|
||
.entry(name.clone())
|
||
.or_default()
|
||
.extend(samples.iter());
|
||
}
|
||
}
|
||
|
||
let mut new_matches: Vec<(i32, String)> = Vec::new();
|
||
for (&tid, samples) in &trace_samples {
|
||
if matched.contains_key(&tid) {
|
||
continue;
|
||
}
|
||
let mut best_name = String::new();
|
||
let mut best_sim = 0.0f32;
|
||
if samples.is_empty() {
|
||
continue;
|
||
}
|
||
// 用 3-angle samples 分別比對 seed,取最高 similarity
|
||
for (name, seed_faces) in &seed_pool {
|
||
for face_emb in samples {
|
||
for seed in seed_faces {
|
||
let s = cosine_similarity(face_emb, seed);
|
||
if s > best_sim {
|
||
best_sim = s;
|
||
best_name = name.clone();
|
||
}
|
||
}
|
||
}
|
||
}
|
||
if best_sim >= TH {
|
||
new_matches.push((tid, best_name));
|
||
}
|
||
}
|
||
for (tid, name) in new_matches {
|
||
matched.insert(tid, name);
|
||
}
|
||
let new = matched.len() - prev;
|
||
tracing::info!(
|
||
"[FaceMatch] Round {}: +{} matched (total {}, {}%)",
|
||
round_n,
|
||
new,
|
||
matched.len(),
|
||
matched.len() * 100 / total_traces
|
||
);
|
||
if new < 5 {
|
||
break;
|
||
}
|
||
}
|
||
|
||
// Step 6: 未匹配的 trace 設 stranger_id = trace_id
|
||
// trace_id 在同一個 file 內是 sequential integer,直接複用為 stranger_id
|
||
let stranger_update = sqlx::query(&format!(
|
||
"UPDATE {} SET stranger_id = trace_id \
|
||
WHERE file_uuid = $1 AND trace_id IS NOT NULL AND identity_id IS NULL \
|
||
AND (stranger_id IS NULL OR stranger_id != trace_id)",
|
||
fd_table
|
||
))
|
||
.bind(file_uuid)
|
||
.execute(pool)
|
||
.await?;
|
||
let stranger_count = stranger_update.rows_affected();
|
||
|
||
// Step 7: Save identity files for all affected identities
|
||
let affected = sqlx::query_scalar::<_, uuid::Uuid>(&format!(
|
||
"SELECT DISTINCT i.uuid FROM {} i \
|
||
JOIN {} fd ON fd.identity_id = i.id \
|
||
WHERE fd.file_uuid=$1 AND fd.identity_id IS NOT NULL",
|
||
identities_table, fd_table
|
||
))
|
||
.bind(file_uuid)
|
||
.fetch_all(pool)
|
||
.await
|
||
.unwrap_or_default();
|
||
for uuid in &affected {
|
||
let us = uuid.to_string().replace('-', "");
|
||
if let Err(e) = crate::core::identity::storage::save_identity_file_by_pool(pool, &us).await
|
||
{
|
||
tracing::warn!("[FaceMatch] Failed to save identity file {}: {}", us, e);
|
||
}
|
||
}
|
||
tracing::info!(
|
||
"[FaceMatch] Done: {}/{} traces matched ({}%), {} strangers, {} identity files",
|
||
matched.len(),
|
||
total_traces,
|
||
matched.len() * 100 / total_traces,
|
||
stranger_count,
|
||
affected.len()
|
||
);
|
||
Ok(updated)
|
||
}
|
||
|
||
/// Bind ASRX speakers to face traces based on temporal overlap.
|
||
/// Reads face_detections (trace_id, identity_id, frame_number) and ASRX
|
||
/// segments (speaker_id, start_time, end_time), computes overlap,
|
||
/// and stores bindings in identity_bindings table.
|
||
pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
|
||
// Load face traces with identity_id and frame numbers
|
||
let fd_table = schema::table_name("face_detections");
|
||
let traces = sqlx::query_as::<_, (i32, Vec<i32>)>(&format!(
|
||
"SELECT trace_id, array_agg(frame_number ORDER BY frame_number) \
|
||
FROM {} WHERE file_uuid=$1 AND trace_id IS NOT NULL AND identity_id IS NOT NULL \
|
||
GROUP BY trace_id",
|
||
fd_table
|
||
))
|
||
.bind(file_uuid)
|
||
.fetch_all(pool)
|
||
.await?;
|
||
|
||
if traces.is_empty() {
|
||
tracing::info!("[SpeakerBind] No face traces with identities");
|
||
return Ok(0);
|
||
}
|
||
|
||
// Load ASRX speakers from the output JSON
|
||
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
|
||
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
|
||
let asrx_path = std::path::Path::new(&output_dir).join(format!("{}.asrx.json", file_uuid));
|
||
|
||
let asrx_data: serde_json::Value = match std::fs::read_to_string(&asrx_path) {
|
||
Ok(s) => serde_json::from_str(&s).unwrap_or_default(),
|
||
Err(_) => {
|
||
tracing::info!("[SpeakerBind] No ASRX file found");
|
||
return Ok(0);
|
||
}
|
||
};
|
||
|
||
// Extract speaker segments: speaker_id → [(start_time, end_time)]
|
||
use std::collections::HashMap;
|
||
let mut speakers: HashMap<String, Vec<(f64, f64)>> = HashMap::new();
|
||
if let Some(segments) = asrx_data.get("segments").and_then(|s| s.as_array()) {
|
||
for seg in segments {
|
||
let sid = seg
|
||
.get("speaker_id")
|
||
.and_then(|s| s.as_str())
|
||
.or_else(|| seg.get("speaker").and_then(|s| s.as_str()));
|
||
if let Some(sid) = sid {
|
||
let start = seg
|
||
.get("start_time")
|
||
.or_else(|| seg.get("start"))
|
||
.and_then(|v| v.as_f64())
|
||
.unwrap_or(0.0);
|
||
let end = seg
|
||
.get("end_time")
|
||
.or_else(|| seg.get("end"))
|
||
.and_then(|v| v.as_f64())
|
||
.unwrap_or(0.0);
|
||
speakers
|
||
.entry(sid.to_string())
|
||
.or_default()
|
||
.push((start, end));
|
||
}
|
||
}
|
||
}
|
||
|
||
if speakers.is_empty() {
|
||
tracing::info!("[SpeakerBind] No speakers found in ASRX data");
|
||
return Ok(0);
|
||
}
|
||
|
||
// Get fps for frame-to-time conversion
|
||
let fps: f64 = 25.0; // default, could also read from DB
|
||
|
||
// For each trace, compute overlap with each speaker
|
||
let mut bindings = 0usize;
|
||
for (trace_id, frames) in &traces {
|
||
if frames.is_empty() {
|
||
continue;
|
||
}
|
||
|
||
// Get identity_id for this trace
|
||
let fd_table = schema::table_name("face_detections");
|
||
let identity_id: Option<i32> = sqlx::query_scalar(
|
||
&format!("SELECT identity_id FROM {} WHERE file_uuid=$1 AND trace_id=$2 AND identity_id IS NOT NULL LIMIT 1", fd_table)
|
||
)
|
||
.bind(file_uuid).bind(trace_id)
|
||
.fetch_optional(pool).await?.flatten();
|
||
|
||
if identity_id.is_none() {
|
||
continue;
|
||
}
|
||
let identity_id = identity_id.unwrap();
|
||
|
||
// Compute overlap with each speaker
|
||
let mut best_speaker = String::new();
|
||
let mut best_overlap = 0usize;
|
||
|
||
for (speaker_id, segments) in &speakers {
|
||
let mut overlap = 0usize;
|
||
for &fn_num in frames {
|
||
let frame_time = fn_num as f64 / fps;
|
||
for (start, end) in segments {
|
||
if frame_time >= *start && frame_time <= *end {
|
||
overlap += 1;
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
if overlap > best_overlap {
|
||
best_overlap = overlap;
|
||
best_speaker = speaker_id.clone();
|
||
}
|
||
}
|
||
|
||
// Only bind if meaningful overlap
|
||
let overlap_ratio = best_overlap as f64 / frames.len() as f64;
|
||
if overlap_ratio > 0.3 && !best_speaker.is_empty() {
|
||
let metadata = serde_json::json!({
|
||
"trace_id": trace_id,
|
||
"overlap_frames": best_overlap,
|
||
"total_frames": frames.len(),
|
||
"overlap_ratio": overlap_ratio,
|
||
});
|
||
|
||
let ib_table = schema::table_name("identity_bindings");
|
||
let _ = sqlx::query(
|
||
&format!("INSERT INTO {} (identity_id, identity_type, identity_value, confidence, metadata) \
|
||
VALUES ($1, 'speaker', $2, $3, $4::jsonb) \
|
||
ON CONFLICT (identity_id, identity_type, identity_value) DO UPDATE SET confidence = EXCLUDED.confidence, metadata = EXCLUDED.metadata", ib_table)
|
||
)
|
||
.bind(identity_id)
|
||
.bind(&best_speaker)
|
||
.bind(overlap_ratio)
|
||
.bind(&metadata)
|
||
.execute(pool).await;
|
||
|
||
bindings += 1;
|
||
}
|
||
}
|
||
|
||
tracing::info!(
|
||
"[SpeakerBind] Created {}/{} speaker bindings",
|
||
bindings,
|
||
traces.len()
|
||
);
|
||
Ok(bindings)
|
||
}
|
||
|
||
/// Pipeline-triggered entry point: runs the full identity agent for a file.
|
||
/// Reads face_clustered.json + asrx.json, extracts persons/speakers, creates identities,
|
||
/// runs iterative face matching, and binds speakers.
|
||
pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Result<()> {
|
||
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
|
||
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
|
||
|
||
let pool = db.pool();
|
||
|
||
// Step 1: 先跑 face matching(不需 face_clustered.json)
|
||
let matched = match_faces_iterative(pool, file_uuid).await.unwrap_or(0);
|
||
|
||
// Step 2: 試著載入 face_clustered.json 建立新 identities
|
||
let video_dir = PathBuf::from(&output_dir).join(file_uuid);
|
||
let face_clustered_path = video_dir.join(format!("{}.face_clustered.json", file_uuid));
|
||
let face_clustered_path = if face_clustered_path.exists() {
|
||
face_clustered_path
|
||
} else {
|
||
PathBuf::from(&output_dir).join(format!("{}.face_clustered.json", file_uuid))
|
||
};
|
||
|
||
if face_clustered_path.exists() {
|
||
let face_data: serde_json::Value =
|
||
std::fs::read_to_string(&face_clustered_path)?.parse()?;
|
||
let asrx_path = video_dir.join(format!("{}.asrx.json", file_uuid));
|
||
let asrx_data: Option<serde_json::Value> = if asrx_path.exists() {
|
||
Some(std::fs::read_to_string(&asrx_path)?.parse()?)
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let persons = extract_persons_from_face_data(&face_data);
|
||
let speakers = extract_speakers_from_asrx_data(&asrx_data);
|
||
let identities = analyze_person_speaker_overlap(&persons, &speakers);
|
||
|
||
for (idx, id_result) in identities.iter().enumerate() {
|
||
let identity_name = format!("stranger_{}", idx);
|
||
let metadata = serde_json::json!({
|
||
"source": "identity_agent",
|
||
"trace_ids": id_result.person_ids,
|
||
"speaker_ids": id_result.speaker_ids,
|
||
"confidence": id_result.confidence,
|
||
"evidence": {
|
||
"speaker_overlap": id_result.evidence.speaker_overlap,
|
||
"frame_ratio": id_result.evidence.frame_ratio,
|
||
},
|
||
"reasoning": id_result.reasoning,
|
||
});
|
||
let _ = sqlx::query(
|
||
&format!("INSERT INTO {} (name, identity_type, source, metadata, status) VALUES ($1, 'people', 'auto', $2::jsonb, 'pending') ON CONFLICT DO NOTHING", schema::table_name("identities"))
|
||
)
|
||
.bind(&identity_name)
|
||
.bind(&metadata)
|
||
.execute(pool)
|
||
.await;
|
||
}
|
||
let _created = identities.len();
|
||
tracing::info!(
|
||
"[IdentityAgent] Created {} auto identities from face_clustered for {}",
|
||
_created,
|
||
file_uuid
|
||
);
|
||
} else {
|
||
tracing::warn!(
|
||
"[IdentityAgent] face_clustered.json not found for {}, skipping identity creation",
|
||
file_uuid
|
||
);
|
||
}
|
||
|
||
let bound = bind_speakers(pool, file_uuid).await.unwrap_or(0);
|
||
|
||
tracing::info!(
|
||
"[IdentityAgent] Done for {}: {} face matches, {} speaker bindings",
|
||
file_uuid,
|
||
matched,
|
||
bound
|
||
);
|
||
Ok(())
|
||
}
|