#!/opt/homebrew/bin/python3.11
"""
Identity Clustering Experiment Runner

Usage:
    python runner.py --config configs/exp_001.json

Each experiment:
1. Reads config parameters
2. Fetches face trace data from DB
3. Runs clustering algorithm
4. Verifies clusters against speakers (always runs)
5. Optionally matches against TMDb
6. Saves all results to experiments/identity_clustering/results/exp_{id}/
"""
|
|
|
|
import sys
|
|
import os
|
|
import json
|
|
import argparse
|
|
import time
|
|
import numpy as np
|
|
from datetime import datetime
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Tuple, Optional
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..", "scripts"))
|
|
|
|
# DB connection
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
|
|
# Database connection string; the DATABASE_URL environment variable overrides
# the local default.
DB_URL = os.environ.get("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
# Schema name interpolated in front of every table name in this file's SQL.
SCHEMA = "dev"
# Directory containing this script; experiment results are written beneath it.
EXPERIMENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
def get_conn():
    """Open and return a new psycopg2 connection to the database at ``DB_URL``.

    The caller owns the connection and is responsible for closing it.
    """
    connection = psycopg2.connect(DB_URL)
    return connection
|
|
|
|
|
|
def load_experiment_config(config_path: str) -> dict:
    """Load an experiment configuration from a JSON file.

    Args:
        config_path: Path to the JSON config file.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # locale encoding.
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
|
def fetch_trace_data(
    cur, file_uuid: str, min_frames: int, centroid_method: str = "mean"
) -> List[dict]:
    """Fetch per-trace statistics and an embedding centroid from face_detections.

    Only detections with a non-NULL ``trace_id`` and ``embedding`` are
    considered, and only traces with at least ``min_frames`` such detections
    are returned (ordered by ``trace_id``).

    Args:
        cur: Open DB-API cursor.
        file_uuid: Identifier of the file whose traces are loaded.
        min_frames: Minimum number of detections a trace must have.
        centroid_method: How the per-trace embeddings are reduced to a single
            vector: ``"mean"`` (default) or ``"median"`` take the element-wise
            statistic; any other value selects the embedding of the
            highest-confidence detection.

    Returns:
        One dict per trace with keys ``trace_id``, ``frame_count``,
        ``start_frame``, ``end_frame``, ``avg_bbox`` (``x``/``y``/``w``/``h``),
        ``avg_confidence``, ``embedding_count`` and ``centroid`` (a plain list,
        or ``None`` when the trace has no embeddings).
    """
    summary_sql = f"""
        SELECT
            trace_id,
            COUNT(*) as frame_count,
            MIN(frame_number) as start_frame,
            MAX(frame_number) as end_frame,
            AVG(x)::float as avg_x,
            AVG(y)::float as avg_y,
            AVG(width)::float as avg_w,
            AVG(height)::float as avg_h,
            AVG(confidence) as avg_confidence
        FROM {SCHEMA}.face_detections
        WHERE file_uuid = %s AND trace_id IS NOT NULL AND embedding IS NOT NULL
        GROUP BY trace_id
        HAVING COUNT(*) >= %s
        ORDER BY trace_id
    """
    cur.execute(summary_sql, (file_uuid, min_frames))
    rows = cur.fetchall()

    # Built once: the statement text is identical for every trace.
    embeddings_sql = (
        f"SELECT embedding FROM {SCHEMA}.face_detections "
        "WHERE file_uuid=%s AND trace_id=%s AND embedding IS NOT NULL "
        "ORDER BY confidence DESC"
    )

    traces = []
    for row in rows:
        trace_id = row[0]
        cur.execute(embeddings_sql, (file_uuid, trace_id))
        # NOTE(review): assumes the driver returns each embedding as a numeric
        # sequence that np.array can convert directly - confirm for this column type.
        embeddings = [np.array(r[0]) for r in cur.fetchall()]

        if not embeddings:
            centroid = None
        elif centroid_method == "mean":
            centroid = np.mean(embeddings, axis=0)
        elif centroid_method == "median":
            centroid = np.median(embeddings, axis=0)
        else:
            # Rows are ordered by confidence DESC, so index 0 is the
            # highest-confidence detection's embedding.
            centroid = embeddings[0]

        traces.append(
            {
                "trace_id": trace_id,
                "frame_count": row[1],
                "start_frame": row[2],
                "end_frame": row[3],
                "avg_bbox": {"x": row[4], "y": row[5], "w": row[6], "h": row[7]},
                "avg_confidence": row[8],
                "embedding_count": len(embeddings),
                "centroid": centroid.tolist() if centroid is not None else None,
            }
        )

    return traces
|
|
|
|
|
|
def cosine_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``.

    A small epsilon is added to the denominator so a zero-length vector
    yields 0 instead of a division-by-zero.
    """
    vec_a = np.array(a)
    vec_b = np.array(b)
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b) + 1e-10
    return np.dot(vec_a, vec_b) / denominator
