921 lines
31 KiB
Rust
921 lines
31 KiB
Rust
use axum::{
|
||
extract::State,
|
||
http::StatusCode,
|
||
response::Json,
|
||
routing::{get, post},
|
||
Router,
|
||
};
|
||
use serde::{Deserialize, Serialize};
|
||
use sqlx::Row;
|
||
use std::path::PathBuf;
|
||
|
||
use crate::api::server::AppState;
|
||
use crate::core::db::PostgresDb;
|
||
|
||
pub fn identity_agent_routes() -> Router<AppState> {
|
||
Router::new()
|
||
.route("/api/v1/agents/identity/analyze", post(analyze_identity))
|
||
.route("/api/v1/agents/identity/suggest", post(suggest_merges))
|
||
.route("/api/v1/agents/identity/status", get(get_identity_status))
|
||
.route(
|
||
"/api/v1/agents/suggest/clustering",
|
||
post(suggest_clustering),
|
||
)
|
||
.route("/api/v1/agents/suggest/merge", post(suggest_merge))
|
||
}
|
||
|
||
/// Request body for `POST /api/v1/agents/identity/analyze`.
#[derive(Debug, Deserialize)]
pub struct AnalyzeIdentityRequest {
    /// Video identifier; used as the output sub-directory name and file-name prefix.
    pub file_uuid: String,
    /// Accepted but not read by `analyze_identity` in this file.
    pub auto_merge_threshold: Option<f64>,
    /// Accepted but not read by `analyze_identity` in this file.
    pub llm_threshold: Option<f64>,
    /// Accepted but not read by `analyze_identity` in this file.
    pub use_llm: Option<bool>,
    /// Accepted but not read by `analyze_identity` in this file.
    pub model: Option<String>,
}

/// Response body for `POST /api/v1/agents/identity/analyze`.
#[derive(Debug, Serialize)]
pub struct AnalyzeIdentityResponse {
    pub success: bool,
    /// Echo of the requested video identifier.
    pub file_uuid: String,
    /// One entry per person found in the face-clustered data.
    pub identities: Vec<IdentityResult>,
    pub processing_status: IdentityProcessingStatus,
}

/// A candidate identity: face-cluster person(s) plus the ASRX speaker(s) that
/// overlap them in time.
#[derive(Debug, Serialize)]
pub struct IdentityResult {
    /// Sequential label such as `identity_1`, assigned in input order.
    pub identity_id: String,
    /// Face-cluster person ids grouped into this identity.
    pub person_ids: Vec<String>,
    /// Speakers whose segments cover more than 50% of the person's frames.
    pub speaker_ids: Vec<String>,
    /// Heuristic score computed by `analyze_person_speaker_overlap`.
    pub confidence: f64,
    pub evidence: IdentityEvidence,
    /// Human-readable explanation of the score.
    pub reasoning: String,
}

/// Signals that contributed to an [`IdentityResult`].
#[derive(Debug, Serialize)]
pub struct IdentityEvidence {
    /// Not computed by `analyze_person_speaker_overlap` (always `None` there).
    pub face_similarity: Option<f64>,
    /// Highest fraction of the person's frames covered by a matching speaker.
    pub speaker_overlap: f64,
    /// Currently set to the same value as `speaker_overlap`.
    pub time_overlap: f64,
    /// The person's frame count divided by 1000.
    pub frame_ratio: f64,
}

/// Summary counters for one analysis run.
#[derive(Debug, Serialize)]
pub struct IdentityProcessingStatus {
    /// `"completed"` when returned by `analyze_identity`.
    pub status: String,
    pub persons_analyzed: i32,
    /// Number of identity results computed; the DB insert uses `ON CONFLICT DO NOTHING`,
    /// so this is not necessarily the number of new rows.
    pub identities_created: i32,
    /// Always 0 in `analyze_identity`.
    pub merges_suggested: i32,
}

/// Request body for `POST /api/v1/agents/identity/suggest`.
#[derive(Debug, Deserialize)]
pub struct SuggestMergesRequest {
    pub file_uuid: String,
}

/// Response body for `POST /api/v1/agents/identity/suggest`.
#[derive(Debug, Serialize)]
pub struct SuggestMergesResponse {
    pub success: bool,
    pub file_uuid: String,
    pub merge_suggestions: Vec<MergeSuggestion>,
    /// Currently always empty in `suggest_merges`.
    pub naming_suggestions: Vec<NamingSuggestion>,
}

/// Proposal to merge several face-cluster persons into one.
#[derive(Debug, Serialize)]
pub struct MergeSuggestion {
    /// Person the others would be merged into (the first person of the identity).
    pub target_person_id: String,
    /// Persons that would be merged into the target.
    pub source_person_ids: Vec<String>,
    pub confidence: f64,
    /// Human-readable evidence lines.
    pub reasons: Vec<String>,
    /// `"auto_apply"` when confidence > 0.8, otherwise `"review_needed"`.
    pub action: String,
}

/// Proposed display name for a person (not populated by any handler visible here).
#[derive(Debug, Serialize)]
pub struct NamingSuggestion {
    pub person_id: String,
    pub suggested_name: String,
    pub confidence: f64,
    pub reasoning: String,
}

