#!/opt/homebrew/bin/python3.11
"""
Multi-Stage Identity Clustering Runner

Stage 1: High-confidence face-level matching
  - Compare ALL face embeddings in each trace against identity references
  - Bind a trace to an identity when a large enough share of its faces
    (``stage1_bind_ratio``) match with a high enough cosine similarity
    (``stage1_face_threshold``)
  - These become "anchors" for the later stages

Stage 1b: Iterative enrichment
  - Faces sampled from already-bound traces become extra references
  - Remaining traces are re-scored round by round until nothing new binds

Stage 2: Trace centroid clustering of remaining unbound traces
  - Use centroid of unbound traces, cluster with adaptive threshold
  - Label each cluster with its dominant speaker

Stage 3 (optional): TMDb matching
"""

import argparse
import json
import os
import sys
import time
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np

try:
    import psycopg2
except ImportError:
    # The clustering stages are pure NumPy; only the DB helpers need the
    # driver, so a missing driver is reported lazily by get_conn().
    psycopg2 = None

DB_URL = os.environ.get("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
SCHEMA = "dev"
EXPERIMENT_DIR = os.path.dirname(os.path.abspath(__file__))


def get_conn():
    """Open a new database connection to ``DB_URL``.

    The caller owns the connection and is responsible for closing it.

    Raises:
        ImportError: if psycopg2 is not installed.
    """
    if psycopg2 is None:
        raise ImportError("psycopg2 is required for database access")
    return psycopg2.connect(DB_URL)


def cosine_similarity(a, b):
    """Return the cosine similarity of two vectors.

    A small epsilon is added to the denominator so that zero vectors yield
    0 instead of a division-by-zero.
    """
    a, b = np.asarray(a), np.asarray(b)
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)


def parse_pg_array(val):
    """Parse PostgreSQL real[] array — returns numpy float64 array or None.

    Accepts ``None``, a NumPy array, a list, or the textual form
    (``{1,2,3}`` or ``[1,2,3]``). Empty text and unsupported types give
    ``None``.
    """
    if val is None:
        return None
    if isinstance(val, np.ndarray):
        return val.astype(np.float64)
    if isinstance(val, list):
        return np.array(val, dtype=np.float64)
    if isinstance(val, str):
        # Tolerate surrounding whitespace as well as the bracket characters.
        s = val.strip().strip('[]{}').strip()
        if not s:
            return None
        return np.fromstring(s, sep=',').astype(np.float64)
    return None


def fetch_trace_with_faces(cur, file_uuid: str, min_frames: int) -> List[dict]:
    """Fetch traces with ALL their individual face embeddings.

    Only traces with at least ``min_frames`` embedded detections are
    returned. Each trace dict carries ``trace_id``, ``frame_count``,
    ``start_frame``, ``end_frame``, ``avg_bbox``, ``faces`` (ordered by
    detection confidence, best first) and ``centroid`` (mean embedding as a
    list, or ``None`` when no embedding could be parsed).
    """
    # Get trace summaries
    cur.execute(
        f"""
        SELECT trace_id, COUNT(*) as fc, MIN(frame_number), MAX(frame_number),
               AVG(x::float), AVG(y::float), AVG(width::float), AVG(height::float)
        FROM {SCHEMA}.face_detections
        WHERE file_uuid=%s AND trace_id IS NOT NULL AND embedding IS NOT NULL
        GROUP BY trace_id HAVING COUNT(*)>=%s
        ORDER BY trace_id
        """,
        (file_uuid, min_frames),
    )
    # Materialise the summaries first: the same cursor is reused below.
    summaries = cur.fetchall()

    traces = []
    for row in summaries:
        tid = row[0]
        cur.execute(
            f"SELECT embedding FROM {SCHEMA}.face_detections WHERE file_uuid=%s AND trace_id=%s AND embedding IS NOT NULL ORDER BY confidence DESC",
            (file_uuid, tid),
        )
        faces = []
        for r in cur.fetchall():
            emb = parse_pg_array(r[0])
            if emb is not None:
                faces.append({"embedding": emb})
        traces.append({
            "trace_id": tid,
            "frame_count": row[1],
            "start_frame": row[2],
            "end_frame": row[3],
            "avg_bbox": {"x": row[4], "y": row[5], "w": row[6], "h": row[7]},
            "faces": faces,
            "centroid": np.mean([f["embedding"] for f in faces], axis=0).tolist() if faces else None,
        })
    return traces