|
|
|
|
|
|
def cluster_by_threshold(
    traces: List[dict],
    threshold: float,
    adaptive: bool = False,
    *,
    short_trace_frames: int = 60,
    short_trace_relaxation: float = 0.05,
) -> List[List[dict]]:
    """Greedy seed-based clustering of traces by centroid cosine similarity.

    Traces are visited in order. Each trace that is not yet assigned becomes
    the seed of a new cluster and absorbs every later unassigned trace whose
    centroid similarity to the seed reaches the threshold. Members are compared
    against the seed only, never against each other.

    Args:
        traces: Trace dicts with ``trace_id`` and ``centroid`` keys (and
            ``frame_count`` when ``adaptive`` is true).
        threshold: Minimum cosine similarity required to join a seed's cluster.
        adaptive: When true, the threshold is lowered by
            ``short_trace_relaxation`` for any pair in which either trace has
            fewer than ``short_trace_frames`` frames.
        short_trace_frames: Frame count below which a trace counts as short.
        short_trace_relaxation: Amount subtracted from ``threshold`` for pairs
            involving a short trace.

    Returns:
        A list of clusters, each a non-empty list of trace dicts. Traces
        without a centroid always end up as singleton clusters.
    """
    clusters: List[List[dict]] = []
    assigned = set()

    for i, seed in enumerate(traces):
        if seed["trace_id"] in assigned:
            continue
        cluster = [seed]
        assigned.add(seed["trace_id"])

        seed_centroid = seed["centroid"]
        # A seed without a centroid cannot be compared with anything, so it
        # stays a singleton and the scan is skipped entirely.
        if seed_centroid is not None:
            seed_is_short = adaptive and seed["frame_count"] < short_trace_frames
            # Every trace before position i is already assigned (it was either
            # a seed or absorbed by one), so only later traces are scanned.
            for candidate in traces[i + 1:]:
                if candidate["trace_id"] in assigned or candidate["centroid"] is None:
                    continue

                sim = cosine_similarity(seed_centroid, candidate["centroid"])
                th = threshold
                if adaptive and (
                    seed_is_short or candidate["frame_count"] < short_trace_frames
                ):
                    # Relax the threshold when either trace is short.
                    th = threshold - short_trace_relaxation

                if sim >= th:
                    cluster.append(candidate)
                    assigned.add(candidate["trace_id"])

        clusters.append(cluster)

    return clusters
|
|
|
|
|
|
def cluster_dbscan(
    traces: List[dict], eps: float = 0.3, min_samples: int = 2
) -> List[List[dict]]:
    """Cluster trace centroids with DBSCAN using cosine distance.

    Traces whose ``centroid`` is ``None`` are ignored and do not appear in the
    result. Each point DBSCAN labels as noise becomes its own singleton cluster.

    Args:
        traces: Trace dicts with a ``centroid`` key.
        eps: DBSCAN neighbourhood radius, in cosine distance.
        min_samples: DBSCAN ``min_samples`` parameter.

    Returns:
        A list of clusters, each a list of trace dicts; empty when no trace
        has a centroid.
    """
    valid = [t for t in traces if t["centroid"] is not None]
    # DBSCAN.fit rejects an empty sample matrix, so answer the degenerate
    # case directly.
    if not valid:
        return []

    # Imported lazily so scikit-learn is only required when this method is used.
    from sklearn.cluster import DBSCAN

    X = np.array([t["centroid"] for t in valid])

    # Cosine distance = 1 - cosine_similarity
    clustering = DBSCAN(eps=eps, min_samples=min_samples, metric="cosine").fit(X)

    clusters_dict = defaultdict(list)
    for i, label in enumerate(clustering.labels_):
        # Noise points carry label -1; give each one a unique key so they are
        # not lumped together into a single cluster.
        key = int(label) if label >= 0 else f"noise_{i}"
        clusters_dict[key].append(valid[i])

    return list(clusters_dict.values())
