#!/opt/homebrew/bin/python3.11
"""
Identity Clustering Experiment Runner

Usage:
    python runner.py --config configs/exp_001.json

Each experiment:
1. Reads config parameters
2. Fetches face trace data from DB
3. Runs clustering algorithm
4. Verifies clusters against speakers (always runs)
5. Optionally matches against TMDb (config key "enable_tmdb")
6. Saves all results to experiments/identity_clustering/results/exp_{id}/
"""

import sys
import os
import json
import argparse
import time
import numpy as np
from datetime import datetime
from collections import defaultdict
from typing import Any, Dict, List, Tuple, Optional

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..", "scripts"))

# DB connection
import psycopg2
import psycopg2.extras

DB_URL = os.environ.get("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
SCHEMA = "dev"
EXPERIMENT_DIR = os.path.dirname(os.path.abspath(__file__))


def get_conn():
    """Open a new psycopg2 connection to ``DB_URL``; the caller must close it."""
    return psycopg2.connect(DB_URL)


def load_experiment_config(config_path: str) -> dict:
    """Read the experiment configuration JSON at *config_path* and return it."""
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)


def _compute_centroid(embeddings: list, method: str):
    """Collapse a trace's embeddings into one vector, or None when there are none.

    ``"mean"`` and ``"median"`` are element-wise; any other value returns the
    first embedding as-is.
    """
    if not embeddings:
        return None
    if method == "mean":
        return np.mean(embeddings, axis=0)
    if method == "median":
        return np.median(embeddings, axis=0)
    return embeddings[0]


def fetch_trace_data(
    cur, file_uuid: str, min_frames: int, centroid_method: str = "mean"
) -> List[dict]:
    """Fetch trace centroids + metadata from face_detections.

    Args:
        cur: Open DB cursor.
        file_uuid: File whose detections are read.
        min_frames: Minimum number of detections (with an embedding) a trace
            needs to be returned.
        centroid_method: ``"mean"`` (default), ``"median"``, or anything else
            to use the highest-confidence embedding of the trace.

    Returns:
        One dict per trace with keys ``trace_id``, ``frame_count``,
        ``start_frame``, ``end_frame``, ``avg_bbox``, ``avg_confidence``,
        ``embedding_count`` and ``centroid`` (a list, or None when the trace
        has no embeddings).
    """
    sql = f"""
        SELECT trace_id,
               COUNT(*) as frame_count,
               MIN(frame_number) as start_frame,
               MAX(frame_number) as end_frame,
               AVG(x)::float as avg_x,
               AVG(y)::float as avg_y,
               AVG(width)::float as avg_w,
               AVG(height)::float as avg_h,
               AVG(confidence) as avg_confidence
        FROM {SCHEMA}.face_detections
        WHERE file_uuid = %s AND trace_id IS NOT NULL AND embedding IS NOT NULL
        GROUP BY trace_id
        HAVING COUNT(*) >= %s
        ORDER BY trace_id
    """
    cur.execute(sql, (file_uuid, min_frames))
    rows = cur.fetchall()

    # Ordered by confidence so that index 0 is the most confident detection.
    embedding_sql = (
        f"SELECT embedding FROM {SCHEMA}.face_detections "
        "WHERE file_uuid=%s AND trace_id=%s AND embedding IS NOT NULL "
        "ORDER BY confidence DESC"
    )

    traces = []
    for row in rows:
        # Get all embeddings for this trace.
        # NOTE(review): assumes the driver returns each embedding as a numeric
        # sequence; confirm how the embedding column is adapted.
        cur.execute(embedding_sql, (file_uuid, row[0]))
        embeddings = [np.array(r[0]) for r in cur.fetchall()]
        centroid = _compute_centroid(embeddings, centroid_method)

        traces.append(
            {
                "trace_id": row[0],
                "frame_count": row[1],
                "start_frame": row[2],
                "end_frame": row[3],
                "avg_bbox": {"x": row[4], "y": row[5], "w": row[6], "h": row[7]},
                "avg_confidence": row[8],
                "embedding_count": len(embeddings),
                "centroid": centroid.tolist() if centroid is not None else None,
            }
        )
    return traces


