feat: merge traces within same cut — centroid similarity threshold 0.75

This commit is contained in:
Accusys
2026-05-14 03:04:03 +08:00
parent 4e933a554c
commit 64bcfd716e

View File

@@ -19,6 +19,7 @@ import sys
import os import os
import json import json
import argparse import argparse
import numpy as np
import psycopg2 import psycopg2
import psycopg2.extras import psycopg2.extras
from datetime import datetime from datetime import datetime
@@ -36,6 +37,129 @@ def get_conn():
return psycopg2.connect(DB_URL) return psycopg2.connect(DB_URL)
def merge_traces_within_cuts(face_data: dict, cut_scenes: list,
                             similarity_threshold: float = 0.75) -> dict:
    """Merge traces within the same cut whose embedding centroids are similar.

    A trace is assigned to the scene that contains the majority of its face
    appearances. Inside each scene, traces are compared pairwise by the dot
    product of their L2-normalised mean embeddings; a later trace is folded
    into an earlier one when the similarity reaches ``similarity_threshold``.

    Args:
        face_data: Tracker output. Reads ``face_data["frames"]`` (frame number
            string -> ``{"faces": [...]}``) where each face may carry
            ``trace_id``, ``embedding``, ``x``/``y``/``width``/``height`` and
            ``confidence``. Mutated in place.
        cut_scenes: Scene dicts with ``start_frame``, ``end_frame``
            (inclusive) and ``scene_number``.
        similarity_threshold: Minimum centroid similarity for two traces to
            be merged.

    Returns:
        The same ``face_data`` object. When at least one merge happened, the
        faces' ``trace_id`` values are relabelled and ``face_data["traces"]``
        and ``face_data["metadata"]["trace_stats"]`` are rebuilt; otherwise
        the input is returned untouched.
    """
    # Imported locally so the function does not rely on a module-level import.
    from collections import defaultdict

    frames = face_data.get("frames", {})
    if not frames:
        return face_data

    # Map each frame to its scene/cut number.
    frame_to_scene = {}
    for scene in cut_scenes:
        for frame_no in range(scene["start_frame"], scene["end_frame"] + 1):
            frame_to_scene[frame_no] = scene["scene_number"]

    # Collect, per trace, the frames it appears in and its embeddings.
    trace_frames = defaultdict(list)
    trace_embeddings = defaultdict(list)
    for fnum_str, frm_data in frames.items():
        fnum = int(fnum_str)
        for face in frm_data.get("faces", []):
            tid = face.get("trace_id")
            if tid is None:
                continue
            trace_frames[tid].append(fnum)
            emb = face.get("embedding")
            if emb is not None:
                trace_embeddings[tid].append(emb)

    if len(trace_embeddings) < 2:
        return face_data

    # Unit-length centroid per trace (left as-is when the mean is all zeros).
    trace_centroids = {}
    for tid, embs in trace_embeddings.items():
        centroid = np.mean(embs, axis=0)
        norm = np.linalg.norm(centroid)
        trace_centroids[tid] = centroid / norm if norm > 0 else centroid

    # Group traces by the scene holding the majority of their appearances.
    # Frames outside every scene vote for -1, and such traces are not merged.
    scene_traces = defaultdict(list)
    for tid, fns in trace_frames.items():
        scene_votes = defaultdict(int)
        for fn in fns:
            scene_votes[frame_to_scene.get(fn, -1)] += 1
        scene = max(scene_votes, key=scene_votes.get) if scene_votes else -1
        if scene >= 0 and tid in trace_centroids:
            scene_traces[scene].append(tid)

    # Decide all merges first (absorbed trace id -> surviving trace id) so the
    # frames only need to be rewritten once.
    remap = {}
    for tids in scene_traces.values():
        if len(tids) < 2:
            continue
        for i, keep_tid in enumerate(tids):
            if keep_tid in remap:
                continue
            keep_centroid = trace_centroids[keep_tid]
            keep_frames = set(trace_frames[keep_tid])
            for other_tid in tids[i + 1:]:
                if other_tid in remap:
                    continue
                # Two traces visible in the same frame are simultaneous
                # detections, not a re-appearance; merging them would put one
                # trace_id on two faces of a single frame.
                if not keep_frames.isdisjoint(trace_frames[other_tid]):
                    continue
                sim = float(np.dot(keep_centroid, trace_centroids[other_tid]))
                if sim >= similarity_threshold:
                    remap[other_tid] = keep_tid
                    keep_frames.update(trace_frames[other_tid])

    merged = len(remap)
    if merged == 0:
        return face_data

    # Relabel faces and rebuild the per-trace paths in a single pass.
    new_trace_paths = defaultdict(list)
    for fnum_str, frm_data in frames.items():
        fnum = int(fnum_str)
        for face in frm_data.get("faces", []):
            tid = face.get("trace_id")
            if tid is None:
                continue
            if tid in remap:
                tid = remap[tid]
                face["trace_id"] = tid
            new_trace_paths[tid].append({
                "frame": fnum,
                # NOTE(review): always 0, not the face's position in the
                # frame's list -- confirm what consumers of "path" expect.
                "face_index": 0,
                "bbox": {"x": face.get("x", 0), "y": face.get("y", 0),
                         "width": face.get("width", 0),
                         "height": face.get("height", 0)},
                "confidence": face.get("confidence", 0.0),
            })

    # Create metadata if absent; fall back to 25 fps when fps is missing or 0
    # so the duration computation cannot divide by zero.
    metadata = face_data.setdefault("metadata", {})
    fps = metadata.get("fps") or 25.0

    new_traces = {}
    for tid, path in new_trace_paths.items():
        frames_sorted = sorted({p["frame"] for p in path})
        first, last = frames_sorted[0], frames_sorted[-1]
        new_traces[str(tid)] = {
            "trace_id": tid,
            "start_frame": first,
            "end_frame": last,
            "duration_frames": last - first + 1,
            "duration_seconds": (last - first) / fps,
            "total_appearances": len(path),
            "path": path,
        }

    face_data["traces"] = new_traces
    metadata["trace_stats"] = {
        "total_traces": len(new_traces),
        "active_traces": len(new_traces),
        "long_traces": sum(1 for t in new_traces.values() if t["duration_frames"] >= 2),
    }
    print(f"[TRACE] Post-merge: {merged} traces merged, {len(new_traces)} total traces")
    return face_data
