Files

447 lines
15 KiB
Python

#!/opt/homebrew/bin/python3.11
"""
Identity Clustering Experiment Runner
Usage:
python runner.py --config configs/exp_001.json
Each experiment:
1. Reads config parameters
2. Fetches face trace data from DB
3. Runs clustering algorithm
4. Verifies clusters against speakers (always run)
5. Optionally matches against TMDb
6. Saves all results to experiments/identity_clustering/results/exp_{id}/
"""
import sys
import os
import json
import argparse
import time
import numpy as np
from datetime import datetime
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..", "scripts"))
# DB connection
import psycopg2
import psycopg2.extras
# Connection string passed to psycopg2; the DATABASE_URL environment variable
# overrides the local default.
DB_URL = os.environ.get("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
# Schema name interpolated into the SQL statements issued by this script.
SCHEMA = "dev"
# Directory containing this script; experiment results are written beneath it.
EXPERIMENT_DIR = os.path.dirname(os.path.abspath(__file__))
def get_conn():
    """Open and return a new psycopg2 connection to the database at ``DB_URL``."""
    connection = psycopg2.connect(DB_URL)
    return connection
def load_experiment_config(config_path: str) -> dict:
    """Load an experiment configuration from a JSON file.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default
    # encoding, which would mangle non-ASCII experiment names on some systems.
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)
def fetch_trace_data(
    cur, file_uuid: str, min_frames: int, centroid_method: str = "mean"
) -> List[dict]:
    """Fetch per-trace statistics and an embedding centroid from face_detections.

    Args:
        cur: Open DB-API cursor (psycopg2 style, ``%s`` placeholders).
        file_uuid: File whose face detections are read.
        min_frames: Minimum number of detections (with an embedding) a trace
            needs to be returned.
        centroid_method: How the per-trace embeddings are reduced to a single
            vector: ``"mean"`` (default), ``"median"``, or any other value to
            take the single highest-confidence embedding.

    Returns:
        One dict per trace, ordered by trace_id, with keys ``trace_id``,
        ``frame_count``, ``start_frame``, ``end_frame``, ``avg_bbox``,
        ``avg_confidence``, ``embedding_count`` and ``centroid`` (a plain list,
        or ``None`` when the trace has no embeddings).
    """
    sql = f"""
        SELECT
            trace_id,
            COUNT(*) as frame_count,
            MIN(frame_number) as start_frame,
            MAX(frame_number) as end_frame,
            AVG(x)::float as avg_x,
            AVG(y)::float as avg_y,
            AVG(width)::float as avg_w,
            AVG(height)::float as avg_h,
            AVG(confidence) as avg_confidence
        FROM {SCHEMA}.face_detections
        WHERE file_uuid = %s AND trace_id IS NOT NULL AND embedding IS NOT NULL
        GROUP BY trace_id
        HAVING COUNT(*) >= %s
        ORDER BY trace_id
    """
    cur.execute(sql, (file_uuid, min_frames))
    rows = cur.fetchall()

    # Highest-confidence embedding first, so index 0 is the "best" one.
    embeddings_sql = (
        f"SELECT embedding FROM {SCHEMA}.face_detections "
        "WHERE file_uuid=%s AND trace_id=%s AND embedding IS NOT NULL "
        "ORDER BY confidence DESC"
    )

    traces = []
    for row in rows:
        trace_id = row[0]
        # NOTE(review): one query per trace; fine for experiments, but a single
        # grouped query would scale better for files with many traces.
        cur.execute(embeddings_sql, (file_uuid, trace_id))
        # Assumes the driver returns each embedding as a numeric sequence
        # — TODO confirm against the column type / registered adapters.
        embeddings = [np.array(r[0]) for r in cur.fetchall()]

        if not embeddings:
            centroid = None
        elif centroid_method == "mean":
            centroid = np.mean(embeddings, axis=0)
        elif centroid_method == "median":
            centroid = np.median(embeddings, axis=0)
        else:
            centroid = embeddings[0]

        traces.append(
            {
                "trace_id": trace_id,
                "frame_count": row[1],
                "start_frame": row[2],
                "end_frame": row[3],
                "avg_bbox": {"x": row[4], "y": row[5], "w": row[6], "h": row[7]},
                "avg_confidence": row[8],
                "embedding_count": len(embeddings),
                "centroid": centroid.tolist() if centroid is not None else None,
            }
        )
    return traces
def cosine_similarity(a, b):
    """Return the cosine similarity of two vectors.

    A small epsilon is added to the denominator so that a zero-length
    vector yields a similarity of 0 instead of a division by zero.
    """
    vec_a = np.array(a)
    vec_b = np.array(b)
    numerator = np.dot(vec_a, vec_b)
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b) + 1e-10
    return numerator / denominator
def cluster_by_threshold(
    traces: List[dict], threshold: float, adaptive: bool = False
) -> List[List[dict]]:
    """Greedily cluster traces whose centroids are similar to a seed trace.

    Traces are visited in order. Each still-unassigned trace seeds a new
    cluster and absorbs every later unassigned trace whose centroid has a
    cosine similarity to the seed of at least the threshold. Similarity is
    only ever measured against the seed, never against other members.

    Args:
        traces: Trace dicts with ``trace_id`` and ``centroid`` (a vector or
            ``None``); ``frame_count`` is also read when ``adaptive`` is set.
        threshold: Minimum cosine similarity for joining a seed's cluster.
        adaptive: When true, the threshold is lowered by 0.05 for any pair
            in which either trace has fewer than 60 frames.

    Returns:
        A list of clusters, each a non-empty list of the input trace dicts.
        Traces without a centroid end up as singleton clusters.
    """
    # Convert centroids and compute norms once instead of once per pair.
    vectors = [
        None if t["centroid"] is None else np.asarray(t["centroid"])
        for t in traces
    ]
    norms = [None if v is None else np.linalg.norm(v) for v in vectors]

    clusters = []
    assigned = set()
    for i, seed in enumerate(traces):
        if seed["trace_id"] in assigned:
            continue
        cluster = [seed]
        assigned.add(seed["trace_id"])

        if vectors[i] is not None:
            # Every trace before position i is already assigned, so only
            # later traces can still join this cluster.
            for j in range(i + 1, len(traces)):
                other = traces[j]
                if other["trace_id"] in assigned or vectors[j] is None:
                    continue
                # Cosine similarity; the epsilon guards zero-length vectors.
                sim = np.dot(vectors[i], vectors[j]) / (norms[i] * norms[j] + 1e-10)
                limit = threshold
                if adaptive and (seed["frame_count"] < 60 or other["frame_count"] < 60):
                    # Short traces have noisier centroids: relax the threshold.
                    limit = threshold - 0.05
                if sim >= limit:
                    cluster.append(other)
                    assigned.add(other["trace_id"])

        clusters.append(cluster)
    return clusters
def cluster_dbscan(
    traces: List[dict], eps: float = 0.3, min_samples: int = 2
) -> List[List[dict]]:
    """Cluster trace centroids with DBSCAN using cosine distance.

    Args:
        traces: Trace dicts with a ``centroid`` entry (a vector or ``None``).
            Traces without a centroid are left out of the result.
        eps: DBSCAN neighbourhood radius, in cosine distance.
        min_samples: DBSCAN ``min_samples`` parameter.

    Returns:
        A list of clusters, each a list of the input trace dicts. Every
        noise point is returned as its own single-trace cluster. Returns an
        empty list when no trace has a centroid.
    """
    valid = [t for t in traces if t["centroid"] is not None]
    if not valid:
        # DBSCAN.fit rejects an empty sample matrix; nothing to cluster.
        return []

    from sklearn.cluster import DBSCAN

    X = np.array([t["centroid"] for t in valid])
    # metric="cosine" means distance = 1 - cosine_similarity
    labels = DBSCAN(eps=eps, min_samples=min_samples, metric="cosine").fit(X).labels_

    grouped = defaultdict(list)
    for i, label in enumerate(labels):
        # Label -1 marks noise; give each noise point a unique key so it
        # stays a singleton instead of being lumped with other noise.
        key = int(label) if label >= 0 else f"noise_{i}"
        grouped[key].append(valid[i])
    return list(grouped.values())
def fetch_tmdb_identities(cur) -> List[dict]:
    """Load every TMDb-sourced identity that has a face embedding.

    Returns a list of dicts with the keys ``id``, ``name`` and ``embedding``.
    """
    query = f"SELECT id, name, face_embedding FROM {SCHEMA}.identities WHERE source='tmdb' AND face_embedding IS NOT NULL"
    cur.execute(query)

    identities = []
    for record in cur.fetchall():
        if record[2] is None:
            continue
        identities.append(
            {"id": record[0], "name": record[1], "embedding": record[2]}
        )
    return identities
def fetch_speaker_overlaps(cur, file_uuid: str) -> dict:
    """Load speaker/face-trace overlap ratios from SPEAKS_AS edges.

    Args:
        cur: Open DB-API cursor (psycopg2 style, ``%s`` placeholders).
        file_uuid: File whose edges are read.

    Returns:
        ``{trace_id: {speaker_id: overlap_ratio}}`` built from plain dicts,
        so looking up a missing key never inserts one. A missing
        ``overlap_ratio`` property is treated as 0. If the same
        (trace, speaker) pair appears on several edges, the last row wins.
    """
    cur.execute(
        f"""
        SELECT
            REPLACE(n.external_id, 'trace_', '')::int as trace_id,
            n2.external_id as speaker_id,
            (e.properties->>'overlap_ratio')::float as overlap_ratio
        FROM {SCHEMA}.tkg_edges e
        JOIN {SCHEMA}.tkg_nodes n ON e.source_node_id = n.id
        JOIN {SCHEMA}.tkg_nodes n2 ON e.target_node_id = n2.id
        WHERE e.edge_type = 'SPEAKS_AS'
        AND n.node_type = 'face_trace'
        AND n2.node_type = 'speaker'
        AND e.file_uuid = %s
        """,
        (file_uuid,),
    )
    overlaps = {}
    for row in cur.fetchall():
        trace_id, speaker_id, ratio = row[0], row[1], row[2] or 0
        if trace_id is None or speaker_id is None:
            continue
        overlaps.setdefault(int(trace_id), {})[speaker_id] = float(ratio)
    return overlaps
def verify_with_speakers(
    clusters: List[dict], speaker_overlaps: dict
) -> List[dict]:
    """Annotate clusters with their dominant speaker and merge same-speaker clusters.

    Each cluster receives ``dominant_speaker``, ``speaker_overlap_score``
    (summed overlap ratio of that speaker, rounded to 3 places) and
    ``speaker_votes``. Clusters that share a dominant speaker with a score
    above 0.5 are then merged into one cluster.

    Args:
        clusters: Cluster dicts carrying ``trace_ids``; these are annotated
            in place. A raw cluster (a list of trace dicts) is also accepted
            and is wrapped in a new dict.
        speaker_overlaps: ``{trace_id: {speaker_id: overlap_ratio}}``.

    Returns:
        Merged clusters first, followed by the clusters that were not merged
        in their original order. A merged cluster carries ``merged_from``
        (indices into ``clusters``), ``merge_reason``, the ``cluster_id`` of
        its first constituent, and the same speaker keys as any other
        cluster, so consumers can treat all entries uniformly.
    """
    # Normalise: raw trace lists become dicts so they can be annotated.
    annotated = []
    for idx, cluster in enumerate(clusters):
        if not isinstance(cluster, dict):
            ids = [t["trace_id"] for t in cluster]
            cluster = {"cluster_id": idx, "trace_count": len(ids), "trace_ids": ids}
        annotated.append(cluster)

    # Sum overlap ratios per speaker across all traces of each cluster.
    for cluster in annotated:
        votes = defaultdict(float)
        for tid in cluster.get("trace_ids") or []:
            for spk, ratio in speaker_overlaps.get(tid, {}).items():
                votes[spk] += ratio
        if votes:
            best_speaker = max(votes, key=votes.get)
            cluster["dominant_speaker"] = best_speaker
            cluster["speaker_overlap_score"] = round(votes[best_speaker], 3)
            cluster["speaker_votes"] = dict(votes)
        else:
            cluster["dominant_speaker"] = None
            cluster["speaker_overlap_score"] = 0
            cluster["speaker_votes"] = {}

    # Group confidently attributed clusters by their dominant speaker.
    by_speaker = defaultdict(list)
    for idx, cluster in enumerate(annotated):
        spk = cluster["dominant_speaker"]
        if spk and cluster["speaker_overlap_score"] > 0.5:
            by_speaker[spk].append(idx)

    merged_indices = set()
    result = []
    for spk, indices in by_speaker.items():
        if len(indices) <= 1:
            continue
        # dict.fromkeys de-duplicates while keeping a stable order.
        trace_ids = list(
            dict.fromkeys(
                tid
                for idx in indices
                for tid in annotated[idx].get("trace_ids") or []
            )
        )
        combined_votes = defaultdict(float)
        for idx in indices:
            for name, score in annotated[idx]["speaker_votes"].items():
                combined_votes[name] += score
        merged_indices.update(indices)
        result.append(
            {
                "cluster_id": annotated[indices[0]].get("cluster_id", indices[0]),
                "merged_from": indices,
                "trace_ids": trace_ids,
                "trace_count": len(trace_ids),
                "dominant_speaker": spk,
                "speaker_overlap_score": round(combined_votes[spk], 3),
                "speaker_votes": dict(combined_votes),
                "merge_reason": "shared_dominant_speaker",
            }
        )

    # Keep every cluster that was not absorbed into a merge.
    result.extend(c for idx, c in enumerate(annotated) if idx not in merged_indices)
    return result
def match_tmdb(
    clusters: List[dict],
    tmdb_identities: List[dict],
    min_similarity: float = 0.55,
) -> List[dict]:
    """Match each raw cluster to the TMDb identities it resembles.

    The cluster is represented by the centroid of its longest trace (most
    frames) among the traces that actually have a centroid.

    Args:
        clusters: Raw clusters, i.e. lists of trace dicts with ``trace_id``,
            ``frame_count`` and ``centroid``. Label dicts are rejected
            because they carry no centroids.
        tmdb_identities: Dicts with ``id``, ``name`` and ``embedding``.
        min_similarity: Minimum cosine similarity for an identity to count
            as a match (default 0.55).

    Returns:
        One result dict per non-empty cluster that has a usable centroid,
        with ``cluster_id`` (index into ``clusters``), ``trace_count``,
        ``total_frames``, ``trace_ids``, ``tmdb_matches`` (best first),
        ``best_match`` and ``best_similarity``.

    Raises:
        TypeError: If a cluster is a dict rather than a list of traces.
    """
    results = []
    for i, cluster in enumerate(clusters):
        if isinstance(cluster, dict):
            raise TypeError(
                "match_tmdb expects raw clusters (lists of trace dicts), "
                "not label dicts"
            )
        if not cluster:
            continue

        # Only traces with a centroid can represent the cluster.
        candidates = [t for t in cluster if t.get("centroid") is not None]
        if not candidates:
            continue
        centroid = max(candidates, key=lambda t: t["frame_count"])["centroid"]

        matches = []
        for identity in tmdb_identities:
            if identity["embedding"] is None:
                continue
            sim = cosine_similarity(centroid, identity["embedding"])
            if sim >= min_similarity:
                matches.append(
                    {
                        "id": identity["id"],
                        "name": identity["name"],
                        "similarity": float(sim),
                    }
                )
        matches.sort(key=lambda m: m["similarity"], reverse=True)

        results.append(
            {
                "cluster_id": i,
                "trace_count": len(cluster),
                "total_frames": sum(t["frame_count"] for t in cluster),
                "trace_ids": [t["trace_id"] for t in cluster],
                "tmdb_matches": matches,
                "best_match": matches[0]["name"] if matches else None,
                "best_similarity": matches[0]["similarity"] if matches else 0,
            }
        )
    return results
def compute_metrics(clusters: List[dict], total_traces: int) -> dict:
    """Summarise clustering coverage and TMDb match rate.

    Args:
        clusters: Either label dicts (with ``trace_count`` and optionally
            ``best_match``) or raw clusters (lists of traces). May be empty.
        total_traces: Number of traces that were available for clustering.

    Returns:
        Dict with ``total_traces``, ``clustered_traces``, ``cluster_count``,
        ``coverage``, ``avg_cluster_size``, ``tmdb_matched`` and
        ``tmdb_coverage``. Ratios are 0 when their denominator would be 0.
    """

    def _size(cluster) -> int:
        # Label dicts carry an explicit count; raw clusters are trace lists.
        if isinstance(cluster, dict):
            return cluster.get("trace_count", len(cluster.get("trace_ids", [])))
        return len(cluster)

    clustered = sum(_size(c) for c in clusters)
    cluster_count = len(clusters)
    tmdb_matched = sum(
        1 for c in clusters if isinstance(c, dict) and c.get("best_match")
    )
    return {
        "total_traces": total_traces,
        "clustered_traces": clustered,
        "cluster_count": cluster_count,
        "coverage": clustered / max(total_traces, 1),
        "avg_cluster_size": clustered / max(cluster_count, 1),
        "tmdb_matched": tmdb_matched,
        "tmdb_coverage": tmdb_matched / max(cluster_count, 1),
    }
def _write_json(path: str, payload) -> None:
    """Write *payload* to *path* as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def _attach_tmdb_matches(labels: List[dict], tmdb_results: List[dict]) -> None:
    """Copy per-cluster TMDb matches onto the (possibly merged) labels in place."""
    by_cluster = {r["cluster_id"]: r for r in tmdb_results}
    for label in labels:
        # A merged label covers several original clusters; pool their matches.
        source_ids = label.get("merged_from") or [label.get("cluster_id")]
        best_by_identity = {}
        for cid in source_ids:
            for match in by_cluster.get(cid, {}).get("tmdb_matches", []):
                current = best_by_identity.get(match["id"])
                if current is None or match["similarity"] > current["similarity"]:
                    best_by_identity[match["id"]] = match
        matches = sorted(
            best_by_identity.values(), key=lambda m: m["similarity"], reverse=True
        )
        label["tmdb_matches"] = matches
        label["best_match"] = matches[0]["name"] if matches else None
        label["best_similarity"] = matches[0]["similarity"] if matches else 0


def run_experiment(config: dict) -> dict:
    """Run one clustering experiment end to end and persist its results.

    Steps: fetch traces, cluster them, annotate/merge clusters by speaker,
    optionally match against TMDb, compute metrics, then write
    ``clusters.json``, ``labels.json``, ``metrics.json``, ``config.json`` and
    ``summary.txt`` to ``results/exp_{id}/`` under ``EXPERIMENT_DIR``.

    Args:
        config: Experiment configuration. ``id`` and ``name`` are required;
            ``file_uuid``, ``min_frames``, ``clustering_method``,
            ``threshold``, ``adaptive_threshold``, ``eps``, ``min_samples``
            and ``enable_tmdb`` are optional.

    Returns:
        The metrics dict, including ``execution_time_s``.

    Raises:
        KeyError: If ``id`` or ``name`` is missing from the config.
    """
    exp_id = config["id"]
    file_uuid = config.get("file_uuid", "1a04db97be5fa12bd77369831dc141fd")
    min_frames = config.get("min_frames", 30)
    print(f"\n{'='*60}")
    print(f"Experiment {exp_id}: {config['name']}")
    print(f"{'='*60}")

    conn = get_conn()
    # try/finally so the connection is released even when a step fails.
    try:
        cur = conn.cursor()
        t0 = time.time()

        # Step 1: Fetch traces
        print(f"\n[1] Fetching traces (min_frames={min_frames})...")
        traces = fetch_trace_data(cur, file_uuid, min_frames)
        print(f" {len(traces)} traces loaded")

        # Step 2: Clustering
        method = config.get("clustering_method", "threshold")
        print(f"\n[2] Clustering: method={method}...")
        if method == "threshold":
            clusters = cluster_by_threshold(
                traces,
                config.get("threshold", 0.85),
                config.get("adaptive_threshold", False),
            )
        elif method == "dbscan":
            clusters = cluster_dbscan(
                traces, config.get("eps", 0.3), config.get("min_samples", 2)
            )
        else:
            # Unknown method: fall back to adaptive threshold clustering.
            print(f" unknown method {method!r}; using adaptive threshold 0.85")
            clusters = cluster_by_threshold(traces, 0.85, True)
        clustered_traces = sum(len(c) for c in clusters)
        print(f" {len(clusters)} clusters, {clustered_traces} traces clustered")

        # Step 3: Speaker verification (mandatory — standard step)
        print(f"\n[3] Speaker verification...")
        speaker_overlaps = fetch_speaker_overlaps(cur, file_uuid)
        # Convert raw clusters to label dicts; cluster_id is the index into
        # `clusters`, which is what merged labels refer to in `merged_from`.
        labels = [
            {
                "cluster_id": i,
                "trace_count": len(c),
                "trace_ids": [t["trace_id"] for t in c],
                "tmdb_matches": [],
                "best_match": None,
            }
            for i, c in enumerate(clusters)
        ]
        labels = verify_with_speakers(labels, speaker_overlaps)
        matched_speakers = sum(1 for l in labels if l.get("dominant_speaker"))
        merged = sum(1 for l in labels if l.get("merge_reason"))
        print(f" {matched_speakers} clusters have speaker match, {merged} merged by speaker")

        # Step 4: TMDb matching (optional)
        if config.get("enable_tmdb", False):
            print(f"\n[4] TMDb matching...")
            tmdb = fetch_tmdb_identities(cur)
            print(f" {len(tmdb)} TMDb identities loaded")
            # Matching needs centroids, which only the raw clusters carry;
            # the results are then mapped back onto the speaker-verified labels.
            _attach_tmdb_matches(labels, match_tmdb(clusters, tmdb))
            matched = sum(1 for l in labels if l.get("best_match"))
            print(f" {matched} clusters matched to TMDb")

        # Step 5: Metrics
        metrics = compute_metrics(labels, len(traces))
        metrics["execution_time_s"] = time.time() - t0
    finally:
        conn.close()

    # Step 6: Save results
    result_dir = os.path.join(EXPERIMENT_DIR, "results", f"exp_{exp_id}")
    os.makedirs(result_dir, exist_ok=True)
    # clusters.json holds the labels whenever there are any (raw clusters
    # are only written when the label list is empty).
    _write_json(os.path.join(result_dir, "clusters.json"), labels if labels else clusters)
    _write_json(os.path.join(result_dir, "labels.json"), labels)
    _write_json(os.path.join(result_dir, "metrics.json"), metrics)
    _write_json(os.path.join(result_dir, "config.json"), config)

    # Summary
    summary = (
        f"\nExperiment {exp_id}: {config['name']}\n"
        "====================================\n"
        f"Date: {datetime.now().isoformat()}\n"
        f"Config: {json.dumps(config, indent=2)}\n"
        "Results:\n"
        f"Traces loaded: {len(traces)}\n"
        f"Clusters: {len(clusters)}\n"
        f"Clustered traces: {clustered_traces}\n"
        f"Coverage: {metrics['coverage']:.1%}\n"
        f"Avg cluster size: {metrics['avg_cluster_size']:.1f}\n"
        f"TMDb matched: {metrics.get('tmdb_matched', 0)}\n"
        f"Execution time: {metrics['execution_time_s']:.1f}s\n"
        "Top clusters:\n"
    )
    sorted_labels = sorted(labels, key=lambda l: l.get("trace_count", 0), reverse=True)
    for l in sorted_labels[:10]:
        # best_match is present but None for unmatched clusters.
        name = l.get("best_match") or "unlabeled"
        summary += (
            f" Cluster {l.get('cluster_id', 'merged')}: {l.get('trace_count', 0)} traces"
            f" → {name} (sim={l.get('best_similarity', 0):.3f})\n"
        )
    with open(os.path.join(result_dir, "summary.txt"), "w", encoding="utf-8") as f:
        f.write(summary)
    print(f"\n[✓] Results saved to {result_dir}")
    print(summary)
    return metrics
def main():
    """Parse ``--config`` from the command line and run that experiment."""
    cli = argparse.ArgumentParser(description="Identity Clustering Experiment Runner")
    cli.add_argument("--config", required=True, help="Experiment config JSON")
    options = cli.parse_args()
    experiment_config = load_experiment_config(options.config)
    run_experiment(experiment_config)


# Only run the CLI when this file is executed as a script.
if __name__ == "__main__":
    main()