|
|
|
|
|
|
def fetch_tmdb_identities(cur) -> List[dict]:
    """Load every TMDb-sourced identity that has a face embedding.

    Returns a list of dicts with keys ``id``, ``name`` and ``embedding``.
    """
    cur.execute(
        f"SELECT id, name, face_embedding FROM {SCHEMA}.identities WHERE source='tmdb' AND face_embedding IS NOT NULL"
    )

    identities = []
    for identity_id, name, embedding in cur.fetchall():
        # The query already excludes NULL embeddings; this check is kept as a
        # second line of defence.
        if embedding is None:
            continue
        identities.append({"id": identity_id, "name": name, "embedding": embedding})
    return identities
|
|
|
|
|
|
def fetch_speaker_overlaps(cur, file_uuid: str) -> dict:
    """Load speaker / face-trace overlap ratios from the TKG ``SPEAKS_AS`` edges.

    Only edges of ``file_uuid`` that run from a ``face_trace`` node to a
    ``speaker`` node are read. The trace id is the source node's
    ``external_id`` with the ``trace_`` prefix removed, cast to int.

    Returns:
        ``{trace_id: {speaker_id: overlap_ratio}}``. Rows with a NULL trace or
        speaker id are skipped, a NULL ratio is stored as ``0.0``, and a later
        row for the same (trace, speaker) pair overwrites an earlier one.
    """
    query = f"""
        SELECT
            REPLACE(n.external_id, 'trace_', '')::int as trace_id,
            n2.external_id as speaker_id,
            (e.properties->>'overlap_ratio')::float as overlap_ratio
        FROM {SCHEMA}.tkg_edges e
        JOIN {SCHEMA}.tkg_nodes n ON e.source_node_id = n.id
        JOIN {SCHEMA}.tkg_nodes n2 ON e.target_node_id = n2.id
        WHERE e.edge_type = 'SPEAKS_AS'
            AND n.node_type = 'face_trace'
            AND n2.node_type = 'speaker'
            AND e.file_uuid = %s
    """
    cur.execute(query, (file_uuid,))

    overlaps = defaultdict(lambda: defaultdict(float))
    for row in cur.fetchall():
        trace_id = row[0]
        speaker_id = row[1]
        if trace_id is None or speaker_id is None:
            continue
        ratio = row[2] or 0
        overlaps[int(trace_id)][speaker_id] = float(ratio)

    return dict(overlaps)
|
|
|
|
|
|
def verify_with_speakers(
    clusters: List[dict],
    speaker_overlaps: dict,
    min_merge_score: float = 0.5,
) -> List[dict]:
    """Annotate clusters with their dominant speaker and merge same-speaker clusters.

    Each cluster dict is annotated in place with ``dominant_speaker``,
    ``speaker_overlap_score`` (the dominant speaker's summed overlap ratio,
    rounded to 3 decimals) and ``speaker_votes`` (summed ratio per speaker).
    Clusters that share a dominant speaker with a score above
    ``min_merge_score`` are then replaced by a single merged cluster.

    Args:
        clusters: Cluster label dicts; each lists its member traces under
            ``trace_ids``. A missing or empty ``trace_ids`` means "no traces".
        speaker_overlaps: ``{trace_id: {speaker_id: overlap_ratio}}``.
        min_merge_score: A cluster takes part in merging only when its
            ``speaker_overlap_score`` is strictly greater than this value.

    Returns:
        A new list holding the merged clusters first, followed by the
        untouched input clusters in their original order. A merged cluster
        has ``cluster_id`` (taken from its first member, falling back to that
        member's list position), ``merged_from`` (input list positions),
        ``trace_ids``, ``trace_count``, ``dominant_speaker``,
        ``speaker_overlap_score``, ``speaker_votes`` and ``merge_reason``.
    """
    for cluster in clusters:
        # Sum the overlap ratio of every trace in the cluster, per speaker.
        speaker_votes = defaultdict(float)
        for tid in cluster.get("trace_ids") or []:
            for spk, ratio in speaker_overlaps.get(tid, {}).items():
                speaker_votes[spk] += ratio

        if speaker_votes:
            best_speaker = max(speaker_votes, key=speaker_votes.get)
            cluster["dominant_speaker"] = best_speaker
            cluster["speaker_overlap_score"] = round(speaker_votes[best_speaker], 3)
            cluster["speaker_votes"] = dict(speaker_votes)
        else:
            cluster["dominant_speaker"] = None
            cluster["speaker_overlap_score"] = 0
            cluster["speaker_votes"] = {}

    # Group cluster positions by confidently-assigned dominant speaker.
    speaker_clusters = defaultdict(list)
    for i, cluster in enumerate(clusters):
        spk = cluster.get("dominant_speaker")
        if spk and cluster.get("speaker_overlap_score", 0) > min_merge_score:
            speaker_clusters[spk].append(i)

    merged = set()
    new_clusters = []
    for spk, indices in speaker_clusters.items():
        if len(indices) <= 1:
            continue
        members = [clusters[idx] for idx in indices]

        # De-duplicate while keeping first-seen order so output is deterministic.
        trace_ids = list(
            dict.fromkeys(
                tid for member in members for tid in member.get("trace_ids") or []
            )
        )
        combined_votes = defaultdict(float)
        for member in members:
            for voter, ratio in member["speaker_votes"].items():
                combined_votes[voter] += ratio

        merged.update(indices)
        new_clusters.append({
            # Downstream reporting indexes labels by cluster_id, so the merged
            # cluster inherits the id of its first member.
            "cluster_id": members[0].get("cluster_id", indices[0]),
            "merged_from": indices,
            "trace_ids": trace_ids,
            "trace_count": len(trace_ids),
            "dominant_speaker": spk,
            "speaker_overlap_score": round(combined_votes[spk], 3),
            "speaker_votes": dict(combined_votes),
            "merge_reason": "shared_dominant_speaker",
        })

    # Keep unmerged clusters
    new_clusters.extend(
        cluster for i, cluster in enumerate(clusters) if i not in merged
    )

    return new_clusters