def cosine_similarity(a, b):
    """Return the cosine similarity of vectors *a* and *b*.

    A small epsilon in the denominator avoids division by zero for zero
    vectors (the result is then 0).
    """
    a, b = np.array(a), np.array(b)
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10)


def cluster_by_threshold(
    traces: List[dict],
    threshold: float,
    adaptive: bool = False,
    short_trace_frames: int = 60,
    short_trace_relax: float = 0.05,
) -> List[List[dict]]:
    """Simple threshold-based clustering.

    Traces are visited in order; each unassigned trace seeds a cluster and
    absorbs every later unassigned trace whose centroid similarity to the
    seed reaches the threshold. A trace without a centroid becomes a
    single-member cluster.

    Args:
        traces: Trace dicts with ``trace_id``, ``frame_count``, ``centroid``.
        threshold: Minimum cosine similarity to join the seed's cluster.
        adaptive: When true, lower the threshold by *short_trace_relax* if
            either trace has fewer than *short_trace_frames* frames.
        short_trace_frames: Frame count below which a trace counts as short.
        short_trace_relax: Amount subtracted from the threshold for short traces.

    Returns:
        A list of clusters, each a list of trace dicts.
    """
    clusters = []
    assigned = set()

    for seed in traces:
        if seed["trace_id"] in assigned:
            continue
        cluster = [seed]
        assigned.add(seed["trace_id"])

        if seed["centroid"] is not None:
            for other in traces:
                # The seed itself is already in `assigned`, so it is skipped here.
                if other["trace_id"] in assigned or other["centroid"] is None:
                    continue

                sim = cosine_similarity(seed["centroid"], other["centroid"])
                th = threshold
                if adaptive and (
                    seed["frame_count"] < short_trace_frames
                    or other["frame_count"] < short_trace_frames
                ):
                    th = threshold - short_trace_relax  # relax for short traces

                if sim >= th:
                    cluster.append(other)
                    assigned.add(other["trace_id"])

        clusters.append(cluster)

    return clusters


def cluster_dbscan(
    traces: List[dict], eps: float = 0.3, min_samples: int = 2
) -> List[List[dict]]:
    """DBSCAN clustering on trace centroids using cosine distance.

    Traces without a centroid are left out of the result. Each point DBSCAN
    labels as noise is returned as its own single-member cluster.

    Returns:
        A list of clusters, each a list of trace dicts; empty when no trace
        has a centroid.
    """
    from sklearn.cluster import DBSCAN

    valid = [t for t in traces if t["centroid"] is not None]
    if not valid:
        # Fitting on an empty array raises inside scikit-learn.
        return []

    X = np.array([t["centroid"] for t in valid])

    # Cosine distance = 1 - cosine_similarity
    labels = DBSCAN(eps=eps, min_samples=min_samples, metric="cosine").fit(X).labels_

    grouped = defaultdict(list)
    for i, label in enumerate(labels):
        # Negative labels are noise: give every noise point a unique key.
        key = int(label) if label >= 0 else f"noise_{i}"
        grouped[key].append(valid[i])

    return list(grouped.values())


def fetch_tmdb_identities(cur) -> List[dict]:
    """Get TMDb identities that have a face embedding.

    Returns:
        Dicts with keys ``id``, ``name`` and ``embedding``.
    """
    cur.execute(
        f"SELECT id, name, face_embedding FROM {SCHEMA}.identities WHERE source='tmdb' AND face_embedding IS NOT NULL"
    )
    return [
        {"id": r[0], "name": r[1], "embedding": r[2]}
        for r in cur.fetchall()
        if r[2] is not None
    ]


