Schema changes: dev.chunks -> dev.chunk; remove old_chunk_id/chunk_index. Correction: asr-1.json format, generate/apply scripts. API: 37/37 endpoints fixed and tested. Docs: HANDOVER_V2.0.md for M4.
325 lines
11 KiB
Python
325 lines
11 KiB
Python
#!/opt/homebrew/bin/python3.11
|
|
"""
|
|
Object Search Agent — searches across YOLO, OCR, ASR, and TKG.
|
|
Usage: python3 scripts/object_search_agent.py --keyword stamp [--uuid <UUID>]
|
|
"""
|
|
import json, sys, argparse
|
|
from collections import defaultdict
|
|
import psycopg2
|
|
|
|
# File UUID used when --uuid is not supplied on the command line.
UUID = "aeed71342a899fe4b4c57b7d41bcb692"

# Connection URI passed to psycopg2.connect().
DB_URL = "postgresql://accusys@localhost:5432/momentry?host=/tmp"

# Frame rate used to turn stored frame numbers into timestamps.
FPS = 25.0

# YOLO class aliases for common search terms.
# Maps a search keyword to the detector class names that are queried for it.
ALIASES = {
    "stamp": ["stamp"],
    "gun": ["knife", "pistol", "rifle", "grenade"],
    "weapon": ["knife", "pistol", "rifle", "grenade"],
    "knife": ["knife"],
    "person": ["person"],
    "letter": ["book"],
    "envelope": ["book"],
    "car": ["car"],
    "tie": ["tie"],
    "phone": ["cell phone"],
    "bottle": ["bottle", "wine glass", "cup"],
    "chair": ["chair"],
    "umbrella": ["umbrella"],
}
|
|
|
|
def search_yolo(cur, keyword, uuid, min_confidence=0.3):
    """Search YOLO detections for object classes matching ``keyword``.

    The keyword is expanded through ``ALIASES`` (falling back to the keyword
    itself), and one query per class is run against ``dev.pre_chunks`` rows
    with ``processor_type='yolo'``.

    Args:
        cur: Open DB-API cursor; must support ``execute`` and ``fetchall``.
        keyword: Search term, looked up in ``ALIASES``.
        uuid: ``file_uuid`` of the file to search.
        min_confidence: A row is kept only if its best detection of the
            class scores strictly above this value.

    Returns:
        A list of dicts with keys ``frame``, ``timestamp``, ``time_str``
        (``M:SS.FF``), ``class``, ``confidence`` and ``source``, ordered by
        frame number across all matched classes.
    """
    classes = ALIASES.get(keyword, [keyword])
    results = []
    for cls in classes:
        cur.execute(
            """
            SELECT start_frame, end_frame, data
            FROM dev.pre_chunks
            WHERE file_uuid=%s AND processor_type='yolo'
            AND data->'objects' IS NOT NULL
            AND data->'objects' @> jsonb_build_array(
            jsonb_build_object('class_name', %s)
            )
            ORDER BY start_frame
            LIMIT 100
            """,
            (uuid, cls),
        )
        for start_frame, _end_frame, data in cur.fetchall():
            # Best score among the detections of this class in the row; a
            # missing or null confidence counts as 0 instead of breaking max().
            top_conf = max(
                (
                    obj.get("confidence") or 0
                    for obj in (data.get("objects") or [])
                    if obj.get("class_name") == cls
                ),
                default=0,
            )
            if top_conf <= min_confidence:
                continue

            ts = start_frame / FPS
            # Derive the frame-within-second field from the frame number.
            # Going through the float timestamp (``(ts % 1) * 25``) truncates
            # e.g. frame 29 to ".03" and ignores FPS.
            whole_seconds = int(start_frame // FPS)
            frame_in_second = int(start_frame % FPS)
            results.append({
                "frame": int(start_frame),
                "timestamp": ts,
                "time_str": (
                    f"{whole_seconds // 60}:{whole_seconds % 60:02d}"
                    f".{frame_in_second:02d}"
                ),
                "class": cls,
                "confidence": round(top_conf, 3),
                "source": "yolo",
            })

    # Each class is queried separately, so merge the per-class runs into one
    # chronological list (stable sort keeps per-frame class order).
    results.sort(key=lambda hit: hit["frame"])
    return results
|
|
|
|
def search_ocr(cur, keyword, uuid):
    """Search OCR text for ``keyword`` (case-insensitive substring match).

    Queries ``dev.pre_chunks`` rows with ``processor_type='ocr'`` whose
    ``data->>'text'`` matches ``%keyword%`` via ILIKE.

    Args:
        cur: Open DB-API cursor; must support ``execute`` and ``fetchall``.
        keyword: Text to look for. ``%`` and ``_`` in it act as ILIKE
            wildcards because the pattern is not escaped.
        uuid: ``file_uuid`` of the file to search.

    Returns:
        A list of at most 50 dicts with keys ``frame``, ``timestamp``,
        ``time_str`` (``M:SS.FF``), ``text`` (first 100 characters) and
        ``source``, ordered by start frame.
    """
    cur.execute(
        """
        SELECT start_frame, end_frame, data
        FROM dev.pre_chunks
        WHERE file_uuid=%s AND processor_type='ocr'
        AND data->>'text' ILIKE %s
        ORDER BY start_frame
        LIMIT 50
        """,
        (uuid, f"%{keyword}%"),
    )
    results = []
    for start_frame, _end_frame, data in cur.fetchall():
        # FPS is a float, so ``//`` and ``%`` yield floats; convert to int
        # before formatting. The previous ``:02d`` on a float raised
        # ValueError for every hit.
        whole_seconds = int(start_frame // FPS)
        frame_in_second = int(start_frame % FPS)
        results.append({
            "frame": start_frame,
            "timestamp": start_frame / FPS,
            "time_str": (
                f"{whole_seconds // 60}:{whole_seconds % 60:02d}"
                f".{frame_in_second:02d}"
            ),
            "text": (data.get("text") or "")[:100],
            "source": "ocr",
        })
    return results
|
|
|
|
def search_asr(cur, keyword, uuid):
    """Search ASR sentence chunks whose text contains ``keyword``.

    Runs a case-insensitive ILIKE match against ``text_content`` of
    ``chunk_type='sentence'`` rows and returns up to 100 hits ordered by
    start time. Each hit is a dict with ``chunk_index``, ``timestamp``,
    ``time_str`` (``M:SS.ss``), ``text`` (first 120 characters) and
    ``source``.
    """
    # NOTE(review): confirm the table/column names used below (dev.chunks,
    # chunk_index) against the current database schema.
    like_pattern = f"%{keyword}%"
    cur.execute(
        """
        SELECT chunk_index, start_time, end_time, text_content
        FROM dev.chunks
        WHERE file_uuid=%s AND chunk_type='sentence'
        AND text_content ILIKE %s
        ORDER BY start_time
        LIMIT 100
        """,
        (uuid, like_pattern),
    )
    return [
        {
            "chunk_index": chunk_idx,
            "timestamp": start,
            "time_str": f"{int(start//60)}:{start%60:05.2f}",
            # NULL text_content is reported as an empty string.
            "text": (content or "")[:120],
            "source": "asr",
        }
        for chunk_idx, start, _end, content in cur.fetchall()
    ]
|
|
|
|
# Weights file loaded by search_gun_detector().
GUN_MODEL_PATH = "/Users/accusys/momentry_core_0.1/models/gun/gun_detector/weights/best.pt"

# Class id -> class name for the custom gun detector.
GUN_CLASSES = {
    0: "grenade",
    1: "knife",
    2: "pistol",
    3: "rifle",
}

# Grounding DINO — Zero-shot gun detector (Large: 7 datasets, confirmed best on Charade)
GDINO_MODEL_NAME = "/Users/accusys/momentry_core_0.1/models/gun/grounding-dino-large-hf"

# Text prompts tried when the search keyword is a weapon term.
GDINO_PROMPTS = [
    "gun",
    "pistol",
    "rifle",
    "weapon",
    "firearm",
]

# Cache filled in by init_gdino(); all three stay None until the first load.
_gdino_processor = None
_gdino_model = None
_gdino_device = None
|
|
|
|
def init_gdino():
    """Load the Grounding DINO processor and model once per process.

    The loaded objects are stored in the module-level ``_gdino_processor``,
    ``_gdino_model`` and ``_gdino_device``. A call made after the model has
    been loaded returns without doing anything.
    """
    global _gdino_processor, _gdino_model, _gdino_device

    if _gdino_model is None:
        # Imported here so the dependencies are only loaded when this
        # function actually runs.
        from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
        import torch

        _gdino_processor = AutoProcessor.from_pretrained(GDINO_MODEL_NAME)
        _gdino_model = AutoModelForZeroShotObjectDetection.from_pretrained(GDINO_MODEL_NAME)

        # Use the MPS backend when torch reports it as available, else CPU.
        if torch.backends.mps.is_available():
            _gdino_device = "mps"
        else:
            _gdino_device = "cpu"
        _gdino_model.to(_gdino_device)
|
|
|
|
def search_zero_shot(video_path, keyword, threshold=0.05, frame_step=1500):
    """Search sampled video frames with Grounding DINO zero-shot detection.

    Args:
        video_path: Path of the video file opened with ``cv2.VideoCapture``.
        keyword: Search term. The weapon terms ("gun", "weapon", "pistol",
            "rifle", "firearm") expand to ``GDINO_PROMPTS``; any other
            keyword is used as the single prompt.
        threshold: Score threshold passed to the processor's
            ``post_process_grounded_object_detection``.
        frame_step: Distance in frames between sampled frames. The default
            of 1500 is one sample per minute at 25 fps.

    Returns:
        A list of dicts with keys ``frame``, ``timestamp``, ``time_str``
        (``M:SS``), ``class`` (the prompt that produced the detection),
        ``confidence`` and ``source``. Sampling stops after the first frame
        that brings the total to 50 or more detections.
    """
    import cv2
    from PIL import Image
    import torch

    # Determine prompts based on keyword
    if keyword in ("gun", "weapon", "pistol", "rifle", "firearm"):
        prompts = GDINO_PROMPTS
    else:
        prompts = [keyword]

    init_gdino()

    results = []
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        for frame_num in range(0, total_frames, frame_step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
            ret, frame = cap.read()
            if not ret:
                break

            img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

            # Per-frame values shared by every prompt and detection below.
            # PIL reports (width, height); the tensor holds the reverse.
            target = torch.tensor([img.size[::-1]])
            ts = frame_num / fps
            time_str = f"{int(ts//60)}:{int(ts%60):02d}"

            for prompt in prompts:
                inputs = _gdino_processor(
                    images=img, text=prompt, return_tensors="pt"
                ).to(_gdino_device)
                with torch.no_grad():
                    outputs = _gdino_model(**inputs)
                dets = _gdino_processor.post_process_grounded_object_detection(
                    outputs, threshold=threshold, target_sizes=target)[0]

                # One result entry per detection returned for this prompt.
                for score in dets["scores"]:
                    results.append({
                        "frame": frame_num,
                        "timestamp": ts,
                        "time_str": time_str,
                        "class": prompt,
                        "confidence": round(score.item(), 3),
                        "source": "grounding-dino",
                    })

            if len(results) >= 50:
                break
    finally:
        # Release the capture handle even if decoding or inference raises.
        cap.release()

    return results
|
|
|
|
def search_gun_detector(video_path, keyword, frame_step=150, confidence=0.25):
    """Run the custom gun detector model on sampled frames of a video.

    Args:
        video_path: Path of the video file opened with ``cv2.VideoCapture``.
        keyword: Search term; expanded through ``ALIASES`` and intersected
            with the class names in ``GUN_CLASSES``.
        frame_step: Distance in frames between sampled frames.
        confidence: Confidence value passed to the model as ``conf``.

    Returns:
        ``[]`` when the keyword maps to none of the detector's classes;
        ``[{"error": ...}]`` when ``ultralytics`` or ``cv2`` cannot be
        imported; otherwise a list of dicts with keys ``frame``,
        ``timestamp``, ``time_str`` (``M:SS.FF``), ``class``, ``confidence``
        and ``source``. Sampling stops after the first frame that brings the
        total to 50 or more detections.
    """
    classes = ALIASES.get(keyword, [])
    target_ids = {cid for cid, cname in GUN_CLASSES.items() if cname in classes}
    if not target_ids:
        return []

    try:
        from ultralytics import YOLO
        import cv2
    except ImportError:
        return [{"error": "ultralytics or cv2 not available"}]

    model = YOLO(GUN_MODEL_PATH)
    results = []
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        for frame_num in range(0, total_frames, frame_step):
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
            ret, frame = cap.read()
            if not ret:
                break

            ts = frame_num / fps
            # Take the frame-within-second field from the frame number;
            # ``(ts % 1) * fps`` can truncate one frame low through float
            # error.
            time_str = (
                f"{int(ts//60)}:{int(ts%60):02d}.{int(frame_num % fps):02d}"
            )

            dets = model(frame, conf=confidence, verbose=False)[0]
            for det in dets.boxes.data:
                # Index 4 is read as the confidence and index 5 as the
                # class id of each detection row.
                cls_id = int(det[5])
                if cls_id not in target_ids:
                    continue
                results.append({
                    "frame": frame_num,
                    "timestamp": ts,
                    "time_str": time_str,
                    "class": GUN_CLASSES[cls_id],
                    "confidence": round(float(det[4]), 3),
                    "source": "gun_detector",
                })

            if len(results) >= 50:
                break
    finally:
        # Release the capture handle even if decoding or inference raises.
        cap.release()

    return results
|
|
|
|
def search_tkg(cur, keyword, uuid):
    """Search TKG nodes whose label or external id contains ``keyword``.

    Runs a case-insensitive ILIKE match on ``dev.tkg_nodes`` for the given
    ``file_uuid`` and returns up to 20 hits, each a dict with ``type``,
    ``id``, ``label``, ``properties`` and ``source``.
    """
    like_pattern = f"%{keyword}%"
    cur.execute(
        """
        SELECT node_type, external_id, label, properties
        FROM dev.tkg_nodes
        WHERE file_uuid=%s
        AND (label ILIKE %s OR external_id ILIKE %s)
        LIMIT 20
        """,
        (uuid, like_pattern, like_pattern),
    )
    return [
        {
            "type": kind,
            "id": external_id,
            "label": node_label,
            "properties": node_props,
            "source": "tkg",
        }
        for kind, external_id, node_label, node_props in cur.fetchall()
    ]
|
|
|
|
def find_video(uuid, base="/Users/accusys/momentry/var/sftpgo/data/demo"):
    """Locate the video file for ``uuid`` below ``base``.

    Files whose name contains the first 8 characters of ``uuid`` are tried
    first, so a non-default uuid is not silently answered with the Charade
    video. If none is found, the first video whose name starts with
    ``Charade`` is returned as a fallback.

    Args:
        uuid: File UUID; only its first 8 characters are used for matching.
            An empty value skips the uuid search.
        base: Directory searched recursively.

    Returns:
        Path of the first matching ``.mp4``/``.mov``/``.avi`` file
        (extension compared case-insensitively, candidates taken in sorted
        order), or ``None`` when nothing matches.
    """
    import glob
    import os

    video_exts = (".mp4", ".mov", ".avi")
    # Escape glob metacharacters so the directory is matched literally.
    root = glob.escape(base)

    def _first_video(name_pattern):
        """Return the first video file below ``base`` matching the pattern."""
        candidates = glob.glob(f"{root}/**/{name_pattern}", recursive=True)
        # Sorted so the choice is deterministic across runs.
        for path in sorted(candidates):
            if path.lower().endswith(video_exts) and os.path.isfile(path):
                return path
        return None

    # An empty prefix would turn the pattern into "**", matching every file.
    prefix = (uuid or "")[:8]
    if prefix:
        match = _first_video(f"*{glob.escape(prefix)}*")
        if match is not None:
            return match

    # Fallback: find Charade by name
    return _first_video("Charade*")
|
|
|
|
def main():
    """Parse the command line, run the requested searches, print the results.

    Database-backed sources (yolo, ocr, asr, tkg) share one connection that
    is opened only when one of them is requested and is always closed.
    Video-based sources run afterwards:

    * ``zero_shot`` runs when requested or when the keyword is a weapon term;
    * ``gun_custom`` runs only when explicitly requested.

    Output is a human-readable summary (top 5 hits per source) followed by
    the results as JSON on stdout.
    """
    parser = argparse.ArgumentParser(description="Movie Object Search Agent")
    parser.add_argument("--keyword", required=True, help="Object to search for")
    parser.add_argument("--uuid", default=UUID)
    parser.add_argument("--sources", default="all",
                        help="yolo,ocr,asr,tkg,gun_custom,zero_shot,all")
    parser.add_argument("--video",
                        help="Path to video file (for gun detector / zero-shot search)")
    args = parser.parse_args()

    kw = args.keyword.lower()

    # Tolerate spaces and mixed case ("yolo, OCR"); "all" expands to the
    # four database-backed sources.
    requested = [s.strip().lower() for s in args.sources.split(",") if s.strip()]
    if "all" in requested:
        src = ["yolo", "ocr", "asr", "tkg"] + [s for s in requested if s != "all"]
    else:
        src = requested

    results = {}

    # (source name, search function, max results kept in the output)
    db_searches = [
        ("yolo", search_yolo, 30),
        ("ocr", search_ocr, 20),
        ("asr", search_asr, 20),
        ("tkg", search_tkg, 10),
    ]
    wanted_db = [entry for entry in db_searches if entry[0] in src]
    if wanted_db:
        conn = psycopg2.connect(DB_URL)
        try:
            with conn.cursor() as cur:
                for name, search, keep in wanted_db:
                    r = search(cur, kw, args.uuid)
                    results[name] = {"count": len(r), "results": r[:keep]}
        finally:
            # Close the connection even when a query raises.
            conn.close()

    wants_zero_shot = "zero_shot" in src or kw in ("gun", "weapon", "pistol", "rifle", "firearm")
    wants_gun_custom = "gun_custom" in src
    video_path = None
    if wants_zero_shot or wants_gun_custom:
        video_path = args.video or find_video(args.uuid)

    if wants_zero_shot:
        if video_path:
            print(" Running Grounding DINO zero-shot search...")
            r = search_zero_shot(video_path, kw)
            results["zero_shot"] = {"count": len(r), "results": r[:20]}
        else:
            results["zero_shot"] = {"count": 0, "results": [], "error": "Video not found"}

    if wants_gun_custom:
        if video_path:
            r = search_gun_detector(video_path, kw)
            if r and "error" in r[0]:
                # The detector reports missing dependencies as an error entry.
                results["gun_custom"] = {"count": 0, "results": [], "error": r[0]["error"]}
            else:
                results["gun_custom"] = {"count": len(r), "results": r[:20]}
        else:
            results["gun_custom"] = {"count": 0, "results": [], "error": "Video not found"}

    # Print summary
    print(f"\n=== Object Search: \"{args.keyword}\" ===\n")
    for src_name, data in results.items():
        print(f"[{src_name.upper()}] {data['count']} matches" + (" — top results:" if data['results'] else ""))
        if data.get("error"):
            print(f" error: {data['error']}")
        for i, r in enumerate(data['results'][:5]):
            if src_name in ("yolo", "zero_shot", "gun_custom"):
                print(f" {i+1}. {r['time_str']} frame={r['frame']} \"{r['class']}\" conf={r['confidence']}")
            elif src_name == "ocr":
                print(f" {i+1}. {r['time_str']} frame={r['frame']} \"{r['text'][:60]}\"")
            elif src_name == "asr":
                print(f" {i+1}. {r['time_str']} \"{r['text'][:60]}\"")
            elif src_name == "tkg":
                # properties may be NULL in the database.
                props = r.get('properties') or {}
                print(f" {i+1}. {r['type']}: {r['label']} ({props.get('total_detections','?')} detections)")
        print()

    # Output as JSON for machine parsing
    print(json.dumps({"keyword": args.keyword, "sources": results}, indent=2))


if __name__ == "__main__":
    main()
|