|
|
|
|
|
|
def match_tmdb(
    clusters: List[List[dict]],
    tmdb_identities: List[dict],
    min_similarity: float = 0.55,
) -> List[dict]:
    """Match each cluster to the most similar TMDb identities.

    A cluster is represented by the centroid of its trace with the most
    frames; that centroid is compared with every identity embedding.

    Args:
        clusters: Clusters as lists of trace dicts (each trace needs
            ``trace_id``, ``frame_count`` and ``centroid``).
        tmdb_identities: Identity dicts with ``id``, ``name`` and ``embedding``.
        min_similarity: Minimum cosine similarity for an identity to be
            reported as a match.

    Returns:
        One result dict per matchable cluster with keys ``cluster_id`` (the
        cluster's position in ``clusters``), ``trace_count``, ``total_frames``,
        ``trace_ids``, ``tmdb_matches`` (sorted by descending similarity),
        ``best_match`` and ``best_similarity``. Empty clusters, and clusters
        whose representative trace has no centroid, are omitted.
    """
    results = []
    for i, cluster in enumerate(clusters):
        if not cluster:
            continue
        # Use the trace with most frames as representative
        best_trace = max(cluster, key=lambda t: t["frame_count"])
        centroid = best_trace.get("centroid")
        if centroid is None:
            continue

        matches = []
        for identity in tmdb_identities:
            if identity["embedding"] is None:
                continue
            sim = cosine_similarity(centroid, identity["embedding"])
            if sim >= min_similarity:
                matches.append(
                    {"id": identity["id"], "name": identity["name"], "similarity": float(sim)}
                )
        matches.sort(key=lambda m: m["similarity"], reverse=True)
        best = matches[0] if matches else None

        results.append(
            {
                "cluster_id": i,
                "trace_count": len(cluster),
                "total_frames": sum(t["frame_count"] for t in cluster),
                "trace_ids": [t["trace_id"] for t in cluster],
                "tmdb_matches": matches,
                "best_match": best["name"] if best else None,
                "best_similarity": best["similarity"] if best else 0,
            }
        )

    return results
|
|
|
|
|
|
def compute_metrics(clusters: List[dict], total_traces: int) -> dict:
    """Compute coverage and TMDb-match statistics for a clustering result.

    Args:
        clusters: Either cluster label dicts or raw clusters (lists of trace
            dicts). A dict contributes its ``trace_count`` (or the length of
            its ``trace_ids`` when ``trace_count`` is absent); a raw cluster
            contributes its length. May be empty.
        total_traces: Number of traces that were available for clustering.

    Returns:
        Dict with ``total_traces``, ``clustered_traces``, ``cluster_count``,
        ``coverage``, ``avg_cluster_size``, ``tmdb_matched`` and
        ``tmdb_coverage``. Ratios use a denominator of at least 1, so empty
        input yields zeros rather than raising.
    """

    def _size(cluster) -> int:
        # Sized per cluster (not from the first element only) so empty and
        # mixed inputs are handled.
        if isinstance(cluster, dict):
            if "trace_count" in cluster:
                return cluster["trace_count"]
            return len(cluster.get("trace_ids", []))
        return len(cluster)

    clustered = sum(_size(c) for c in clusters)
    cluster_count = len(clusters)
    tmdb_matched = sum(
        1 for c in clusters if isinstance(c, dict) and c.get("best_match")
    )

    return {
        "total_traces": total_traces,
        "clustered_traces": clustered,
        "cluster_count": cluster_count,
        "coverage": clustered / max(total_traces, 1),
        "avg_cluster_size": clustered / max(cluster_count, 1),
        "tmdb_matched": tmdb_matched,
        "tmdb_coverage": tmdb_matched / max(cluster_count, 1),
    }
