Files
momentry_core/scripts/test_owl_vit_stamps.py
Warren 8f05a7c188 feat: update Python processors and add utility scripts
- Update ASR, face, OCR, pose processors
- Add release pre-flight check script
- Add synonym generation, chunk processing scripts
- Add face recognition, stamp search utilities
2026-04-30 15:07:49 +08:00

115 lines
3.5 KiB
Python

#!/opt/homebrew/bin/python3.11
"""
Test OWL-ViT for "Stamps" Detection
"""
import json
import os
import sys

import cv2
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# Identifier of the video under test; every input and output for this run
# lives under output/<UUID>/.
UUID = "384b0ff44aaaa1f1"
_RUN_DIR = f"output/{UUID}"
VIDEO_PATH = f"{_RUN_DIR}/{UUID}.mp4"  # source video frames are read from here
ASR_PATH = f"{_RUN_DIR}/{UUID}.asr.json"  # ASR transcript with "segments"
OUTPUT_DIR = f"{_RUN_DIR}/owl_vit_results"  # annotated frames are written here
# Create the results directory up front (no error if it already exists).
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 1. Find timestamps where "stamp" is mentioned in the ASR transcript.
#    Produces `target_times`: the list of timestamps (seconds) to inspect.
print("🔍 Analyzing ASR for 'stamp' mentions...")
# JSON text is UTF-8 by spec; don't let the platform locale pick the codec.
with open(ASR_PATH, encoding="utf-8") as f:
    asr_data = json.load(f)
mention_times = []
for seg in asr_data.get("segments", []):
    # Read each field once, with defaults, so a segment missing "text"/"start"
    # (or with a null "text") does not crash the scan or the report below.
    seg_text = seg.get("text") or ""
    seg_start = seg.get("start", 0)
    if "stamp" in seg_text.lower():
        mention_times.append(seg_start)
        print(f" 🗣️ Found: '{seg_text}' @ {seg_start:.2f}s")
if not mention_times:
    print("❌ No mentions of 'stamp' found.")
    # sys.exit() instead of the site-provided exit(), which is not guaranteed
    # to exist (e.g. under `python -S`). Exit status stays 0 as before.
    sys.exit()
print(f"✅ Found {len(mention_times)} ASR mentions of 'stamp'.")
# The ASR scan above only gates the run. The frames actually checked are the
# hand-picked timestamps below: around the "Stamps" chunk (Chunk 833, ~5851s)
# and the final confrontation (~6700s+), because early mentions might be just
# dialogue about stamps without showing them.
priority_times = [5851.6, 5860.4, 6756.6, 6846.0]
print(f"🔥 Prioritizing high-probability timestamps: {priority_times}")
target_times = priority_times
print(f"✅ Found {len(target_times)} candidate timestamps.")
# 2. Load the OWL-ViT open-vocabulary detector and its processor.
#    The base/patch32 checkpoint is chosen for speed; larger checkpoints are
#    more accurate but slower.
print("🧠 Loading OWL-ViT model...")
_OWL_VIT_CHECKPOINT = "google/owlvit-base-patch32"
processor = OwlViTProcessor.from_pretrained(_OWL_VIT_CHECKPOINT)
model = OwlViTForObjectDetection.from_pretrained(_OWL_VIT_CHECKPOINT)
# 3. Process frames: for each candidate timestamp, grab the frame, run OWL-ViT
#    with stamp-related text queries, annotate it, and save the result.

# Loop-invariant text queries; the processor takes one list of queries per image.
TEXT_QUERIES = [["a postage stamp", "a stamp on a letter", "a stamp in an album"]]
# Detections must score strictly above this to be kept (the post-processor
# filters with `score > threshold`).
SCORE_THRESHOLD = 0.15


def _detect_stamps(image_pil):
    """Run OWL-ViT on one PIL image.

    Returns a list of ``(label_text, score, box)`` tuples, where ``score`` is a
    Python float and ``box`` is ``[x_min, y_min, x_max, y_max]`` as ints,
    rescaled by the post-processor to the image's pixel size.
    """
    inputs = processor(text=TEXT_QUERIES, images=image_pil, return_tensors="pt")
    # Inference only: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (width, height); the post-processor expects (height, width).
    target_sizes = torch.Tensor([image_pil.size[::-1]])
    result = processor.post_process_object_detection(
        outputs=outputs, target_sizes=target_sizes, threshold=SCORE_THRESHOLD
    )[0]  # single image in the batch
    queries = TEXT_QUERIES[0]
    return [
        (queries[label.item()], score.item(), box.int().tolist())
        for box, score, label in zip(
            result["boxes"], result["scores"], result["labels"]
        )
    ]


def _draw_detection(frame, label_text, score, box):
    """Draw a green box and a ``"<label> <score>"`` caption onto *frame* in place."""
    x_min, y_min, x_max, y_max = box
    cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
    # Keep the caption on-screen for boxes that touch the top edge.
    cv2.putText(
        frame,
        f"{label_text} {score:.2f}",
        (x_min, max(y_min - 10, 10)),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (0, 255, 0),
        2,
    )


cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    # Without this check every read fails and the run "succeeds" silently.
    print(f"❌ Could not open video: {VIDEO_PATH}")
    sys.exit(1)
try:
    for t in target_times:
        cap.set(cv2.CAP_PROP_POS_MSEC, t * 1000)
        ret, frame = cap.read()
        if not ret:
            print(f" ⚠️ Could not read frame at {t:.2f}s, skipping")
            continue
        # OpenCV frames are BGR; the model's processor expects an RGB image.
        image_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        detections = _detect_stamps(image_pil)
        for label_text, score, box in detections:
            print(f" ✅ Detected '{label_text}' ({score:.2f}) at {t:.2f}s")
            _draw_detection(frame, label_text, score, box)
        if not detections:
            print(f" ❌ No stamp detected at {t:.2f}s")
            cv2.putText(
                frame,
                "No Stamp Found",
                (50, 50),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 0, 255),
                2,
            )
        # Save every annotated frame, hits and misses alike. Previously the
        # "No Stamp Found" annotation was drawn but never written out.
        save_path = os.path.join(OUTPUT_DIR, f"stamp_detect_{int(t)}.jpg")
        if cv2.imwrite(save_path, frame):
            print(f" 💾 Saved to {save_path}")
        else:
            print(f" ⚠️ Failed to write {save_path}")
finally:
    # Release the capture even if inference raises part-way through.
    cap.release()
print("🏁 Done.")