82 lines
3.0 KiB
Python
82 lines
3.0 KiB
Python
"""Query Generator: Generate 15 test prompts from DB data"""
|
|
import random, psycopg2, json
|
|
|
|
# PostgreSQL connection URL handed to psycopg2.connect() by generate().
DB_URL = "postgresql://accusys@localhost:5432/momentry"
|
|
|
|
def generate(file_uuid, *, output_dir="/Users/accusys/momentry/output_dev"):
    """Build up to 15 test queries (identity, scene, object) for one file.

    Three groups of at most five queries each are produced:

    * identity (Q01-Q05): the TMDB identities with the most face detections
      for ``file_uuid``, read from the database at ``DB_URL``;
    * scene (Q06-Q10): randomly chosen entries of the ``scenes`` list in
      ``<output_dir>/<file_uuid>.cut.json``;
    * object (Q11-Q15): randomly chosen object classes among the ten most
      frequent ones in ``<output_dir>/<file_uuid>.yolo.json`` (``person``
      and ``tie`` excluded), or from a built-in default list when that file
      does not exist.

    Args:
        file_uuid: Identifier of the file; used as the SQL filter value and
            as the base name of the ``.cut.json`` / ``.yolo.json`` files.
        output_dir: Directory holding the per-file JSON outputs.

    Returns:
        A list of query dicts. Every dict has ``id``, ``type`` and
        ``prompt``; the remaining keys depend on ``type``.

    Raises:
        Whatever ``psycopg2`` raises on connection/query failure, and
        ``KeyError`` if a scene entry lacks ``start_frame``/``end_frame``.
    """
    import os
    from collections import Counter

    queries = []

    # 1. Identity queries (5) — top TMDB actors by face count.
    # The connection is only needed for this one query, so it is released
    # right away; try/finally guarantees that even when the query fails.
    # (psycopg2's ``with conn`` only ends the transaction, it does not close.)
    conn = psycopg2.connect(DB_URL)
    try:
        cur = conn.cursor()
        try:
            cur.execute("""
                SELECT i.name, fd.trace_id, COUNT(*) as faces
                FROM dev.face_detections fd
                JOIN dev.identities i ON i.id = fd.identity_id
                WHERE fd.file_uuid = %s AND i.source = 'tmdb'
                GROUP BY i.name, fd.trace_id
                ORDER BY faces DESC LIMIT 5
            """, (file_uuid,))
            identity_rows = cur.fetchall()
        finally:
            cur.close()
    finally:
        conn.close()

    scene_hints = ["indoor", "outdoor", "in a conversation", "walking", "talking"]
    for i, (name, tid, cnt) in enumerate(identity_rows):
        hint = scene_hints[i % len(scene_hints)]
        queries.append({
            "id": f"Q{i+1:02d}", "type": "identity",
            "prompt": f"Show {name} {hint}",
            "expected_identity": name,
            "expected_trace_id": tid,
            "face_count_gt": cnt,
        })

    # 2. Scene queries (5) — from cut.json file.
    cut_data = _load_json(os.path.join(output_dir, f"{file_uuid}.cut.json"))
    cuts = cut_data.get("scenes", []) if cut_data is not None else []

    scene_labels = ["restaurant", "hotel_room", "office", "street",
                    "bedroom", "park", "kitchen", "car_interior", "bar", "living_room"]
    random.shuffle(cuts)
    # NOTE(review): the label is picked by position in scene_labels, not
    # derived from the cut itself — confirm that this pairing is intended.
    for i, cut in enumerate(cuts[:5]):
        label = scene_labels[i % len(scene_labels)]
        queries.append({
            "id": f"Q{i+6:02d}", "type": "scene",
            "prompt": f"Show the scene in a {label.replace('_', ' ')}",
            "expected_scene": label,
            "cut_start": cut["start_frame"],
            "cut_end": cut["end_frame"],
        })

    # 3. Object queries (5) — from yolo.json.
    yolo_data = _load_json(os.path.join(output_dir, f"{file_uuid}.yolo.json"))
    if yolo_data is not None:
        class_counts = Counter()
        for frame in yolo_data.get("frames", {}).values():
            for det in frame.get("detections", []):
                cls = det.get("class_name", det.get("class", ""))
                # Skip detections without a usable class name: they would
                # otherwise produce the prompt "Find scenes containing a ".
                if cls and cls not in ("person", "tie"):
                    class_counts[cls] += 1
        top_classes = [c for c, _ in class_counts.most_common(10)]
    else:
        top_classes = ["chair", "car", "bottle", "book", "tvmonitor", "cell phone", "cup", "diningtable"]

    random.shuffle(top_classes)
    for i, cls in enumerate(top_classes[:5]):
        queries.append({
            "id": f"Q{i+11:02d}", "type": "object",
            "prompt": f"Find scenes containing a {cls}",
            "expected_object": cls,
        })

    return queries


def _load_json(path):
    """Return the parsed JSON content of *path*, or None if the file is missing."""
    try:
        with open(path, encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None
|