diff --git a/scripts/identity_matcher.py b/scripts/identity_matcher.py
index 1e73e1c..a7d3858 100644
--- a/scripts/identity_matcher.py
+++ b/scripts/identity_matcher.py
@@ -38,6 +38,7 @@ from qdrant_faces import (
     search_seeds,
     get_trace_centroid,
 )
+from tkg_helper import batch_mark_suggestions, batch_mark_strangers
 
 TH_ROUND_1 = 0.55
 TH_ROUND_2 = 0.55
@@ -355,6 +356,7 @@ def main():
     parser.add_argument("--identity-map", help="JSON file with {trace_id: {identity_id, uuid, name}} (for Round 2+)")
     parser.add_argument("--output", help="Output JSON file path")
     parser.add_argument("--stranger", action="store_true", help="Also run stranger clustering")
+    parser.add_argument("--mark-tkg", action="store_true", help="Mark TKG face_track nodes with suggestions")
     args = parser.parse_args()
 
     if args.round == 1:
@@ -396,6 +398,15 @@ def main():
         stranger_clusters = cluster_strangers(args.file_uuid, matched_traces)
         result["stranger_clusters"] = stranger_clusters
 
+    # Mark TKG nodes if requested
+    if args.mark_tkg:
+        tkg_updated = batch_mark_suggestions(args.file_uuid, suggestions)
+        result["tkg_nodes_updated"] = tkg_updated
+
+        if args.stranger and stranger_clusters:
+            tkg_strangers = batch_mark_strangers(args.file_uuid, stranger_clusters)
+            result["tkg_strangers_updated"] = tkg_strangers
+
     output_json = json.dumps(result, indent=2, ensure_ascii=False)
 
     if args.output:
diff --git a/scripts/utils/tkg_helper.py b/scripts/utils/tkg_helper.py
new file mode 100644
index 0000000..adcaa15
--- /dev/null
+++ b/scripts/utils/tkg_helper.py
@@ -0,0 +1,365 @@
+#!/opt/homebrew/bin/python3.11
+"""
+TKG Helper - PostgreSQL TKG node operations for Identity Agent
+
+Functions:
+- mark_face_track_suggested(): Mark face_track node as 'suggested'
+- mark_face_track_confirmed(): Mark face_track node as 'confirmed'
+- mark_face_track_stranger(): Mark face_track node as 'stranger'
+- clear_face_track_status(): Remove all identity binding keys from a node
+- get_face_track_nodes(): Get all face_track nodes for a file
+- get_pending_face_tracks(): Get face_track nodes with status='pending' or no status
+- get_suggested_face_tracks(): Get face_track nodes with status='suggested'
+- batch_mark_suggestions(): Mark several face_track nodes as 'suggested'
+- batch_mark_strangers(): Mark several face_track nodes as 'stranger'
+
+TKG face_track node properties schema:
+{
+    "trace_id": int,
+    "frame_count": int,
+    "start_frame": int,
+    "end_frame": int,
+    "avg_bbox": {...},
+    "avg_pose": {...},
+
+    // Identity binding state
+    "status": "pending" | "suggested" | "confirmed" | "stranger",
+
+    // Suggested fields (dropped on confirm / stranger / clear)
+    "pending_identity_name": str | null,
+    "pending_identity_uuid": str | null,
+    "pending_identity_id": int | null,
+    "suggested_by": "tmdb" | "propagation" | "manual" | null,
+    "confidence": float,
+
+    // Confirmed fields (dropped on stranger / clear)
+    "identity_uuid": str | null,
+    "identity_id": int | null,
+    "identity_ref": str | null,
+    "identity_name": str | null,
+
+    // Stranger fields (dropped on confirm / clear)
+    "stranger_id": int | null,
+    "stranger_ref": str | null
+}
+"""
+
+import os
+import sys
+import json
+import psycopg2
+import psycopg2.extras
+from typing import Optional, Dict, List
+
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+DB_URL = os.environ.get("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
+SCHEMA = os.environ.get("DATABASE_SCHEMA", "dev")
+
+# Property keys owned by each binding state. Keeping them in one place lets
+# every state transition strip exactly the keys the other states wrote.
+_PENDING_KEYS = (
+    "pending_identity_name",
+    "pending_identity_uuid",
+    "pending_identity_id",
+    "suggested_by",
+    "confidence",
+)
+_IDENTITY_KEYS = ("identity_uuid", "identity_ref", "identity_id", "identity_name")
+_STRANGER_KEYS = ("stranger_id", "stranger_ref")
+
+
+def get_conn():
+    """Get PostgreSQL connection"""
+    return psycopg2.connect(DB_URL)
+
+
+def table_name(table: str) -> str:
+    """Get schema-prefixed table name"""
+    if SCHEMA == "public":
+        return table
+    return f"{SCHEMA}.{table}"
+
+
+def _strip_keys_sql(*key_groups) -> str:
+    """Build a "- 'k1' - 'k2' ..." JSONB key-removal chain from constant key groups"""
+    return " ".join(f"- '{key}'" for group in key_groups for key in group)
+
+
+def _update_face_track(
+    file_uuid: str,
+    trace_id: int,
+    set_expr: str,
+    props: Optional[Dict],
+    action: str,
+) -> bool:
+    """Run one properties UPDATE against a single face_track node
+
+    Args:
+        file_uuid: Video file UUID
+        trace_id: Face trace ID (node external_id is face_track_<trace_id>)
+        set_expr: SQL expression assigned to properties; built from module
+            constants only and containing one %s placeholder iff props is given
+        props: Properties to merge (sent as a JSONB parameter), or None
+        action: Verb for the error log line, e.g. 'marking'
+
+    Returns:
+        True if a row was updated; False if none matched or the update failed
+    """
+    query = f"""
+        UPDATE {table_name("tkg_nodes")}
+        SET properties = {set_expr}
+        WHERE file_uuid = %s AND node_type = 'face_track' AND external_id = %s
+    """
+    conn = get_conn()
+    try:
+        # Serialization and cursor creation happen inside the try so that any
+        # failure is reported and the connection is still closed.
+        merged = () if props is None else (json.dumps(props),)
+        with conn.cursor() as cur:
+            cur.execute(query, (*merged, file_uuid, f"face_track_{trace_id}"))
+            updated = cur.rowcount > 0
+        conn.commit()
+        return updated
+    except Exception as e:
+        print(f"[TKG] Error {action} trace {trace_id}: {e}")
+        conn.rollback()
+        return False
+    finally:
+        conn.close()
+
+
+def _fetch_face_tracks(file_uuid: str, status_filter: str = "") -> List[Dict]:
+    """Select a file's face_track nodes, optionally narrowed by a constant SQL filter"""
+    query = f"""
+        SELECT id, external_id, label, properties, created_at
+        FROM {table_name("tkg_nodes")}
+        WHERE file_uuid = %s AND node_type = 'face_track'
+        {status_filter}
+        ORDER BY external_id
+    """
+    conn = get_conn()
+    try:
+        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
+            cur.execute(query, (file_uuid,))
+            return [dict(row) for row in cur.fetchall()]
+    finally:
+        conn.close()
+
+
+def mark_face_track_suggested(
+    file_uuid: str,
+    trace_id: int,
+    identity_id: int,
+    identity_uuid: str,
+    name: str,
+    confidence: float,
+    suggested_by: str = "tmdb",
+) -> bool:
+    """Mark face_track node as 'suggested'
+
+    Args:
+        file_uuid: Video file UUID
+        trace_id: Face trace ID
+        identity_id: PG identity.id
+        identity_uuid: Identity UUID
+        name: Identity name
+        confidence: Matching confidence score (None is stored as 0.0)
+        suggested_by: 'tmdb' | 'propagation' | 'manual'
+
+    Returns:
+        True if successful
+    """
+    # Normalise the score before any connection is opened: a missing score
+    # used to raise in round() after connecting and leaked the connection.
+    score = 0.0 if confidence is None else float(confidence)
+    props = {
+        "status": "suggested",
+        "pending_identity_name": name,
+        "pending_identity_uuid": identity_uuid,
+        "pending_identity_id": identity_id,
+        "suggested_by": suggested_by,
+        "confidence": round(score, 4),
+    }
+
+    updated = _update_face_track(
+        file_uuid, trace_id, "properties || %s::jsonb", props, "marking"
+    )
+    if updated:
+        print(f"[TKG] Marked trace {trace_id} as suggested: {name} (confidence={score:.4f})")
+    return updated
+
+
+def mark_face_track_confirmed(
+    file_uuid: str,
+    trace_id: int,
+    identity_id: int,
+    identity_uuid: str,
+    name: str,
+) -> bool:
+    """Mark face_track node as 'confirmed'
+
+    Args:
+        file_uuid: Video file UUID
+        trace_id: Face trace ID
+        identity_id: PG identity.id
+        identity_uuid: Identity UUID
+        name: Identity name
+
+    Returns:
+        True if successful
+    """
+    props = {
+        "status": "confirmed",
+        "identity_uuid": identity_uuid,
+        "identity_id": identity_id,
+        "identity_ref": f"{file_uuid}:identity_{identity_id}",
+        "identity_name": name,
+    }
+
+    # Merge the confirmed fields, then drop whatever the suggested and
+    # stranger states left behind (stranger_id included, not just stranger_ref).
+    set_expr = f"(properties || %s::jsonb) {_strip_keys_sql(_PENDING_KEYS, _STRANGER_KEYS)}"
+    updated = _update_face_track(file_uuid, trace_id, set_expr, props, "confirming")
+    if updated:
+        print(f"[TKG] Marked trace {trace_id} as confirmed: {name}")
+    return updated
+
+
+def mark_face_track_stranger(
+    file_uuid: str,
+    trace_id: int,
+    stranger_cluster_id: int,
+) -> bool:
+    """Mark face_track node as 'stranger'
+
+    Args:
+        file_uuid: Video file UUID
+        trace_id: Face trace ID
+        stranger_cluster_id: Stranger cluster ID
+
+    Returns:
+        True if successful
+    """
+    props = {
+        "status": "stranger",
+        "stranger_id": stranger_cluster_id,
+        "stranger_ref": f"stranger_{stranger_cluster_id}",
+    }
+
+    # Merge the stranger fields, then drop every suggested and confirmed key
+    # (identity_id / identity_name included) so no stale identity remains.
+    set_expr = f"(properties || %s::jsonb) {_strip_keys_sql(_PENDING_KEYS, _IDENTITY_KEYS)}"
+    updated = _update_face_track(file_uuid, trace_id, set_expr, props, "marking stranger")
+    if updated:
+        print(f"[TKG] Marked trace {trace_id} as stranger cluster {stranger_cluster_id}")
+    return updated
+
+
+def get_face_track_nodes(file_uuid: str) -> List[Dict]:
+    """Get all face_track nodes for a file
+
+    Args:
+        file_uuid: Video file UUID
+
+    Returns:
+        List of face_track nodes with properties
+    """
+    return _fetch_face_tracks(file_uuid)
+
+
+def get_pending_face_tracks(file_uuid: str) -> List[Dict]:
+    """Get face_track nodes with status='pending' or NULL status
+
+    Args:
+        file_uuid: Video file UUID
+
+    Returns:
+        List of pending face_track nodes
+    """
+    return _fetch_face_tracks(
+        file_uuid,
+        "AND (properties->>'status' IS NULL OR properties->>'status' = 'pending')",
+    )
+
+
+def get_suggested_face_tracks(file_uuid: str) -> List[Dict]:
+    """Get face_track nodes with status='suggested'
+
+    Args:
+        file_uuid: Video file UUID
+
+    Returns:
+        List of suggested face_track nodes
+    """
+    return _fetch_face_tracks(file_uuid, "AND properties->>'status' = 'suggested'")
+
+
+def clear_face_track_status(file_uuid: str, trace_id: int) -> bool:
+    """Clear identity binding status from face_track node
+
+    Args:
+        file_uuid: Video file UUID
+        trace_id: Face trace ID
+
+    Returns:
+        True if successful
+    """
+    removal = _strip_keys_sql(("status",), _PENDING_KEYS, _IDENTITY_KEYS, _STRANGER_KEYS)
+    return _update_face_track(file_uuid, trace_id, f"properties {removal}", None, "clearing")
+
+
+def batch_mark_suggestions(file_uuid: str, suggestions: Dict) -> int:
+    """Batch mark multiple face_track nodes as 'suggested'
+
+    Args:
+        file_uuid: Video file UUID
+        suggestions: {trace_id: {identity_id, identity_uuid, name, score, suggested_by}}
+
+    Returns:
+        Number of nodes updated
+    """
+    suggestions = suggestions or {}
+    updated = 0
+    for trace_key, suggestion in suggestions.items():
+        try:
+            trace_id = int(trace_key)
+        except (TypeError, ValueError):
+            # One malformed key must not abort the rest of the batch.
+            print(f"[TKG] Skipping suggestion with invalid trace id: {trace_key!r}")
+            continue
+
+        if mark_face_track_suggested(
+            file_uuid,
+            trace_id,
+            suggestion.get("identity_id"),
+            suggestion.get("identity_uuid"),
+            suggestion.get("name"),
+            suggestion.get("score", 0.0),
+            suggestion.get("suggested_by", "tmdb"),
+        ):
+            updated += 1
+
+    print(f"[TKG] Batch marked {updated}/{len(suggestions)} traces as suggested")
+    return updated
+
+
+def batch_mark_strangers(file_uuid: str, stranger_clusters: Dict) -> int:
+    """Batch mark multiple face_track nodes as 'stranger'
+
+    Args:
+        file_uuid: Video file UUID
+        stranger_clusters: {cluster_id: [trace_ids]}
+
+    Returns:
+        Number of nodes updated
+    """
+    stranger_clusters = stranger_clusters or {}
+    updated = 0
+    for cluster_id, trace_ids in stranger_clusters.items():
+        for trace_id in trace_ids or []:
+            if mark_face_track_stranger(file_uuid, trace_id, cluster_id):
+                updated += 1
+
+    print(f"[TKG] Batch marked {updated} traces as strangers in {len(stranger_clusters)} clusters")
+    return updated