def fetch_speaker_overlaps(cur, file_uuid: str) -> dict:
    """Return ``{trace_id: {speaker_external_id: overlap_ratio}}``.

    Built from ``SPEAKS_AS`` edges between ``face_trace`` and ``speaker``
    nodes of the given file. A missing ratio is stored as ``0.0``.
    """
    cur.execute(
        f"""
        SELECT REPLACE(n.external_id,'trace_','')::int, n2.external_id,
               (e.properties->>'overlap_ratio')::float
        FROM {SCHEMA}.tkg_edges e
        JOIN {SCHEMA}.tkg_nodes n ON e.source_node_id=n.id
        JOIN {SCHEMA}.tkg_nodes n2 ON e.target_node_id=n2.id
        WHERE e.edge_type='SPEAKS_AS' AND n.node_type='face_trace'
          AND n2.node_type='speaker' AND e.file_uuid=%s
        """,
        (file_uuid,),
    )
    overlaps = defaultdict(dict)
    for tid, spk, ratio in cur.fetchall():
        # Compare against None explicitly: trace id 0 is a valid id and
        # must not be dropped by a truthiness test.
        if tid is not None and spk:
            overlaps[int(tid)][spk] = float(ratio or 0)
    return dict(overlaps)


def fetch_identity_references(cur) -> List[dict]:
    """Get registered identities with face embeddings as references.

    Returns a list of ``{"id", "name", "embedding"}`` dicts; identities
    whose embedding cannot be parsed are skipped.
    """
    cur.execute(f"SELECT id, name, face_embedding FROM {SCHEMA}.identities WHERE face_embedding IS NOT NULL")
    results = []
    for r in cur.fetchall():
        emb = parse_pg_array(r[2])
        if emb is None:
            continue
        results.append({"id": r[0], "name": r[1], "embedding": emb})
    return results


# ===== STAGE 1: High-confidence face-level matching =====

def stage1_high_confidence_binding(
    traces: List[dict],
    identities: List[dict],
    face_match_threshold: float = 0.92,
    trace_bind_ratio: float = 0.85,
) -> Tuple[List[dict], List[dict]]:
    """
    For each trace, compare EVERY face against EVERY identity.

    A trace is bound to an identity when the share of its faces whose
    similarity is >= ``face_match_threshold`` reaches ``trace_bind_ratio``.
    When several identities qualify, the one with the most matching faces
    wins (first one on ties). Bound traces are mutated in place: they gain
    ``binding`` and ``binding_stage`` keys.

    Returns (bound_traces, unbound_traces)
    """
    bound = []
    unbound = []
    for trace in traces:
        faces = trace.get("faces", [])
        if not faces:
            unbound.append(trace)
            continue

        best_identity = None
        best_match_count = 0
        for ident in identities:
            match_count = sum(
                1
                for face in faces
                if cosine_similarity(face["embedding"], ident["embedding"]) >= face_match_threshold
            )
            ratio = match_count / len(faces)
            if ratio >= trace_bind_ratio and match_count > best_match_count:
                best_match_count = match_count
                best_identity = {
                    "id": ident["id"],
                    "name": ident["name"],
                    "match_ratio": round(ratio, 3),
                    "matched_faces": match_count,
                    "total_faces": len(faces),
                }

        if best_identity:
            trace["binding"] = best_identity
            trace["binding_stage"] = "stage1_face_level"
            bound.append(trace)
        else:
            unbound.append(trace)
    return bound, unbound


# ===== STAGE 1b: Iterative enrichment from already-bound traces =====