/// Response body for `GET /api/v1/agents/identity/status`.
#[derive(Debug, Serialize)]
pub struct IdentityStatusResponse {
    pub success: bool,
    pub agent_name: String,
    pub version: String,
    /// Model names advertised by the status endpoint.
    pub supported_models: Vec<String>,
    pub default_thresholds: DefaultThresholds,
}

/// Default thresholds advertised by the status endpoint.
/// NOTE(review): these values are not visibly consumed by the analysis code in this file.
#[derive(Debug, Serialize)]
pub struct DefaultThresholds {
    pub auto_merge_threshold: f64,
    pub llm_threshold: f64,
    pub face_similarity_threshold: f64,
}
|
||
|
||
async fn analyze_identity(
|
||
State(state): State<AppState>,
|
||
Json(req): Json<AnalyzeIdentityRequest>,
|
||
) -> Result<Json<AnalyzeIdentityResponse>, (StatusCode, String)> {
|
||
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
|
||
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
|
||
|
||
let video_dir = PathBuf::from(&output_dir).join(&req.file_uuid);
|
||
|
||
let face_clustered_path = video_dir.join(format!("{}.face_clustered.json", req.file_uuid));
|
||
let asrx_path = video_dir.join(format!("{}.asrx.json", req.file_uuid));
|
||
|
||
// 如果子目錄找不到,試根目錄
|
||
let face_clustered_path = if face_clustered_path.exists() {
|
||
face_clustered_path
|
||
} else {
|
||
PathBuf::from(&output_dir).join(format!("{}.face_clustered.json", req.file_uuid))
|
||
};
|
||
|
||
if !face_clustered_path.exists() {
|
||
return Err((
|
||
StatusCode::NOT_FOUND,
|
||
format!("Face clustered data not found for video: {}", req.file_uuid),
|
||
));
|
||
}
|
||
|
||
let face_data: serde_json::Value = std::fs::read_to_string(&face_clustered_path)
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to read face data: {}", e)))?
|
||
.parse()
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse face data: {}", e)))?;
|
||
|
||
let asrx_data: Option<serde_json::Value> = if asrx_path.exists() {
|
||
Some(std::fs::read_to_string(&asrx_path)
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to read asrx data: {}", e)))?
|
||
.parse()
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to parse asrx data: {}", e)))?)
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let persons = extract_persons_from_face_data(&face_data);
|
||
let speakers = extract_speakers_from_asrx_data(&asrx_data);
|
||
|
||
let identities = analyze_person_speaker_overlap(&persons, &speakers);
|
||
|
||
// 將 identity 結果寫入 DB
|
||
let pool = state.db.pool();
|
||
for id_result in &identities {
|
||
let identity_name = format!("person_{}", id_result.person_ids.first().map(|s| &**s).unwrap_or("unknown"));
|
||
let metadata = serde_json::json!({
|
||
"source": "identity_agent",
|
||
"trace_ids": id_result.person_ids,
|
||
"speaker_ids": id_result.speaker_ids,
|
||
"confidence": id_result.confidence,
|
||
"evidence": {
|
||
"speaker_overlap": id_result.evidence.speaker_overlap,
|
||
"frame_ratio": id_result.evidence.frame_ratio,
|
||
},
|
||
"reasoning": id_result.reasoning,
|
||
});
|
||
|
||
let _ = sqlx::query(
|
||
"INSERT INTO dev.identities (name, identity_type, source, metadata, status) VALUES ($1, 'people', 'auto', $2::jsonb, 'pending') ON CONFLICT DO NOTHING"
|
||
)
|
||
.bind(&identity_name)
|
||
.bind(&metadata)
|
||
.execute(pool)
|
||
.await;
|
||
}
|
||
|
||
// 迭代多角度 face embedding 比對(TMDb seed → 傳播)
|
||
let _ = match_faces_iterative(pool, &req.file_uuid).await.unwrap_or(0);
|
||
|
||
// 將 ASRX speaker 綁定到已匹配 identity 的 trace
|
||
let _ = bind_speakers(pool, &req.file_uuid).await.unwrap_or(0);
|
||
|
||
let processing_status = IdentityProcessingStatus {
|
||
status: "completed".to_string(),
|
||
persons_analyzed: persons.len() as i32,
|
||
identities_created: identities.len() as i32,
|
||
merges_suggested: 0,
|
||
};
|
||
|
||
Ok(Json(AnalyzeIdentityResponse {
|
||
success: true,
|
||
file_uuid: req.file_uuid.clone(),
|
||
identities,
|
||
processing_status,
|
||
}))
|
||
}
|
||
|
||
async fn suggest_merges(
|
||
State(state): State<AppState>,
|
||
Json(req): Json<SuggestMergesRequest>,
|
||
) -> Result<Json<SuggestMergesResponse>, (StatusCode, String)> {
|
||
let analyze_req = AnalyzeIdentityRequest {
|
||
file_uuid: req.file_uuid.clone(),
|
||
auto_merge_threshold: Some(0.8),
|
||
llm_threshold: Some(0.5),
|
||
use_llm: Some(true),
|
||
model: Some("gemma4".to_string()),
|
||
};
|
||
|
||
let analyze_result = analyze_identity(State(state), Json(analyze_req)).await?;
|
||
|
||
let merge_suggestions: Vec<MergeSuggestion> = analyze_result
|
||
.identities
|
||
.iter()
|
||
.filter(|id| id.person_ids.len() > 1)
|
||
.map(|id| {
|
||
let reasons = vec![
|
||
format!(
|
||
"Shared speaker overlap: {:.0}%",
|
||
id.evidence.speaker_overlap * 100.0
|
||
),
|
||
format!(
|
||
"Face similarity: {:.2}",
|
||
id.evidence.face_similarity.unwrap_or(0.0)
|
||
),
|
||
format!("Confidence: {:.2}", id.confidence),
|
||
];
|
||
|
||
MergeSuggestion {
|
||
target_person_id: id.person_ids[0].clone(),
|
||
source_person_ids: id.person_ids[1..].to_vec(),
|
||
confidence: id.confidence,
|
||
reasons,
|
||
action: if id.confidence > 0.8 {
|
||
"auto_apply"
|
||
} else {
|
||
"review_needed"
|
||
}
|
||
.to_string(),
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
Ok(Json(SuggestMergesResponse {
|
||
success: true,
|
||
file_uuid: req.file_uuid,
|
||
merge_suggestions,
|
||
naming_suggestions: vec![],
|
||
}))
|
||
}
|
||
|
||
async fn get_identity_status() -> Result<Json<IdentityStatusResponse>, (StatusCode, String)> {
|
||
Ok(Json(IdentityStatusResponse {
|
||
success: true,
|
||
agent_name: "Identity Agent".to_string(),
|
||
version: "1.0.0".to_string(),
|
||
supported_models: vec!["gemma4".to_string(), "qwen3".to_string()],
|
||
default_thresholds: DefaultThresholds {
|
||
auto_merge_threshold: 0.8,
|
||
llm_threshold: 0.5,
|
||
face_similarity_threshold: 0.3,
|
||
},
|
||
}))
|
||
}
|
||
|
||
fn extract_persons_from_face_data(face_data: &serde_json::Value) -> Vec<PersonData> {
|
||
let mut persons = Vec::new();
|
||
|
||
if let Some(frames) = face_data.get("frames").and_then(|f| f.as_array()) {
|
||
let mut person_frames_map: std::collections::HashMap<String, Vec<i32>> =
|
||
std::collections::HashMap::new();
|
||
|
||
for frame in frames {
|
||
if let Some(frame_num) = frame.get("frame").and_then(|f| f.as_i64()) {
|
||
if let Some(person_id) = frame.get("person_id").and_then(|p| p.as_str()) {
|
||
person_frames_map
|
||
.entry(person_id.to_string())
|
||
.or_insert_with(Vec::new)
|
||
.push(frame_num as i32);
|
||
}
|
||
}
|
||
}
|
||
|
||
for (person_id, frames) in person_frames_map {
|
||
persons.push(PersonData {
|
||
person_id,
|
||
frames,
|
||
avg_embedding: None,
|
||
});
|
||
}
|
||
}
|
||
|
||
persons
|
||
}
|
||
|
||
fn extract_speakers_from_asrx_data(asrx_data: &Option<serde_json::Value>) -> Vec<SpeakerData> {
|
||
let mut speakers = Vec::new();
|
||
if let Some(data) = asrx_data {
|
||
if let Some(segments) = data.get("segments").and_then(|s| s.as_array()) {
|
||
let mut speaker_segments_map: std::collections::HashMap<String, Vec<(f64, f64)>> =
|
||
std::collections::HashMap::new();
|
||
for segment in segments {
|
||
let speaker_id = segment.get("speaker_id").and_then(|s| s.as_str())
|
||
.or_else(|| segment.get("speaker").and_then(|s| s.as_str()));
|
||
if let Some(speaker_id) = speaker_id {
|
||
let start = segment.get("start").or_else(|| segment.get("start_time")).and_then(|s| s.as_f64()).unwrap_or(0.0);
|
||
let end = segment.get("end").or_else(|| segment.get("end_time")).and_then(|e| e.as_f64()).unwrap_or(0.0);
|
||
speaker_segments_map
|
||
.entry(speaker_id.to_string())
|
||
.or_insert_with(Vec::new)
|
||
.push((start, end));
|
||
}
|
||
}
|
||
for (speaker_id, segments) in speaker_segments_map {
|
||
speakers.push(SpeakerData { speaker_id, segments });
|
||
}
|
||
}
|
||
}
|
||
speakers
|
||
}
|
||
|
||
fn analyze_person_speaker_overlap(
|
||
persons: &[PersonData],
|
||
speakers: &[SpeakerData],
|
||
) -> Vec<IdentityResult> {
|
||
let mut identities = Vec::new();
|
||
|
||
for (i, person) in persons.iter().enumerate() {
|
||
let identity_id = format!("identity_{}", i + 1);
|
||
|
||
let mut speaker_ids = Vec::new();
|
||
let mut max_overlap: f64 = 0.0;
|
||
|
||
for speaker in speakers {
|
||
let overlap_frames = calculate_overlap(person, speaker);
|
||
let overlap_ratio = overlap_frames as f64 / person.frames.len() as f64;
|
||
|
||
if overlap_ratio > 0.5 {
|
||
speaker_ids.push(speaker.speaker_id.clone());
|
||
max_overlap = max_overlap.max(overlap_ratio);
|
||
}
|
||
}
|
||
|
||
let confidence = if speaker_ids.len() > 0 {
|
||
0.7 + max_overlap * 0.2
|
||
} else {
|
||
0.5
|
||
};
|
||
|
||
let reasoning = if speaker_ids.len() > 0 {
|
||
format!(
|
||
"Person has high overlap with speakers: {}",
|
||
speaker_ids.join(", ")
|
||
)
|
||
} else {
|
||
"Person has no speaker overlap".to_string()
|
||
};
|
||
|
||
identities.push(IdentityResult {
|
||
identity_id,
|
||
person_ids: vec![person.person_id.clone()],
|
||
speaker_ids,
|
||
confidence,
|
||
evidence: IdentityEvidence {
|
||
face_similarity: None,
|
||
speaker_overlap: max_overlap,
|
||
time_overlap: max_overlap,
|
||
frame_ratio: person.frames.len() as f64 / 1000.0,
|
||
},
|
||
reasoning,
|
||
});
|
||
}
|
||
|
||
identities
|
||
}
|
||
|
||
/// Counts how many of `person`'s frames fall inside any of `speaker`'s segments.
///
/// Frame numbers are converted to seconds assuming 25 fps. Segment bounds are
/// inclusive on both ends, and each frame is counted at most once, even if several
/// segments contain it.
fn calculate_overlap(person: &PersonData, speaker: &SpeakerData) -> i32 {
    // Fixed default frame rate; the real video fps is not available here.
    const FPS: f64 = 25.0;

    let is_spoken = |seconds: f64| {
        speaker
            .segments
            .iter()
            .any(|&(start, end)| seconds >= start && seconds <= end)
    };

    person
        .frames
        .iter()
        .map(|&frame| frame as f64 / FPS)
        .filter(|&seconds| is_spoken(seconds))
        .count() as i32
}
|
||
|
||
/// Request body for `POST /api/v1/agents/suggest/clustering`.
#[derive(Debug, Deserialize)]
pub struct SuggestClusteringRequest {
    /// Restrict suggestions to one video; `None` considers all videos.
    pub file_uuid: Option<String>,
    /// Minimum number of face detections per trace (default 3 in `suggest_clustering`).
    pub min_cluster_size: Option<usize>,
    /// Accepted but not read by `suggest_clustering` in this file.
    pub similarity_threshold: Option<f64>,
}

/// Response body for `POST /api/v1/agents/suggest/clustering`.
#[derive(Debug, Serialize)]
pub struct SuggestClusteringResponse {
    pub success: bool,
    pub suggestions: Vec<ClusteringSuggestion>,
    /// Count of face detections without an `identity_id`, as computed by `suggest_clustering`.
    pub total_unclustered: usize,
}

/// A face trace that could become a new identity cluster.
#[derive(Debug, Serialize)]
pub struct ClusteringSuggestion {
    /// `trace_<trace_id>`.
    pub cluster_id: String,
    /// Number of face detections in the trace.
    pub face_count: usize,
    /// Always 0.0 in `suggest_clustering`.
    pub avg_confidence: f64,
    /// Always `None` in `suggest_clustering`.
    pub suggested_name: Option<String>,
    /// Always `None` in `suggest_clustering`.
    pub representative_face: Option<String>,
}
|
||
|
||
async fn suggest_clustering(
|
||
State(state): State<AppState>,
|
||
Json(req): Json<SuggestClusteringRequest>,
|
||
) -> Result<Json<SuggestClusteringResponse>, (StatusCode, String)> {
|
||
let file_filter = match &req.file_uuid {
|
||
Some(uuid) => format!("AND fd.file_uuid = '{}'", uuid),
|
||
None => String::new(),
|
||
};
|
||
|
||
let query = format!(
|
||
r#"
|
||
SELECT trace_id, file_uuid, COUNT(*) as face_count
|
||
FROM dev.face_detections fd
|
||
WHERE fd.trace_id IS NOT NULL
|
||
AND NOT EXISTS (
|
||
SELECT 1 FROM dev.identities i
|
||
WHERE i.metadata->>'trace_id' = fd.trace_id::text
|
||
)
|
||
{}
|
||
GROUP BY trace_id, file_uuid
|
||
HAVING COUNT(*) >= $1
|
||
ORDER BY face_count DESC
|
||
"#,
|
||
file_filter
|
||
);
|
||
|
||
let pool = state.db.pool();
|
||
let rows = sqlx::query(&query)
|
||
.bind(req.min_cluster_size.unwrap_or(3) as i64)
|
||
.fetch_all(pool)
|
||
.await
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||
|
||
let suggestions: Vec<ClusteringSuggestion> = rows
|
||
.into_iter()
|
||
.map(|row| {
|
||
let trace_id: Option<i32> = row.try_get("trace_id").ok();
|
||
let face_count: i64 = row.get("face_count");
|
||
ClusteringSuggestion {
|
||
cluster_id: format!("trace_{}", trace_id.unwrap_or(0)),
|
||
face_count: face_count as usize,
|
||
avg_confidence: 0.0,
|
||
suggested_name: None,
|
||
representative_face: None,
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
let total_unclustered: i64 = sqlx::query_scalar(
|
||
r#"
|
||
SELECT COUNT(*) FROM face_detections fd
|
||
WHERE fd.identity_id IS NULL
|
||
"#,
|
||
)
|
||
.fetch_one(pool)
|
||
.await
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||
|
||
Ok(Json(SuggestClusteringResponse {
|
||
success: true,
|
||
suggestions,
|
||
total_unclustered: total_unclustered as usize,
|
||
}))
|
||
}
|
||
|
||
/// Request body for `POST /api/v1/agents/suggest/merge`.
#[derive(Debug, Deserialize)]
pub struct SuggestMergeRequest {
    /// Optional identity `uuid`; when set, only pairs involving it are considered.
    pub identity_id: Option<String>,
    /// Minimum `similarity_score` to report (default 0.8 in `suggest_merge`).
    pub similarity_threshold: Option<f64>,
}

/// Response body for `POST /api/v1/agents/suggest/merge`.
#[derive(Debug, Serialize)]
pub struct SuggestMergeResponse {
    pub success: bool,
    pub suggestions: Vec<IdentityMergeSuggestion>,
}

/// Proposal to merge two `people` identities.
#[derive(Debug, Serialize)]
pub struct IdentityMergeSuggestion {
    /// `uuid` of the identity with the lower `id`.
    pub source_identity_id: String,
    /// `uuid` of the identity with the higher `id`.
    pub target_identity_id: String,
    pub source_name: String,
    pub target_name: String,
    /// `shared_files / 10`, capped at 1.0.
    pub similarity_score: f64,
    /// File count computed by the `suggest_merge` query for this pair.
    pub shared_files: usize,
    /// Human-readable summary of the score.
    pub reason: String,
}
|
||
|
||
async fn suggest_merge(
|
||
State(state): State<AppState>,
|
||
Json(req): Json<SuggestMergeRequest>,
|
||
) -> Result<Json<SuggestMergeResponse>, (StatusCode, String)> {
|
||
let similarity_threshold = req.similarity_threshold.unwrap_or(0.8);
|
||
|
||
let identity_filter = match &req.identity_id {
|
||
Some(id) => format!("AND i1.uuid = '{}' OR i2.uuid = '{}'", id, id),
|
||
None => String::new(),
|
||
};
|
||
|
||
let query = format!(
|
||
r#"
|
||
SELECT
|
||
i1.uuid as source_uuid,
|
||
i2.uuid as target_uuid,
|
||
i1.name as source_name,
|
||
i2.name as target_name,
|
||
COUNT(DISTINCT fd1.file_uuid) as shared_files
|
||
FROM identities i1
|
||
JOIN identities i2 ON i1.id < i2.id
|
||
LEFT JOIN face_detections fd1 ON fd1.identity_id = i1.id
|
||
LEFT JOIN face_detections fd2 ON fd2.identity_id = i2.id AND fd1.file_uuid = fd2.file_uuid
|
||
WHERE i1.identity_type = 'people'
|
||
AND i2.identity_type = 'people'
|
||
AND i1.id != i2.id
|
||
{}
|
||
GROUP BY i1.uuid, i2.uuid, i1.name, i2.name
|
||
HAVING COUNT(DISTINCT fd1.file_uuid) > 0
|
||
ORDER BY shared_files DESC
|
||
LIMIT 50
|
||
"#,
|
||
identity_filter
|
||
);
|
||
|
||
let pool = state.db.pool();
|
||
let rows = sqlx::query(&query)
|
||
.fetch_all(pool)
|
||
.await
|
||
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
|
||
|
||
let suggestions: Vec<IdentityMergeSuggestion> = rows
|
||
.into_iter()
|
||
.filter_map(|row| {
|
||
let shared_files: i64 = row.get("shared_files");
|
||
if shared_files > 0 {
|
||
let similarity = (shared_files as f64 / 10.0).min(1.0);
|
||
if similarity >= similarity_threshold {
|
||
Some(IdentityMergeSuggestion {
|
||
source_identity_id: row.get("source_uuid"),
|
||
target_identity_id: row.get("target_uuid"),
|
||
source_name: row.get("source_name"),
|
||
target_name: row.get("target_name"),
|
||
similarity_score: similarity,
|
||
shared_files: shared_files as usize,
|
||
reason: format!(
|
||
"Share {} file(s) - similarity: {:.1}%",
|
||
shared_files,
|
||
similarity * 100.0
|
||
),
|
||
})
|
||
} else {
|
||
None
|
||
}
|
||
} else {
|
||
None
|
||
}
|
||
})
|
||
.collect();
|
||
|
||
Ok(Json(SuggestMergeResponse {
|
||
success: true,
|
||
suggestions,
|
||
}))
|
||
}
|
||
|
||
/// Frames in which one face-cluster person appears.
#[derive(Debug)]
struct PersonData {
    /// `person_id` value from the face-clustered JSON.
    person_id: String,
    /// Frame numbers where this person was detected.
    frames: Vec<i32>,
    /// Always `None` where constructed in this file; not read here.
    avg_embedding: Option<Vec<f64>>,
}

/// Time ranges in which one ASRX speaker talks.
#[derive(Debug)]
struct SpeakerData {
    speaker_id: String,
    /// `(start, end)` pairs in seconds.
    segments: Vec<(f64, f64)>,
}
|
||
|
||
/// Cosine similarity of two embedding vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or when either vector
/// has zero magnitude. Otherwise the result is `dot(a, b) / (|a| * |b|)`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    dot / (mag_a * mag_b)
}
|
||
|
||
/// 迭代多角度 face embedding 比對 + 傳播
|
||
/// Round 1: 用 TMDb seed face_embedding 比對 face_detections (threshold 0.50)
|
||
/// Round 2+: 用已匹配 trace 的所有 face 作為 seed,傳播到未匹配 trace
|
||
async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
|
||
// Step 1: 載入 TMDb identities (source='tmdb' 且有 face_embedding)
|
||
let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec<f32>)>(
|
||
"SELECT id, name, face_embedding::real[] FROM dev.identities WHERE source='tmdb' AND face_embedding IS NOT NULL"
|
||
)
|
||
.fetch_all(pool).await?;
|
||
|
||
if tmdb_rows.is_empty() {
|
||
tracing::warn!("[FaceMatch] No TMDb identities with face embeddings");
|
||
return Ok(0);
|
||
}
|
||
tracing::info!("[FaceMatch] Loaded {} TMDb seed identities", tmdb_rows.len());
|
||
|
||
// Step 2: 載入所有 face_detections,按 trace_id 分組
|
||
let fd_rows = sqlx::query_as::<_, (i32, Vec<f32>)>(
|
||
"SELECT trace_id, embedding FROM dev.face_detections \
|
||
WHERE file_uuid=$1 AND trace_id IS NOT NULL AND embedding IS NOT NULL \
|
||
ORDER BY trace_id"
|
||
)
|
||
.bind(file_uuid)
|
||
.fetch_all(pool).await?;
|
||
|
||
if fd_rows.is_empty() {
|
||
tracing::warn!("[FaceMatch] No face detections with embeddings");
|
||
return Ok(0);
|
||
}
|
||
|
||
// 分組:trace_id → Vec<embedding>
|
||
use std::collections::HashMap;
|
||
let mut trace_faces: HashMap<i32, Vec<Vec<f32>>> = HashMap::new();
|
||
for (tid, emb) in &fd_rows {
|
||
trace_faces.entry(*tid).or_insert_with(Vec::new).push(emb.clone());
|
||
}
|
||
|
||
// 去重:同一個 trace 內,embedding 太接近的只留一個
|
||
for faces in trace_faces.values_mut() {
|
||
faces.sort_by(|a, b| b[0].partial_cmp(&a[0]).unwrap_or(std::cmp::Ordering::Equal));
|
||
faces.dedup_by(|a, b| cosine_similarity(a, b) > 0.99);
|
||
}
|
||
|
||
let total_traces = trace_faces.len();
|
||
tracing::info!("[FaceMatch] Loaded {} traces with {} faces", total_traces, fd_rows.len());
|
||
|
||
// Step 3: 建立 TMDb 查找表
|
||
let tmdb_seeds: Vec<(i32, String, Vec<f32>)> = tmdb_rows;
|
||
|
||
// Step 4: 迭代匹配
|
||
const TH: f32 = 0.50;
|
||
let mut matched: HashMap<i32, String> = HashMap::new(); // trace_id → identity_name
|
||
|
||
// Round 1: 直接比對 TMDb
|
||
for (&tid, faces) in &trace_faces {
|
||
let mut best_name = String::new();
|
||
let mut best_sim = 0.0f32;
|
||
for (_, ref name, ref tmdb_emb) in &tmdb_seeds {
|
||
for face_emb in faces {
|
||
let s = cosine_similarity(face_emb, tmdb_emb);
|
||
if s > best_sim { best_sim = s; best_name = name.clone(); }
|
||
}
|
||
}
|
||
if best_sim >= TH {
|
||
matched.insert(tid, best_name);
|
||
}
|
||
}
|
||
tracing::info!("[FaceMatch] Round 1: {} matched ({}%)", matched.len(), matched.len() * 100 / total_traces);
|
||
|
||
// Round 2+: 用已匹配的 face 作為 seed 傳播
|
||
for round_n in 2..=10 {
|
||
let prev = matched.len();
|
||
// 建立 seed pool: name → Vec<embedding>
|
||
let mut seed_pool: HashMap<String, Vec<&Vec<f32>>> = HashMap::new();
|
||
for (&tid, name) in &matched {
|
||
if let Some(faces) = trace_faces.get(&tid) {
|
||
seed_pool.entry(name.clone()).or_default().extend(faces.iter());
|
||
}
|
||
}
|
||
|
||
let mut new_matches: Vec<(i32, String)> = Vec::new();
|
||
for (&tid, faces) in &trace_faces {
|
||
if matched.contains_key(&tid) { continue; }
|
||
let mut best_name = String::new();
|
||
let mut best_sim = 0.0f32;
|
||
if faces.is_empty() { continue; }
|
||
let ref_face = &faces[0];
|
||
for (name, seed_faces) in &seed_pool {
|
||
for seed in seed_faces {
|
||
let s = cosine_similarity(ref_face, seed);
|
||
if s > best_sim { best_sim = s; best_name = name.clone(); }
|
||
}
|
||
}
|
||
if best_sim >= TH {
|
||
new_matches.push((tid, best_name));
|
||
}
|
||
}
|
||
for (tid, name) in new_matches {
|
||
matched.insert(tid, name);
|
||
}
|
||
let new = matched.len() - prev;
|
||
tracing::info!("[FaceMatch] Round {}: +{} matched (total {}, {}%)", round_n, new, matched.len(), matched.len() * 100 / total_traces);
|
||
if new < 5 { break; }
|
||
}
|
||
|
||
// Step 5: 寫入 DB
|
||
let mut updated = 0usize;
|
||
for (tid, name) in &matched {
|
||
let id_opt = sqlx::query_scalar::<_, Option<i32>>(
|
||
"SELECT id FROM dev.identities WHERE name=$1 AND source='tmdb'"
|
||
)
|
||
.bind(name)
|
||
.fetch_optional(pool).await?;
|
||
if let Some(identity_id) = id_opt {
|
||
let _ = sqlx::query(
|
||
"UPDATE dev.face_detections SET identity_id=$1 WHERE file_uuid=$2 AND trace_id=$3"
|
||
)
|
||
.bind(identity_id)
|
||
.bind(file_uuid)
|
||
.bind(tid)
|
||
.execute(pool).await;
|
||
updated += 1;
|
||
}
|
||
}
|
||
|
||
tracing::info!("[FaceMatch] Done: {}/{} traces matched ({}%)", matched.len(), total_traces, matched.len() * 100 / total_traces);
|
||
Ok(updated)
|
||
}
|
||
|
||
/// Bind ASRX speakers to face traces based on temporal overlap.
|
||
/// Reads face_detections (trace_id, identity_id, frame_number) and ASRX
|
||
/// segments (speaker_id, start_time, end_time), computes overlap,
|
||
/// and stores bindings in identity_bindings table.
|
||
pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result<usize> {
|
||
// Load face traces with identity_id and frame numbers
|
||
let traces = sqlx::query_as::<_, (i32, Vec<i32>)>(
|
||
"SELECT trace_id, array_agg(frame_number ORDER BY frame_number) \
|
||
FROM dev.face_detections WHERE file_uuid=$1 AND trace_id IS NOT NULL AND identity_id IS NOT NULL \
|
||
GROUP BY trace_id"
|
||
)
|
||
.bind(file_uuid)
|
||
.fetch_all(pool).await?;
|
||
|
||
if traces.is_empty() {
|
||
tracing::info!("[SpeakerBind] No face traces with identities");
|
||
return Ok(0);
|
||
}
|
||
|
||
// Load ASRX speakers from the output JSON
|
||
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
|
||
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
|
||
let asrx_path = std::path::Path::new(&output_dir).join(format!("{}.asrx.json", file_uuid));
|
||
|
||
let asrx_data: serde_json::Value = match std::fs::read_to_string(&asrx_path) {
|
||
Ok(s) => serde_json::from_str(&s).unwrap_or_default(),
|
||
Err(_) => {
|
||
tracing::info!("[SpeakerBind] No ASRX file found");
|
||
return Ok(0);
|
||
}
|
||
};
|
||
|
||
// Extract speaker segments: speaker_id → [(start_time, end_time)]
|
||
use std::collections::HashMap;
|
||
let mut speakers: HashMap<String, Vec<(f64, f64)>> = HashMap::new();
|
||
if let Some(segments) = asrx_data.get("segments").and_then(|s| s.as_array()) {
|
||
for seg in segments {
|
||
let sid = seg.get("speaker_id").and_then(|s| s.as_str())
|
||
.or_else(|| seg.get("speaker").and_then(|s| s.as_str()));
|
||
if let Some(sid) = sid {
|
||
let start = seg.get("start_time").or_else(|| seg.get("start")).and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||
let end = seg.get("end_time").or_else(|| seg.get("end")).and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||
speakers.entry(sid.to_string()).or_default().push((start, end));
|
||
}
|
||
}
|
||
}
|
||
|
||
if speakers.is_empty() {
|
||
tracing::info!("[SpeakerBind] No speakers found in ASRX data");
|
||
return Ok(0);
|
||
}
|
||
|
||
// Get fps for frame-to-time conversion
|
||
let fps: f64 = 25.0; // default, could also read from DB
|
||
|
||
// For each trace, compute overlap with each speaker
|
||
let mut bindings = 0usize;
|
||
for (trace_id, frames) in &traces {
|
||
if frames.is_empty() { continue; }
|
||
|
||
// Get identity_id for this trace
|
||
let identity_id: Option<i32> = sqlx::query_scalar(
|
||
"SELECT identity_id FROM dev.face_detections WHERE file_uuid=$1 AND trace_id=$2 AND identity_id IS NOT NULL LIMIT 1"
|
||
)
|
||
.bind(file_uuid).bind(trace_id)
|
||
.fetch_optional(pool).await?.flatten();
|
||
|
||
if identity_id.is_none() { continue; }
|
||
let identity_id = identity_id.unwrap();
|
||
|
||
// Compute overlap with each speaker
|
||
let mut best_speaker = String::new();
|
||
let mut best_overlap = 0usize;
|
||
|
||
for (speaker_id, segments) in &speakers {
|
||
let mut overlap = 0usize;
|
||
for &fn_num in frames {
|
||
let frame_time = fn_num as f64 / fps;
|
||
for (start, end) in segments {
|
||
if frame_time >= *start && frame_time <= *end {
|
||
overlap += 1;
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
if overlap > best_overlap {
|
||
best_overlap = overlap;
|
||
best_speaker = speaker_id.clone();
|
||
}
|
||
}
|
||
|
||
// Only bind if meaningful overlap
|
||
let overlap_ratio = best_overlap as f64 / frames.len() as f64;
|
||
if overlap_ratio > 0.3 && !best_speaker.is_empty() {
|
||
let metadata = serde_json::json!({
|
||
"trace_id": trace_id,
|
||
"overlap_frames": best_overlap,
|
||
"total_frames": frames.len(),
|
||
"overlap_ratio": overlap_ratio,
|
||
});
|
||
|
||
let _ = sqlx::query(
|
||
"INSERT INTO dev.identity_bindings (identity_id, identity_type, identity_value, confidence, metadata) \
|
||
VALUES ($1, 'speaker', $2, $3, $4::jsonb) \
|
||
ON CONFLICT (identity_id, identity_type, identity_value) DO UPDATE SET confidence = EXCLUDED.confidence, metadata = EXCLUDED.metadata"
|
||
)
|
||
.bind(identity_id)
|
||
.bind(&best_speaker)
|
||
.bind(overlap_ratio)
|
||
.bind(&metadata)
|
||
.execute(pool).await;
|
||
|
||
bindings += 1;
|
||
}
|
||
}
|
||
|
||
tracing::info!("[SpeakerBind] Created {}/{} speaker bindings", bindings, traces.len());
|
||
Ok(bindings)
|
||
}
|
||
|
||
/// Pipeline-triggered entry point: runs the full identity agent for a file.
|
||
/// Reads face_clustered.json + asrx.json, extracts persons/speakers, creates identities,
|
||
/// runs iterative face matching, and binds speakers.
|
||
pub async fn run_identity_agent(db: &PostgresDb, file_uuid: &str) -> anyhow::Result<()> {
|
||
let output_dir = std::env::var("MOMENTRY_OUTPUT_DIR")
|
||
.unwrap_or_else(|_| "/Users/accusys/momentry/output".to_string());
|
||
|
||
let video_dir = PathBuf::from(&output_dir).join(file_uuid);
|
||
let face_clustered_path = video_dir.join(format!("{}.face_clustered.json", file_uuid));
|
||
let face_clustered_path = if face_clustered_path.exists() {
|
||
face_clustered_path
|
||
} else {
|
||
PathBuf::from(&output_dir).join(format!("{}.face_clustered.json", file_uuid))
|
||
};
|
||
|
||
if !face_clustered_path.exists() {
|
||
tracing::warn!("[IdentityAgent] face_clustered.json not found for {}", file_uuid);
|
||
return Ok(());
|
||
}
|
||
|
||
let face_data: serde_json::Value = std::fs::read_to_string(&face_clustered_path)?.parse()?;
|
||
let asrx_path = video_dir.join(format!("{}.asrx.json", file_uuid));
|
||
let asrx_data: Option<serde_json::Value> = if asrx_path.exists() {
|
||
Some(std::fs::read_to_string(&asrx_path)?.parse()?)
|
||
} else {
|
||
None
|
||
};
|
||
|
||
let persons = extract_persons_from_face_data(&face_data);
|
||
let speakers = extract_speakers_from_asrx_data(&asrx_data);
|
||
let identities = analyze_person_speaker_overlap(&persons, &speakers);
|
||
|
||
let pool = db.pool();
|
||
for id_result in &identities {
|
||
let identity_name = format!("person_{}", id_result.person_ids.first().map(|s| &**s).unwrap_or("unknown"));
|
||
let metadata = serde_json::json!({
|
||
"source": "identity_agent",
|
||
"trace_ids": id_result.person_ids,
|
||
"speaker_ids": id_result.speaker_ids,
|
||
"confidence": id_result.confidence,
|
||
"evidence": {
|
||
"speaker_overlap": id_result.evidence.speaker_overlap,
|
||
"frame_ratio": id_result.evidence.frame_ratio,
|
||
},
|
||
"reasoning": id_result.reasoning,
|
||
});
|
||
let _ = sqlx::query(
|
||
"INSERT INTO dev.identities (name, identity_type, source, metadata, status) VALUES ($1, 'people', 'auto', $2::jsonb, 'pending') ON CONFLICT DO NOTHING"
|
||
)
|
||
.bind(&identity_name)
|
||
.bind(&metadata)
|
||
.execute(pool)
|
||
.await;
|
||
}
|
||
|
||
let matched = match_faces_iterative(pool, file_uuid).await.unwrap_or(0);
|
||
let bound = bind_speakers(pool, file_uuid).await.unwrap_or(0);
|
||
|
||
tracing::info!(
|
||
"[IdentityAgent] Done for {}: {} identities, {} face matches, {} speaker bindings",
|
||
file_uuid, identities.len(), matched, bound
|
||
);
|
||
Ok(())
|
||
}
|