|
|
|
|
|
|
def _write_json(path: str, payload) -> None:
    """Serialize ``payload`` as indented UTF-8 JSON at ``path``."""
    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be explicit rather than locale-dependent.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def _attach_tmdb_matches(labels: List[dict], traces: List[dict], tmdb: List[dict]) -> None:
    """Run TMDb matching on each label's traces and copy the outcome onto the label."""
    trace_by_id = {t["trace_id"]: t for t in traces}
    # match_tmdb needs the trace dicts (frame counts, centroids), not label
    # dicts, so rebuild one trace list per label.
    trace_groups = [
        [trace_by_id[tid] for tid in label.get("trace_ids", []) if tid in trace_by_id]
        for label in labels
    ]
    # match_tmdb reports each group's list position as "cluster_id"; that
    # position maps the result back to its label.
    for result in match_tmdb(trace_groups, tmdb):
        label = labels[result["cluster_id"]]
        label["tmdb_matches"] = result["tmdb_matches"]
        label["best_match"] = result["best_match"]
        label["best_similarity"] = result["best_similarity"]
        label["total_frames"] = result["total_frames"]


def _format_summary(
    config: dict,
    traces_loaded: int,
    cluster_count: int,
    clustered_traces: int,
    metrics: dict,
    labels: List[dict],
) -> str:
    """Build the human-readable summary text for one experiment run."""
    lines = [
        "",
        f"Experiment {config['id']}: {config['name']}",
        "====================================",
        f"Date: {datetime.now().isoformat()}",
        f"Config: {json.dumps(config, indent=2)}",
        "",
        "Results:",
        f"Traces loaded: {traces_loaded}",
        f"Clusters: {cluster_count}",
        f"Clustered traces: {clustered_traces}",
        f"Coverage: {metrics['coverage']:.1%}",
        f"Avg cluster size: {metrics['avg_cluster_size']:.1f}",
        f"TMDb matched: {metrics.get('tmdb_matched', 0)}",
        f"Execution time: {metrics['execution_time_s']:.1f}s",
        "",
        "Top clusters:",
    ]
    sorted_labels = sorted(labels, key=lambda l: l.get("trace_count", 0), reverse=True)
    for l in sorted_labels[:10]:
        # "best_match" is present but None for unmatched clusters, so a .get()
        # default alone would never apply.
        name = l.get("best_match") or "unlabeled"
        lines.append(
            f" Cluster {l['cluster_id']}: {l.get('trace_count', 0)} traces → {name} (sim={l.get('best_similarity', 0):.3f})"
        )
    return "\n".join(lines) + "\n"


