diff --git a/src/api/identity_agent_api.rs b/src/api/identity_agent_api.rs index 30495f1..7236d60 100644 --- a/src/api/identity_agent_api.rs +++ b/src/api/identity_agent_api.rs @@ -619,10 +619,13 @@ fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { } } -/// 迭代多角度 face embedding 比對 + 傳播 -/// Round 1: 用 TMDb seed face_embedding 比對 face_detections (threshold 0.50) +/// 迭代多角度 face embedding 比對 + 傳播 (Qdrant version) +/// Round 1: 用 TMDb seed face_embedding 比對 Qdrant embeddings (threshold 0.50) /// Round 2+: 用已匹配 trace 的所有 face 作為 seed,傳播到未匹配 trace async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result { + use crate::core::db::face_embedding_db::FaceEmbeddingDb; + use std::collections::HashMap; + // Step 1: 載入 TMDb identities (source='tmdb' 且有 face_embedding) let identities_table = schema::table_name("identities"); let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec)>( @@ -635,12 +638,167 @@ async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow:: return Ok(0); } tracing::info!( - "[FaceMatch] Loaded {} TMDb seed identities", + "[FaceMatch-Qdrant] Loaded {} TMDb seed identities", + tmdb_rows.len() + ); + + // Step 2: Load embeddings from Qdrant + let face_db = FaceEmbeddingDb::new(); + let qdrant_embeddings = face_db.get_all_embeddings_for_file(file_uuid).await?; + + if qdrant_embeddings.is_empty() { + tracing::warn!("[FaceMatch-Qdrant] No face embeddings in Qdrant for {}", file_uuid); + return match_faces_iterative_pg(pool, file_uuid).await; // Fallback to PG + } + + // Group: trace_id → Vec<(frame, embedding)> + let mut trace_faces_raw: HashMap)>> = HashMap::new(); + for (_, emb, payload) in &qdrant_embeddings { + trace_faces_raw + .entry(payload.trace_id) + .or_default() + .push((payload.frame, emb.clone())); + } + + // Sample 3 embeddings per trace (front, mid, back) + let mut trace_samples: HashMap>> = HashMap::new(); + for (tid, mut faces) in trace_faces_raw { + faces.sort_by_key(|(frame, _)| *frame); + let n = faces.len(); + let indices = if n <= 3 { + (0..n).collect::>() + } else { + vec![0, n / 2, n - 1] + }; + let samples: Vec> = indices.iter().map(|&i| faces[i].1.clone()).collect(); + trace_samples.insert(tid, samples); + } + + let total_traces = trace_samples.len(); + let sample_count: usize = trace_samples.values().map(|v| v.len()).sum(); + tracing::info!( + "[FaceMatch-Qdrant] Loaded {} traces, sampled {} embeddings", + total_traces, + sample_count + ); + + // Step 3: Match against TMDb seeds + const TH: f32 = 0.50; + let tmdb_seeds: Vec<(i32, String, Vec)> = tmdb_rows; + let mut matched: HashMap = HashMap::new(); + + for (&tid, samples) in &trace_samples { + let mut best_name = String::new(); + let mut best_sim = 0.0f32; + for (_, ref name, ref tmdb_emb) in &tmdb_seeds { + for face_emb in samples { + let s = cosine_similarity(face_emb, tmdb_emb); + if s > best_sim { + best_sim = s; + best_name = name.clone(); + } + } + } + if best_sim >= TH { + matched.insert(tid, best_name); + } + } + tracing::info!( + "[FaceMatch-Qdrant] Round 1: matched {} traces (threshold={})", + matched.len(), + TH + ); + + // Round 2+: Propagate + let mut round = 2; + while matched.len() < trace_samples.len() { + let prev_count = matched.len(); + + // Collect new matches in separate HashMap + let mut new_matches: HashMap = HashMap::new(); + + for (&tid, samples) in &trace_samples { + if matched.contains_key(&tid) { + continue; + } + + for (matched_tid, matched_name) in &matched { + if let Some(matched_embs) = trace_samples.get(matched_tid) { + for face_emb in samples { + for ref_emb in matched_embs { + let s = cosine_similarity(face_emb, ref_emb); + if s >= TH { + new_matches.insert(tid, matched_name.clone()); + break; + } + } + } + } + } + } + + // Merge new matches + matched.extend(new_matches); + + if matched.len() == prev_count { + break; + } + tracing::info!( + "[FaceMatch-Qdrant] Round {}: matched {} total", + round, + matched.len() + ); + round += 1; + } + + // Update face_detections.identity_id + let fd_table = schema::table_name("face_detections"); + let identities_map: HashMap = tmdb_seeds + .iter() + .map(|(id, name, _)| (name.clone(), *id)) + .collect(); + + let mut updated = 0usize; + for (tid, name) in &matched { + let identity_id = identities_map.get(name); + if let Some(id) = identity_id { + let rows = sqlx::query(&format!( + "UPDATE {} SET identity_id = $1 WHERE file_uuid = $2 AND trace_id = $3", + fd_table + )) + .bind(*id) + .bind(file_uuid) + .bind(*tid) + .execute(pool) + .await? + .rows_affected(); + updated += rows as usize; + } + } + + tracing::info!("[FaceMatch-Qdrant] Updated {} face_detections", updated); + Ok(updated) +} + +/// Fallback: PostgreSQL-based matching (original implementation) +async fn match_faces_iterative_pg(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Result { + // Step 1: 載入 TMDb identities (source='tmdb' 且有 face_embedding) + let identities_table = schema::table_name("identities"); + let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec)>( + &format!("SELECT id, name, face_embedding::real[] FROM {} WHERE source='tmdb' AND face_embedding IS NOT NULL", identities_table) + ) + .fetch_all(pool).await?; + + if tmdb_rows.is_empty() { + tracing::warn!("[FaceMatch-PG] No TMDb identities with face embeddings"); + return Ok(0); + } + tracing::info!( + "[FaceMatch-PG] Loaded {} TMDb seed identities", tmdb_rows.len() ); // Step 2: 載入所有 face_detections(含 frame_number),按 trace_id 分組 - // frame_number is BIGINT (i64) in database let fd_table = schema::table_name("face_detections"); let fd_rows = sqlx::query_as::<_, (i32, i64, Vec)>(&format!( "SELECT trace_id, frame_number, embedding FROM {} \ @@ -653,7 +811,7 @@ async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow:: .await?; if fd_rows.is_empty() { - tracing::warn!("[FaceMatch] No face detections with embeddings"); + tracing::warn!("[FaceMatch-PG] No face detections with embeddings"); return Ok(0); } @@ -668,7 +826,6 @@ async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow:: } // 從每個 trace 選取不同角度的 3 個 face embedding - // 策略:按 frame_number 排序,取前中後各 1 個 let mut trace_samples: HashMap>> = HashMap::new(); for (tid, mut faces) in trace_faces_raw { faces.sort_by_key(|(frame, _)| *frame); @@ -686,7 +843,7 @@ async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow:: let total_traces = trace_samples.len(); let sample_count: usize = trace_samples.values().map(|v| v.len()).sum(); tracing::info!( - "[FaceMatch] Loaded {} traces, sampled {} embeddings (3-angle)", + "[FaceMatch-PG] Loaded {} traces, sampled {} embeddings (3-angle)", total_traces, sample_count ); @@ -699,7 +856,6 @@ async fn match_faces_iterative(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow:: let mut matched: HashMap = HashMap::new(); // trace_id → identity_name // Round 1: 用 3-angle samples 比對 TMDb - // 每個 trace 選 3 個不同角度 face,取最高 similarity for (&tid, samples) in &trace_samples { let mut best_name = String::new(); let mut best_sim = 0.0f32;