def stage1b_iterative_enrichment(
    bound: List[dict],
    unbound: List[dict],
    identities: List[dict],
    speaker_overlaps: dict,
    config: dict,
    refs_per_trace: int = 3,
) -> Tuple[List[dict], List[dict]]:
    """Bind more traces using faces of already-bound traces as references.

    Every bound trace contributes its first ``refs_per_trace`` faces (faces
    are ordered best-confidence first by ``fetch_trace_with_faces``) to the
    reference set of its identity. Unbound traces are then scored against
    those sets; traces that bind contribute their own faces for the next
    round. The loop stops at the first round that binds nothing.

    Config keys read: ``stage1_face_threshold`` (0.55),
    ``stage1b_min_face_similarity`` (0.30),
    ``stage1b_composite_threshold`` (0.50).

    ``bound`` is extended in place and newly bound traces gain ``binding``
    and ``binding_stage`` keys.

    Returns (bound_traces, remaining_unbound_traces)
    """
    if not (bound and identities and unbound):
        return bound, unbound

    # identity_id -> list of reference embeddings
    identity_refs: Dict[object, list] = {}
    for trace in bound:
        binding = trace.get("binding", {})
        iid = binding.get("id") if isinstance(binding, dict) else None
        if not iid or not trace.get("faces"):
            continue
        identity_refs.setdefault(iid, []).extend(
            face["embedding"] for face in trace["faces"][:refs_per_trace]
        )

    id_to_name = {ident["id"]: ident["name"] for ident in identities}
    for iid, refs in identity_refs.items():
        print(f" {id_to_name.get(iid, '?'):<20} {len(refs)} reference faces (multi-angle sampling)")

    # Total speaker overlap per trace, used to up-weight traces that speak.
    speaker_counts = {tid: sum(spks.values()) for tid, spks in speaker_overlaps.items()}
    max_speaker_count = max(max(speaker_counts.values(), default=1), 1)

    # Loop invariants, looked up once.
    face_threshold = config.get("stage1_face_threshold", 0.55)
    min_sim = config.get("stage1b_min_face_similarity", 0.30)
    composite_threshold = config.get("stage1b_composite_threshold", 0.50)

    round_num = 0
    while True:
        round_num += 1
        bound_this_round = []
        for trace in unbound:
            faces = trace.get("faces", [])
            if not faces:
                continue
            spk_weight = 1.0 + 0.3 * speaker_counts.get(trace["trace_id"], 0) / max_speaker_count

            best_score = 0
            best_iid = None
            best_sim = 0
            best_match_count = 0
            for iid, refs in identity_refs.items():
                # Compare each face against ALL references, take max per face
                face_sims = [
                    max(cosine_similarity(face["embedding"], ref) for ref in refs)
                    for face in faces
                ]
                avg_sim = np.mean(face_sims)
                # Absolute minimum: if avg similarity is too low, never bind
                if avg_sim < min_sim:
                    continue
                match_ratio = sum(1 for s in face_sims if s >= face_threshold) / len(face_sims)
                # Composite score: similarity + match ratio + speaker weight
                composite = avg_sim * spk_weight * (0.4 + 0.6 * match_ratio)
                if composite > best_score and composite > composite_threshold:
                    best_score = composite
                    best_iid = iid
                    best_sim = avg_sim
                    # NOTE(review): the reported match count uses a fixed 0.50
                    # cut-off, not stage1_face_threshold — confirm intended.
                    best_match_count = sum(1 for s in face_sims if s >= 0.50)

            if best_iid is not None:
                trace["binding"] = {
                    "id": best_iid,
                    "name": id_to_name.get(best_iid, "?"),
                    "avg_similarity": round(best_sim, 3),
                    "match_ratio": round(best_match_count / max(len(faces), 1), 3),
                    "composite_score": round(best_score, 3),
                    "source": f"video_ref_r{round_num}",
                }
                trace["binding_stage"] = f"stage1b_r{round_num}"
                bound_this_round.append(trace)
                bound.append(trace)

        if not bound_this_round:
            break

        # Enrich references only after the round, so every trace in a round
        # is scored against the same reference set.
        for trace in bound_this_round:
            identity_refs[trace["binding"]["id"]].extend(
                face["embedding"] for face in trace.get("faces", [])[:refs_per_trace]
            )

        bound_ids = {t["trace_id"] for t in bound_this_round}
        unbound = [t for t in unbound if t["trace_id"] not in bound_ids]
        print(f" Round {round_num}: {len(bound_this_round)} traces bound, {len(unbound)} unbound")

    return bound, unbound


# ===== STAGE 2: Centroid clustering of unbound traces =====

