Files
momentry_core/scripts/paligemma_vs_gdino.py
Accusys 39ba5ddf76 feat: Phase 1 handover - schema migration, correction mechanism, API fixes
Schema changes: dev.chunks->dev.chunk, remove old_chunk_id/chunk_index
Correction: asr-1.json format, generate/apply scripts
API: 37/37 endpoints fixed and tested
Docs: HANDOVER_V2.0.md for M4
2026-05-11 07:03:22 +08:00

122 lines
4.8 KiB
Python

#!/opt/homebrew/bin/python3.11
"""
Full comparison: Grounding DINO Base vs PaliGemma 3B mix-224
Tests on 8 known timepoints with gun/stamp prompts.
"""
import json, os, sys, time, cv2, torch, re
from PIL import Image
# ----- Configuration -----
# Source movie and the directory that receives comparison.json.
VIDEO = "/Users/accusys/momentry/var/sftpgo/data/demo/Charade (1963) Cary Grant & Audrey Hepburn \uff5c Comedy Mystery Romance Thriller \uff5c Full Movie.mp4"
OUTPUT_DIR = "/Users/accusys/momentry/output_dev/paligemma_vs_gdino"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Each entry is (seconds into the video, label); the label is the second
# count with an "s" suffix and is used to build result keys.
TIMEPOINTS = [
    (sec, f"{sec}s")
    for sec in (2646, 3188, 3697, 5341, 5461, 6309, 6377, 6479)
]

# Text prompts run against every frame by both models.
PROMPTS = ["gun", "pistol", "stamp", "envelope", "passport"]

# Prefer the Apple MPS backend when present, otherwise run on CPU.
if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Device: {device}")
# ----- Load all frames -----
# Grab one BGR frame per timepoint from VIDEO, keyed by its label
# (e.g. "2646s"). Timepoints whose frame cannot be read are skipped with
# a warning; the script aborts if the video cannot be opened at all or
# no frame could be read, since the model sections would have nothing
# to evaluate.
cap = cv2.VideoCapture(VIDEO)
if not cap.isOpened():
    sys.exit(f"ERROR: cannot open video: {VIDEO}")
try:
    # CAP_PROP_FPS can come back as 0.0; fall back to 25 fps so seeking
    # by frame index still lands near the requested second.
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    frames = {}
    for t_sec, label in TIMEPOINTS:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(t_sec * fps))
        ret, frame = cap.read()
        if ret:
            frames[label] = frame
        else:
            print(f"WARNING: could not read frame at {label}")
finally:
    # Release the capture even if seeking/reading raises.
    cap.release()
print(f"Loaded {len(frames)} frames")
if not frames:
    sys.exit("ERROR: no frames loaded; nothing to compare")
# Collected per-model results, written to comparison.json at the end.
all_results = {}

