Release v1.0.0 candidate
This commit is contained in:
@@ -1,146 +1,235 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use tracing::{error, info};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct FaceDetection {
|
||||
face_id: String,
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TmdbIdentity {
|
||||
id: i64,
|
||||
id: i32,
|
||||
name: String,
|
||||
face_embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
const MATCH_THRESHOLD: f32 = 0.55;
|
||||
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
dot / (norm_a * norm_b)
|
||||
if a.len() != b.len() || a.is_empty() { return 0.0; }
|
||||
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
|
||||
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if na == 0.0 || nb == 0.0 { 0.0 } else { dot / (na * nb) }
|
||||
}
|
||||
|
||||
/// Match unassigned face detections against TMDb-sourced identities.
|
||||
/// For each face detection with identity_id IS NULL, compute cosine similarity
|
||||
/// against all TMDb identities that have face_embedding set.
|
||||
/// If similarity > MATCH_THRESHOLD, bind the face to the identity.
|
||||
/// Match face detections against TMDb identities using iterative multi-angle propagation.
|
||||
/// Round 1: seed match against TMDb face_embeddings (threshold 0.50)
|
||||
/// Round 2+: propagate to remaining traces using matched faces as reference
|
||||
pub async fn match_faces_against_tmdb(db: &PostgresDb, file_uuid: &str) -> Result<usize> {
|
||||
// Step 1: Fetch unassigned face detections for this file
|
||||
let detections: Vec<FaceDetection> = sqlx::query_as::<_, (String, Vec<f32>)>(
|
||||
"SELECT face_id, embedding FROM dev.face_detections \
|
||||
WHERE file_uuid = $1 AND identity_id IS NULL AND embedding IS NOT NULL",
|
||||
let pool = db.pool();
|
||||
|
||||
// Step 1: Load TMDb identities with face embeddings
|
||||
let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec<f32>)>(
|
||||
"SELECT id, name, face_embedding::real[] FROM dev.identities WHERE source='tmdb' AND face_embedding IS NOT NULL"
|
||||
)
|
||||
.fetch_all(pool).await?;
|
||||
|
||||
if tmdb_rows.is_empty() {
|
||||
info!("[TKG-MATCH] No TMDb identities with face embeddings");
|
||||
return Ok(0);
|
||||
}
|
||||
info!("[TKG-MATCH] {} TMDb seeds loaded", tmdb_rows.len());
|
||||
|
||||
// Step 2: Load face_detections grouped by trace_id
|
||||
let fd_rows = sqlx::query_as::<_, (i32, Vec<f32>)>(
|
||||
"SELECT trace_id, embedding FROM dev.face_detections \
|
||||
WHERE file_uuid=$1 AND trace_id IS NOT NULL AND embedding IS NOT NULL \
|
||||
ORDER BY trace_id"
|
||||
)
|
||||
.bind(file_uuid)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
.context("Failed to fetch unassigned face detections")?
|
||||
.into_iter()
|
||||
.map(|(face_id, embedding)| FaceDetection { face_id, embedding })
|
||||
.collect();
|
||||
.fetch_all(pool).await?;
|
||||
|
||||
if detections.is_empty() {
|
||||
info!(
|
||||
"[TMDB-FACE] No unassigned face detections for {}",
|
||||
file_uuid
|
||||
);
|
||||
if fd_rows.is_empty() {
|
||||
info!("[TKG-MATCH] No face detections for {}", file_uuid);
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
// Step 2: Fetch TMDb identities with face embeddings
|
||||
let identities: Vec<TmdbIdentity> = sqlx::query_as::<_, (i64, String, Vec<f32>)>(
|
||||
"SELECT id, name, face_embedding::real[] FROM dev.identities \
|
||||
WHERE source = 'tmdb' AND face_embedding IS NOT NULL",
|
||||
)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
.context("Failed to fetch TMDb identities")?
|
||||
.into_iter()
|
||||
.map(|(id, name, emb)| TmdbIdentity {
|
||||
id,
|
||||
name,
|
||||
face_embedding: emb,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if identities.is_empty() {
|
||||
info!("[TMDB-FACE] No TMDb identities with face embeddings for matching");
|
||||
return Ok(0);
|
||||
let mut trace_faces: HashMap<i32, Vec<Vec<f32>>> = HashMap::new();
|
||||
for (tid, emb) in &fd_rows {
|
||||
trace_faces.entry(*tid).or_default().push(emb.clone());
|
||||
}
|
||||
// Dedup near-identical embeddings within trace
|
||||
for faces in trace_faces.values_mut() {
|
||||
faces.sort_by(|a, b| a[0].partial_cmp(&b[0]).unwrap_or(std::cmp::Ordering::Equal));
|
||||
faces.dedup_by(|a, b| cosine_similarity(a, b) > 0.99);
|
||||
}
|
||||
|
||||
info!(
|
||||
"[TMDB-FACE] Matching {} face detections against {} TMDb identities",
|
||||
detections.len(),
|
||||
identities.len()
|
||||
);
|
||||
let total = trace_faces.len();
|
||||
info!("[TKG-MATCH] {} traces with {} faces", total, fd_rows.len());
|
||||
|
||||
// Step 3: For each face detection, find best matching identity
|
||||
let mut bindings_created = 0usize;
|
||||
// Step 3: Iterative matching
|
||||
const TH: f32 = 0.50;
|
||||
let mut matched: HashMap<i32, (i32, String)> = HashMap::new(); // trace_id → (identity_id, name)
|
||||
|
||||
for det in &detections {
|
||||
let mut best_match: Option<(i64, f32)> = None;
|
||||
// Round 1: against TMDb seeds
|
||||
for (&tid, faces) in &trace_faces {
|
||||
let mut best_id = 0i32;
|
||||
let mut best_name = String::new();
|
||||
let mut best_sim = 0.0f32;
|
||||
for (id, name, tmdb_emb) in &tmdb_rows {
|
||||
for face in faces {
|
||||
let s = cosine_similarity(face, tmdb_emb);
|
||||
if s > best_sim { best_sim = s; best_id = *id; best_name = name.clone(); }
|
||||
}
|
||||
}
|
||||
if best_sim >= TH {
|
||||
matched.insert(tid, (best_id, best_name));
|
||||
}
|
||||
}
|
||||
info!("[TKG-MATCH] Round 1: {} ({}/{})", matched.len(), matched.len() * 100 / total, total);
|
||||
|
||||
for identity in &identities {
|
||||
let sim = cosine_similarity(&det.embedding, &identity.face_embedding);
|
||||
if sim > MATCH_THRESHOLD {
|
||||
match best_match {
|
||||
Some((_, best_sim)) if sim > best_sim => {
|
||||
best_match = Some((identity.id, sim));
|
||||
}
|
||||
None => {
|
||||
best_match = Some((identity.id, sim));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
// Round 2+: propagate
|
||||
for round_n in 2..=10 {
|
||||
let prev = matched.len();
|
||||
let mut seed_pool: HashMap<i32, Vec<&Vec<f32>>> = HashMap::new();
|
||||
for (&tid, (id, _)) in &matched {
|
||||
if let Some(faces) = trace_faces.get(&tid) {
|
||||
seed_pool.entry(*id).or_default().extend(faces.iter());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((identity_id, similarity)) = best_match {
|
||||
// Update face_detection with identity_id
|
||||
let _ = sqlx::query(
|
||||
"UPDATE dev.face_detections SET identity_id = $1, identity_confidence = $2 \
|
||||
WHERE file_uuid = $3 AND face_id = $4",
|
||||
)
|
||||
.bind(identity_id)
|
||||
.bind(similarity as f64)
|
||||
.bind(file_uuid)
|
||||
.bind(&det.face_id)
|
||||
.execute(db.pool())
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Also create identity_binding
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO dev.identity_bindings (identity_id, identity_type, identity_value, source, confidence) \
|
||||
VALUES ($1, 'face', $2, 'tmdb_agent', $3) \
|
||||
ON CONFLICT (identity_id, identity_type, identity_value) DO UPDATE SET confidence = EXCLUDED.confidence"
|
||||
)
|
||||
.bind(identity_id)
|
||||
.bind(&det.face_id)
|
||||
.bind(similarity as f64)
|
||||
.execute(db.pool())
|
||||
.await
|
||||
.ok();
|
||||
|
||||
bindings_created += 1;
|
||||
let mut new_matches: Vec<(i32, i32, String)> = Vec::new();
|
||||
for (&tid, faces) in &trace_faces {
|
||||
if matched.contains_key(&tid) || faces.is_empty() { continue; }
|
||||
let ref_face = &faces[0];
|
||||
let mut best_id = 0i32;
|
||||
let mut best_name = String::new();
|
||||
let mut best_sim = 0.0f32;
|
||||
for (&id, seed_faces) in &seed_pool {
|
||||
for seed in seed_faces {
|
||||
let s = cosine_similarity(ref_face, seed);
|
||||
if s > best_sim { best_sim = s; best_id = id; }
|
||||
}
|
||||
}
|
||||
if best_sim >= TH {
|
||||
// Look up name for this id
|
||||
for (id, name, _) in &tmdb_rows {
|
||||
if *id == best_id { best_name = name.clone(); break; }
|
||||
}
|
||||
new_matches.push((tid, best_id, best_name));
|
||||
}
|
||||
}
|
||||
for (tid, id, name) in new_matches {
|
||||
matched.insert(tid, (id, name));
|
||||
}
|
||||
let new = matched.len() - prev;
|
||||
if new < 5 { break; }
|
||||
}
|
||||
|
||||
info!(
|
||||
"[TMDB-FACE] Created {} face-to-TMDb bindings for {}",
|
||||
bindings_created, file_uuid
|
||||
);
|
||||
// Step 4: Quality control
|
||||
// 4a: Remove low-confidence traces (fewer than 4 face detections)
|
||||
let mut after_qc = HashMap::new();
|
||||
for (&tid, &(id, ref name)) in &matched {
|
||||
let cnt: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM dev.face_detections WHERE file_uuid=$1 AND trace_id=$2"
|
||||
)
|
||||
.bind(file_uuid).bind(tid)
|
||||
.fetch_one(pool).await.unwrap_or(0);
|
||||
if cnt >= 4 {
|
||||
after_qc.insert(tid, (id, name.clone()));
|
||||
} else {
|
||||
info!("[TKG-QC] trace {} removed: only {} face(s), need >= 4", tid, cnt);
|
||||
}
|
||||
}
|
||||
let matched = after_qc;
|
||||
let removed_low = total - matched.len();
|
||||
if removed_low > 0 {
|
||||
info!("[TKG-QC] Removed {} low-confidence traces (< 4 faces)", removed_low);
|
||||
}
|
||||
|
||||
Ok(bindings_created)
|
||||
// 4b: Temporal collision check
|
||||
let removed_collisions = quality_check_temporal_collisions(pool, file_uuid).await?;
|
||||
if removed_collisions > 0 {
|
||||
info!("[TKG-QC] Resolved {} temporal collisions", removed_collisions);
|
||||
}
|
||||
|
||||
// Step 5: Update DB
|
||||
let mut updated = 0usize;
|
||||
for (&tid, &(id, _)) in &matched {
|
||||
let r = sqlx::query(
|
||||
"UPDATE dev.face_detections SET identity_id=$1 WHERE file_uuid=$2 AND trace_id=$3"
|
||||
)
|
||||
.bind(id).bind(file_uuid).bind(tid)
|
||||
.execute(pool).await?;
|
||||
if r.rows_affected() > 0 { updated += 1; }
|
||||
}
|
||||
|
||||
info!("[TKG-MATCH] Done: {}/{} traces matched ({}%)",
|
||||
matched.len(), total, matched.len() * 100 / total);
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Quality check: detect temporal collisions where two different traces of the same
|
||||
/// identity appear in the same frame (impossible for one person).
|
||||
/// Unbind the lower-confidence trace from the conflicting pair.
|
||||
/// RCA reference: docs_v1.0/API_V1.0.0/INTERNAL/RCA_TRACE39_TRACE45_COLLISION_V1.0.0.md
|
||||
async fn quality_check_temporal_collisions(pool: &sqlx::PgPool, file_uuid: &str) -> Result<usize> {
|
||||
// Find all collision pairs: same identity, same frame, different trace
|
||||
let collisions = sqlx::query_as::<_, (i32, i32, i32, i32)>(
|
||||
r#"
|
||||
SELECT a.identity_id, a.trace_id, b.trace_id, a.frame_number
|
||||
FROM dev.face_detections a
|
||||
JOIN dev.face_detections b
|
||||
ON a.file_uuid = b.file_uuid
|
||||
AND a.frame_number = b.frame_number
|
||||
AND a.trace_id < b.trace_id
|
||||
WHERE a.file_uuid = $1
|
||||
AND a.identity_id IS NOT NULL
|
||||
AND a.identity_id = b.identity_id
|
||||
ORDER BY a.identity_id, a.frame_number
|
||||
"#
|
||||
)
|
||||
.bind(file_uuid)
|
||||
.fetch_all(pool).await?;
|
||||
|
||||
if collisions.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
// Group collisions by (identity_id, trace_a, trace_b) and count frames
|
||||
use std::collections::HashMap;
|
||||
let mut collision_groups: HashMap<(i32, i32, i32), usize> = HashMap::new();
|
||||
for (id, ta, tb, _) in &collisions {
|
||||
*collision_groups.entry((*id, *ta, *tb)).or_default() += 1;
|
||||
}
|
||||
|
||||
let mut unbound = 0usize;
|
||||
for ((id, ta, tb), overlap_frames) in &collision_groups {
|
||||
// Get face detection count for each trace
|
||||
let cnt_a: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM dev.face_detections WHERE file_uuid=$1 AND trace_id=$2 AND identity_id=$3"
|
||||
)
|
||||
.bind(file_uuid).bind(ta).bind(id)
|
||||
.fetch_one(pool).await.unwrap_or(0);
|
||||
|
||||
let cnt_b: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM dev.face_detections WHERE file_uuid=$1 AND trace_id=$2 AND identity_id=$3"
|
||||
)
|
||||
.bind(file_uuid).bind(tb).bind(id)
|
||||
.fetch_one(pool).await.unwrap_or(0);
|
||||
|
||||
// Unbind the trace with fewer detections (likely the false positive)
|
||||
let victim = if cnt_a <= cnt_b { *ta } else { *tb };
|
||||
let victim_cnt = if cnt_a <= cnt_b { cnt_a } else { cnt_b };
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE dev.face_detections SET identity_id=NULL WHERE file_uuid=$1 AND trace_id=$2"
|
||||
)
|
||||
.bind(file_uuid).bind(victim)
|
||||
.execute(pool).await?;
|
||||
|
||||
unbound += 1;
|
||||
warn!("[TKG-QC] Collision identity={}: trace {} vs trace {} ({} overlap frames). Unbound trace {} ({} detections)",
|
||||
id, ta, tb, overlap_frames, victim, victim_cnt);
|
||||
}
|
||||
|
||||
Ok(unbound)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user