def stage2_cluster_unbound(
    traces: List[dict], threshold: float, adaptive: bool = False
) -> List[List[dict]]:
    """Greedily cluster unbound traces by centroid similarity.

    Each not-yet-assigned trace seeds a cluster and absorbs every other
    unassigned trace whose centroid similarity to the seed reaches the
    threshold. With ``adaptive`` the threshold is lowered by 0.05 when
    either trace has fewer than 10 frames. Traces without a centroid end up
    as singleton clusters.
    """
    clusters = []
    assigned = set()
    for i, seed in enumerate(traces):
        if seed["trace_id"] in assigned:
            continue
        cluster = [seed]
        assigned.add(seed["trace_id"])
        for j, other in enumerate(traces):
            if other["trace_id"] in assigned or i == j:
                continue
            if seed["centroid"] is None or other["centroid"] is None:
                continue
            sim = cosine_similarity(seed["centroid"], other["centroid"])
            th = threshold
            if adaptive and (seed["frame_count"] < 10 or other["frame_count"] < 10):
                th -= 0.05
            if sim >= th:
                cluster.append(other)
                assigned.add(other["trace_id"])
        clusters.append(cluster)
    return clusters


def apply_speaker_verification(clusters: List[List[dict]], speaker_overlaps: dict) -> List[dict]:
    """Turn clusters into label dicts annotated with their dominant speaker.

    The dominant speaker is the one with the largest summed overlap ratio
    across the cluster's traces (``None`` when no trace has speaker data).
    Binding information is copied from the first trace of the cluster.
    """
    labels = []
    for i, cluster in enumerate(clusters):
        trace_ids = [t["trace_id"] for t in cluster]
        votes = defaultdict(float)
        for tid in trace_ids:
            for spk, ratio in speaker_overlaps.get(tid, {}).items():
                votes[spk] += ratio
        best_spk = max(votes, key=votes.get) if votes else None
        first = cluster[0] if cluster else {}
        labels.append({
            "cluster_id": i,
            "trace_count": len(cluster),
            "trace_ids": trace_ids,
            "dominant_speaker": best_spk,
            "speaker_score": round(votes.get(best_spk, 0), 3) if best_spk else 0,
            "binding": first.get("binding"),
            "binding_stage": first.get("binding_stage"),
        })
    return labels


def split_temporal_collisions(
    labels: List[dict],
    trace_timing: Dict[object, Tuple[int, int]],
    confidence_of: Callable[[object], float],
) -> int:
    """Split traces whose frame ranges overlap inside the same label.

    Two traces of one person cannot be on screen at the same time, so for
    every overlapping pair the trace with the lower ``confidence_of`` value
    is moved into a new single-trace label (``binding_stage`` =
    ``"collision_split"``) appended to ``labels``.

    ``labels`` and the affected label dicts are mutated in place.

    Returns the number of traces split off.
    """
    splits = 0
    new_labels = []
    for label in labels:
        if label.get("trace_count", 0) < 2:
            continue
        # Iterate over a snapshot: the live list shrinks as losers are
        # removed, which would otherwise invalidate the pair indices.
        snapshot = list(label["trace_ids"])
        removed = set()
        for i, a in enumerate(snapshot):
            if a in removed:
                continue
            ta = trace_timing.get(a)
            if not ta:
                continue
            for b in snapshot[i + 1:]:
                if b in removed:
                    continue
                tb = trace_timing.get(b)
                if not tb:
                    continue
                # Overlap: max(start) < min(end)
                # NOTE(review): strict '<' ignores traces that share only a
                # single boundary frame — confirm that is intended.
                if not max(ta[0], tb[0]) < min(ta[1], tb[1]):
                    continue
                splits += 1
                print(f" COLLISION: trace {a} & {b} overlap (frames {max(ta[0],tb[0])}-{min(ta[1],tb[1])}), splitting...")
                loser_tid = a if confidence_of(a) < confidence_of(b) else b
                removed.add(loser_tid)
                label["trace_ids"].remove(loser_tid)
                label["trace_count"] -= 1
                new_labels.append({
                    "cluster_id": len(labels) + len(new_labels),
                    "trace_count": 1,
                    "trace_ids": [loser_tid],
                    "binding": None,
                    "binding_stage": "collision_split",
                })
                if loser_tid == a:
                    # 'a' has left the label; its remaining pairs are moot.
                    break
    labels.extend(new_labels)
    return splits


