feat: Phase 1 handover - schema migration, correction mechanism, API fixes
Schema changes: dev.chunks->dev.chunk, remove old_chunk_id/chunk_index. Correction: asr-1.json format, generate/apply scripts. API: 37/37 endpoints fixed and tested. Docs: HANDOVER_V2.0.md for M4.
This commit is contained in:
156
scripts/zero_shot_gun_test.py
Normal file
156
scripts/zero_shot_gun_test.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/opt/homebrew/bin/python3.11
|
||||
"""
|
||||
Zero-shot Gun Detection Test — OWL-ViT vs Grounding DINO
|
||||
Tests on 8 known timepoints: 5 original pistol frames + 3 ASR gun mentions.
|
||||
"""
|
||||
import json, os, sys, time, cv2
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
# --- Output location -------------------------------------------------------
# Annotated JPEGs and the JSON results file are all written here; the
# directory is created up front so later writes do not have to check for it.
OUTPUT_DIR = "/Users/accusys/momentry/output_dev/zero_shot_test"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Input video -----------------------------------------------------------
VIDEO = "/Users/accusys/momentry/var/sftpgo/data/demo/Charade (1963) Cary Grant & Audrey Hepburn \uff5c Comedy Mystery Romance Thriller \uff5c Full Movie.mp4"

# --- Text queries tried against every sampled frame ------------------------
PROMPTS = ["gun", "pistol", "rifle", "weapon"]

# --- Frames to sample ------------------------------------------------------
# Each entry is (offset in seconds, label used in file names and result
# keys, human-readable description printed in the logs).
TIMEPOINTS = [
    (2646, "2646s", "ASR: He has a gun"),
    (3188, "3188s", "Original pistol"),
    (3697, "3697s", "ASR: Where's your gun"),
    (5341, "5341s", "ASR: He already killed 3 men"),
    (5461, "5461s", "Original pistol"),
    (6309, "6309s", "Original pistol"),
    (6377, "6377s", "Original gun"),
    (6479, "6479s", "Original pistol"),
]
|
||||
|
||||
# Open the source video once; every frame grab in this script reuses `cap`.
cap = cv2.VideoCapture(VIDEO)
if not cap.isOpened():
    # Fail fast: with an unopened capture every read fails, every timepoint
    # is skipped, and both models would end up reporting zero detections
    # only after their weights have been loaded.
    sys.exit(f"Cannot open video: {VIDEO}")

# Fall back to 25 fps when the backend reports a falsy (0) frame rate, so
# the seconds -> frame-index conversion in get_frame() still works.
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
||||
|
||||
def get_frame(t_sec):
    """Seek the shared capture to ``t_sec`` seconds and decode one frame.

    The offset is converted to a frame index with the module-level ``fps``
    and applied to the module-level ``cap``.

    Returns:
        The frame produced by ``cap.read()``, or ``None`` when the read
        reports failure.
    """
    frame_index = int(t_sec * fps)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
    ok, image = cap.read()
    if not ok:
        return None
    return image
