169 lines
5.9 KiB
Python
169 lines
5.9 KiB
Python
"""Executor: Search API, download trace video, extract key frames"""
|
|
import json, subprocess, os, cv2, sys
|
|
from PIL import Image
|
|
|
|
# Base URL of the HTTP API that trace videos are downloaded from.
API = "http://localhost:3003"

# Value sent in the X-API-Key request header.
# NOTE(review): this credential is committed in source; consider loading it
# from the environment or a secrets store instead.
KEY = "muser_68600856036340bcafc01930eb4bd839_1774418104_97221b69"

# Directory that receives downloaded trace videos; created at import time.
FRAME_OUTPUT = "/tmp/qa"
os.makedirs(FRAME_OUTPUT, exist_ok=True)

# Prepend this script's directory, then its parent, to the import path
# (same final order as two successive insert(0, ...) calls: parent first,
# script directory second).
sys.path[:0] = [
    os.path.dirname(__file__),
    os.path.join(os.path.dirname(__file__), ".."),
]
|
|
|
|
|
|
def find_trace_by_identity(actor_name, file_uuid):
    """Find the trace with the most face detections for a named identity.

    Queries ``dev.face_detections`` joined to ``dev.identities`` for rows in
    ``file_uuid`` whose identity name equals ``actor_name`` and whose
    ``trace_id`` is not NULL, grouped per trace.

    Args:
        actor_name: Identity name, compared for equality with
            ``dev.identities.name``.
        file_uuid: File identifier, compared with
            ``dev.face_detections.file_uuid``.

    Returns:
        A dict with keys ``trace_id``, ``faces`` (detection count),
        ``start_frame`` and ``end_frame`` (min/max frame number of the
        trace) for the trace with the highest detection count, or ``None``
        when no row matches.
    """
    import psycopg2

    conn = psycopg2.connect("postgresql://accusys@localhost:5432/momentry")
    try:
        # The cursor context manager closes the cursor; the connection is
        # closed in ``finally`` so neither leaks when the query raises.
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT fd.trace_id, COUNT(*) as faces,
                       MIN(fd.frame_number) as start_f, MAX(fd.frame_number) as end_f
                FROM dev.face_detections fd
                JOIN dev.identities i ON i.id = fd.identity_id
                WHERE i.name = %s AND fd.file_uuid = %s AND fd.trace_id IS NOT NULL
                GROUP BY fd.trace_id
                ORDER BY faces DESC LIMIT 1
                """,
                (actor_name, file_uuid),
            )
            row = cur.fetchone()
    finally:
        conn.close()

    if row:
        return {
            "trace_id": row[0],
            "faces": row[1],
            "start_frame": row[2],
            "end_frame": row[3],
        }
    return None
|
|
|
|
|
|
def find_trace_in_frame_range(start_frame, end_frame, file_uuid):
    """Find the trace with the most face detections inside a frame range.

    Only detections of ``file_uuid`` with a non-NULL ``trace_id`` and a
    ``frame_number`` between ``start_frame`` and ``end_frame`` (inclusive,
    SQL ``BETWEEN``) are counted.

    Args:
        start_frame: First frame number of the range.
        end_frame: Last frame number of the range.
        file_uuid: File identifier, compared with
            ``dev.face_detections.file_uuid``.

    Returns:
        A dict with keys ``trace_id``, ``faces`` (detection count within the
        range), ``start_frame`` and ``end_frame`` (min/max matching frame
        number) for the trace with the highest count, or ``None`` when no
        row matches.
    """
    import psycopg2

    conn = psycopg2.connect("postgresql://accusys@localhost:5432/momentry")
    try:
        # The cursor context manager closes the cursor; the connection is
        # closed in ``finally`` so neither leaks when the query raises.
        with conn.cursor() as cur:
            cur.execute(
                """
                SELECT trace_id, COUNT(*) as faces,
                       MIN(frame_number) as start_f, MAX(frame_number) as end_f
                FROM dev.face_detections
                WHERE file_uuid = %s AND trace_id IS NOT NULL
                  AND frame_number BETWEEN %s AND %s
                GROUP BY trace_id
                ORDER BY faces DESC LIMIT 1
                """,
                (file_uuid, start_frame, end_frame),
            )
            row = cur.fetchone()
    finally:
        conn.close()

    if row:
        return {
            "trace_id": row[0],
            "faces": row[1],
            "start_frame": row[2],
            "end_frame": row[3],
        }
    return None
|
|
|
|
|
|
def find_trace_by_object(object_name, file_uuid):
    """Find a trace near the first frame where YOLO detected ``object_name``.

    Reads ``<file_uuid>.yolo.json`` from the YOLO output directory and scans
    its ``frames`` mapping in file order. The first detection whose
    ``class_name`` contains ``object_name`` (case-insensitive substring
    match) selects a target frame; the lookup then searches for a trace in
    the window of 50 frames on either side of it.

    Args:
        object_name: Object class to look for. ``None`` is treated like the
            empty string, which matches any detection.
        file_uuid: File identifier used both for the YOLO file name and the
            trace lookup.

    Returns:
        The result of ``find_trace_in_frame_range`` for the chosen window.
        When the YOLO file does not exist, the result of a lookup over frames
        0..1000000 instead. ``None`` when no detection matches.
    """
    yolo_path = os.path.join(
        "/Users/accusys/momentry/output_dev", f"{file_uuid}.yolo.json"
    )
    if not os.path.exists(yolo_path):
        # No YOLO output for this file: fall back to the whole frame range.
        return find_trace_in_frame_range(0, 1000000, file_uuid)

    with open(yolo_path, encoding="utf-8") as f:
        yolo = json.load(f)

    # Lower-case the search term once instead of once per detection.
    needle = (object_name or "").lower()

    for fnum_str, frm in yolo.get("frames", {}).items():
        for det in frm.get("detections", []):
            # ``or ""`` guards against an explicit null class_name.
            label = (det.get("class_name") or "").lower()
            if needle in label:
                target_frame = int(fnum_str)
                return find_trace_in_frame_range(
                    max(0, target_frame - 50),
                    target_frame + 50,
                    file_uuid,
                )
    return None
|
|
|
|
|
|
def download_trace_video(file_uuid, trace_id, output_path):
    """Download a trace video with curl (``mode=normal``, ``padding=1``).

    Args:
        file_uuid: File identifier placed in the request URL.
        trace_id: Trace identifier placed in the request URL.
        output_path: Local path curl writes the response body to.

    Returns:
        ``True`` when curl exits successfully and ``output_path`` is a
        non-empty file; ``False`` on any failure (HTTP error status, curl
        error, 60 s timeout, curl not being runnable, or an empty result).
    """
    cmd = [
        "curl", "-sk",
        # Without --fail curl exits 0 on HTTP 4xx/5xx and saves the error
        # body to output_path, which would then look like a video.
        "--fail",
        "-H", "X-API-Key: " + KEY,
        "-o", output_path,
        f"{API}/api/v1/file/{file_uuid}/trace/{trace_id}/video?mode=normal&padding=1",
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, timeout=60)
    except (subprocess.TimeoutExpired, OSError):
        # Timeout or curl could not be started: report failure to the caller
        # instead of raising, matching the boolean contract.
        return False

    if result.returncode != 0:
        # A file left over from an earlier run must not count as success.
        return False
    return os.path.isfile(output_path) and os.path.getsize(output_path) > 0
|
|
|
|
|
|
def extract_frames(video_path, n_frames=1):
    """Extract ``n_frames`` evenly spaced frames from a video.

    With ``n_frames`` of 1 (or less) only the middle frame is taken. For
    larger values the frames are spread evenly between 20% and 80% of the
    video, so ``n_frames=3`` samples at 20%, 50% and 80%.

    Args:
        video_path: Path of the video file to open with OpenCV.
        n_frames: Number of frames to sample.

    Returns:
        A list of RGB ``PIL.Image`` objects, one per frame that could be
        read. Empty when the video reports zero frames. May be shorter than
        ``n_frames`` when individual reads fail.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total == 0:
            return []

        count = int(n_frames)
        if count > 1:
            # Integer form of total * (0.2 + 0.6 * i / (count - 1)); avoids
            # float rounding and gives 20/50/80% for count == 3.
            span = count - 1
            positions = [
                total * (span + 3 * i) // (5 * span) for i in range(count)
            ]
        else:
            positions = [total // 2]  # just the middle frame

        # Keep every index inside the valid frame range.
        positions = [max(0, min(p, total - 1)) for p in positions]

        frames = []
        for pos in positions:
            cap.set(cv2.CAP_PROP_POS_FRAMES, pos)
            ret, frame = cap.read()
            if ret:
                # OpenCV decodes to BGR; PIL expects RGB.
                frames.append(
                    Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                )
        return frames
    finally:
        # Release the capture even if conversion of a frame raises.
        cap.release()
|
|
|
|
|
|
def execute(query, file_uuid, fps=25.0):
    """Run one query: type-specific trace search, video download, frame grab.

    Args:
        query: Mapping with at least ``id``, ``type`` and ``prompt``.
            Depending on ``type`` it is also read for ``expected_identity``
            (``"identity"``), ``cut_start``/``cut_end`` (``"scene"``) or
            ``expected_object`` (``"object"``). Any other type finds no
            trace.
        file_uuid: File identifier passed to the search and download helpers.
        fps: Frame rate used to convert trace frame numbers to seconds.
            Defaults to 25.0, the value previously hard-coded here.

    Returns:
        A dict that always has ``query``, ``status`` and ``frames``.
        ``status`` is ``"no_trace"`` or ``"no_video"`` (with empty
        ``frames``) on failure. On ``"ok"`` the dict also has ``trace_id``,
        ``video_path``, ``trace_start`` and ``trace_end`` (seconds).

    Raises:
        ValueError: If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError("fps must be positive")

    qid = query["id"]
    qtype = query["type"]
    print(f" [{qid}] ({qtype}) {query['prompt'][:55]}...", end="", flush=True)

    # Type-specific search
    trace_info = None
    if qtype == "identity":
        actor = query.get("expected_identity")
        if actor:
            trace_info = find_trace_by_identity(actor, file_uuid)
    elif qtype == "scene":
        start = query.get("cut_start", 0)
        end = query.get("cut_end", 1000000)
        trace_info = find_trace_in_frame_range(start, end, file_uuid)
    elif qtype == "object":
        obj_name = query.get("expected_object", "")
        trace_info = find_trace_by_object(obj_name, file_uuid)

    if trace_info is None:
        print(" ❌ no trace found")
        return {"query": query, "status": "no_trace", "frames": []}

    # A search helper may hand back either a dict or a bare trace id.
    if isinstance(trace_info, dict):
        trace_id = trace_info["trace_id"]
        start_frame = trace_info.get("start_frame", 0)
        end_frame = trace_info.get("end_frame", 0)
    else:
        trace_id, start_frame, end_frame = trace_info, 0, 0

    trace_start = start_frame / fps if start_frame > 0 else 0
    # Without a usable end frame, assume a 30 second window.
    trace_end = end_frame / fps if end_frame > 0 else trace_start + 30

    # Download video
    vid_path = f"{FRAME_OUTPUT}/{qid}_video.mp4"
    if not download_trace_video(file_uuid, trace_id, vid_path):
        print(" ❌ video dl failed")
        return {"query": query, "status": "no_video", "frames": []}

    size = os.path.getsize(vid_path)
    print(f" ({size//1024}KB)", end="", flush=True)

    # Extract frames
    frames = extract_frames(vid_path)
    print(f" {len(frames)} frames")

    return {
        "query": query,
        "status": "ok",
        "trace_id": trace_id,
        "video_path": vid_path,
        "frames": frames,
        "trace_start": trace_start,
        "trace_end": trace_end,
    }
|