def assign_temp_identities(labels: List[dict], file_uuid: str) -> int:
    """Give every label without a binding a temporary identity name.

    Multi-trace labels are named ``Person_<uuid8>_<nnn>``, single-trace
    labels ``Stranger_<uuid8>_<nnn>``. Labels with no traces are skipped.
    Labels are mutated in place.

    Returns the number of temporary identities created.
    """
    temp_count = 0
    for label in labels:
        if label.get("binding") is not None:
            continue  # already has known identity
        tids = label.get("trace_ids", [])
        if not tids:
            continue
        temp_count += 1
        prefix = "Person" if len(tids) >= 2 else "Stranger"
        label["binding"] = {
            "name": f"{prefix}_{file_uuid[:8]}_{temp_count:03d}",
            "source": "auto_temp",
            "trace_count": len(tids),
        }
        label["binding_stage"] = "auto_temp"
    return temp_count


def _write_bindings_to_db(all_labels: List[dict], file_uuid: str) -> int:
    """Persist label bindings; return the number of face rows updated.

    Identities are looked up by name and created when missing. Only
    detections without an ``identity_id`` are updated. Everything is
    committed once at the end, so a failure leaves the database untouched.
    """
    total_written = 0
    with closing(get_conn()) as conn, closing(conn.cursor()) as cur:
        for label in all_labels:
            binding = label.get("binding")
            if not binding:
                continue
            identity_name = binding.get("name", "")
            if not identity_name:
                continue

            # Get or create identity
            cur.execute(f"SELECT id FROM {SCHEMA}.identities WHERE name=%s", (identity_name,))
            row = cur.fetchone()
            if row:
                identity_id = row[0]
            else:
                source = binding.get("source", "auto")
                cur.execute(
                    f"INSERT INTO {SCHEMA}.identities (name, identity_type, source, status) VALUES (%s,'people',%s,'pending') RETURNING id",
                    (identity_name, source),
                )
                identity_id = cur.fetchone()[0]

            # Bind all faces in each trace to the identity
            for tid in label["trace_ids"]:
                cur.execute(
                    f"UPDATE {SCHEMA}.face_detections SET identity_id=%s WHERE file_uuid=%s AND trace_id=%s AND identity_id IS NULL",
                    (identity_id, file_uuid, tid),
                )
                affected = cur.rowcount
                if affected > 0:
                    # Write to identity_bindings for traceability
                    confidence = float(binding.get("avg_similarity", 0.8))
                    cur.execute(
                        f"INSERT INTO {SCHEMA}.identity_bindings (identity_id, identity_type, identity_value, confidence) VALUES (%s,'trace',%s,%s) ON CONFLICT DO NOTHING",
                        (identity_id, str(tid), confidence),
                    )
                    total_written += affected
        conn.commit()
    return total_written