def fetch_speaker_overlaps(cur, file_uuid: str) -> dict:
    """Get speaker-face trace overlap from TKG edges.

    Returns:
        ``{trace_id: {speaker_id: overlap_ratio}}`` built from ``SPEAKS_AS``
        edges between ``face_trace`` and ``speaker`` nodes. A missing ratio
        is stored as 0.
    """
    cur.execute(
        f"""
        SELECT
            REPLACE(n.external_id, 'trace_', '')::int as trace_id,
            n2.external_id as speaker_id,
            (e.properties->>'overlap_ratio')::float as overlap_ratio
        FROM {SCHEMA}.tkg_edges e
        JOIN {SCHEMA}.tkg_nodes n ON e.source_node_id = n.id
        JOIN {SCHEMA}.tkg_nodes n2 ON e.target_node_id = n2.id
        WHERE e.edge_type = 'SPEAKS_AS'
          AND n.node_type = 'face_trace'
          AND n2.node_type = 'speaker'
          AND e.file_uuid = %s
        """,
        (file_uuid,),
    )
    overlaps = defaultdict(dict)
    for row in cur.fetchall():
        trace_id, speaker_id, ratio = row[0], row[1], row[2] or 0
        if trace_id is None or speaker_id is None:
            continue
        overlaps[int(trace_id)][speaker_id] = float(ratio)
    return dict(overlaps)


def _cluster_trace_ids(cluster) -> list:
    """Return the trace ids of a label dict or of a raw list of trace dicts."""
    if isinstance(cluster, dict):
        return list(cluster.get("trace_ids") or [])
    return [t["trace_id"] for t in cluster]


def verify_with_speakers(
    clusters: List[Any], speaker_overlaps: dict, merge_min_score: float = 0.5
) -> List[dict]:
    """Annotate clusters with dominant speaker from time overlap, then merge.

    Each cluster receives ``dominant_speaker``, ``speaker_overlap_score`` (sum
    of overlap ratios for that speaker, rounded to 3 decimals) and
    ``speaker_votes``. Clusters sharing a dominant speaker whose score exceeds
    *merge_min_score* are merged into one entry.

    Args:
        clusters: Label dicts (with ``trace_ids``) or raw clusters (lists of
            trace dicts). Label dicts are annotated in place; raw clusters
            are converted to new label dicts.
        speaker_overlaps: ``{trace_id: {speaker_id: overlap_ratio}}``.
        merge_min_score: Score a cluster must exceed to take part in a merge.

    Returns:
        Merged clusters first, followed by the clusters that were not merged.
        Merged entries carry ``merged_from`` (positions in *clusters*) and
        ``merge_reason``, plus the same keys as ordinary label dicts so that
        downstream code can treat both alike.
    """
    annotated = []
    for index, cluster in enumerate(clusters):
        if not isinstance(cluster, dict):
            # Raw cluster list: wrap it so annotations have somewhere to live.
            trace_ids = _cluster_trace_ids(cluster)
            cluster = {
                "cluster_id": index,
                "trace_count": len(trace_ids),
                "trace_ids": trace_ids,
            }

        # Collect all speaker overlaps for traces in this cluster
        speaker_votes = defaultdict(float)
        for tid in _cluster_trace_ids(cluster):
            for spk, ratio in speaker_overlaps.get(tid, {}).items():
                speaker_votes[spk] += ratio

        if speaker_votes:
            best_speaker = max(speaker_votes, key=speaker_votes.get)
            cluster["dominant_speaker"] = best_speaker
            cluster["speaker_overlap_score"] = round(speaker_votes[best_speaker], 3)
            cluster["speaker_votes"] = dict(speaker_votes)
        else:
            cluster["dominant_speaker"] = None
            cluster["speaker_overlap_score"] = 0
            cluster["speaker_votes"] = {}
        annotated.append(cluster)

    # Merge clusters that share dominant speaker (high overlap with same speaker)
    speaker_clusters = defaultdict(list)
    for index, cluster in enumerate(annotated):
        spk = cluster.get("dominant_speaker")
        if spk and cluster.get("speaker_overlap_score", 0) > merge_min_score:
            speaker_clusters[spk].append(index)

    merged = set()
    new_clusters = []
    for spk, indices in speaker_clusters.items():
        if len(indices) <= 1:
            continue

        members = [annotated[idx] for idx in indices]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        trace_ids = list(
            dict.fromkeys(tid for m in members for tid in _cluster_trace_ids(m))
        )
        combined_votes = defaultdict(float)
        for member in members:
            for voter, ratio in member["speaker_votes"].items():
                combined_votes[voter] += ratio

        merged.update(indices)
        new_clusters.append(
            {
                "cluster_id": members[0].get("cluster_id", indices[0]),
                "merged_from": indices,
                "trace_ids": trace_ids,
                "trace_count": len(trace_ids),
                "tmdb_matches": [],
                "best_match": None,
                "dominant_speaker": spk,
                "speaker_overlap_score": round(combined_votes[spk], 3),
                "speaker_votes": dict(combined_votes),
                "merge_reason": "shared_dominant_speaker",
            }
        )

    # Keep unmerged clusters
    new_clusters.extend(c for i, c in enumerate(annotated) if i not in merged)
    return new_clusters


