Files
momentry_core/src/api/identity_agent_api.rs
2026-05-08 00:48:15 +08:00

921 lines
31 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
extract::State,
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use sqlx::Row;
use std::path::PathBuf;
use crate::api::server::AppState;
use crate::core::db::PostgresDb;
/// Builds the identity-agent router.
///
/// Identity endpoints live under `/api/v1/agents/identity/*` (analyze, suggest, status) and the
/// suggestion endpoints under `/api/v1/agents/suggest/*` (clustering, merge). All handlers share
/// [`AppState`].
pub fn identity_agent_routes() -> Router<AppState> {
    let identity: Router<AppState> = Router::new()
        .route("/api/v1/agents/identity/analyze", post(analyze_identity))
        .route("/api/v1/agents/identity/suggest", post(suggest_merges))
        .route("/api/v1/agents/identity/status", get(get_identity_status));

    let suggestions: Router<AppState> = Router::new()
        .route("/api/v1/agents/suggest/clustering", post(suggest_clustering))
        .route("/api/v1/agents/suggest/merge", post(suggest_merge));

    identity.merge(suggestions)
}
/// Request body for `POST /api/v1/agents/identity/analyze`.
#[derive(Debug, Deserialize)]
pub struct AnalyzeIdentityRequest {
/// File identifier used to build the paths of the file's `face_clustered.json` and
/// `asrx.json` outputs under `MOMENTRY_OUTPUT_DIR`.
pub file_uuid: String,
/// NOTE(review): not read by `analyze_identity`; `suggest_merges` passes `Some(0.8)`.
pub auto_merge_threshold: Option<f64>,
/// NOTE(review): not read by `analyze_identity`; `suggest_merges` passes `Some(0.5)`.
pub llm_threshold: Option<f64>,
/// NOTE(review): not read by `analyze_identity`; `suggest_merges` passes `Some(true)`.
pub use_llm: Option<bool>,
/// NOTE(review): not read by `analyze_identity`; `suggest_merges` passes `Some("gemma4")`.
pub model: Option<String>,
}
/// Response body for `POST /api/v1/agents/identity/analyze`.
#[derive(Debug, Serialize)]
pub struct AnalyzeIdentityResponse {
/// `true` on the success path (failures are returned as HTTP errors instead).
pub success: bool,
/// Echo of the requested file identifier.
pub file_uuid: String,
/// One entry per person extracted from the face-clustered data.
pub identities: Vec<IdentityResult>,
/// Summary counters for this run.
pub processing_status: IdentityProcessingStatus,
}
/// A candidate identity produced by `analyze_person_speaker_overlap`.
#[derive(Debug, Serialize)]
pub struct IdentityResult {
/// `identity_<n>`, where `n` is the 1-based position of the person in the analysed list.
pub identity_id: String,
/// Person ids grouped into this identity (the current analyser always emits exactly one).
pub person_ids: Vec<String>,
/// Speakers whose segments cover more than half of the person's frames.
pub speaker_ids: Vec<String>,
/// `0.7 + 0.2 * speaker_overlap` when at least one speaker matched, otherwise `0.5`.
pub confidence: f64,
/// Raw numbers behind `confidence`.
pub evidence: IdentityEvidence,
/// Human-readable explanation of the speaker match.
pub reasoning: String,
}
/// Evidence values attached to an [`IdentityResult`].
#[derive(Debug, Serialize)]
pub struct IdentityEvidence {
/// Face-embedding similarity; `analyze_person_speaker_overlap` always leaves this `None`.
pub face_similarity: Option<f64>,
/// Highest fraction of the person's frames covered by a matched speaker (0 if none matched).
pub speaker_overlap: f64,
/// Currently set to the same value as `speaker_overlap`.
pub time_overlap: f64,
/// Number of frames the person appears in, divided by 1000.
pub frame_ratio: f64,
}
/// Summary counters returned with an analysis.
#[derive(Debug, Serialize)]
pub struct IdentityProcessingStatus {
/// `"completed"` when `analyze_identity` returns successfully.
pub status: String,
/// Number of persons extracted from the face-clustered data.
pub persons_analyzed: i32,
/// Number of identity results produced. Rows skipped by `ON CONFLICT DO NOTHING` are still
/// counted.
pub identities_created: i32,
/// Always `0` in `analyze_identity`.
pub merges_suggested: i32,
}
/// Request body for `POST /api/v1/agents/identity/suggest`.
#[derive(Debug, Deserialize)]
pub struct SuggestMergesRequest {
/// File to analyse; forwarded to `analyze_identity`.
pub file_uuid: String,
}
/// Response body for `POST /api/v1/agents/identity/suggest`.
#[derive(Debug, Serialize)]
pub struct SuggestMergesResponse {
/// `true` on the success path.
pub success: bool,
/// Echo of the requested file identifier.
pub file_uuid: String,
/// One entry per analysed identity that groups more than one person id.
pub merge_suggestions: Vec<MergeSuggestion>,
/// `suggest_merges` always returns this empty.
pub naming_suggestions: Vec<NamingSuggestion>,
}
/// A proposal to merge several person ids into one.
#[derive(Debug, Serialize)]
pub struct MergeSuggestion {
/// Person id the others would be merged into (the identity's first person id).
pub target_person_id: String,
/// The identity's remaining person ids.
pub source_person_ids: Vec<String>,
/// Confidence copied from the underlying identity.
pub confidence: f64,
/// Human-readable evidence strings (speaker overlap, face similarity, confidence).
pub reasons: Vec<String>,
/// `"auto_apply"` when `confidence > 0.8`, otherwise `"review_needed"`.
pub action: String,
}
/// A proposed display name for a person.
///
/// NOTE(review): `suggest_merges` never produces these (its list is always empty).
#[derive(Debug, Serialize)]
pub struct NamingSuggestion {
pub person_id: String,
pub suggested_name: String,
pub confidence: f64,
pub reasoning: String,
}
/// Response body for `GET /api/v1/agents/identity/status`.
#[derive(Debug, Serialize)]
pub struct IdentityStatusResponse {
/// Always `true` as built by `get_identity_status`.
pub success: bool,
/// Display name of the agent.
pub agent_name: String,
/// Agent version string.
pub version: String,
/// Model names advertised by the status endpoint.
pub supported_models: Vec<String>,
/// Advertised default thresholds.
pub default_thresholds: DefaultThresholds,
}
/// Default thresholds reported by the status endpoint.
#[derive(Debug, Serialize)]
pub struct DefaultThresholds {
/// Reported as 0.8; the same value `suggest_merges` uses as its `auto_apply` cutoff.
pub auto_merge_threshold: f64,
/// Reported as 0.5.
pub llm_threshold: f64,
/// Reported as 0.3. NOTE(review): `match_faces_iterative` actually uses 0.50.
pub face_similarity_threshold: f64,
}
async fn analyze_identity(
State(state): State<AppState>,
Json(req): Json<AnalyzeIdentityRequest>,
) -> Result<Json<AnalyzeIdentityResponse>, (StatusCode, String)> {
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
let video_dir = PathBuf::from(&output_dir).join(&req.file_uuid);
let face_clustered_path = video_dir.join(format!("{}.face_clustered.json", req.file_uuid));
let asrx_path = video_dir.join(format!("{}.asrx.json", req.file_uuid));
// 如果子目錄找不到,試根目錄
let face_clustered_path = if face_clustered_path.exists() {
face_clustered_path
} else {
PathBuf::from(&output_dir).join(format!("{}.face_clustered.json", req.file_uuid))
};
if !face_clustered_path.exists() {
return Err((
StatusCode::NOT_FOUND,
format!("Face clustered data not found for video: {}", req.file_uuid),
));
}
let face_data: serde_json::Value = std::fs::read_to_string(&face_clustered_path)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to read face data: {}", e)))?
.parse()
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse face data: {}", e)))?;
let asrx_data: Option<serde_json::Value> = if asrx_path.exists() {
Some(std::fs::read_to_string(&asrx_path)
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to read asrx data: {}", e)))?
.parse()
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse asrx data: {}", e)))?)
} else {
None
};
let persons = extract_persons_from_face_data(&face_data);
let speakers = extract_speakers_from_asrx_data(&asrx_data);
let identities = analyze_person_speaker_overlap(&persons, &speakers);
// 將 identity 結果寫入 DB
let pool = state.db.pool();
for id_result in &identities {
let identity_name = format!("person_{}", id_result.person_ids.first().map(|s| &**s).unwrap_or("unknown"));
let metadata = serde_json::json!({
"source": "identity_agent",
"trace_ids": id_result.person_ids,
"speaker_ids": id_result.speaker_ids,
"confidence": id_result.confidence,
"evidence": {
"speaker_overlap": id_result.evidence.speaker_overlap,
"frame_ratio": id_result.evidence.frame_ratio,
},
"reasoning": id_result.reasoning,
});
let _ = sqlx::query(
"INSERT INTO dev.identities (name, identity_type, source, metadata, status) VALUES ($1, 'people', 'auto', $2::jsonb, 'pending') ON CONFLICT DO NOTHING"
)
.bind(&identity_name)
.bind(&metadata)
.execute(pool)
.await;
}
// 迭代多角度 face embedding 比對TMDb seed → 傳播)
let _ = match_faces_iterative(pool, &req.file_uuid).await.unwrap_or(0);
// 將 ASRX speaker 綁定到已匹配 identity 的 trace
let _ = bind_speakers(pool, &req.file_uuid).await.unwrap_or(0);
let processing_status = IdentityProcessingStatus {
status: "completed".to_string(),
persons_analyzed: persons.len() as i32,
identities_created: identities.len() as i32,
merges_suggested: 0,
};
Ok(Json(AnalyzeIdentityResponse {
success: true,
file_uuid: req.file_uuid.clone(),
identities,
processing_status,
}))
}
/// `POST /api/v1/agents/identity/suggest` — derives merge suggestions from a fresh analysis.
///
/// Runs [`analyze_identity`] for `file_uuid` (auto-merge 0.8, LLM 0.5, `use_llm` on, model
/// `gemma4`), so it reads the same files and performs the same database writes.
///
/// Every resulting identity that groups more than one person id becomes a suggestion targeting
/// its first person id. Suggestions with confidence above 0.8 are marked `auto_apply`, the rest
/// `review_needed`. `naming_suggestions` is always empty.
///
/// NOTE(review): `analyze_person_speaker_overlap` currently emits exactly one person id per
/// identity, so `merge_suggestions` is always empty as well.
///
/// # Errors
/// Propagates any error returned by [`analyze_identity`].
async fn suggest_merges(
    State(state): State<AppState>,
    Json(req): Json<SuggestMergesRequest>,
) -> Result<Json<SuggestMergesResponse>, (StatusCode, String)> {
    let file_uuid = req.file_uuid;
    let request = AnalyzeIdentityRequest {
        file_uuid: file_uuid.clone(),
        auto_merge_threshold: Some(0.8),
        llm_threshold: Some(0.5),
        use_llm: Some(true),
        model: Some("gemma4".to_string()),
    };
    let Json(analysis) = analyze_identity(State(state), Json(request)).await?;

    let mut merge_suggestions = Vec::new();
    for identity in &analysis.identities {
        // Only identities that group at least two person ids describe a merge.
        let (target, sources) = match identity.person_ids.split_first() {
            Some((first, rest)) if !rest.is_empty() => (first, rest),
            _ => continue,
        };
        let action = if identity.confidence > 0.8 {
            "auto_apply"
        } else {
            "review_needed"
        };
        merge_suggestions.push(MergeSuggestion {
            target_person_id: target.clone(),
            source_person_ids: sources.to_vec(),
            confidence: identity.confidence,
            reasons: vec![
                format!(
                    "Shared speaker overlap: {:.0}%",
                    identity.evidence.speaker_overlap * 100.0
                ),
                format!(
                    "Face similarity: {:.2}",
                    identity.evidence.face_similarity.unwrap_or(0.0)
                ),
                format!("Confidence: {:.2}", identity.confidence),
            ],
            action: action.to_string(),
        });
    }

    Ok(Json(SuggestMergesResponse {
        success: true,
        file_uuid,
        merge_suggestions,
        naming_suggestions: Vec::new(),
    }))
}
/// `GET /api/v1/agents/identity/status` — static description of the identity agent.
///
/// Always succeeds. Reports the agent name and version, the advertised model names, and the
/// default thresholds.
async fn get_identity_status() -> Result<Json<IdentityStatusResponse>, (StatusCode, String)> {
    let default_thresholds = DefaultThresholds {
        auto_merge_threshold: 0.8,
        llm_threshold: 0.5,
        face_similarity_threshold: 0.3,
    };
    let supported_models = ["gemma4", "qwen3"]
        .iter()
        .map(|model| model.to_string())
        .collect();
    let response = IdentityStatusResponse {
        success: true,
        agent_name: String::from("Identity Agent"),
        version: String::from("1.0.0"),
        supported_models,
        default_thresholds,
    };
    Ok(Json(response))
}
/// Groups face-clustered frames by person.
///
/// Reads the top-level `"frames"` array. Each entry must have an integer `"frame"` and a string
/// `"person_id"`; entries missing either are ignored. Input without a `"frames"` array yields an
/// empty list.
///
/// The result is ordered by `person_id`, so callers that number persons by position (e.g. the
/// `identity_<n>` ids) get the same numbering on every call. Frame numbers keep their input
/// order per person.
fn extract_persons_from_face_data(face_data: &serde_json::Value) -> Vec<PersonData> {
    use std::collections::BTreeMap;

    let frames = match face_data.get("frames").and_then(|f| f.as_array()) {
        Some(frames) => frames,
        None => return Vec::new(),
    };

    // BTreeMap (not HashMap) so the output order is deterministic.
    let mut person_frames: BTreeMap<String, Vec<i32>> = BTreeMap::new();
    for frame in frames {
        let frame_num = frame.get("frame").and_then(|f| f.as_i64());
        let person_id = frame.get("person_id").and_then(|p| p.as_str());
        if let (Some(frame_num), Some(person_id)) = (frame_num, person_id) {
            person_frames
                .entry(person_id.to_string())
                .or_default()
                .push(frame_num as i32);
        }
    }

    person_frames
        .into_iter()
        .map(|(person_id, frames)| PersonData {
            person_id,
            frames,
            avg_embedding: None,
        })
        .collect()
}
/// Groups ASRX transcript segments by speaker.
///
/// Reads the top-level `"segments"` array (if any). A segment's speaker comes from
/// `"speaker_id"`, falling back to `"speaker"`. Segments without a string speaker are ignored.
///
/// The time range comes from `"start"`/`"end"`, falling back to `"start_time"`/`"end_time"`.
/// A key that is missing *or* not a number falls through to the alternative, then to `0.0`.
/// `calculate_overlap` compares these values against `frame / 25.0`, i.e. treats them as seconds.
///
/// The result is ordered by speaker id, so downstream speaker lists are deterministic.
fn extract_speakers_from_asrx_data(asrx_data: &Option<serde_json::Value>) -> Vec<SpeakerData> {
    use std::collections::BTreeMap;

    let segments = match asrx_data
        .as_ref()
        .and_then(|data| data.get("segments"))
        .and_then(|s| s.as_array())
    {
        Some(segments) => segments,
        None => return Vec::new(),
    };

    // First numeric value among `primary` / `fallback`, else 0.0.
    let time_of = |segment: &serde_json::Value, primary: &str, fallback: &str| -> f64 {
        segment
            .get(primary)
            .and_then(|v| v.as_f64())
            .or_else(|| segment.get(fallback).and_then(|v| v.as_f64()))
            .unwrap_or(0.0)
    };

    let mut by_speaker: BTreeMap<String, Vec<(f64, f64)>> = BTreeMap::new();
    for segment in segments {
        let speaker_id = segment
            .get("speaker_id")
            .and_then(|s| s.as_str())
            .or_else(|| segment.get("speaker").and_then(|s| s.as_str()));
        if let Some(speaker_id) = speaker_id {
            let start = time_of(segment, "start", "start_time");
            let end = time_of(segment, "end", "end_time");
            by_speaker
                .entry(speaker_id.to_string())
                .or_default()
                .push((start, end));
        }
    }

    by_speaker
        .into_iter()
        .map(|(speaker_id, segments)| SpeakerData {
            speaker_id,
            segments,
        })
        .collect()
}
/// Scores every person against every speaker and builds one [`IdentityResult`] per person.
///
/// A speaker's overlap ratio is the fraction of the person's frames that fall inside one of the
/// speaker's segments (see [`calculate_overlap`]). Speakers with a ratio above 0.5 are attached
/// to the identity.
///
/// Confidence is `0.7 + 0.2 * best_ratio` when any speaker is attached, otherwise `0.5`.
/// Identities are numbered `identity_1`, `identity_2`, ... in input order.
fn analyze_person_speaker_overlap(
    persons: &[PersonData],
    speakers: &[SpeakerData],
) -> Vec<IdentityResult> {
    persons
        .iter()
        .enumerate()
        .map(|(idx, person)| {
            let total_frames = person.frames.len() as f64;

            let mut matched_speakers: Vec<String> = Vec::new();
            let mut best_ratio: f64 = 0.0;
            for speaker in speakers {
                let ratio = calculate_overlap(person, speaker) as f64 / total_frames;
                if ratio > 0.5 {
                    matched_speakers.push(speaker.speaker_id.clone());
                    best_ratio = best_ratio.max(ratio);
                }
            }

            let has_speaker = !matched_speakers.is_empty();
            let confidence = if has_speaker {
                0.7 + best_ratio * 0.2
            } else {
                0.5
            };
            let reasoning = if has_speaker {
                format!(
                    "Person has high overlap with speakers: {}",
                    matched_speakers.join(", ")
                )
            } else {
                "Person has no speaker overlap".to_string()
            };

            IdentityResult {
                identity_id: format!("identity_{}", idx + 1),
                person_ids: vec![person.person_id.clone()],
                speaker_ids: matched_speakers,
                confidence,
                evidence: IdentityEvidence {
                    face_similarity: None,
                    speaker_overlap: best_ratio,
                    time_overlap: best_ratio,
                    frame_ratio: total_frames / 1000.0,
                },
                reasoning,
            }
        })
        .collect()
}
/// Counts the person's frames whose timestamp lies inside any of the speaker's segments.
///
/// Frame numbers are converted to time by dividing by a hard-coded 25 fps. Segment bounds are
/// inclusive on both ends.
fn calculate_overlap(person: &PersonData, speaker: &SpeakerData) -> i32 {
    const ASSUMED_FPS: f64 = 25.0;

    person
        .frames
        .iter()
        .map(|&frame| frame as f64 / ASSUMED_FPS)
        .filter(|&t| {
            speaker
                .segments
                .iter()
                .any(|&(start, end)| t >= start && t <= end)
        })
        .count() as i32
}
/// Request body for `POST /api/v1/agents/suggest/clustering`.
#[derive(Debug, Deserialize)]
pub struct SuggestClusteringRequest {
/// Restrict the search to one file; `None` searches all files.
pub file_uuid: Option<String>,
/// Minimum number of detections a trace needs to be reported (default 3).
pub min_cluster_size: Option<usize>,
/// NOTE(review): accepted but not read by `suggest_clustering`.
pub similarity_threshold: Option<f64>,
}
/// Response body for `POST /api/v1/agents/suggest/clustering`.
#[derive(Debug, Serialize)]
pub struct SuggestClusteringResponse {
pub success: bool,
/// Candidate traces, most detections first.
pub suggestions: Vec<ClusteringSuggestion>,
/// Number of face detections whose `identity_id` is NULL.
pub total_unclustered: usize,
}
/// One face trace proposed as a cluster.
#[derive(Debug, Serialize)]
pub struct ClusteringSuggestion {
/// `trace_<trace_id>` (`trace_0` if the trace id could not be decoded as `i32`).
pub cluster_id: String,
/// Number of detections in the trace.
pub face_count: usize,
/// Always `0.0` in `suggest_clustering`.
pub avg_confidence: f64,
/// Always `None` in `suggest_clustering`.
pub suggested_name: Option<String>,
/// Always `None` in `suggest_clustering`.
pub representative_face: Option<String>,
}
/// `POST /api/v1/agents/suggest/clustering` — lists face traces that look like candidate clusters.
///
/// Returns traces in `dev.face_detections`, optionally limited to `file_uuid`, that meet both
/// conditions:
/// * at least `min_cluster_size` (default 3) detections;
/// * no `dev.identities` row whose `metadata->>'trace_id'` equals the trace id.
///
/// Traces are returned with the most detections first. `total_unclustered` counts detections in
/// `dev.face_detections` with no `identity_id` and is not filtered by file.
/// `similarity_threshold` is accepted but not used.
///
/// # Errors
/// `500 Internal Server Error` if either query fails.
async fn suggest_clustering(
    State(state): State<AppState>,
    Json(req): Json<SuggestClusteringRequest>,
) -> Result<Json<SuggestClusteringResponse>, (StatusCode, String)> {
    // `$2` carries the optional file filter (NULL = all files). It is bound rather than spliced
    // into the SQL text, so a crafted `file_uuid` cannot inject SQL.
    // NOTE(review): the identity agent stores `trace_ids` (an array) in identity metadata, while
    // this filter reads `metadata->>'trace_id'`; confirm which key is intended.
    const CLUSTER_QUERY: &str = r#"
        SELECT trace_id, file_uuid, COUNT(*) as face_count
        FROM dev.face_detections fd
        WHERE fd.trace_id IS NOT NULL
          AND NOT EXISTS (
              SELECT 1 FROM dev.identities i
              WHERE i.metadata->>'trace_id' = fd.trace_id::text
          )
          AND ($2::text IS NULL OR fd.file_uuid = $2)
        GROUP BY trace_id, file_uuid
        HAVING COUNT(*) >= $1
        ORDER BY face_count DESC
    "#;
    // Schema-qualified like the cluster query above, so both numbers describe the same table.
    const UNCLUSTERED_QUERY: &str = r#"
        SELECT COUNT(*) FROM dev.face_detections fd
        WHERE fd.identity_id IS NULL
    "#;

    let pool = state.db.pool();
    let min_cluster_size = req.min_cluster_size.unwrap_or(3) as i64;
    let rows = sqlx::query(CLUSTER_QUERY)
        .bind(min_cluster_size)
        .bind(req.file_uuid.as_deref())
        .fetch_all(pool)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;

    let suggestions: Vec<ClusteringSuggestion> = rows
        .into_iter()
        .map(|row| {
            let trace_id: Option<i32> = row.try_get("trace_id").ok();
            let face_count: i64 = row.get("face_count");
            ClusteringSuggestion {
                cluster_id: format!("trace_{}", trace_id.unwrap_or(0)),
                face_count: face_count as usize,
                avg_confidence: 0.0,
                suggested_name: None,
                representative_face: None,
            }
        })
        .collect();

    let total_unclustered: i64 = sqlx::query_scalar(UNCLUSTERED_QUERY)
        .fetch_one(pool)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;

    Ok(Json(SuggestClusteringResponse {
        success: true,
        suggestions,
        total_unclustered: total_unclustered as usize,
    }))
}
/// Request body for `POST /api/v1/agents/suggest/merge`.
#[derive(Debug, Deserialize)]
pub struct SuggestMergeRequest {
/// Optional filter: pairs in which either identity has this `uuid`. `None` considers all pairs.
pub identity_id: Option<String>,
/// Minimum `similarity_score` to report (default 0.8).
pub similarity_threshold: Option<f64>,
}
/// Response body for `POST /api/v1/agents/suggest/merge`.
#[derive(Debug, Serialize)]
pub struct SuggestMergeResponse {
pub success: bool,
pub suggestions: Vec<IdentityMergeSuggestion>,
}
/// A pair of `people` identities that co-occur in files and may be the same person.
#[derive(Debug, Serialize)]
pub struct IdentityMergeSuggestion {
/// `uuid` of the identity with the smaller `id`.
pub source_identity_id: String,
/// `uuid` of the identity with the larger `id`.
pub target_identity_id: String,
pub source_name: String,
pub target_name: String,
/// `min(shared_files / 10, 1.0)`.
pub similarity_score: f64,
/// Value of the `shared_files` column returned by the `suggest_merge` query.
pub shared_files: usize,
/// Human-readable summary of `shared_files` and `similarity_score`.
pub reason: String,
}
/// `POST /api/v1/agents/suggest/merge` — proposes merges between `people` identities that appear
/// in the same files.
///
/// For every pair of `people` identities (ordered by `id`), the query counts the distinct files
/// in which both have face detections. The score is `min(shared_files / 10, 1.0)`.
///
/// Pairs scoring at least `similarity_threshold` (default 0.8) are returned, at most 50, with the
/// most shared files first. When `identity_id` is given, only pairs where either side has that
/// `uuid` are considered.
///
/// # Errors
/// `500 Internal Server Error` if the query fails. Rows whose columns cannot be decoded are
/// skipped instead of panicking.
async fn suggest_merge(
    State(state): State<AppState>,
    Json(req): Json<SuggestMergeRequest>,
) -> Result<Json<SuggestMergeResponse>, (StatusCode, String)> {
    // `identity_files` first reduces detections to distinct (identity, file) pairs. Counting the
    // joined rows then counts shared files directly, instead of counting all of i1's files
    // and multiplying detection rows.
    // `$1` is the optional identity filter (NULL = no filter). It is bound, not spliced into the
    // SQL, and the parentheses keep its `OR` from bypassing the other conditions.
    // NOTE(review): unlike the rest of this file these tables are not schema-qualified with
    // `dev.`; confirm the connection's search_path resolves them to the intended tables.
    const QUERY: &str = r#"
        WITH identity_files AS (
            SELECT DISTINCT identity_id, file_uuid
            FROM face_detections
            WHERE identity_id IS NOT NULL
        )
        SELECT
            i1.uuid as source_uuid,
            i2.uuid as target_uuid,
            i1.name as source_name,
            i2.name as target_name,
            COUNT(*) as shared_files
        FROM identities i1
        JOIN identities i2 ON i1.id < i2.id
        JOIN identity_files f1 ON f1.identity_id = i1.id
        JOIN identity_files f2 ON f2.identity_id = i2.id AND f2.file_uuid = f1.file_uuid
        WHERE i1.identity_type = 'people'
          AND i2.identity_type = 'people'
          AND ($1::text IS NULL OR i1.uuid::text = $1 OR i2.uuid::text = $1)
        GROUP BY i1.uuid, i2.uuid, i1.name, i2.name
        ORDER BY shared_files DESC
        LIMIT 50
    "#;

    let similarity_threshold = req.similarity_threshold.unwrap_or(0.8);
    let pool = state.db.pool();
    let rows = sqlx::query(QUERY)
        .bind(req.identity_id.as_deref())
        .fetch_all(pool)
        .await
        .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;

    let suggestions: Vec<IdentityMergeSuggestion> = rows
        .into_iter()
        .filter_map(|row| {
            let shared_files: i64 = row.try_get("shared_files").ok()?;
            let similarity = (shared_files as f64 / 10.0).min(1.0);
            if shared_files <= 0 || similarity < similarity_threshold {
                return None;
            }
            Some(IdentityMergeSuggestion {
                source_identity_id: row.try_get("source_uuid").ok()?,
                target_identity_id: row.try_get("target_uuid").ok()?,
                source_name: row.try_get("source_name").ok()?,
                target_name: row.try_get("target_name").ok()?,
                similarity_score: similarity,
                shared_files: shared_files as usize,
                reason: format!(
                    "Share {} file(s) - similarity: {:.1}%",
                    shared_files,
                    similarity * 100.0
                ),
            })
        })
        .collect();

    Ok(Json(SuggestMergeResponse {
        success: true,
        suggestions,
    }))
}
/// Frames in which one clustered person appears, as read from `face_clustered.json`.
#[derive(Debug)]
struct PersonData {
/// The `person_id` string from the face-clustered frames.
person_id: String,
/// `frame` numbers in which the person appears.
frames: Vec<i32>,
/// Always `None` as built by `extract_persons_from_face_data`.
avg_embedding: Option<Vec<f64>>,
}
/// Time segments attributed to one ASRX speaker.
#[derive(Debug)]
struct SpeakerData {
speaker_id: String,
/// `(start, end)` pairs, compared against `frame / 25.0` by `calculate_overlap`.
segments: Vec<(f64, f64)>,
}
/// Cosine similarity of two embedding vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero magnitude.
/// Otherwise returns `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // One left-to-right pass accumulating the dot product and both squared norms.
    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(d, sa, sb), (&x, &y)| (d + x * y, sa + x * x, sb + y * y),
    );
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// 迭代多角度 face embedding 比對 + 傳播
/// Round 1: 用 TMDb seed face_embedding 比對 face_detections (threshold 0.50)
/// Round 2+: 用已匹配 trace 的所有 face 作為 seed傳播到未匹配 trace
async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
// Step 1: 載入 TMDb identities (source='tmdb' 且有 face_embedding)
let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec<f32>)>(
"SELECT id, name, face_embedding::real[] FROM dev.identities WHERE source='tmdb' AND face_embedding IS NOT NULL"
)
.fetch_all(pool).await?;
if tmdb_rows.is_empty() {
tracing::warn!("[FaceMatch] No TMDb identities with face embeddings");
return Ok(0);
}
tracing::info!("[FaceMatch] Loaded {} TMDb seed identities", tmdb_rows.len());
// Step 2: 載入所有 face_detections按 trace_id 分組
let fd_rows = sqlx::query_as::<_, (i32, Vec<f32>)>(
"SELECT trace_id, embedding FROM dev.face_detections \
WHERE file_uuid=$1 AND trace_id IS NOT NULL AND embedding IS NOT NULL \
ORDER BY trace_id"
)
.bind(file_uuid)
.fetch_all(pool).await?;
if fd_rows.is_empty() {
tracing::warn!("[FaceMatch] No face detections with embeddings");
return Ok(0);
}
// 分組trace_id → Vec<embedding>
use std::collections::HashMap;
let mut trace_faces: HashMap<i32, Vec<Vec<f32>>> = HashMap::new();
for (tid, emb) in &fd_rows {
trace_faces.entry(*tid).or_insert_with(Vec::new).push(emb.clone());
}
// 去重:同一個 trace 內embedding 太接近的只留一個
for faces in trace_faces.values_mut() {
faces.sort_by(|a, b| b[0].partial_cmp(&a[0]).unwrap_or(std::cmp::Ordering::Equal));
faces.dedup_by(|a, b| cosine_similarity(a, b) > 0.99);
}
let total_traces = trace_faces.len();
tracing::info!("[FaceMatch] Loaded {} traces with {} faces", total_traces, fd_rows.len());
// Step 3: 建立 TMDb 查找表
let tmdb_seeds: Vec<(i32, String, Vec<f32>)> = tmdb_rows;
// Step 4: 迭代匹配
const TH: f32 = 0.50;
let mut matched: HashMap<i32, String> = HashMap::new(); // trace_id → identity_name
// Round 1: 直接比對 TMDb
for (&tid, faces) in &trace_faces {
let mut best_name = String::new();
let mut best_sim = 0.0f32;
for (_, ref name, ref tmdb_emb) in &tmdb_seeds {
for face_emb in faces {
let s = cosine_similarity(face_emb, tmdb_emb);
if s > best_sim { best_sim = s; best_name = name.clone(); }
}
}
if best_sim >= TH {
matched.insert(tid, best_name);
}
}
tracing::info!("[FaceMatch] Round 1: {} matched ({}%)", matched.len(), matched.len() * 100 / total_traces);
// Round 2+: 用已匹配的 face 作為 seed 傳播
for round_n in 2..=10 {
let prev = matched.len();
// 建立 seed pool: name → Vec<embedding>
let mut seed_pool: HashMap<String, Vec<&Vec<f32>>> = HashMap::new();
for (&tid, name) in &matched {
if let Some(faces) = trace_faces.get(&tid) {
seed_pool.entry(name.clone()).or_default().extend(faces.iter());
}
}
let mut new_matches: Vec<(i32, String)> = Vec::new();
for (&tid, faces) in &trace_faces {
if matched.contains_key(&tid) { continue; }
let mut best_name = String::new();
let mut best_sim = 0.0f32;
if faces.is_empty() { continue; }
let ref_face = &faces[0];
for (name, seed_faces) in &seed_pool {
for seed in seed_faces {
let s = cosine_similarity(ref_face, seed);
if s > best_sim { best_sim = s; best_name = name.clone(); }
}
}
if best_sim >= TH {
new_matches.push((tid, best_name));
}
}
for (tid, name) in new_matches {
matched.insert(tid, name);
}
let new = matched.len() - prev;
tracing::info!("[FaceMatch] Round {}: +{} matched (total {}, {}%)", round_n, new, matched.len(), matched.len() * 100 / total_traces);
if new < 5 { break; }
}
// Step 5: 寫入 DB
let mut updated = 0usize;
for (tid, name) in &matched {
let id_opt = sqlx::query_scalar::<_, Option<i32>>(
"SELECT id FROM dev.identities WHERE name=$1 AND source='tmdb'"
)
.bind(name)
.fetch_optional(pool).await?;
if let Some(identity_id) = id_opt {
let _ = sqlx::query(
"UPDATE dev.face_detections SET identity_id=$1 WHERE file_uuid=$2 AND trace_id=$3"
)
.bind(identity_id)
.bind(file_uuid)
.bind(tid)
.execute(pool).await;
updated += 1;
}
}
tracing::info!("[FaceMatch] Done: {}/{} traces matched ({}%)", matched.len(), total_traces, matched.len() * 100 / total_traces);
Ok(updated)
}
/// Bind ASRX speakers to face traces based on temporal overlap.
/// Reads face_detections (trace_id, identity_id, frame_number) and ASRX
/// segments (speaker_id, start_time, end_time), computes overlap,
/// and stores bindings in identity_bindings table.
pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
// Load face traces with identity_id and frame numbers
let traces = sqlx::query_as::<_, (i32, Vec<i32>)>(
"SELECT trace_id, array_agg(frame_number ORDER BY frame_number) \
FROM dev.face_detections WHERE file_uuid=$1 AND trace_id IS NOT NULL AND identity_id IS NOT NULL \
GROUP BY trace_id"
)
.bind(file_uuid)
.fetch_all(pool).await?;
if traces.is_empty() {
tracing::info!("[SpeakerBind] No face traces with identities");
return Ok(0);
}
// Load ASRX speakers from the output JSON
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
let asrx_path = std::path::Path::new(&output_dir).join(format!("{}.asrx.json", file_uuid));
let asrx_data: serde_json::Value = match std::fs::read_to_string(&asrx_path) {
Ok(s) => serde_json::from_str(&s).unwrap_or_default(),
Err(_) => {
tracing::info!("[SpeakerBind] No ASRX file found");
return Ok(0);
}
};
// Extract speaker segments: speaker_id → [(start_time, end_time)]
use std::collections::HashMap;
let mut speakers: HashMap<String, Vec<(f64, f64)>> = HashMap::new();
if let Some(segments) = asrx_data.get("segments").and_then(|s| s.as_array()) {
for seg in segments {
let sid = seg.get("speaker_id").and_then(|s| s.as_str())
.or_else(|| seg.get("speaker").and_then(|s| s.as_str()));
if let Some(sid) = sid {
let start = seg.get("start_time").or_else(|| seg.get("start")).and_then(|v| v.as_f64()).unwrap_or(0.0);
let end = seg.get("end_time").or_else(|| seg.get("end")).and_then(|v| v.as_f64()).unwrap_or(0.0);
speakers.entry(sid.to_string()).or_default().push((start, end));
}
}
}
if speakers.is_empty() {
tracing::info!("[SpeakerBind] No speakers found in ASRX data");
return Ok(0);
}
// Get fps for frame-to-time conversion
let fps: f64 = 25.0; // default, could also read from DB
// For each trace, compute overlap with each speaker
let mut bindings = 0usize;
for (trace_id, frames) in &traces {
if frames.is_empty() { continue; }
// Get identity_id for this trace
let identity_id: Option<i32> = sqlx::query_scalar(
"SELECT identity_id FROM dev.face_detections WHERE file_uuid=$1 AND trace_id=$2 AND identity_id IS NOT NULL LIMIT 1"
)
.bind(file_uuid).bind(trace_id)
.fetch_optional(pool).await?.flatten();
if identity_id.is_none() { continue; }
let identity_id = identity_id.unwrap();
// Compute overlap with each speaker
let mut best_speaker = String::new();
let mut best_overlap = 0usize;
for (speaker_id, segments) in &speakers {
let mut overlap = 0usize;
for &fn_num in frames {
let frame_time = fn_num as f64 / fps;
for (start, end) in segments {
if frame_time >= *start && frame_time <= *end {
overlap += 1;
break;
}
}
}
if overlap > best_overlap {
best_overlap = overlap;
best_speaker = speaker_id.clone();
}
}
// Only bind if meaningful overlap
let overlap_ratio = best_overlap as f64 / frames.len() as f64;
if overlap_ratio > 0.3 && !best_speaker.is_empty() {
let metadata = serde_json::json!({
"trace_id": trace_id,
"overlap_frames": best_overlap,
"total_frames": frames.len(),
"overlap_ratio": overlap_ratio,
});
let _ = sqlx::query(
"INSERT INTO dev.identity_bindings (identity_id, identity_type, identity_value, confidence, metadata) \
VALUES ($1, 'speaker', $2, $3, $4::jsonb) \
ON CONFLICT (identity_id, identity_type, identity_value) DO UPDATE SET confidence = EXCLUDED.confidence, metadata = EXCLUDED.metadata"
)
.bind(identity_id)
.bind(&best_speaker)
.bind(overlap_ratio)
.bind(&metadata)
.execute(pool).await;
bindings += 1;
}
}
tracing::info!("[SpeakerBind] Created {}/{} speaker bindings", bindings, traces.len());
Ok(bindings)
}
/// Pipeline-triggered entry point: runs the full identity agent for a file.
///
/// Reads `<file_uuid>.face_clustered.json` (required) and `<file_uuid>.asrx.json` (optional).
/// Each is looked up in `$MOMENTRY_OUTPUT_DIR/<file_uuid>/` first and then in
/// `$MOMENTRY_OUTPUT_DIR/`.
///
/// It then extracts persons and speakers, inserts pending `auto` identities into
/// `dev.identities`, runs iterative face matching, and binds speakers.
///
/// Returns `Ok(())` without doing anything if the face-clustered file is missing. Database
/// failures in the individual steps are logged rather than returned.
///
/// # Errors
/// Returns an error (naming the offending path) if an existing JSON file cannot be read or
/// parsed.
pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Result<()> {
    let output_dir = PathBuf::from(
        std::env::var("MOMENTRY_OUTPUT_DIR")
            .unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string()),
    );
    // Pipeline outputs live either in a per-file sub-directory or directly in the output
    // directory; prefer the sub-directory.
    let locate = |suffix: &str| -> PathBuf {
        let file_name = format!("{}.{}", file_uuid, suffix);
        let nested = output_dir.join(file_uuid).join(&file_name);
        if nested.exists() {
            nested
        } else {
            output_dir.join(file_name)
        }
    };
    let read_json = |path: &std::path::Path| -> anyhow::Result<serde_json::Value> {
        let text = std::fs::read_to_string(path)
            .map_err(|e| anyhow::anyhow!("failed to read {}: {}", path.display(), e))?;
        serde_json::from_str(&text)
            .map_err(|e| anyhow::anyhow!("failed to parse {}: {}", path.display(), e))
    };

    let face_clustered_path = locate("face_clustered.json");
    if !face_clustered_path.exists() {
        tracing::warn!("[IdentityAgent] face_clustered.json not found for {}", file_uuid);
        return Ok(());
    }
    let face_data = read_json(&face_clustered_path)?;

    let asrx_path = locate("asrx.json");
    let asrx_data = if asrx_path.exists() {
        Some(read_json(&asrx_path)?)
    } else {
        None
    };

    let persons = extract_persons_from_face_data(&face_data);
    let speakers = extract_speakers_from_asrx_data(&asrx_data);
    let identities = analyze_person_speaker_overlap(&persons, &speakers);

    let pool = db.pool();
    for id_result in &identities {
        let identity_name = format!(
            "person_{}",
            id_result
                .person_ids
                .first()
                .map(String::as_str)
                .unwrap_or("unknown")
        );
        let metadata = serde_json::json!({
            "source": "identity_agent",
            "trace_ids": id_result.person_ids,
            "speaker_ids": id_result.speaker_ids,
            "confidence": id_result.confidence,
            "evidence": {
                "speaker_overlap": id_result.evidence.speaker_overlap,
                "frame_ratio": id_result.evidence.frame_ratio,
            },
            "reasoning": id_result.reasoning,
        });
        // Best effort: a failed insert is logged and the remaining steps still run.
        if let Err(e) = sqlx::query(
            "INSERT INTO dev.identities (name, identity_type, source, metadata, status) VALUES ($1, 'people', 'auto', $2::jsonb, 'pending') ON CONFLICT DO NOTHING"
        )
        .bind(&identity_name)
        .bind(&metadata)
        .execute(pool)
        .await
        {
            tracing::warn!("[IdentityAgent] Failed to insert identity {}: {}", identity_name, e);
        }
    }

    let matched = match match_faces_iterative(pool, file_uuid).await {
        Ok(n) => n,
        Err(e) => {
            tracing::warn!("[IdentityAgent] Face matching failed for {}: {}", file_uuid, e);
            0
        }
    };
    let bound = match bind_speakers(pool, file_uuid).await {
        Ok(n) => n,
        Err(e) => {
            tracing::warn!("[IdentityAgent] Speaker binding failed for {}: {}", file_uuid, e);
            0
        }
    };
    tracing::info!(
        "[IdentityAgent] Done for {}: {} identities, {} face matches, {} speaker bindings",
        file_uuid, identities.len(), matched, bound
    );
    Ok(())
}