def _save_results(exp_id, all_labels: List[dict], metrics: dict, config: dict) -> str:
    """Write labels, metrics and config as JSON; return the result directory."""
    result_dir = os.path.join(EXPERIMENT_DIR, "results", f"exp_{exp_id}")
    os.makedirs(result_dir, exist_ok=True)
    for name, data in [("labels.json", all_labels), ("metrics.json", metrics), ("config.json", config)]:
        # ensure_ascii=False emits raw non-ASCII names, so pin the encoding.
        with open(os.path.join(result_dir, name), "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False, default=str)
    return result_dir


# ===== Main Experiment =====

def run_experiment(config: dict) -> dict:
    """Run the full clustering pipeline for one file and return its metrics.

    Config keys read: ``id`` (required), ``file_uuid``, ``min_frames``,
    ``enable_identity_match``, ``stage1_face_threshold``,
    ``stage1_bind_ratio``, ``stage1b_min_face_similarity``,
    ``stage1b_composite_threshold``, ``stage2_threshold``,
    ``stage2_adaptive``, ``enable_temporal_collision_check``, ``write_db``.

    Side effects: prints progress, writes ``labels.json``, ``metrics.json``
    and ``config.json`` under ``results/exp_<id>``, and updates the
    database when ``write_db`` is truthy.
    """
    exp_id = config["id"]
    file_uuid = config.get("file_uuid", "")

    with closing(get_conn()) as conn, closing(conn.cursor()) as cur:
        t0 = time.time()

        # Load data
        traces = fetch_trace_with_faces(cur, file_uuid, config.get("min_frames", 3))
        identities = fetch_identity_references(cur) if config.get("enable_identity_match", True) else []
        speaker_overlaps = fetch_speaker_overlaps(cur, file_uuid)
        print(f"Traces: {len(traces)}, Identities: {len(identities)}, Speaker edges: {len(speaker_overlaps)}")

        # Stage 1: first-pass binding against registered identities
        # (relaxed thresholds compared to the function defaults)
        bound, unbound = [], traces
        if identities:
            bound, unbound = stage1_high_confidence_binding(
                traces, identities,
                config.get("stage1_face_threshold", 0.55),
                config.get("stage1_bind_ratio", 0.60),
            )
            print(f"Stage 1 (TMDb): {len(bound)} traces bound, {len(unbound)} unbound")

        # Stage 1b: iterative enrichment from the bound traces' own faces
        bound, unbound = stage1b_iterative_enrichment(
            bound, unbound, identities, speaker_overlaps, config
        )

        # Stage 2: cluster whatever is still unbound
        clusters = stage2_cluster_unbound(
            unbound,
            config.get("stage2_threshold", 0.85),
            config.get("stage2_adaptive", False),
        )
        print(f"Stage 2: {len(clusters)} clusters from {len(unbound)} unbound traces")

        # Speaker verification
        all_labels = apply_speaker_verification(clusters, speaker_overlaps)

        # --- Temporal Collision Check ---
        if config.get("enable_temporal_collision_check", True):
            trace_timing = {t["trace_id"]: (t["start_frame"], t["end_frame"]) for t in traces}
            confidence_cache = {}

            def avg_confidence(tid):
                # Per-face confidence is not kept in the trace dicts, so it
                # is read from the DB (once per trace).
                if tid not in confidence_cache:
                    cur.execute(
                        f"SELECT AVG(confidence) FROM {SCHEMA}.face_detections WHERE file_uuid=%s AND trace_id=%s",
                        (file_uuid, tid),
                    )
                    confidence_cache[tid] = cur.fetchone()[0] or 0
                return confidence_cache[tid]

            collision_splits = split_temporal_collisions(all_labels, trace_timing, avg_confidence)
            if collision_splits > 0:
                print(f" Temporal collision: {collision_splits} traces split")

    # Merge bound traces (stage 1 and 1b) into labels, one label per trace
    for t in bound:
        spk = speaker_overlaps.get(t["trace_id"]) or {}
        all_labels.append({
            "cluster_id": len(all_labels),
            "trace_count": 1,
            "trace_ids": [t["trace_id"]],
            "binding": t.get("binding"),
            # Keep the stage that actually bound the trace (1 or 1b).
            "binding_stage": t.get("binding_stage", "stage1_face_level"),
            # Highest-overlap speaker; dict order reflects unordered DB rows.
            "dominant_speaker": max(spk, key=spk.get) if spk else None,
        })

    # --- Temp Identity: assign names to unbound clusters ---
    temp_count = assign_temp_identities(all_labels, file_uuid)
    if temp_count > 0:
        print(f" Temp identities created: {temp_count}")

    # Metrics
    clustered = sum(len(c) for c in clusters)
    metrics = {
        "total_traces": len(traces),
        "stage1_bound": len(bound),
        "stage1_bound_traces": len(bound),
        "stage2_clusters": len(clusters),
        "stage2_unbound_clustered": clustered,
        "total_clusters": len(all_labels),
        "execution_time_s": time.time() - t0,
        "coverage": (len(bound) + clustered) / max(len(traces), 1),
    }
    for k, v in metrics.items():
        print(f" {k}: {v}")

    # --- Write bindings to database ---
    if config.get("write_db", False):
        total_written = _write_bindings_to_db(all_labels, file_uuid)
        print(f"\nDB write: {total_written} face_detections updated")

    # Save
    result_dir = _save_results(exp_id, all_labels, metrics, config)
    print(f"\nSaved to {result_dir}")
    return metrics


def main():
    """Command-line entry point: load a JSON config and run the experiment."""
    p = argparse.ArgumentParser()
    p.add_argument("--config", required=True)
    p.add_argument("--write-db", action="store_true", help="Write bindings to database")
    args = p.parse_args()
    with open(args.config, encoding="utf-8") as f:
        config = json.load(f)
    if args.write_db:
        config["write_db"] = True
    run_experiment(config)


if __name__ == "__main__":
    main()