Files
momentry_core/src/api/identity_agent_api.rs
M5Max128 3a33d00449 refactor: modularize server.rs into separate route modules
- Extract scan.rs, files.rs, types.rs, processing.rs, visual_chunk_search.rs
- Move AppState and AppConfig to types.rs
- Each module exposes pub fn xxx_routes() -> Router<AppState>
- server.rs reduced from 5005 to 118 lines (orchestrator only)
- All stubs filled with real implementations from git history
- Verify: cargo check, clippy, tests all pass
2026-05-21 16:38:49 +08:00

1075 lines
36 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
extract::{Multipart, State},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use std::path::PathBuf;
use crate::api::types::AppState;
use crate::core::db::schema;
use crate::core::db::PostgresDb;
/// Builds the router for the identity-agent HTTP endpoints.
///
/// Two POST routes are registered:
/// - `/api/v1/agents/identity/match-from-photo`, handled by `match_from_photo`
/// - `/api/v1/agents/identity/match-from-trace`, handled by `match_from_trace`
pub fn identity_agent_routes() -> Router<AppState> {
    let photo_route = "/api/v1/agents/identity/match-from-photo";
    let trace_route = "/api/v1/agents/identity/match-from-trace";
    let router = Router::new().route(photo_route, post(match_from_photo));
    router.route(trace_route, post(match_from_trace))
}
/// One candidate identity produced by `analyze_person_speaker_overlap`.
#[derive(Debug, Serialize)]
pub struct IdentityResult {
/// `person_id` of the person that seeded this identity.
pub identity_id: String,
/// Person ids grouped into this identity (the seed plus persons sharing a frame with it).
pub person_ids: Vec<String>,
/// Speaker ids whose segments overlap the seed person's time span.
pub speaker_ids: Vec<String>,
/// Heuristic score; `analyze_person_speaker_overlap` sets it to `0.5 + 0.3 * speaker_overlap`.
pub confidence: f64,
/// Numeric evidence backing the score.
pub evidence: IdentityEvidence,
/// Human-readable explanation of the grouping.
pub reasoning: String,
}
/// Numeric evidence attached to an `IdentityResult`.
#[derive(Debug, Serialize)]
pub struct IdentityEvidence {
/// Face-embedding similarity; `analyze_person_speaker_overlap` leaves this `None`.
pub face_similarity: Option<f64>,
/// Fraction of speakers whose segments overlap the seed person's time span.
pub speaker_overlap: f64,
/// Time-overlap score; always `1.0` where this struct is built in this file.
pub time_overlap: f64,
/// The seed person's frame count divided by 100 (not clamped to 1.0).
pub frame_ratio: f64,
}
/// Response body shared by the match-from-photo and match-from-trace endpoints.
#[derive(Debug, Serialize)]
struct MatchFromPhotoResponse {
/// `true` on every non-error response.
success: bool,
/// Identity UUID with hyphens removed.
identity_uuid: String,
/// Video whose face detections were searched.
file_uuid: String,
/// Number of faces assigned (match-from-trace counts the source trace as one).
matches: usize,
/// Trace ids that received the identity.
traces_matched: Vec<i32>,
/// Human-readable summary of the outcome.
message: String,
}
/// `POST /api/v1/agents/identity/match-from-photo`
///
/// Multipart fields:
/// - `identity_uuid`: identity to assign (hyphens optional).
/// - `file_uuid`: video whose `face_detections` are searched.
/// - `image`: photo containing the face.
///
/// The steps are:
/// 1. Write the photo to a uniquely named temp file.
/// 2. Embed it with `extract_face_embedding.py`, found under `MOMENTRY_SCRIPTS_DIR` and run
///    with `MOMENTRY_PYTHON_PATH`.
/// 3. Give the single nearest face detection in the file `identity_id = <identity>`. "Nearest"
///    is by pgvector cosine distance, with no similarity threshold.
///
/// # Errors
/// - `400`: missing fields, a non-hex `identity_uuid`, no `image` field, or extraction failure.
/// - `404`: the identity does not exist.
/// - `500`: temp-file, subprocess, output-parsing, or database failures (including the
///   identity assignment itself).
async fn match_from_photo(
    State(state): State<AppState>,
    mut multipart: Multipart,
) -> Result<Json<MatchFromPhotoResponse>, (StatusCode, Json<serde_json::Value>)> {
    // Per-process sequence number so concurrent requests for the same identity never share a
    // temp file.
    static TEMP_SEQ: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);

    let mut identity_uuid = String::new();
    let mut file_uuid = String::new();
    let mut image_data: Option<Vec<u8>> = None;
    // A multipart read error simply ends field collection; missing fields are reported below.
    while let Ok(Some(field)) = multipart.next_field().await {
        let name = field.name().unwrap_or("").to_string();
        match name.as_str() {
            "identity_uuid" => identity_uuid = field.text().await.unwrap_or_default(),
            "file_uuid" => file_uuid = field.text().await.unwrap_or_default(),
            "image" => image_data = Some(field.bytes().await.unwrap_or_default().to_vec()),
            _ => {}
        }
    }

    let uuid_clean = identity_uuid.replace('-', "");
    if uuid_clean.is_empty() || file_uuid.is_empty() {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "success": false, "message": "identity_uuid and file_uuid are required"
            })),
        ));
    }
    // `uuid_clean` becomes part of a filesystem path below. Restricting it to hex digits (the
    // hyphen-free form of a UUID) rules out path separators and `..` segments.
    if !uuid_clean.chars().all(|c| c.is_ascii_hexdigit()) {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "success": false, "message": "identity_uuid must be a UUID"
            })),
        ));
    }
    let data = image_data.ok_or_else(|| {
        (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "success": false, "message": "No image field found. Use field name 'image'."
            })),
        )
    })?;

    // Resolve the identity's internal id before running the (expensive) extractor.
    let id_table = schema::table_name("identities");
    let identity_id_row: Option<(i32,)> = sqlx::query_as(&format!(
        "SELECT id FROM {} WHERE REPLACE(uuid::text, '-', '') = $1",
        id_table
    ))
    .bind(&uuid_clean)
    .fetch_optional(state.db.pool())
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"message": format!("DB error: {}", e)})),
        )
    })?;
    let identity_id = match identity_id_row {
        Some((id,)) => id,
        None => {
            return Err((
                StatusCode::NOT_FOUND,
                Json(serde_json::json!({
                    "success": false, "message": "Identity not found"
                })),
            ))
        }
    };

    // 1. Save the uploaded image to a unique temp file.
    let scripts_dir = std::env::var("MOMENTRY_SCRIPTS_DIR")
        .unwrap_or_else(|_| "/Users/accusys/momentry_core_0.1/scripts".to_string());
    let python_path = std::env::var("MOMENTRY_PYTHON_PATH")
        .unwrap_or_else(|_| "/opt/homebrew/bin/python3.11".to_string());
    let temp_dir = std::env::temp_dir().join("momentry_match_face");
    std::fs::create_dir_all(&temp_dir).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"message": format!("Failed to create temp dir: {}", e)})),
        )
    })?;
    let seq = TEMP_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let temp_img = temp_dir.join(format!(
        "{}_{}_{}.jpg",
        uuid_clean,
        std::process::id(),
        seq
    ));
    if let Err(e) = std::fs::write(&temp_img, &data) {
        // A partial write may have left a file behind.
        let _ = std::fs::remove_file(&temp_img);
        return Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"message": format!("Failed to save temp image: {}", e)})),
        ));
    }

    // 2. Extract the face embedding via the Python script.
    let extract_script = std::path::Path::new(&scripts_dir).join("extract_face_embedding.py");
    let spawned = tokio::process::Command::new(&*python_path)
        .arg(&extract_script)
        .arg(&temp_img)
        .output()
        .await;
    // Remove the temp image whether or not the extractor could be spawned.
    let _ = std::fs::remove_file(&temp_img);
    let output = spawned.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"message": format!("Failed to run extractor: {}", e)})),
        )
    })?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err((
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({
                "success": false, "message": format!("Face extraction failed: {}", stderr)
            })),
        ));
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    let extract_result: serde_json::Value = serde_json::from_str(&stdout).map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"message": "Failed to parse extractor output"})),
        )
    })?;
    let embedding: Vec<f64> = serde_json::from_value(
        extract_result
            .get("embedding")
            .ok_or_else(|| {
                (
                    StatusCode::BAD_REQUEST,
                    Json(serde_json::json!({"message": "No embedding in extractor output"})),
                )
            })?
            .clone(),
    )
    .map_err(|_| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"message": "Invalid embedding format"})),
        )
    })?;
    let embedding_f32: Vec<f32> = embedding.into_iter().map(|v| v as f32).collect();

    // 3. Find the single closest face detection in this file (no similarity threshold).
    let fd_table = schema::table_name("face_detections");
    let best_match: Option<(i32, i32, f64)> = sqlx::query_as(&format!(
        r#"SELECT id, trace_id,
1 - (embedding::vector <=> $1::vector) as similarity
FROM {}
WHERE file_uuid = $2 AND embedding IS NOT NULL
ORDER BY embedding::vector <=> $1::vector
LIMIT 1"#,
        fd_table
    ))
    .bind(&embedding_f32)
    .bind(&file_uuid)
    .fetch_optional(state.db.pool())
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"message": format!("Search failed: {}", e)})),
        )
    })?;
    let (fb_id, fb_trace, fb_sim) = match best_match {
        Some(found) => found,
        None => {
            return Ok(Json(MatchFromPhotoResponse {
                success: true,
                identity_uuid: uuid_clean,
                file_uuid,
                matches: 0,
                traces_matched: Vec::new(),
                message: "No matching face found in video".to_string(),
            }))
        }
    };

    // 4. Assign the identity. This is the endpoint's purpose, so a failure must not be
    //    reported as a successful match.
    sqlx::query(&format!(
        "UPDATE {} SET identity_id = $1 WHERE id = $2",
        fd_table
    ))
    .bind(identity_id)
    .bind(fb_id)
    .execute(state.db.pool())
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({"message": format!("DB error: {}", e)})),
        )
    })?;

    // 5. Refresh the identity file (best effort; the DB assignment above already succeeded).
    let _ = crate::core::identity::storage::save_identity_file(&*state.db, &uuid_clean).await;

    Ok(Json(MatchFromPhotoResponse {
        success: true,
        identity_uuid: uuid_clean,
        file_uuid,
        matches: 1,
        traces_matched: vec![fb_trace],
        message: format!(
            "Best trace: trace_id={}, similarity={:.4}",
            fb_trace, fb_sim
        ),
    }))
}
/// JSON body of `POST /api/v1/agents/identity/match-from-trace`.
#[derive(Debug, Deserialize)]
struct MatchFromTraceRequest {
/// Video the trace belongs to.
file_uuid: String,
/// Trace whose faces seed the search; all of its detections are bound to the identity.
trace_id: i32,
/// Identity to assign; hyphens are optional (they are stripped before lookup).
identity_uuid: String,
}
async fn match_from_trace(
State(state): State<AppState>,
Json(req): Json<MatchFromTraceRequest>,
) -> Result<Json<MatchFromPhotoResponse>, (StatusCode, Json<serde_json::Value>)> {
let uuid_clean = req.identity_uuid.replace('-', "");
// 1. Get 3 best face embeddings from this trace at different angles
// Divide trace frame range into 3 segments, pick best face from each
let fd_table = schema::table_name("face_detections");
let all_faces: Vec<(Vec<f32>, i64)> = sqlx::query_as::<_, (Vec<f32>, i64)>(&format!(
"SELECT embedding, frame_number FROM {} \
WHERE file_uuid = $1 AND trace_id = $2 AND embedding IS NOT NULL \
ORDER BY frame_number ASC",
fd_table
))
.bind(&req.file_uuid)
.bind(req.trace_id)
.fetch_all(state.db.pool())
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": format!("DB error: {}", e)})),
)
})?;
if all_faces.is_empty() {
return Err((
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"success": false, "message": "No embedding found for this trace"
})),
));
}
// Pick 3 samples: divide frame range into 3 segments, use face with largest area per segment
let total = all_faces.len();
let segments = [
(0, total / 3),
(total / 3, total * 2 / 3),
(total * 2 / 3, total),
];
let mut query_embeddings: Vec<Vec<f32>> = Vec::new();
// Get width*height info if available (not all pipelines store it)
let face_sizes: Vec<(i64, i32)> = sqlx::query_as::<_, (i64, i32)>(&format!(
"SELECT frame_number, COALESCE(width, 0) * COALESCE(height, 0) AS area \
FROM {} WHERE file_uuid = $1 AND trace_id = $2 AND embedding IS NOT NULL \
ORDER BY frame_number ASC",
fd_table
))
.bind(&req.file_uuid)
.bind(req.trace_id)
.fetch_all(state.db.pool())
.await
.unwrap_or_default();
let face_sizes_map: std::collections::HashMap<i64, i32> = face_sizes.into_iter().collect();
for (start, end) in segments {
let seg_start = start.min(total - 1);
let seg_end = end.min(total);
if seg_start >= seg_end {
continue;
}
let seg_slice = &all_faces[seg_start..seg_end];
// Pick the face with largest area within this segment
let best_idx = seg_slice
.iter()
.enumerate()
.max_by_key(|(_, f)| face_sizes_map.get(&f.1).copied().unwrap_or(0))
.map(|(i, _)| i)
.unwrap_or(0);
query_embeddings.push(seg_slice[best_idx].0.clone());
}
if query_embeddings.is_empty() {
query_embeddings.push(all_faces[total / 2].0.clone());
}
// 2. Three angles each find their best match; union all results
let mut validated: Vec<(i32, i32, f64)> = Vec::new();
let mut seen_trace_ids = std::collections::HashSet::new();
for qemb in &query_embeddings {
let top = sqlx::query_as::<_, (i32, i32, f64)>(&format!(
r#"SELECT id, trace_id,
1 - (embedding::vector <=> $1::vector) as similarity
FROM {}
WHERE file_uuid = $2
AND trace_id != $3
AND embedding IS NOT NULL
ORDER BY embedding::vector <=> $1::vector
LIMIT 1"#,
fd_table
))
.bind(qemb)
.bind(&req.file_uuid)
.bind(req.trace_id)
.fetch_optional(state.db.pool())
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": format!("Search failed: {}", e)})),
)
})?;
if let Some((cface_id, c_trace_id, c_sim)) = top {
if seen_trace_ids.insert(c_trace_id) {
validated.push((cface_id, c_trace_id, c_sim));
}
}
}
// 3. Look up identity internal ID
let id_table = schema::table_name("identities");
let identity_id_row: Option<(i32,)> = sqlx::query_as(&format!(
"SELECT id FROM {} WHERE REPLACE(uuid::text, '-', '') = $1",
id_table
))
.bind(&uuid_clean)
.fetch_optional(state.db.pool())
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({"message": format!("DB error: {}", e)})),
)
})?;
let identity_id = match identity_id_row {
Some((id,)) => id,
None => {
return Err((
StatusCode::NOT_FOUND,
Json(serde_json::json!({
"success": false, "message": "Identity not found"
})),
))
}
};
// 4. Update matched face_detections
let mut traces_matched: Vec<i32> = Vec::new();
for (id, trace_id, _similarity) in &validated {
if let Err(e) = sqlx::query(&format!(
"UPDATE {} SET identity_id = $1 WHERE id = $2",
fd_table
))
.bind(identity_id)
.bind(id)
.execute(state.db.pool())
.await
{
tracing::warn!(
"[match-from-trace] Failed to update face_detection {}: {}",
id,
e
);
} else {
if !traces_matched.contains(trace_id) {
traces_matched.push(*trace_id);
}
}
}
// 5. Also bind the source trace itself
let _ = sqlx::query(&format!(
"UPDATE {} SET identity_id = $1 WHERE file_uuid = $2 AND trace_id = $3",
fd_table
))
.bind(identity_id)
.bind(&req.file_uuid)
.bind(req.trace_id)
.execute(state.db.pool())
.await;
if !traces_matched.contains(&req.trace_id) {
traces_matched.push(req.trace_id);
}
// 6. Save identity file
let _ = crate::core::identity::storage::save_identity_file(&*state.db, &uuid_clean).await;
let match_count = validated.len() + 1;
let trace_count = traces_matched.len();
Ok(Json(MatchFromPhotoResponse {
success: true,
identity_uuid: uuid_clean,
file_uuid: req.file_uuid,
matches: match_count,
traces_matched,
message: format!(
"Matched {} faces ({} unique traces)",
match_count, trace_count
),
}))
}
/// Collects the frames of each clustered person from `face_clustered.json` data.
///
/// Reads `face_data["frames"]`. Each entry needs an integer `frame` and a string `person_id`.
/// Entries missing either field, or whose frame does not fit in an `i32`, are skipped.
///
/// Persons are returned sorted by `person_id`. Frames keep input order. `avg_embedding` is
/// always `None`.
fn extract_persons_from_face_data(face_data: &serde_json::Value) -> Vec<PersonData> {
    let frames = match face_data.get("frames").and_then(|f| f.as_array()) {
        Some(frames) => frames,
        None => return Vec::new(),
    };
    // BTreeMap rather than HashMap keeps the output order stable across runs. Downstream, the
    // person grouping and the `stranger_{idx}` names depend on that order.
    let mut frames_by_person: std::collections::BTreeMap<String, Vec<i32>> =
        std::collections::BTreeMap::new();
    for frame in frames {
        let frame_num = frame
            .get("frame")
            .and_then(|f| f.as_i64())
            .and_then(|n| i32::try_from(n).ok());
        let person_id = frame.get("person_id").and_then(|p| p.as_str());
        if let (Some(frame_num), Some(person_id)) = (frame_num, person_id) {
            frames_by_person
                .entry(person_id.to_string())
                .or_default()
                .push(frame_num);
        }
    }
    frames_by_person
        .into_iter()
        .map(|(person_id, frames)| PersonData {
            person_id,
            frames,
            avg_embedding: None,
        })
        .collect()
}
/// Groups ASRX segments into one `SpeakerData` per speaker id.
///
/// Reads `asrx_data["segments"]`. Each segment needs a speaker id and numeric start/end
/// bounds, using the same keys `bind_speakers` accepts:
/// - speaker id: `speaker_id`, falling back to `speaker`;
/// - bounds: `start_time`/`end_time`, falling back to `start`/`end`.
///
/// Segments missing any of these are skipped. Speakers appear in order of their first
/// segment, and each keeps its segments in input order. Returns an empty list when
/// `asrx_data` is `None` or has no `segments` array.
fn extract_speakers_from_asrx_data(asrx_data: &Option<serde_json::Value>) -> Vec<SpeakerData> {
    let segments = match asrx_data
        .as_ref()
        .and_then(|data| data.get("segments"))
        .and_then(|s| s.as_array())
    {
        Some(segments) => segments,
        None => return Vec::new(),
    };
    let mut speakers: Vec<SpeakerData> = Vec::new();
    let mut index_by_id: std::collections::HashMap<String, usize> =
        std::collections::HashMap::new();
    for seg in segments {
        let number = |key: &str, fallback: &str| {
            seg.get(key)
                .or_else(|| seg.get(fallback))
                .and_then(|v| v.as_f64())
        };
        let speaker_id = seg
            .get("speaker_id")
            .and_then(|v| v.as_str())
            .or_else(|| seg.get("speaker").and_then(|v| v.as_str()));
        if let (Some(start), Some(end), Some(speaker_id)) = (
            number("start_time", "start"),
            number("end_time", "end"),
            speaker_id,
        ) {
            // One entry per speaker: the overlap score divides by `speakers.len()`, which
            // must count speakers, not segments.
            let idx = *index_by_id
                .entry(speaker_id.to_string())
                .or_insert_with(|| {
                    speakers.push(SpeakerData {
                        speaker_id: speaker_id.to_string(),
                        segments: Vec::new(),
                    });
                    speakers.len() - 1
                });
            speakers[idx].segments.push((start, end));
        }
    }
    speakers
}
/// Groups persons into candidate identities and attaches overlapping speakers.
///
/// Grouping:
/// - Persons are visited in slice order.
/// - Each not-yet-visited person seeds a new identity.
/// - The seed absorbs every other unvisited person that shares at least one frame number
///   with it.
///
/// A speaker is attached when any of its segments intersects the seed person's time span.
/// The span is `[min_frame, max_frame] / FPS`; segment bounds are assumed to be in seconds,
/// as in `bind_speakers`.
///
/// Scoring:
/// - `speaker_overlap = attached speakers / speakers.len()` (0 when there are none);
/// - `confidence = 0.5 + 0.3 * speaker_overlap`;
/// - `frame_ratio = frame count / 100`;
/// - `time_overlap` is fixed at `1.0`.
fn analyze_person_speaker_overlap(
    persons: &[PersonData],
    speakers: &[SpeakerData],
) -> Vec<IdentityResult> {
    // Frame rate used to convert frame numbers to seconds. It matches the default assumed by
    // `bind_speakers`. NOTE(review): ideally read from the file's metadata.
    const FPS: f64 = 25.0;

    let mut identities: Vec<IdentityResult> = Vec::new();
    let mut visited: std::collections::HashSet<&str> = std::collections::HashSet::new();
    for person in persons {
        if !visited.insert(person.person_id.as_str()) {
            continue;
        }

        // NOTE(review): persons appearing in the same frame are merged into one identity;
        // confirm that co-occurrence is really meant to imply "same person".
        let seed_frames: std::collections::HashSet<i32> = person.frames.iter().copied().collect();
        let mut matched_persons = vec![person.person_id.clone()];
        for other in persons {
            if visited.contains(other.person_id.as_str()) {
                continue;
            }
            if other.frames.iter().any(|f| seed_frames.contains(f)) {
                visited.insert(other.person_id.as_str());
                matched_persons.push(other.person_id.clone());
            }
        }

        // Speaker segments are times, so compare against the person's span in seconds.
        let start_s = person.frames.iter().min().copied().unwrap_or(0) as f64 / FPS;
        let end_s = person.frames.iter().max().copied().unwrap_or(0) as f64 / FPS;
        let mut matched_speakers: Vec<String> = Vec::new();
        for speaker in speakers {
            let overlaps = speaker
                .segments
                .iter()
                .any(|&(seg_start, seg_end)| seg_start <= end_s && seg_end >= start_s);
            if overlaps && !matched_speakers.contains(&speaker.speaker_id) {
                matched_speakers.push(speaker.speaker_id.clone());
            }
        }

        let frame_count = person.frames.len() as f64;
        let speaker_overlap = matched_speakers.len() as f64 / speakers.len().max(1) as f64;
        let reasoning = format!(
            "Matched {} persons with {} speakers, overlap={:.2}",
            matched_persons.len(),
            matched_speakers.len(),
            speaker_overlap
        );
        identities.push(IdentityResult {
            identity_id: person.person_id.clone(),
            person_ids: matched_persons,
            speaker_ids: matched_speakers,
            confidence: 0.5 + (speaker_overlap * 0.3),
            evidence: IdentityEvidence {
                face_similarity: None,
                speaker_overlap,
                time_overlap: 1.0,
                frame_ratio: frame_count / 100.0,
            },
            reasoning,
        });
    }
    identities
}
/// A clustered person and the frames it appears in, as read from `face_clustered.json`.
#[derive(Debug)]
struct PersonData {
/// `person_id` value from the cluster file.
person_id: String,
/// Frame numbers (the `frame` field) of the entries carrying this `person_id`.
frames: Vec<i32>,
/// Average face embedding; never populated in this file (always `None`).
avg_embedding: Option<Vec<f64>>,
}
/// An ASRX speaker and its speech segments.
#[derive(Debug)]
struct SpeakerData {
/// Speaker identifier from the ASRX data.
speaker_id: String,
/// `(start, end)` segment bounds from the ASRX data (presumably seconds — confirm).
segments: Vec<(f64, f64)>,
}
/// Computes the cosine of the angle between `a` and `b`.
///
/// Degenerate inputs (mismatched lengths, empty slices, or a zero-norm vector) yield `0.0`
/// instead of `NaN`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}
/// 迭代多角度 face embedding 比對 + 傳播
/// Round 1: 用 TMDb seed face_embedding 比對 face_detections (threshold 0.50)
/// Round 2+: 用已匹配 trace 的所有 face 作為 seed傳播到未匹配 trace
async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
// Step 1: 載入 TMDb identities (source='tmdb' 且有 face_embedding)
let identities_table = schema::table_name("identities");
let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec<f32>)>(
&format!("SELECT id, name, face_embedding::real[] FROM {} WHERE source='tmdb' AND face_embedding IS NOT NULL", identities_table)
)
.fetch_all(pool).await?;
if tmdb_rows.is_empty() {
tracing::warn!("[FaceMatch] No TMDb identities with face embeddings");
return Ok(0);
}
tracing::info!(
"[FaceMatch] Loaded {} TMDb seed identities",
tmdb_rows.len()
);
// Step 2: 載入所有 face_detections含 frame_number按 trace_id 分組
let fd_table = schema::table_name("face_detections");
let fd_rows = sqlx::query_as::<_, (i32, i32, Vec<f32>)>(&format!(
"SELECT trace_id, frame_number, embedding FROM {} \
WHERE file_uuid=$1 AND trace_id IS NOT NULL AND embedding IS NOT NULL \
ORDER BY trace_id, frame_number",
fd_table
))
.bind(file_uuid)
.fetch_all(pool)
.await?;
if fd_rows.is_empty() {
tracing::warn!("[FaceMatch] No face detections with embeddings");
return Ok(0);
}
// 分組trace_id → (frame_number, embedding)
use std::collections::HashMap;
let mut trace_faces_raw: HashMap<i32, Vec<(i32, Vec<f32>)>> = HashMap::new();
for (tid, frame, emb) in &fd_rows {
trace_faces_raw
.entry(*tid)
.or_insert_with(Vec::new)
.push((*frame, emb.clone()));
}
// 從每個 trace 選取不同角度的 3 個 face embedding
// 策略:按 frame_number 排序,取前中後各 1 個
let mut trace_samples: HashMap<i32, Vec<Vec<f32>>> = HashMap::new();
for (tid, mut faces) in trace_faces_raw {
faces.sort_by_key(|(frame, _)| *frame);
let n = faces.len();
let indices = if n <= 3 {
(0..n).collect()
} else {
let mid = n / 2;
vec![0, mid, n - 1]
};
let samples: Vec<Vec<f32>> = indices.iter().map(|&i| faces[i].1.clone()).collect();
trace_samples.insert(tid, samples);
}
let total_traces = trace_samples.len();
let sample_count: usize = trace_samples.values().map(|v| v.len()).sum();
tracing::info!(
"[FaceMatch] Loaded {} traces, sampled {} embeddings (3-angle)",
total_traces,
sample_count
);
// Step 3: 建立 TMDb 查找表
let tmdb_seeds: Vec<(i32, String, Vec<f32>)> = tmdb_rows;
// Step 4: 迭代匹配
const TH: f32 = 0.50;
let mut matched: HashMap<i32, String> = HashMap::new(); // trace_id → identity_name
// Round 1: 用 3-angle samples 比對 TMDb
// 每個 trace 選 3 個不同角度 face取最高 similarity
for (&tid, samples) in &trace_samples {
let mut best_name = String::new();
let mut best_sim = 0.0f32;
for (_, ref name, ref tmdb_emb) in &tmdb_seeds {
for face_emb in samples {
let s = cosine_similarity(face_emb, tmdb_emb);
if s > best_sim {
best_sim = s;
best_name = name.clone();
}
}
}
if best_sim >= TH {
matched.insert(tid, best_name);
}
}
tracing::info!(
"[FaceMatch] Round 1: {} matched ({}%) — writing to DB",
matched.len(),
matched.len() * 100 / total_traces
);
// Step 5: 寫入 DB — Round 1 結果先存
let identities_table = schema::table_name("identities");
let fd_table = schema::table_name("face_detections");
let mut updated = 0usize;
for (tid, name) in &matched {
let id_opt = sqlx::query_scalar::<_, Option<i32>>(&format!(
"SELECT id FROM {} WHERE name=$1 AND source='tmdb'",
identities_table
))
.bind(name)
.fetch_optional(pool)
.await?;
if let Some(identity_id) = id_opt {
let _ = sqlx::query(&format!(
"UPDATE {} SET identity_id=$1 WHERE file_uuid=$2 AND trace_id=$3",
fd_table
))
.bind(identity_id)
.bind(file_uuid)
.bind(tid)
.execute(pool)
.await;
updated += 1;
}
}
tracing::info!("[FaceMatch] Round 1: updated {} face_detections", updated);
// Round 2+: 用已匹配的 face 作為 seed 傳播(剩餘未匹配的 trace
let initial_matched = matched.len();
for round_n in 2..=5 {
let prev = matched.len();
// 建立 seed pool: name → Vec<embedding>
let mut seed_pool: HashMap<String, Vec<&Vec<f32>>> = HashMap::new();
for (&tid, name) in &matched {
if let Some(samples) = trace_samples.get(&tid) {
seed_pool
.entry(name.clone())
.or_default()
.extend(samples.iter());
}
}
let mut new_matches: Vec<(i32, String)> = Vec::new();
for (&tid, samples) in &trace_samples {
if matched.contains_key(&tid) {
continue;
}
let mut best_name = String::new();
let mut best_sim = 0.0f32;
if samples.is_empty() {
continue;
}
// 用 3-angle samples 分別比對 seed取最高 similarity
for (name, seed_faces) in &seed_pool {
for face_emb in samples {
for seed in seed_faces {
let s = cosine_similarity(face_emb, seed);
if s > best_sim {
best_sim = s;
best_name = name.clone();
}
}
}
}
if best_sim >= TH {
new_matches.push((tid, best_name));
}
}
for (tid, name) in new_matches {
matched.insert(tid, name);
}
let new = matched.len() - prev;
tracing::info!(
"[FaceMatch] Round {}: +{} matched (total {}, {}%)",
round_n,
new,
matched.len(),
matched.len() * 100 / total_traces
);
if new < 5 {
break;
}
}
// Step 6: 未匹配的 trace 設 stranger_id = trace_id
// trace_id 在同一個 file 內是 sequential integer直接複用為 stranger_id
let stranger_update = sqlx::query(&format!(
"UPDATE {} SET stranger_id = trace_id \
WHERE file_uuid = $1 AND trace_id IS NOT NULL AND identity_id IS NULL \
AND (stranger_id IS NULL OR stranger_id != trace_id)",
fd_table
))
.bind(file_uuid)
.execute(pool)
.await?;
let stranger_count = stranger_update.rows_affected();
// Step 7: Save identity files for all affected identities
let affected = sqlx::query_scalar::<_, uuid::Uuid>(&format!(
"SELECT DISTINCT i.uuid FROM {} i \
JOIN {} fd ON fd.identity_id = i.id \
WHERE fd.file_uuid=$1 AND fd.identity_id IS NOT NULL",
identities_table, fd_table
))
.bind(file_uuid)
.fetch_all(pool)
.await
.unwrap_or_default();
for uuid in &affected {
let us = uuid.to_string().replace('-', "");
if let Err(e) = crate::core::identity::storage::save_identity_file_by_pool(pool, &us).await
{
tracing::warn!("[FaceMatch] Failed to save identity file {}: {}", us, e);
}
}
tracing::info!(
"[FaceMatch] Done: {}/{} traces matched ({}%), {} strangers, {} identity files",
matched.len(),
total_traces,
matched.len() * 100 / total_traces,
stranger_count,
affected.len()
);
Ok(updated)
}
/// Bind ASRX speakers to face traces based on temporal overlap.
/// Reads face_detections (trace_id, identity_id, frame_number) and ASRX
/// segments (speaker_id, start_time, end_time), computes overlap,
/// and stores bindings in identity_bindings table.
pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
// Load face traces with identity_id and frame numbers
let fd_table = schema::table_name("face_detections");
let traces = sqlx::query_as::<_, (i32, Vec<i32>)>(&format!(
"SELECT trace_id, array_agg(frame_number ORDER BY frame_number) \
FROM {} WHERE file_uuid=$1 AND trace_id IS NOT NULL AND identity_id IS NOT NULL \
GROUP BY trace_id",
fd_table
))
.bind(file_uuid)
.fetch_all(pool)
.await?;
if traces.is_empty() {
tracing::info!("[SpeakerBind] No face traces with identities");
return Ok(0);
}
// Load ASRX speakers from the output JSON
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
let asrx_path = std::path::Path::new(&output_dir).join(format!("{}.asrx.json", file_uuid));
let asrx_data: serde_json::Value = match std::fs::read_to_string(&asrx_path) {
Ok(s) => serde_json::from_str(&s).unwrap_or_default(),
Err(_) => {
tracing::info!("[SpeakerBind] No ASRX file found");
return Ok(0);
}
};
// Extract speaker segments: speaker_id → [(start_time, end_time)]
use std::collections::HashMap;
let mut speakers: HashMap<String, Vec<(f64, f64)>> = HashMap::new();
if let Some(segments) = asrx_data.get("segments").and_then(|s| s.as_array()) {
for seg in segments {
let sid = seg
.get("speaker_id")
.and_then(|s| s.as_str())
.or_else(|| seg.get("speaker").and_then(|s| s.as_str()));
if let Some(sid) = sid {
let start = seg
.get("start_time")
.or_else(|| seg.get("start"))
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
let end = seg
.get("end_time")
.or_else(|| seg.get("end"))
.and_then(|v| v.as_f64())
.unwrap_or(0.0);
speakers
.entry(sid.to_string())
.or_default()
.push((start, end));
}
}
}
if speakers.is_empty() {
tracing::info!("[SpeakerBind] No speakers found in ASRX data");
return Ok(0);
}
// Get fps for frame-to-time conversion
let fps: f64 = 25.0; // default, could also read from DB
// For each trace, compute overlap with each speaker
let mut bindings = 0usize;
for (trace_id, frames) in &traces {
if frames.is_empty() {
continue;
}
// Get identity_id for this trace
let fd_table = schema::table_name("face_detections");
let identity_id: Option<i32> = sqlx::query_scalar(
&format!("SELECT identity_id FROM {} WHERE file_uuid=$1 AND trace_id=$2 AND identity_id IS NOT NULL LIMIT 1", fd_table)
)
.bind(file_uuid).bind(trace_id)
.fetch_optional(pool).await?.flatten();
if identity_id.is_none() {
continue;
}
let identity_id = identity_id.unwrap();
// Compute overlap with each speaker
let mut best_speaker = String::new();
let mut best_overlap = 0usize;
for (speaker_id, segments) in &speakers {
let mut overlap = 0usize;
for &fn_num in frames {
let frame_time = fn_num as f64 / fps;
for (start, end) in segments {
if frame_time >= *start && frame_time <= *end {
overlap += 1;
break;
}
}
}
if overlap > best_overlap {
best_overlap = overlap;
best_speaker = speaker_id.clone();
}
}
// Only bind if meaningful overlap
let overlap_ratio = best_overlap as f64 / frames.len() as f64;
if overlap_ratio > 0.3 && !best_speaker.is_empty() {
let metadata = serde_json::json!({
"trace_id": trace_id,
"overlap_frames": best_overlap,
"total_frames": frames.len(),
"overlap_ratio": overlap_ratio,
});
let ib_table = schema::table_name("identity_bindings");
let _ = sqlx::query(
&format!("INSERT INTO {} (identity_id, identity_type, identity_value, confidence, metadata) \
VALUES ($1, 'speaker', $2, $3, $4::jsonb) \
ON CONFLICT (identity_id, identity_type, identity_value) DO UPDATE SET confidence = EXCLUDED.confidence, metadata = EXCLUDED.metadata", ib_table)
)
.bind(identity_id)
.bind(&best_speaker)
.bind(overlap_ratio)
.bind(&metadata)
.execute(pool).await;
bindings += 1;
}
}
tracing::info!(
"[SpeakerBind] Created {}/{} speaker bindings",
bindings,
traces.len()
);
Ok(bindings)
}
/// Pipeline-triggered entry point: runs the full identity agent for a file.
/// Reads face_clustered.json + asrx.json, extracts persons/speakers, creates identities,
/// runs iterative face matching, and binds speakers.
pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Result<()> {
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
let pool = db.pool();
// Step 1: 先跑 face matching不需 face_clustered.json
let matched = match_faces_iterative(pool, file_uuid).await.unwrap_or(0);
// Step 2: 試著載入 face_clustered.json 建立新 identities
let video_dir = PathBuf::from(&output_dir).join(file_uuid);
let face_clustered_path = video_dir.join(format!("{}.face_clustered.json", file_uuid));
let face_clustered_path = if face_clustered_path.exists() {
face_clustered_path
} else {
PathBuf::from(&output_dir).join(format!("{}.face_clustered.json", file_uuid))
};
if face_clustered_path.exists() {
let face_data: serde_json::Value =
std::fs::read_to_string(&face_clustered_path)?.parse()?;
let asrx_path = video_dir.join(format!("{}.asrx.json", file_uuid));
let asrx_data: Option<serde_json::Value> = if asrx_path.exists() {
Some(std::fs::read_to_string(&asrx_path)?.parse()?)
} else {
None
};
let persons = extract_persons_from_face_data(&face_data);
let speakers = extract_speakers_from_asrx_data(&asrx_data);
let identities = analyze_person_speaker_overlap(&persons, &speakers);
for (idx, id_result) in identities.iter().enumerate() {
let identity_name = format!("stranger_{}", idx);
let metadata = serde_json::json!({
"source": "identity_agent",
"trace_ids": id_result.person_ids,
"speaker_ids": id_result.speaker_ids,
"confidence": id_result.confidence,
"evidence": {
"speaker_overlap": id_result.evidence.speaker_overlap,
"frame_ratio": id_result.evidence.frame_ratio,
},
"reasoning": id_result.reasoning,
});
let _ = sqlx::query(
&format!("INSERT INTO {} (name, identity_type, source, metadata, status) VALUES ($1, 'people', 'auto', $2::jsonb, 'pending') ON CONFLICT DO NOTHING", schema::table_name("identities"))
)
.bind(&identity_name)
.bind(&metadata)
.execute(pool)
.await;
}
let _created = identities.len();
tracing::info!(
"[IdentityAgent] Created {} auto identities from face_clustered for {}",
_created,
file_uuid
);
} else {
tracing::warn!(
"[IdentityAgent] face_clustered.json not found for {}, skipping identity creation",
file_uuid
);
}
let bound = bind_speakers(pool, file_uuid).await.unwrap_or(0);
tracing::info!(
"[IdentityAgent] Done for {}: {} face matches, {} speaker bindings",
file_uuid,
matched,
bound
);
Ok(())
}