|
||||
|
||||
def save_annotated(frame, detections, prompt, model_name, label):
    """Draw the detections on a copy of ``frame`` and save it as a JPEG.

    Args:
        frame: Image array to annotate; it is copied, not modified.
        detections: Iterable of dicts with ``"bbox"`` (x1, y1, x2, y2),
            ``"score"`` and ``"label"`` entries.
        prompt: Text query that produced the detections (used in the name).
        model_name: Short model identifier (used in the name).
        label: Timepoint label (used in the name).

    Returns:
        The file name (not the full path) written inside ``OUTPUT_DIR``.
    """
    green = (0, 255, 0)
    canvas = frame.copy()

    for det in detections:
        left, top, right, bottom = (int(coord) for coord in det["bbox"])
        score = det["score"]
        name = det["label"]
        cv2.rectangle(canvas, (left, top), (right, bottom), green, 2)
        # Caption sits 5 px above the box's top-left corner.
        cv2.putText(
            canvas,
            f"{name} {score:.2f}",
            (left, top - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            green,
            2,
        )

    filename = f"{label}_{model_name}_prompt-{prompt}.jpg"
    out_path = os.path.join(OUTPUT_DIR, filename)
    cv2.imwrite(out_path, canvas, [cv2.IMWRITE_JPEG_QUALITY, 85])
    return filename
|
||||
|
||||
# Per-model results, keyed by model name; dumped to JSON at the end.
all_results = {}

# ========== OWL-ViT ==========
print("=" * 60)
print("OWL-ViT (google/owlvit-base-patch32)")
print("=" * 60)

# Imported here so the banner is printed before the (slow) library import.
from transformers import OwlViTProcessor, OwlViTForObjectDetection

owl_proc = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
owl_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
# Prefer Apple's Metal backend when available; the same device is reused by
# the later model sections of this script.
device = "mps" if torch.backends.mps.is_available() else "cpu"
owl_model.to(device)
print(f"Device: {device}")

# Maps "<label>_prompt-<prompt>" -> list of detection dicts.
owl_dets = {}
t0 = time.time()
for t_sec, label, desc in TIMEPOINTS:
    frame = get_frame(t_sec)
    if frame is None:
        # Report the gap instead of dropping the timepoint silently, so a
        # missing frame is not mistaken for "no detections".
        print(f" [{desc}] could not read frame at {label}, skipping")
        continue
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    # PIL reports (width, height); post-processing expects (height, width).
    # The size is the same for every prompt, so build the tensor once.
    target = torch.tensor([img.size[::-1]])
    for prompt in PROMPTS:
        inputs = owl_proc(text=[[prompt]], images=img, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = owl_model(**inputs)
        dets = owl_proc.post_process_grounded_object_detection(
            outputs, threshold=0.05, target_sizes=target
        )[0]
        det_list = [
            {
                "bbox": [round(v, 1) for v in box],
                "score": round(score, 3),
                "label": prompt,
            }
            for box, score in zip(dets["boxes"].tolist(), dets["scores"].tolist())
        ]
        save_annotated(frame, det_list, prompt, "owlvit", label)
        key = f"{label}_prompt-{prompt}"
        owl_dets[key] = det_list
        if det_list:
            best = max(d["score"] for d in det_list)
            print(f" [{desc}] prompt='{prompt}': {len(det_list)} det best={best:.3f}")

all_results["owlvit"] = {"elapsed": round(time.time()-t0, 1), "detections": owl_dets}
|
||||
|
||||
# ========== Grounding DINO ==========
print("\n" + "=" * 60)
print("Grounding DINO (IDEA-Research/grounding-dino-base)")
print("=" * 60)

# Imported here so the banner is printed before the (slow) library import.
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

gd_proc = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
gd_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base")
gd_model.to(device)

# Maps "<label>_prompt-<prompt>" -> list of detection dicts.
gd_dets = {}
t0 = time.time()
for t_sec, label, desc in TIMEPOINTS:
    frame = get_frame(t_sec)
    if frame is None:
        # Report the gap instead of dropping the timepoint silently, so a
        # missing frame is not mistaken for "no detections".
        print(f" [{desc}] could not read frame at {label}, skipping")
        continue
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    # PIL reports (width, height); post-processing expects (height, width).
    # The size is the same for every prompt, so build the tensor once.
    target = torch.tensor([img.size[::-1]])
    for prompt in PROMPTS:
        # Grounding DINO expects text queries to be lowercase and terminated
        # by a period; the bare word is normalized here. Result keys, file
        # names and stored labels keep the original prompt.
        query = prompt.lower().strip()
        if not query.endswith("."):
            query += "."
        inputs = gd_proc(images=img, text=query, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = gd_model(**inputs)
        dets = gd_proc.post_process_grounded_object_detection(
            outputs, threshold=0.05, target_sizes=target
        )[0]
        det_list = [
            {
                "bbox": [round(v, 1) for v in box],
                "score": round(score, 3),
                "label": prompt,
            }
            for box, score in zip(dets["boxes"].tolist(), dets["scores"].tolist())
        ]
        save_annotated(frame, det_list, prompt, "grounding-dino", label)
        key = f"{label}_prompt-{prompt}"
        gd_dets[key] = det_list
        if det_list:
            best = max(d["score"] for d in det_list)
            print(f" [{desc}] prompt='{prompt}': {len(det_list)} det best={best:.3f}")

all_results["grounding-dino"] = {"elapsed": round(time.time()-t0, 1), "detections": gd_dets}
|
||||
|
||||
# All frames have been grabbed; release the video handle.
cap.release()

# ========== Summary ==========
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for model in ["owlvit", "grounding-dino"]:
    d = all_results[model]
    dets = d["detections"]

    # Gather (prompt, score) pairs per timepoint across all prompts, so the
    # headline figure really counts timepoints rather than
    # (timepoint, prompt) combinations.
    per_timepoint = []
    for _t_sec, label, desc in TIMEPOINTS:
        candidates = [
            (p, dd["score"])
            for p in PROMPTS
            for dd in dets.get(f"{label}_prompt-{p}", [])
        ]
        per_timepoint.append((desc, candidates))

    hits = sum(1 for _desc, candidates in per_timepoint if candidates)
    total = sum(len(v) for v in dets.values())
    print(f"\n{model} ({d['elapsed']}s): {hits}/{len(TIMEPOINTS)} timepoints, {total} total detections")

    for desc, candidates in per_timepoint:
        if candidates:
            best = max(candidates, key=lambda x: x[1])
            print(f" {desc}: best={best[1]:.3f} (prompt='{best[0]}')")
        else:
            print(f" {desc}: no detections")

# Use a context manager so the results file is flushed and closed.
with open(os.path.join(OUTPUT_DIR, "zero_shot_results.json"), "w", encoding="utf-8") as fh:
    json.dump(all_results, fh, indent=2)
print(f"\nSaved to {OUTPUT_DIR}/")
|
||||
Reference in New Issue
Block a user