def run_face_tracker(face_json_path: str, traced_json_path: str) -> str: def run_face_tracker(face_json_path: str, traced_json_path: str) -> str:
"""Run face_tracker.py on face.json, returns path to face_traced.json""" """Run face_tracker.py on face.json, returns path to face_traced.json"""
from face_tracker import track_faces from face_tracker import track_faces
@@ -115,14 +239,21 @@ def run_face_tracker(face_json_path: str, traced_json_path: str) -> str:
# Load cut boundaries from cut.json (same directory as face.json) # Load cut boundaries from cut.json (same directory as face.json)
cut_boundaries = None cut_boundaries = None
cut_scenes = None
cuts_path = face_json_path.replace("_traced.json", ".cut.json").replace(".face.json", ".cut.json") cuts_path = face_json_path.replace("_traced.json", ".cut.json").replace(".face.json", ".cut.json")
if os.path.exists(cuts_path): if os.path.exists(cuts_path):
with open(cuts_path) as f: with open(cuts_path) as f:
cuts = json.load(f) cuts = json.load(f)
cut_boundaries = {s["start_frame"] for s in cuts.get("scenes", []) if s["start_frame"] > 0} cut_scenes = cuts.get("scenes", [])
cut_boundaries = {s["start_frame"] for s in cut_scenes if s["start_frame"] > 0}
print(f"[TRACE] Loaded {len(cut_boundaries)} cut boundaries") print(f"[TRACE] Loaded {len(cut_boundaries)} cut boundaries")
face_data = track_faces(face_data, use_embedding=True, cut_boundaries=cut_boundaries) face_data = track_faces(face_data, use_embedding=True, cut_boundaries=cut_boundaries)
# Merge traces within same cut (same person re-appearing after occlusion/pose change)
if cut_scenes and len(cut_scenes) > 0:
face_data = merge_traces_within_cuts(face_data, cut_scenes)
metadata = face_data.get("metadata", {}) metadata = face_data.get("metadata", {})
metadata["tracking_method"] = "iou_embedding" metadata["tracking_method"] = "iou_embedding"
metadata["tracked_at"] = datetime.now().isoformat() metadata["tracked_at"] = datetime.now().isoformat()