# ===== Grounding DINO Base =====
print("\n" + "="*60)
print("Grounding DINO Base")
print("="*60)
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
t0 = time.time()  # the reported elapsed time includes model loading
gd_proc = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
gd_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
# "<timepoint label>_<prompt>" -> detection scores (rounded to 3 places)
# for boxes above the 0.1 threshold; empty list when nothing is found.
gd_dets = {}
for label, frame in frames.items():
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    # PIL reports (width, height); reversed to (height, width) for
    # target_sizes. It only depends on the frame, so build it once here.
    target = torch.tensor([img.size[::-1]])
    for pname in PROMPTS:
        # One phrase per query, formatted as "<phrase>.".
        inputs = gd_proc(images=img, text=f"{pname}.", return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = gd_model(**inputs)
        dets = gd_proc.post_process_grounded_object_detection(
            outputs, threshold=0.1, target_sizes=target
        )[0]
        scores = [round(s.item(), 3) for s in dets["scores"]] if len(dets["boxes"]) > 0 else []
        gd_dets[f"{label}_{pname}"] = scores
all_results["grounding-dino-base"] = {"elapsed": round(time.time()-t0, 1), "detections": gd_dets}
print(f" Done: {all_results['grounding-dino-base']['elapsed']}s")
# Drop the model before loading PaliGemma. The MPS cache flush only
# applies to the MPS backend, so skip it on the CPU fallback.
del gd_model
if device == "mps":
    torch.mps.empty_cache()
# ===== PaliGemma 3B mix-224 =====
print("\n" + "="*60)
print("PaliGemma 3B mix-224")
print("="*60)
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
t0 = time.time()  # the reported elapsed time includes model loading
pg_proc = AutoProcessor.from_pretrained("google/paligemma-3b-mix-224")
pg_model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-mix-224", dtype=torch.bfloat16
).to(device)
print(f" Model loaded: {sum(p.numel() for p in pg_model.parameters())/1e6:.0f}M params")
# Boxes are parsed as groups of four <locNNNN> tokens (hence the // 4).
loc_pattern = re.compile(r'<loc(\d+)>')
# "<timepoint label>_<prompt>" -> placeholder scores: 1.0 per parsed box
# (or a single 1.0 for a text-only mention), empty list when nothing found.
# No confidence is extracted from the text output.
pg_dets = {}
for label, frame in frames.items():
    img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    for pname in PROMPTS:
        t_infer = time.time()
        prompt = f"detect {pname}"
        inputs = pg_proc(text=prompt, images=img, return_tensors="pt").to(device)
        input_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = pg_model.generate(**inputs, max_new_tokens=100)
        # generate() returns the prompt tokens followed by the continuation;
        # decode only the continuation so `result` is the model's answer
        # rather than an echo of "detect <prompt>".
        result = pg_proc.decode(outputs[0][input_len:], skip_special_tokens=True)
        infer_time = time.time() - t_infer
        n_dets = len(loc_pattern.findall(result)) // 4
        # Fallback: the answer names the object without coordinates. The
        # 'detect' guard rejects output that just repeats the prompt.
        lowered = result.lower()
        has_detection = n_dets > 0 or (pname in lowered and 'detect' not in lowered)
        scores = [1.0] * (n_dets if n_dets > 0 else 1) if has_detection else []
        pg_dets[f"{label}_{pname}"] = scores
        if has_detection:
            print(f" {label} prompt={pname:10s}: {n_dets} det ({infer_time:.1f}s) result={result[:80]}")
all_results["paligemma-3b-mix-224"] = {"elapsed": round(time.time()-t0, 1), "detections": pg_dets}
print(f" Done: {all_results['paligemma-3b-mix-224']['elapsed']}s")
# The MPS cache flush only applies to the MPS backend; skip it on CPU.
del pg_model
if device == "mps":
    torch.mps.empty_cache()
# ===== Summary =====
# For every model, count the timepoints where each prompt produced at least
# one detection, print a table for the headline prompts, then save the raw
# per-model results as JSON.
n_points = len(TIMEPOINTS)
header = (
    f"{'Model':<28} {'Time':>8} {'Params':>8} "
    f"{'Gun hits':>12} {'Pistol hits':>12} {'Stamp hits':>12}"
)
print("\n" + "=" * len(header))
print(header)
print("-" * len(header))
for model_name in ["grounding-dino-base", "paligemma-3b-mix-224"]:
    d = all_results[model_name]
    dets = d["detections"]
    # TIMEPOINTS entries are (seconds, label); detection keys use the label.
    # A timepoint is a hit when its detection list exists and is non-empty.
    summary = {
        pname: sum(1 for _, label in TIMEPOINTS if dets.get(f"{label}_{pname}"))
        for pname in PROMPTS
    }
    params = "232M" if "grounding" in model_name else "2923M"
    elapsed = f"{d['elapsed']:.1f}s"
    gun_h = f"{summary.get('gun', 0)}/{n_points}"
    pistol_h = f"{summary.get('pistol', 0)}/{n_points}"
    stamp_h = f"{summary.get('stamp', 0)}/{n_points}"
    print(f"{model_name:<28} {elapsed:>8} {params:>8} {gun_h:>12} {pistol_h:>12} {stamp_h:>12}")
out_path = os.path.join(OUTPUT_DIR, "comparison.json")
with open(out_path, "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=2)
print(f"\nSaved to {OUTPUT_DIR}/")