def run_experiment(config: dict) -> dict:
    """Run one clustering experiment end to end and persist its results.

    Steps: fetch traces, cluster them, verify/merge clusters by speaker
    overlap, optionally match clusters against TMDb, compute metrics, and
    write ``clusters.json``, ``labels.json``, ``metrics.json``, ``config.json``
    and ``summary.txt`` under ``results/exp_{id}`` next to this script.

    Args:
        config: Experiment configuration. ``id`` and ``name`` are required.
            Optional keys: ``file_uuid``, ``min_frames`` (default 30),
            ``clustering_method`` (``"threshold"`` or ``"dbscan"``),
            ``threshold`` (0.85), ``adaptive_threshold`` (False), ``eps``
            (0.3), ``min_samples`` (2) and ``enable_tmdb`` (False).

    Returns:
        The metrics dict, including ``execution_time_s``.
    """
    exp_id = config["id"]
    file_uuid = config.get("file_uuid", "1a04db97be5fa12bd77369831dc141fd")
    min_frames = config.get("min_frames", 30)

    print(f"\n{'='*60}")
    print(f"Experiment {exp_id}: {config['name']}")
    print(f"{'='*60}")

    conn = get_conn()
    # try/finally guarantees the connection is released even when a step fails.
    try:
        cur = conn.cursor()
        t0 = time.time()

        # Step 1: Fetch traces
        print(f"\n[1] Fetching traces (min_frames={min_frames})...")
        traces = fetch_trace_data(cur, file_uuid, min_frames)
        print(f" {len(traces)} traces loaded")

        # Step 2: Clustering
        method = config.get("clustering_method", "threshold")
        print(f"\n[2] Clustering: method={method}...")

        if method == "threshold":
            threshold = config.get("threshold", 0.85)
            adaptive = config.get("adaptive_threshold", False)
            clusters = cluster_by_threshold(traces, threshold, adaptive)
        elif method == "dbscan":
            eps = config.get("eps", 0.3)
            min_samples = config.get("min_samples", 2)
            clusters = cluster_dbscan(traces, eps, min_samples)
        else:
            # Unknown method names fall back to adaptive threshold clustering;
            # say so instead of switching algorithms silently.
            print(f" unknown clustering_method {method!r}; using adaptive threshold 0.85")
            clusters = cluster_by_threshold(traces, 0.85, True)

        clustered_traces = sum(len(c) for c in clusters)
        print(f" {len(clusters)} clusters, {clustered_traces} traces clustered")

        # Step 3: Speaker verification (mandatory - standard step)
        print("\n[3] Speaker verification...")
        speaker_overlaps = fetch_speaker_overlaps(cur, file_uuid)
        # Convert raw clusters to label dicts
        labels = [
            {
                "cluster_id": i,
                "trace_count": len(c),
                "trace_ids": [t["trace_id"] for t in c],
                "tmdb_matches": [],
                "best_match": None,
            }
            for i, c in enumerate(clusters)
        ]
        labels = verify_with_speakers(labels, speaker_overlaps)
        # Speaker-merged labels may lack the reporting keys; fill them in so
        # every label has the same shape. A merged label takes the smallest
        # id of the clusters it replaced.
        for position, label in enumerate(labels):
            if "cluster_id" not in label:
                sources = label.get("merged_from") or [position]
                label["cluster_id"] = min(sources)
            label.setdefault("tmdb_matches", [])
            label.setdefault("best_match", None)
        matched_speakers = sum(1 for l in labels if l.get("dominant_speaker"))
        merged = sum(1 for l in labels if l.get("merge_reason"))
        print(f" {matched_speakers} clusters have speaker match, {merged} merged by speaker")

        # Step 4: TMDb matching (optional)
        if config.get("enable_tmdb", False):
            print("\n[4] TMDb matching...")
            tmdb = fetch_tmdb_identities(cur)
            print(f" {len(tmdb)} TMDb identities loaded")
            _attach_tmdb_matches(labels, traces, tmdb)
            matched = sum(1 for l in labels if l.get("best_match"))
            print(f" {matched} clusters matched to TMDb")

        # Step 5: Metrics
        metrics = compute_metrics(labels if labels else clusters, len(traces))
        metrics["execution_time_s"] = time.time() - t0

        cur.close()
    finally:
        conn.close()

    # Step 6: Save results
    result_dir = os.path.join(EXPERIMENT_DIR, "results", f"exp_{exp_id}")
    os.makedirs(result_dir, exist_ok=True)

    _write_json(os.path.join(result_dir, "clusters.json"), labels if labels else clusters)
    _write_json(os.path.join(result_dir, "labels.json"), labels)
    _write_json(os.path.join(result_dir, "metrics.json"), metrics)
    _write_json(os.path.join(result_dir, "config.json"), config)

    # Summary
    summary = _format_summary(
        config, len(traces), len(clusters), clustered_traces, metrics, labels
    )
    # The summary contains non-ASCII characters, so write it as UTF-8 explicitly.
    with open(os.path.join(result_dir, "summary.txt"), "w", encoding="utf-8") as f:
        f.write(summary)

    print(f"\n[✓] Results saved to {result_dir}")
    print(summary)

    return metrics
|
|
|
|
|
|
def main():
    """Command-line entry point: load the ``--config`` JSON file and run it."""
    arg_parser = argparse.ArgumentParser(
        description="Identity Clustering Experiment Runner"
    )
    arg_parser.add_argument("--config", required=True, help="Experiment config JSON")
    cli_args = arg_parser.parse_args()

    experiment_config = load_experiment_config(cli_args.config)
    run_experiment(experiment_config)
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|