def match_tmdb(
    clusters: List[Any],
    tmdb_identities: List[dict],
    trace_lookup: Optional[Dict[Any, dict]] = None,
    min_similarity: float = 0.55,
) -> List[dict]:
    """Match each cluster to best TMDb identity.

    The representative of a cluster is its member with the most frames among
    those that have a centroid.

    Args:
        clusters: Raw clusters (lists of trace dicts) or label dicts carrying
            ``trace_ids``.
        tmdb_identities: Dicts with ``id``, ``name`` and ``embedding``.
        trace_lookup: ``{trace_id: trace dict}``; needed to resolve the
            members of label dicts, which hold only trace ids.
        min_similarity: Minimum cosine similarity for a TMDb match.

    Returns:
        One result dict per cluster with ``cluster_id``, ``trace_count``,
        ``total_frames``, ``trace_ids``, ``tmdb_matches`` (best first),
        ``best_match`` and ``best_similarity``. Label dicts keep their other
        keys (e.g. speaker annotations) and are always returned, even without
        a usable centroid; raw clusters without a usable centroid are skipped.
    """
    lookup = trace_lookup or {}
    results = []

    for index, cluster in enumerate(clusters):
        is_label = isinstance(cluster, dict)
        if is_label:
            trace_ids = list(cluster.get("trace_ids") or [])
            members = [lookup[tid] for tid in trace_ids if tid in lookup]
        else:
            members = list(cluster)
            trace_ids = [t["trace_id"] for t in members]

        candidates = [t for t in members if t.get("centroid") is not None]
        if not candidates and not is_label:
            continue

        matches = []
        if candidates:
            # Use the trace with most frames as representative
            representative = max(candidates, key=lambda t: t["frame_count"])
            centroid = representative["centroid"]
            for identity in tmdb_identities:
                if identity["embedding"] is None:
                    continue
                sim = cosine_similarity(centroid, identity["embedding"])
                if sim >= min_similarity:
                    matches.append(
                        {
                            "id": identity["id"],
                            "name": identity["name"],
                            "similarity": float(sim),
                        }
                    )
            matches.sort(key=lambda m: m["similarity"], reverse=True)

        # Copy label dicts so existing annotations survive the TMDb step.
        cluster_result = dict(cluster) if is_label else {}
        cluster_result.update(
            {
                "cluster_id": cluster.get("cluster_id", index) if is_label else index,
                "trace_count": len(trace_ids),
                "total_frames": sum(t["frame_count"] for t in members),
                "trace_ids": trace_ids,
                "tmdb_matches": matches,
                "best_match": matches[0]["name"] if matches else None,
                "best_similarity": matches[0]["similarity"] if matches else 0,
            }
        )
        results.append(cluster_result)

    return results


