Files
momentry_core/experiments/identity_clustering/runner_v2.py

432 lines
16 KiB
Python

#!/opt/homebrew/bin/python3.11
"""
Multi-Stage Identity Clustering Runner
Stage 1: High-confidence face-level matching
- Compare ALL face embeddings in each trace against identity references
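- Bind a trace to an identity when a large enough share of its faces match above a similarity threshold (function defaults: ratio 0.85, similarity 0.92; run_experiment defaults: ratio 0.60, similarity 0.55)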
- These become "anchors" for Stage 2
Stage 2: Trace centroid clustering of remaining unbound traces
- Use centroid of unbound traces, cluster with adaptive threshold
- Merge clusters with speaker overlap verification
Stage 3 (optional): TMDb matching
"""
import sys, os, json, argparse, time, numpy as np
from datetime import datetime
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import psycopg2
# Postgres connection string; the DATABASE_URL environment variable overrides the local default.
DB_URL = os.environ.get("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
# Schema name interpolated into the SQL statements below.
SCHEMA = "dev"
# Directory containing this script; run_experiment writes its results beneath it.
EXPERIMENT_DIR = os.path.dirname(os.path.abspath(__file__))
def get_conn():
    """Open and return a new psycopg2 connection to the database at DB_URL."""
    connection = psycopg2.connect(DB_URL)
    return connection
def cosine_similarity(a, b):
    """Return the cosine similarity between two vectors.

    Both inputs are converted with ``np.array``. A 1e-10 term is added to the
    denominator so that a zero-norm input produces 0 instead of a division
    by zero.
    """
    vec_a = np.array(a)
    vec_b = np.array(b)
    denominator = np.linalg.norm(vec_a) * np.linalg.norm(vec_b) + 1e-10
    return np.dot(vec_a, vec_b) / denominator
def parse_pg_array(val):
    """Parse a PostgreSQL real[] value into a 1-D float64 NumPy array.

    Accepted inputs:
      * ``None``                      -> ``None``
      * ``np.ndarray``                -> float64 copy
      * ``list`` / ``tuple``          -> float64 array
      * ``str`` such as ``"{1,2,3}"`` or ``"[1, 2, 3]"`` (surrounding
        whitespace allowed)           -> float64 array

    Returns ``None`` for unsupported types and for any input that yields no
    elements, so callers can uniformly skip missing embeddings with an
    ``is None`` check instead of tripping over a zero-length vector later.
    """
    if val is None:
        return None
    if isinstance(val, np.ndarray):
        arr = val.astype(np.float64)
    elif isinstance(val, (list, tuple)):
        arr = np.array(val, dtype=np.float64)
    elif isinstance(val, str):
        # Strip whitespace together with the brackets: a padded literal such
        # as " {1,2} " would otherwise keep its braces and fail to parse.
        body = val.strip(" \t\r\n[]{}")
        if not body:
            return None
        arr = np.fromstring(body, sep=",").astype(np.float64)
    else:
        return None
    # An empty vector is as unusable as a missing one (np.dot against a real
    # embedding would raise), so report it the same way as an empty string.
    return arr if arr.size else None
def fetch_trace_with_faces(cur, file_uuid: str, min_frames: int) -> List[dict]:
    """Load every qualifying trace of a file together with its face embeddings.

    A trace qualifies when it has at least ``min_frames`` detections with a
    non-null embedding. For each one a second query pulls the individual
    embeddings ordered by detection confidence (highest first).

    Returns a list of dicts with keys ``trace_id``, ``frame_count``,
    ``start_frame``, ``end_frame``, ``avg_bbox`` (x/y/w/h averages),
    ``faces`` (list of ``{"embedding": ndarray}``) and ``centroid`` (mean
    embedding as a plain list, or ``None`` when no embedding parsed).
    """
    summary_sql = f"""
        SELECT trace_id, COUNT(*) as fc, MIN(frame_number), MAX(frame_number),
        AVG(x::float), AVG(y::float), AVG(width::float), AVG(height::float)
        FROM {SCHEMA}.face_detections
        WHERE file_uuid=%s AND trace_id IS NOT NULL AND embedding IS NOT NULL
        GROUP BY trace_id HAVING COUNT(*)>=%s ORDER BY trace_id
        """
    faces_sql = f"SELECT embedding FROM {SCHEMA}.face_detections WHERE file_uuid=%s AND trace_id=%s AND embedding IS NOT NULL ORDER BY confidence DESC"

    cur.execute(summary_sql, (file_uuid, min_frames))
    # Materialise the summaries first: the cursor is reused for the per-trace query.
    summaries = cur.fetchall()

    collected = []
    for trace_id, frame_count, first_frame, last_frame, avg_x, avg_y, avg_w, avg_h in summaries:
        cur.execute(faces_sql, (file_uuid, trace_id))
        parsed = (parse_pg_array(record[0]) for record in cur.fetchall())
        face_list = [
            {"embedding": vector.astype(np.float64)}
            for vector in parsed
            if vector is not None
        ]
        if face_list:
            centroid = np.mean([face["embedding"] for face in face_list], axis=0).tolist()
        else:
            centroid = None
        collected.append({
            "trace_id": trace_id,
            "frame_count": frame_count,
            "start_frame": first_frame,
            "end_frame": last_frame,
            "avg_bbox": {"x": avg_x, "y": avg_y, "w": avg_w, "h": avg_h},
            "faces": face_list,
            "centroid": centroid,
        })
    return collected
def fetch_speaker_overlaps(cur, file_uuid: str) -> dict:
    """Load the face-trace -> speaker overlap ratios recorded for a file.

    Reads the ``SPEAKS_AS`` edges that link ``face_trace`` nodes to
    ``speaker`` nodes and returns ``{trace_id: {speaker_external_id: ratio}}``.
    The trace id is taken from the node's ``external_id`` with the
    ``trace_`` prefix removed; a missing ``overlap_ratio`` is stored as 0.0.
    """
    cur.execute(f"""
        SELECT REPLACE(n.external_id,'trace_','')::int, n2.external_id,
        (e.properties->>'overlap_ratio')::float
        FROM {SCHEMA}.tkg_edges e
        JOIN {SCHEMA}.tkg_nodes n ON e.source_node_id=n.id
        JOIN {SCHEMA}.tkg_nodes n2 ON e.target_node_id=n2.id
        WHERE e.edge_type='SPEAKS_AS' AND n.node_type='face_trace' AND n2.node_type='speaker' AND e.file_uuid=%s
        """, (file_uuid,))
    overlaps = defaultdict(lambda: defaultdict(float))
    for tid, spk, ratio in cur.fetchall():
        # Test for None explicitly: a plain truthiness check would silently
        # discard the perfectly valid trace id 0.
        if tid is None or not spk:
            continue
        overlaps[int(tid)][spk] = float(ratio or 0)
    return dict(overlaps)
def fetch_identity_references(cur) -> List[dict]:
    """Return the registered identities that carry a face embedding.

    Each entry is ``{"id", "name", "embedding"}`` with the embedding as a
    float64 array; rows whose embedding cannot be parsed are left out.
    """
    query = f"SELECT id, name, face_embedding FROM {SCHEMA}.identities WHERE face_embedding IS NOT NULL"
    cur.execute(query)
    references = []
    for identity_id, identity_name, raw_embedding in cur.fetchall():
        vector = parse_pg_array(raw_embedding)
        if vector is not None:
            references.append({
                "id": identity_id,
                "name": identity_name,
                "embedding": vector.astype(np.float64),
            })
    return references
# ===== STAGE 1: High-confidence face-level matching =====
def stage1_high_confidence_binding(
    traces: List[dict], identities: List[dict],
    face_match_threshold: float = 0.92,
    trace_bind_ratio: float = 0.85,
) -> Tuple[List[dict], List[dict]]:
    """Bind traces to identities by comparing every face with every identity.

    A face counts as a match when its cosine similarity to the identity
    embedding is at least ``face_match_threshold``. A trace is bound when the
    matching share of its faces is at least ``trace_bind_ratio``; if several
    identities qualify, the one with the most matching faces wins (the first
    one encountered on a tie).

    Bound traces are annotated in place with ``binding`` and
    ``binding_stage``. Traces without faces are always unbound.

    Returns ``(bound_traces, unbound_traces)``.
    """
    bound, unbound = [], []
    for trace in traces:
        faces = trace.get("faces", [])
        if not faces:
            unbound.append(trace)
            continue

        total = len(faces)
        winner = None
        winner_hits = 0
        for ident in identities:
            hits = sum(
                1
                for face in faces
                if cosine_similarity(face["embedding"], ident["embedding"]) >= face_match_threshold
            )
            share = hits / total
            if share < trace_bind_ratio or hits <= winner_hits:
                continue
            winner_hits = hits
            winner = {
                "id": ident["id"],
                "name": ident["name"],
                "match_ratio": round(share, 3),
                "matched_faces": hits,
                "total_faces": total,
            }

        if winner is None:
            unbound.append(trace)
            continue
        trace["binding"] = winner
        trace["binding_stage"] = "stage1_face_level"
        bound.append(trace)
    return bound, unbound
# ===== STAGE 2: Centroid clustering of unbound traces =====
def stage2_cluster_unbound(
    traces: List[dict], threshold: float, adaptive: bool = False
) -> List[dict]:
    """Greedily group unbound traces by the similarity of their centroids.

    Traces are visited in order; each not-yet-assigned trace seeds a new
    cluster and absorbs every later unassigned trace whose centroid has a
    cosine similarity to the seed of at least ``threshold``. With
    ``adaptive`` set, the threshold is lowered by 0.05 whenever either trace
    of the pair has fewer than 10 frames. A trace without a centroid ends up
    in a cluster of its own.

    Returns a list of clusters, each cluster being a list of trace dicts.
    """
    clusters = []
    taken = set()
    for seed_idx, seed in enumerate(traces):
        if seed["trace_id"] in taken:
            continue
        members = [seed]
        taken.add(seed["trace_id"])
        for other_idx, other in enumerate(traces):
            if other_idx == seed_idx or other["trace_id"] in taken:
                continue
            if seed["centroid"] is None or other["centroid"] is None:
                continue
            similarity = cosine_similarity(seed["centroid"], other["centroid"])
            short_pair = adaptive and (seed["frame_count"] < 10 or other["frame_count"] < 10)
            cutoff = threshold - 0.05 if short_pair else threshold
            if similarity >= cutoff:
                members.append(other)
                taken.add(other["trace_id"])
        clusters.append(members)
    return clusters
def apply_speaker_verification(clusters: List[dict], speaker_overlaps: dict) -> List[dict]:
    """Build one label record per cluster, tagged with its dominant speaker.

    For every cluster the overlap ratios of all member traces are summed per
    speaker; the speaker with the largest total becomes ``dominant_speaker``
    and that total (rounded to 3 decimals) becomes ``speaker_score``. The
    ``binding`` / ``binding_stage`` fields are copied from the first trace of
    the cluster. Clusters are only labelled here; none are merged.
    """
    def describe(index, members):
        member_ids = [member["trace_id"] for member in members]
        tally = defaultdict(float)
        for trace_id in member_ids:
            if trace_id not in speaker_overlaps:
                continue
            for speaker, ratio in speaker_overlaps[trace_id].items():
                tally[speaker] += ratio
        top_speaker = max(tally, key=tally.get) if tally else None
        if top_speaker:
            score = round(tally.get(top_speaker, 0), 3)
        else:
            score = 0
        head = members[0]
        return {
            "cluster_id": index,
            "trace_count": len(members),
            "trace_ids": member_ids,
            "dominant_speaker": top_speaker,
            "speaker_score": score,
            "binding": head.get("binding"),
            "binding_stage": head.get("binding_stage"),
        }

    return [describe(index, members) for index, members in enumerate(clusters)]
# ===== Main Experiment =====
def _seed_identity_refs(bound: List[dict], per_trace: int = 3) -> Dict[object, list]:
    """Group the leading ``per_trace`` face embeddings of each bound trace by identity id."""
    refs: Dict[object, list] = {}
    for trace in bound:
        binding = trace.get("binding", {})
        iid = binding.get("id") if isinstance(binding, dict) else None
        faces = trace.get("faces")
        if iid is None or not faces:
            continue
        # Faces are loaded highest-confidence-first, so the head of the list
        # holds the best-quality samples of the trace.
        refs.setdefault(iid, []).extend(f["embedding"] for f in faces[:per_trace])
    return refs


def _bind_with_video_refs(
    unbound: List[dict],
    identity_refs: Dict[object, list],
    id_to_name: dict,
    speaker_overlaps: dict,
    face_threshold: float,
    min_composite: float = 0.35,
    refs_per_trace: int = 3,
) -> Tuple[List[dict], List[dict]]:
    """Iteratively bind unbound traces against a growing per-identity reference set.

    Each round scores every remaining trace against every identity: per face
    the best similarity to any reference is taken, then
    ``composite = mean_sim * speaker_weight * (0.4 + 0.6 * match_ratio)``.
    A trace is bound to the identity with the highest composite above
    ``min_composite``. After a round, the leading ``refs_per_trace`` faces of
    each newly bound trace are added to that identity's references (so
    ``identity_refs`` is mutated) and the loop repeats until a round binds
    nothing.

    Returns ``(newly_bound, still_unbound)``; bound traces are annotated in
    place with ``binding`` and ``binding_stage``.
    """
    # Total speaker overlap per trace, used to favour traces that speak a lot.
    speaker_totals = {tid: sum(spks.values()) for tid, spks in speaker_overlaps.items()}
    # Loop-invariant normaliser (never below 1), hoisted out of the scoring loop.
    max_total = max(max(speaker_totals.values(), default=1), 1)

    newly_bound: List[dict] = []
    round_num = 0
    while True:
        round_num += 1
        bound_this_round = []
        for trace in unbound:
            faces = trace.get("faces", [])
            if not faces:
                continue
            spk_weight = 1.0 + 0.3 * speaker_totals.get(trace["trace_id"], 0) / max_total
            best_score = 0
            best = None
            for iid, refs in identity_refs.items():
                # Compare each face against ALL references, keep the best per face.
                face_sims = [
                    max(cosine_similarity(face["embedding"], ref) for ref in refs)
                    for face in faces
                ]
                avg_sim = np.mean(face_sims)
                match_ratio = sum(1 for s in face_sims if s >= face_threshold) / len(face_sims)
                composite = avg_sim * spk_weight * (0.4 + 0.6 * match_ratio)
                if composite > best_score and composite > min_composite:
                    best_score = composite
                    best = (iid, avg_sim, match_ratio)
            if best is None:
                continue
            best_iid, best_sim, best_ratio = best
            trace["binding"] = {
                "id": best_iid, "name": id_to_name.get(best_iid, "?"),
                "avg_similarity": round(best_sim, 3),
                # Reported with the same threshold that drove the score, so
                # the stored ratio always explains the composite next to it.
                "match_ratio": round(best_ratio, 3),
                "composite_score": round(best_score, 3),
                "source": f"video_ref_r{round_num}",
            }
            trace["binding_stage"] = f"stage1b_r{round_num}"
            bound_this_round.append(trace)

        if not bound_this_round:
            break

        # Enrich references only after the round, so every trace in a round
        # is scored against the same reference set.
        for trace in bound_this_round:
            identity_refs[trace["binding"]["id"]].extend(
                f["embedding"] for f in trace["faces"][:refs_per_trace]
            )
        bound_ids = {trace["trace_id"] for trace in bound_this_round}
        unbound = [trace for trace in unbound if trace["trace_id"] not in bound_ids]
        newly_bound.extend(bound_this_round)
        print(f" Round {round_num}: {len(bound_this_round)} traces bound, {len(unbound)} unbound")
    return newly_bound, unbound


def _write_bindings_to_db(all_labels: List[dict], file_uuid: str) -> int:
    """Persist label bindings; return the number of face_detections rows updated.

    Identities are resolved by name (created as pending/auto when missing).
    Only detections whose ``identity_id`` is still NULL are updated, and one
    ``identity_bindings`` row is recorded per trace that changed. Everything
    runs in a single transaction committed at the end; on any error the
    connection is closed without committing.
    """
    total_written = 0
    identity_ids: dict = {}  # name -> id, avoids one SELECT per label
    conn = get_conn()
    try:
        cur = conn.cursor()
        for label in all_labels:
            binding = label.get("binding")
            if not binding:
                continue
            identity_name = binding.get("name", "")
            if not identity_name:
                continue
            # Get or create identity
            identity_id = identity_ids.get(identity_name)
            if identity_id is None:
                cur.execute(f"SELECT id FROM {SCHEMA}.identities WHERE name=%s", (identity_name,))
                row = cur.fetchone()
                if row:
                    identity_id = row[0]
                else:
                    cur.execute(
                        f"INSERT INTO {SCHEMA}.identities (name, identity_type, source, status) VALUES (%s,'people','auto','pending') RETURNING id",
                        (identity_name,))
                    identity_id = cur.fetchone()[0]
                identity_ids[identity_name] = identity_id
            # Stage-1 bindings carry no avg_similarity, hence the 0.8 fallback.
            confidence = float(binding.get("avg_similarity", 0.8))
            # Bind all faces in each trace to the identity
            for tid in label["trace_ids"]:
                cur.execute(
                    f"UPDATE {SCHEMA}.face_detections SET identity_id=%s WHERE file_uuid=%s AND trace_id=%s AND identity_id IS NULL",
                    (identity_id, file_uuid, tid))
                affected = cur.rowcount
                if affected > 0:
                    # Write to identity_bindings for traceability
                    cur.execute(
                        f"INSERT INTO {SCHEMA}.identity_bindings (identity_id, identity_type, identity_value, confidence) VALUES (%s,'trace',%s,%s) ON CONFLICT DO NOTHING",
                        (identity_id, str(tid), confidence))
                    total_written += affected
        conn.commit()
        cur.close()
    finally:
        conn.close()
    return total_written


def _save_results(exp_id, all_labels: List[dict], metrics: dict, config: dict) -> str:
    """Write labels/metrics/config as JSON under results/exp_<id>; return that directory."""
    result_dir = os.path.join(EXPERIMENT_DIR, "results", f"exp_{exp_id}")
    os.makedirs(result_dir, exist_ok=True)
    for name, data in (("labels.json", all_labels), ("metrics.json", metrics), ("config.json", config)):
        # ensure_ascii=False emits raw non-ASCII text (e.g. identity names),
        # so the file encoding must not depend on the platform default.
        with open(os.path.join(result_dir, name), "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False, default=str)
    return result_dir


def run_experiment(config: dict) -> dict:
    """Run the multi-stage identity clustering experiment described by ``config``.

    Recognised config keys (defaults in parentheses): ``id`` (required),
    ``file_uuid`` (""), ``min_frames`` (3), ``enable_identity_match`` (True),
    ``stage1_face_threshold`` (0.55), ``stage1_bind_ratio`` (0.60),
    ``stage2_threshold`` (0.85), ``stage2_adaptive`` (False),
    ``write_db`` (False).

    Pipeline: load traces, identity references and speaker overlaps; bind
    traces to registered identities (stage 1); iteratively bind further traces
    against references sampled from already-bound traces (stage 1b); cluster
    what is left by centroid (stage 2); label clusters with their dominant
    speaker; optionally write bindings to the database; save labels, metrics
    and config as JSON under ``results/exp_<id>``.

    Returns the metrics dict. Raises ``KeyError`` when ``config`` has no
    ``id``; database errors propagate after the connection is closed.
    """
    exp_id = config["id"]
    file_uuid = config.get("file_uuid", "")
    face_threshold = config.get("stage1_face_threshold", 0.55)

    # Load data. The connection is released as soon as loading finishes and
    # is closed even when a query fails.
    conn = get_conn()
    t0 = time.time()
    try:
        cur = conn.cursor()
        traces = fetch_trace_with_faces(cur, file_uuid, config.get("min_frames", 3))
        identities = fetch_identity_references(cur) if config.get("enable_identity_match", True) else []
        speaker_overlaps = fetch_speaker_overlaps(cur, file_uuid)
        cur.close()
    finally:
        conn.close()
    print(f"Traces: {len(traces)}, Identities: {len(identities)}, Speaker edges: {len(speaker_overlaps)}")

    # Stage 1: first-pass binding against registered identities (relaxed thresholds)
    bound, unbound = [], traces
    if identities:
        bound, unbound = stage1_high_confidence_binding(
            traces, identities,
            face_threshold,
            config.get("stage1_bind_ratio", 0.60),
        )
        print(f"Stage 1 (TMDb): {len(bound)} traces bound, {len(unbound)} unbound")

    # Stage 1b: iterative enrichment — each bound trace contributes its 3 best
    # faces as extra references for its identity.
    if bound and identities and unbound:
        identity_refs = _seed_identity_refs(bound)
        id_to_name = {ident["id"]: ident["name"] for ident in identities}
        for iid, refs in identity_refs.items():
            print(f" {id_to_name.get(iid, '?'):<20} {len(refs)} reference faces (multi-angle sampling)")
        newly_bound, unbound = _bind_with_video_refs(
            unbound, identity_refs, id_to_name, speaker_overlaps, face_threshold
        )
        bound.extend(newly_bound)

    # Stage 2: centroid clustering of whatever is still unbound
    clusters = stage2_cluster_unbound(
        unbound,
        config.get("stage2_threshold", 0.85),
        config.get("stage2_adaptive", False),
    )
    print(f"Stage 2: {len(clusters)} clusters from {len(unbound)} unbound traces")

    # Speaker verification
    all_labels = apply_speaker_verification(clusters, speaker_overlaps)

    # Merge bound traces into labels, one single-trace label each
    for t in bound:
        spks = speaker_overlaps.get(t["trace_id"]) or {}
        all_labels.append({
            "cluster_id": len(all_labels),
            "trace_count": 1,
            "trace_ids": [t["trace_id"]],
            "binding": t.get("binding"),
            # Keep the stage that actually produced the binding (stage 1 or a
            # stage-1b round) instead of labelling everything as stage 1.
            "binding_stage": t.get("binding_stage", "stage1_face_level"),
            # Dominant = highest overlap ratio, matching how clusters are
            # labelled, rather than whichever speaker happened to come first.
            "dominant_speaker": max(spks, key=spks.get) if spks else None,
        })

    # Metrics
    clustered = sum(len(c) for c in clusters)
    metrics = {
        "total_traces": len(traces),
        "stage1_bound": len(bound),
        "stage1_bound_traces": len(bound),
        "stage2_clusters": len(clusters),
        "stage2_unbound_clustered": clustered,
        "total_clusters": len(all_labels),
        "execution_time_s": time.time() - t0,
        "coverage": (len(bound) + clustered) / max(len(traces), 1),
    }
    for k, v in metrics.items():
        print(f" {k}: {v}")

    # --- Write bindings to database ---
    if config.get("write_db", False):
        total_written = _write_bindings_to_db(all_labels, file_uuid)
        print(f"\nDB write: {total_written} face_detections updated")

    # Save
    result_dir = _save_results(exp_id, all_labels, metrics, config)
    print(f"\nSaved to {result_dir}")
    return metrics
def main():
    """Command-line entry point: load a JSON config and run the experiment.

    ``--config`` (required) names the JSON file; ``--write-db`` forces
    ``write_db`` to True in the loaded config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--write-db", action="store_true", help="Write bindings to database")
    cli = parser.parse_args()

    with open(cli.config) as handle:
        config = json.load(handle)

    # The flag can only switch writing on; it never overrides a config that
    # already enables it.
    if cli.write_db:
        config["write_db"] = True
    run_experiment(config)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()