201 lines
6.8 KiB
Python
201 lines
6.8 KiB
Python
#!/opt/homebrew/bin/python3.11
|
|
"""
|
|
Match face_detections against TMDb identities via face embedding similarity.
|
|
Port of match_faces_against_tmdb from src/core/tmdb/face_agent.rs
|
|
|
|
Usage: python3 scripts/match_faces_to_tmdb.py <file_uuid> [--schema dev]
|
|
"""
|
|
|
|
import sys
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
import numpy as np
|
|
from collections import defaultdict
|
|
|
|
# Postgres connection string passed to psycopg2.connect().
DATABASE_URL = "postgres://accusys@localhost:5432/momentry"

# Minimum cosine similarity for a face to count as a hit against an identity.
THRESHOLD = 0.5

# Matched traces with fewer faces than this are discarded during QC.
QC_MIN_FACES = 4
|
|
|
|
|
|
def cosine_similarity(a, b):
    """Return the cosine similarity of two vectors.

    Both inputs are converted to float64 arrays. If either vector has a
    zero norm the similarity is defined as 0.0 instead of dividing by zero.
    """
    vec_a = np.array(a, dtype=np.float64)
    vec_b = np.array(b, dtype=np.float64)

    norm_a = np.linalg.norm(vec_a)
    norm_b = np.linalg.norm(vec_b)

    # Guard clause: a zero-length vector has no direction to compare.
    if not norm_a or not norm_b:
        return 0.0

    return vec_a.dot(vec_b) / (norm_a * norm_b)
|
|
|
|
|
|
# Faces whose similarity to the previously kept face exceeds this value are
# treated as near-identical duplicates within a trace.
_DEDUP_SIMILARITY = 0.99


def _group_faces_by_trace(fd_rows):
    """Group detection rows by trace_id, skipping rows with an empty embedding."""
    trace_faces = defaultdict(list)
    for row in fd_rows:
        emb = row["embedding"]
        if emb:
            trace_faces[row["trace_id"]].append({
                "id": row["id"],
                # Convert once here so the matching loops do not re-parse the list.
                "embedding": np.asarray(emb, dtype=np.float64),
                "frame": row["frame_number"],
                "confidence": row["confidence"],
            })
    return trace_faces


def _dedup_trace_faces(trace_faces):
    """Drop near-identical embeddings within each trace (in place).

    Faces are sorted by the first embedding component and each face is only
    compared with the previously kept one, so this is a cheap heuristic rather
    than an exhaustive pairwise dedup.
    """
    for tid, faces in trace_faces.items():
        faces.sort(key=lambda face: face["embedding"][0])
        unique = []
        for face in faces:
            if (
                not unique
                or cosine_similarity(face["embedding"], unique[-1]["embedding"])
                <= _DEDUP_SIMILARITY
            ):
                unique.append(face)
        trace_faces[tid] = unique


def _match_traces(trace_faces, reference_pool):
    """Return {trace_id: (identity_id, name, avg_similarity)} for matched traces.

    A face/reference pair counts when its similarity is >= THRESHOLD. For each
    trace the identity with the highest mean similarity over its counted pairs
    wins; ties keep the identity encountered first.
    """
    matched = {}
    for tid, faces in trace_faces.items():
        trace_scores = defaultdict(list)
        for face in faces:
            for ref in reference_pool:
                sim = cosine_similarity(face["embedding"], ref["embedding"])
                if sim >= THRESHOLD:
                    trace_scores[ref["identity_id"]].append((sim, ref["name"]))

        if not trace_scores:
            continue

        best_identity = None
        best_score = 0
        best_name = None
        for identity_id, scores in trace_scores.items():
            avg_sim = np.mean([sim for sim, _ in scores])
            if avg_sim > best_score:
                best_score = avg_sim
                best_identity = identity_id
                best_name = scores[0][1]

        # Compare against None so a falsy identity id (e.g. 0) is not dropped.
        if best_identity is not None:
            matched[tid] = (best_identity, best_name, best_score)
    return matched


def _drop_short_traces(matched, trace_faces):
    """Unmatch traces with fewer than QC_MIN_FACES faces; return how many were removed."""
    too_short = [tid for tid in matched if len(trace_faces[tid]) < QC_MIN_FACES]
    for tid in too_short:
        del matched[tid]
    return len(too_short)


def _resolve_frame_collisions(matched, trace_faces):
    """Unmatch traces that claim an identity already claimed in the same frame.

    When several matched traces assign the same identity to faces in one
    frame, the trace owning the highest-confidence face keeps the match and
    the others are removed from ``matched``. Returns the number of traces
    removed.
    """
    # frame -> identity_id -> [(trace_id, confidence), ...], recorded in the
    # same order the traces and faces are iterated.
    claims = defaultdict(lambda: defaultdict(list))
    for tid, faces in trace_faces.items():
        if tid in matched:
            identity_id = matched[tid][0]
            for face in faces:
                claims[face["frame"]][identity_id].append((tid, face["confidence"]))

    removed = 0
    for identity_claims in claims.values():
        for frame_claims in identity_claims.values():
            if len(frame_claims) < 2:
                continue
            # Traces may have been unmatched while handling an earlier frame.
            live = [claim for claim in frame_claims if claim[0] in matched]
            if not live:
                continue
            # A NULL confidence ranks lowest instead of breaking the sort.
            live.sort(
                key=lambda claim: float("-inf") if claim[1] is None else claim[1],
                reverse=True,
            )
            winner = live[0][0]
            for tid, _ in live[1:]:
                # A trace with two faces in one frame must not evict itself.
                if tid != winner and tid in matched:
                    del matched[tid]
                    removed += 1
    return removed