def _cluster_size(cluster) -> int:
    """Return the number of traces in a label dict or raw cluster list."""
    if isinstance(cluster, dict):
        if "trace_count" in cluster:
            return cluster["trace_count"]
        return len(cluster.get("trace_ids") or [])
    return len(cluster)


def compute_metrics(clusters: List[Any], total_traces: int) -> dict:
    """Summarise clustering coverage and TMDb match rate.

    Args:
        clusters: Label dicts and/or raw clusters; may be empty.
        total_traces: Number of traces that were loaded.

    Returns:
        Dict with ``total_traces``, ``clustered_traces``, ``cluster_count``,
        ``coverage``, ``avg_cluster_size``, ``tmdb_matched`` and
        ``tmdb_coverage``. Ratios use a denominator of at least 1.
    """
    clustered = sum(_cluster_size(c) for c in clusters)
    cluster_count = len(clusters)
    tmdb_matched = sum(
        1 for c in clusters if isinstance(c, dict) and c.get("best_match")
    )
    return {
        "total_traces": total_traces,
        "clustered_traces": clustered,
        "cluster_count": cluster_count,
        "coverage": clustered / max(total_traces, 1),
        "avg_cluster_size": clustered / max(cluster_count, 1),
        "tmdb_matched": tmdb_matched,
        "tmdb_coverage": tmdb_matched / max(cluster_count, 1),
    }


def _cluster_traces(traces: List[dict], config: dict, method: str) -> List[List[dict]]:
    """Run the clustering algorithm selected by *method* with config parameters."""
    if method == "threshold":
        return cluster_by_threshold(
            traces,
            config.get("threshold", 0.85),
            config.get("adaptive_threshold", False),
        )
    if method == "dbscan":
        return cluster_dbscan(
            traces, config.get("eps", 0.3), config.get("min_samples", 2)
        )
    # Unknown method: keep the historical fallback, but say so.
    print(f" unknown clustering method {method!r}; using adaptive threshold 0.85")
    return cluster_by_threshold(traces, 0.85, True)