def match_faces_to_tmdb(file_uuid: str, schema: str = "dev"):
    """Bind face detections of one file to TMDb identities by embedding similarity.

    Steps: load TMDb identities that have a face embedding, load the file's
    face detections grouped by trace, dedup near-identical faces per trace,
    match each trace against the identities, apply QC (minimum faces per
    trace, same-frame identity collisions), then set ``identity_id`` on the
    detections of matched traces where it is still NULL.

    Args:
        file_uuid: Value matched against ``face_detections.file_uuid``.
        schema: Database schema holding the ``identities`` and
            ``face_detections`` tables. It is inserted as a quoted SQL
            identifier, so it is matched case-sensitively.

    Returns:
        The number of ``face_detections`` rows updated; 0 when there are no
        TMDb identities with embeddings or no detections for the file.

    Database errors raised by psycopg2 propagate to the caller; the
    connection is closed on every exit path.
    """
    # Imported locally so this block needs no extra top-of-file import.
    from psycopg2 import sql

    schema_ident = sql.Identifier(schema)

    conn = psycopg2.connect(DATABASE_URL)
    try:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            # Step 1: Load TMDb identities with face embeddings
            cur.execute(
                sql.SQL("""
                    SELECT id, name, tmdb_id, face_embedding::real[] as embedding
                    FROM {schema}.identities
                    WHERE source = 'tmdb' AND face_embedding IS NOT NULL
                """).format(schema=schema_ident)
            )
            reference_pool = [
                {
                    "embedding": np.asarray(row["embedding"], dtype=np.float64),
                    "identity_id": row["id"],
                    "name": row["name"],
                }
                for row in cur.fetchall()
                if row["embedding"]
            ]

            print(f"[TMDB-MATCH] Loaded {len(reference_pool)} TMDb identities")

            if not reference_pool:
                print("[TMDB-MATCH] No TMDb identities with embeddings")
                return 0

            # Step 2: Load face_detections with trace_id and embedding
            cur.execute(
                sql.SQL("""
                    SELECT id, trace_id, frame_number, embedding::real[] as embedding, confidence
                    FROM {schema}.face_detections
                    WHERE file_uuid = %s AND trace_id IS NOT NULL AND embedding IS NOT NULL
                    ORDER BY trace_id, frame_number
                """).format(schema=schema_ident),
                (file_uuid,),
            )
            fd_rows = cur.fetchall()
            if not fd_rows:
                print(f"[TMDB-MATCH] No face detections for {file_uuid}")
                return 0

            trace_faces = _group_faces_by_trace(fd_rows)
            _dedup_trace_faces(trace_faces)

            total_traces = len(trace_faces)
            # Counts fetched rows, i.e. before dedup.
            total_faces = len(fd_rows)
            print(f"[TMDB-MATCH] {total_traces} traces with {total_faces} faces")

            # Step 3: Single-pass matching against the TMDb seeds only
            print(f"[TMDB-MATCH] Matching {total_traces} traces against {len(reference_pool)} TMDb identities (threshold={THRESHOLD})")
            # trace_id -> (identity_id, name, avg_similarity)
            matched = _match_traces(trace_faces, reference_pool)

            # Step 4: Quality Control - minimum faces per trace
            qc_removed = _drop_short_traces(matched, trace_faces)

            # Step 5: Temporal collision check
            qc_removed += _resolve_frame_collisions(matched, trace_faces)

            if qc_removed > 0:
                print(f"[TMDB-MATCH] QC removed {qc_removed} traces")

            # Step 6: Update face_detections.identity_id
            # NOTE(review): only faces that survived dedup are updated, so
            # near-duplicate detections of a matched trace keep a NULL
            # identity_id - confirm this is intended.
            update_query = sql.SQL("""
                UPDATE {schema}.face_detections
                SET identity_id = %s
                WHERE id = %s AND identity_id IS NULL
            """).format(schema=schema_ident)

            bindings_created = 0
            for tid, (identity_id, _name, _score) in matched.items():
                for face in trace_faces[tid]:
                    cur.execute(update_query, (identity_id, face["id"]))
                    bindings_created += cur.rowcount

            conn.commit()

            print(f"[TMDB-MATCH] {bindings_created} bindings created, {len(matched)} traces matched")
            return bindings_created
    finally:
        # Closing without commit discards any uncommitted work on error paths.
        conn.close()
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # CLI: one positional file UUID plus an optional --schema (default "dev").
    cli = argparse.ArgumentParser()
    cli.add_argument("file_uuid", help="Video file UUID")
    cli.add_argument("--schema", default="dev", help="Database schema")
    options = cli.parse_args()

    match_faces_to_tmdb(options.file_uuid, options.schema)