def _write_json(path: str, payload) -> None:
    """Write *payload* as indented UTF-8 JSON to *path*."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def run_experiment(config: dict) -> dict:
    """Main experiment flow.

    Fetches traces, clusters them, annotates and merges clusters by speaker,
    optionally matches them against TMDb, then writes ``clusters.json``,
    ``labels.json``, ``metrics.json``, ``config.json`` and ``summary.txt``
    under ``results/exp_{id}`` next to this file.

    Args:
        config: Experiment configuration; ``id`` and ``name`` are required.

    Returns:
        The metrics dict, including ``execution_time_s``.

    Raises:
        KeyError: If ``id`` or ``name`` is missing from *config*.
    """
    exp_id = config["id"]
    file_uuid = config.get("file_uuid", "1a04db97be5fa12bd77369831dc141fd")
    min_frames = config.get("min_frames", 30)

    print(f"\n{'='*60}")
    print(f"Experiment {exp_id}: {config['name']}")
    print(f"{'='*60}")

    conn = get_conn()
    try:
        # The cursor context manager closes the cursor even when a step fails.
        with conn.cursor() as cur:
            t0 = time.time()

            # Step 1: Fetch traces
            print(f"\n[1] Fetching traces (min_frames={min_frames})...")
            traces = fetch_trace_data(
                cur, file_uuid, min_frames, config.get("centroid_method", "mean")
            )
            print(f" {len(traces)} traces loaded")

            # Step 2: Clustering
            method = config.get("clustering_method", "threshold")
            print(f"\n[2] Clustering: method={method}...")
            clusters = _cluster_traces(traces, config, method)
            clustered_traces = sum(len(c) for c in clusters)
            print(f" {len(clusters)} clusters, {clustered_traces} traces clustered")

            # Step 3: Speaker verification (mandatory — standard step)
            print("\n[3] Speaker verification...")
            speaker_overlaps = fetch_speaker_overlaps(cur, file_uuid)

            # Convert raw clusters to label dicts
            labels = [
                {
                    "cluster_id": i,
                    "trace_count": len(c),
                    "trace_ids": [t["trace_id"] for t in c],
                    "tmdb_matches": [],
                    "best_match": None,
                }
                for i, c in enumerate(clusters)
            ]
            labels = verify_with_speakers(labels, speaker_overlaps)
            matched_speakers = sum(1 for lb in labels if lb.get("dominant_speaker"))
            merged = sum(1 for lb in labels if lb.get("merge_reason"))
            print(f" {matched_speakers} clusters have speaker match, {merged} merged by speaker")

            # Step 4: TMDb matching (optional)
            if config.get("enable_tmdb", False):
                print("\n[4] TMDb matching...")
                tmdb = fetch_tmdb_identities(cur)
                print(f" {len(tmdb)} TMDb identities loaded")
                # Labels only hold trace ids; the lookup supplies the centroids.
                trace_lookup = {t["trace_id"]: t for t in traces}
                labels = match_tmdb(labels, tmdb, trace_lookup=trace_lookup)
                matched = sum(1 for lb in labels if lb.get("best_match"))
                print(f" {matched} clusters matched to TMDb")

            # Step 5: Metrics
            metrics = compute_metrics(labels if labels else clusters, len(traces))
            metrics["execution_time_s"] = time.time() - t0
    finally:
        conn.close()

    # Step 6: Save results
    result_dir = os.path.join(EXPERIMENT_DIR, "results", f"exp_{exp_id}")
    os.makedirs(result_dir, exist_ok=True)

    _write_json(
        os.path.join(result_dir, "clusters.json"), clusters if not labels else labels
    )
    _write_json(os.path.join(result_dir, "labels.json"), labels)
    _write_json(os.path.join(result_dir, "metrics.json"), metrics)
    _write_json(os.path.join(result_dir, "config.json"), config)

    # Summary
    summary = f"""
Experiment {exp_id}: {config['name']}
====================================
Date: {datetime.now().isoformat()}
Config: {json.dumps(config, indent=2)}

Results:
  Traces loaded: {len(traces)}
  Clusters: {len(clusters)}
  Clustered traces: {clustered_traces}
  Coverage: {metrics['coverage']:.1%}
  Avg cluster size: {metrics['avg_cluster_size']:.1f}
  TMDb matched: {metrics.get('tmdb_matched', 0)}
  Execution time: {metrics['execution_time_s']:.1f}s

Top clusters:
"""
    sorted_labels = sorted(
        labels, key=lambda lb: lb.get("trace_count", 0), reverse=True
    )
    top_lines = []
    for lb in sorted_labels[:10]:
        # "best_match" exists with a None value when nothing matched.
        name = lb.get("best_match") or "unlabeled"
        top_lines.append(
            f" Cluster {lb.get('cluster_id')}: {lb.get('trace_count', 0)} traces → {name} (sim={lb.get('best_similarity', 0):.3f})\n"
        )
    summary += "".join(top_lines)

    with open(os.path.join(result_dir, "summary.txt"), "w", encoding="utf-8") as f:
        f.write(summary)

    print(f"\n[✓] Results saved to {result_dir}")
    print(summary)

    return metrics


def main():
    """Parse ``--config`` from the command line and run that experiment."""
    parser = argparse.ArgumentParser(description="Identity Clustering Experiment Runner")
    parser.add_argument("--config", required=True, help="Experiment config JSON")
    args = parser.parse_args()

    config = load_experiment_config(args.config)
    run_experiment(config)


if __name__ == "__main__":
    main()