feat: Phase 1 handover - schema migration, correction mechanism, API fixes
Schema changes: dev.chunks->dev.chunk, remove old_chunk_id/chunk_index Correction: asr-1.json format, generate/apply scripts API: 37/37 endpoints fixed and tested Docs: HANDOVER_V2.0.md for M4
This commit is contained in:
343
scripts/gdino_frame_api.py
Normal file
343
scripts/gdino_frame_api.py
Normal file
@@ -0,0 +1,343 @@
|
||||
#!/opt/homebrew/bin/python3.11
|
||||
"""
|
||||
Grounding DINO Frame API v2 — Zero-shot detection + natural language range search.
|
||||
Usage:
|
||||
python3 scripts/gdino_frame_api.py # Start server (port 5051)
|
||||
curl http://localhost:5051/detect -d '{"time":5461,"prompt":"gun"}'
|
||||
curl http://localhost:5051/search -d '{"query":"find the gun","range":"0-6780"}'
|
||||
"""
|
||||
import json, os, sys, time, cv2, torch, re, psycopg2, threading
|
||||
from PIL import Image, ImageDraw
|
||||
from flask import Flask, request, jsonify, send_file
|
||||
from datetime import datetime, timezone
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
RESOURCE_ID = "grounding-dino-v1"
|
||||
RESOURCE_TYPE = "vision_detector"
|
||||
CATEGORY = "zero_shot_detection"
|
||||
MODEL_NAME = "IDEA-Research/grounding-dino-base"
|
||||
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
BASE_DIR = "/Users/accusys/momentry/output_dev"
|
||||
SHOTS_DIR = os.path.join(BASE_DIR, "api_shots")
|
||||
os.makedirs(SHOTS_DIR, exist_ok=True)
|
||||
DB_URL = "postgresql://accusys@localhost:5432/momentry?host=/tmp"
|
||||
PORT = int(os.environ.get("GDINO_API_PORT", 5051))
|
||||
|
||||
VIDEO_PATHS = {
|
||||
"aeed71342a899fe4b4c57b7d41bcb692":
|
||||
"/Users/accusys/momentry/var/sftpgo/data/demo/Charade (1963) Cary Grant & Audrey Hepburn \uff5c Comedy Mystery Romance Thriller \uff5c Full Movie.mp4",
|
||||
}
|
||||
|
||||
_model = None
|
||||
_processor = None
|
||||
|
||||
def register_resource():
|
||||
"""Register this service as a resource in dev.resources."""
|
||||
try:
|
||||
conn = psycopg2.connect(DB_URL)
|
||||
cur = conn.cursor()
|
||||
cur.execute("""
|
||||
INSERT INTO dev.resources (resource_id, resource_type, category, capabilities, config, metadata, status, last_heartbeat)
|
||||
VALUES (%s, %s, %s, %s::jsonb, %s::jsonb, %s::jsonb, %s, NOW())
|
||||
ON CONFLICT (resource_id)
|
||||
DO UPDATE SET status = %s, last_heartbeat = NOW(), config = %s::jsonb
|
||||
""", (
|
||||
RESOURCE_ID, RESOURCE_TYPE, CATEGORY,
|
||||
json.dumps({
|
||||
"detect": "Single-frame object detection",
|
||||
"search": "Time-range search with natural language query",
|
||||
"target_formats": ["file_uuid:chunk_id", "file_uuid:trace_id", "file_uuid:chunk_index", "range"],
|
||||
}),
|
||||
json.dumps({"port": PORT, "device": DEVICE, "model": MODEL_NAME, "host": "localhost"}),
|
||||
json.dumps({"version": "2.0", "docs": "/health"}),
|
||||
"online", "online", json.dumps({"port": PORT, "device": DEVICE, "model": MODEL_NAME}),
|
||||
))
|
||||
conn.commit()
|
||||
cur.close(); conn.close()
|
||||
print(f"[Resource] Registered as '{RESOURCE_ID}' (type={RESOURCE_TYPE})")
|
||||
except Exception as e:
|
||||
print(f"[Resource] Registration failed: {e}")
|
||||
|
||||
def heartbeat_loop():
|
||||
"""Update heartbeat every 60 seconds."""
|
||||
while True:
|
||||
try:
|
||||
conn = psycopg2.connect(DB_URL)
|
||||
cur = conn.cursor()
|
||||
cur.execute("UPDATE dev.resources SET last_heartbeat = NOW() WHERE resource_id = %s", (RESOURCE_ID,))
|
||||
conn.commit()
|
||||
cur.close(); conn.close()
|
||||
except:
|
||||
pass
|
||||
time.sleep(60)
|
||||
|
||||
def get_model():
|
||||
global _model, _processor
|
||||
if _model is None:
|
||||
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
|
||||
print(f"[GDINO] Loading model on {DEVICE}...")
|
||||
t0 = time.time()
|
||||
_processor = AutoProcessor.from_pretrained(MODEL_NAME)
|
||||
_model = AutoModelForZeroShotObjectDetection.from_pretrained(MODEL_NAME).to(DEVICE)
|
||||
print(f"[GDINO] Loaded in {time.time()-t0:.1f}s")
|
||||
return _model, _processor
|
||||
|
||||
def find_video(uuid):
|
||||
if uuid in VIDEO_PATHS: return VIDEO_PATHS[uuid]
|
||||
import glob
|
||||
base = "/Users/accusys/momentry/var/sftpgo/data/demo"
|
||||
for f in glob.glob(f"{base}/**/Charade*", recursive=True):
|
||||
if f.endswith((".mp4", ".mov", ".avi")): VIDEO_PATHS[uuid] = f; return f
|
||||
for f in glob.glob(f"{base}/**/*{uuid[:8]}*", recursive=True):
|
||||
if f.endswith((".mp4", ".mov", ".avi")): VIDEO_PATHS[uuid] = f; return f
|
||||
return None
|
||||
|
||||
def resolve_target(target_str):
|
||||
"""Resolve 'file_uuid:chunk_id' or 'file_uuid:trace_id' to (file_uuid, start_time, end_time).
|
||||
Returns (uuid, start_sec, end_sec, label) or None.
|
||||
"""
|
||||
if not target_str or ":" not in target_str:
|
||||
return None
|
||||
parts = target_str.split(":", 1)
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
uuid, identifier = parts
|
||||
|
||||
conn = psycopg2.connect(DB_URL)
|
||||
cur = conn.cursor()
|
||||
|
||||
# Try chunk_id first
|
||||
cur.execute("""
|
||||
SELECT start_time, end_time, chunk_id FROM dev.chunks
|
||||
WHERE file_uuid=%s AND chunk_id=%s LIMIT 1
|
||||
""", (uuid, identifier))
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
cur.close(); conn.close()
|
||||
return (uuid, float(row[0]), float(row[1]), identifier)
|
||||
|
||||
# Try chunk_index
|
||||
if identifier.isdigit():
|
||||
cid = f"{uuid}_{identifier}"
|
||||
cur.execute("""
|
||||
SELECT start_time, end_time, chunk_id FROM dev.chunks
|
||||
WHERE file_uuid=%s AND chunk_id=%s LIMIT 1
|
||||
""", (uuid, cid))
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
cur.close(); conn.close()
|
||||
return (uuid, float(row[0]), float(row[1]), cid)
|
||||
|
||||
# Try trace_id
|
||||
if identifier.startswith("trace_") or identifier.isdigit():
|
||||
trace_id = identifier.replace("trace_", "")
|
||||
cur.execute("""
|
||||
SELECT MIN(start_time), MAX(end_time), chunk_id FROM dev.chunks
|
||||
WHERE file_uuid=%s AND chunk_type='trace' AND chunk_id LIKE %s
|
||||
GROUP BY chunk_id LIMIT 1
|
||||
""", (uuid, f"%_trace_{trace_id}"))
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
cur.close(); conn.close()
|
||||
return (uuid, float(row[0]), float(row[1]), f"trace_{trace_id}")
|
||||
|
||||
cur.close(); conn.close()
|
||||
return None
|
||||
|
||||
def parse_query(query):
|
||||
"""Extract search object from natural language query."""
|
||||
query = query.lower().strip()
|
||||
# Direct object name
|
||||
articles = ["a ", "an ", "the ", "some ", "any "]
|
||||
prefixes = ["find ", "show ", "search ", "where is ", "where are ",
|
||||
"looking for ", "detect ", "locate ", "spot ", "scan for "]
|
||||
for p in prefixes:
|
||||
if query.startswith(p):
|
||||
query = query[len(p):]
|
||||
for a in articles:
|
||||
if query.startswith(a):
|
||||
query = query[len(a):]
|
||||
# Remove trailing punctuation and extra words
|
||||
query = query.rstrip(".?!,")
|
||||
for suffix in [" in the image", " in this scene", " in the picture",
|
||||
" being held", " in hand", " in frame", " please"]:
|
||||
if query.endswith(suffix):
|
||||
query = query[: -len(suffix)]
|
||||
return query.strip()
|
||||
|
||||
def infer_frame(img, prompt, threshold=0.1):
|
||||
"""Run Grounding DINO on a PIL image. Returns list of detections."""
|
||||
model, processor = get_model()
|
||||
inputs = processor(images=img, text=f"{prompt}.", return_tensors="pt").to(DEVICE)
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
dets = processor.post_process_grounded_object_detection(
|
||||
outputs, threshold=threshold, target_sizes=[img.size[::-1]])[0]
|
||||
results = []
|
||||
for i in range(len(dets["boxes"])):
|
||||
results.append({
|
||||
"bbox": [round(v, 1) for v in dets["boxes"][i].tolist()],
|
||||
"score": round(dets["scores"][i].item(), 3),
|
||||
"label": prompt,
|
||||
})
|
||||
return results
|
||||
|
||||
@app.route("/detect", methods=["POST"])
|
||||
def detect():
|
||||
"""Detect objects in a single frame.
|
||||
Input: {"uuid","time","prompt","threshold"}
|
||||
"""
|
||||
data = request.json or {}
|
||||
uuid = data.get("uuid", "aeed71342a899fe4b4c57b7d41bcb692")
|
||||
t_sec = data.get("time", 0)
|
||||
prompt = data.get("prompt", "gun")
|
||||
threshold = data.get("threshold", 0.1)
|
||||
|
||||
video = find_video(uuid)
|
||||
if not video: return jsonify({"error": "Video not found"}), 404
|
||||
|
||||
cap = cv2.VideoCapture(video)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, int(t_sec * fps))
|
||||
ret, frame = cap.read()
|
||||
cap.release()
|
||||
if not ret: return jsonify({"error": f"Cannot read frame at {t_sec}s"}), 400
|
||||
|
||||
img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
t0 = time.time()
|
||||
detections = infer_frame(img, prompt, threshold)
|
||||
infer_ms = (time.time() - t0) * 1000
|
||||
|
||||
draw = ImageDraw.Draw(img)
|
||||
for d in detections:
|
||||
b = d["bbox"]
|
||||
draw.rectangle(b, outline="lime", width=3)
|
||||
draw.text((b[0], b[1]-18), f"{d['label']} {d['score']:.2f}", fill="lime")
|
||||
|
||||
shot_name = f"{uuid[:8]}_{int(t_sec)}s_{prompt}.jpg"
|
||||
img.save(os.path.join(SHOTS_DIR, shot_name))
|
||||
|
||||
return jsonify({
|
||||
"detections": detections,
|
||||
"time_ms": round(infer_ms, 1),
|
||||
"n_detections": len(detections),
|
||||
"shot_url": f"/shots/{shot_name}",
|
||||
})
|
||||
|
||||
@app.route("/search", methods=["POST"])
|
||||
def search():
|
||||
"""Search across a time range with natural language query.
|
||||
Input: {"uuid","target":"file_uuid:chunk_id","query":"find the gun","range":"0-6780","interval":30,"threshold":0.15}
|
||||
target: 'file_uuid:chunk_id' or 'file_uuid:trace_id' — resolves to time range automatically
|
||||
range: manual time range (used if target not provided)
|
||||
"""
|
||||
data = request.json or {}
|
||||
uuid = data.get("uuid", "aeed71342a899fe4b4c57b7d41bcb692")
|
||||
target_str = data.get("target", "")
|
||||
query = data.get("query", "find the gun")
|
||||
range_str = data.get("range", "0-6780")
|
||||
interval = data.get("interval", 30)
|
||||
threshold = data.get("threshold", 0.15)
|
||||
|
||||
prompt = parse_query(query)
|
||||
if not prompt:
|
||||
return jsonify({"error": f"Cannot parse query: {query}"}), 400
|
||||
|
||||
# Resolve target → time range
|
||||
resolved_label = ""
|
||||
if target_str:
|
||||
resolved = resolve_target(target_str)
|
||||
if resolved:
|
||||
uuid, range_start, range_end, resolved_label = resolved
|
||||
else:
|
||||
return jsonify({"error": f"Cannot resolve target: {target_str}"}), 404
|
||||
else:
|
||||
# Parse manual range
|
||||
if "-" in range_str:
|
||||
parts = range_str.split("-")
|
||||
range_start = float(parts[0])
|
||||
range_end = float(parts[1]) if len(parts) > 1 else 6780
|
||||
else:
|
||||
range_start = 0
|
||||
range_end = 6780
|
||||
|
||||
video = find_video(uuid)
|
||||
if not video: return jsonify({"error": "Video not found"}), 404
|
||||
|
||||
cap = cv2.VideoCapture(video)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
hits = []
|
||||
t_start = time.time()
|
||||
frame_step = int(interval * fps)
|
||||
|
||||
for frame_num in range(int(range_start * fps), min(int(range_end * fps), total_frames), frame_step):
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
|
||||
ret, frame = cap.read()
|
||||
if not ret: continue
|
||||
|
||||
img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
detections = infer_frame(img, prompt, threshold)
|
||||
|
||||
if detections:
|
||||
ts = frame_num / fps
|
||||
best = max(d["score"] for d in detections)
|
||||
hits.append({
|
||||
"time": round(ts, 1),
|
||||
"time_str": f"{int(ts//60)}:{int(ts%60):02d}.{int((ts%1)*fps):02d}",
|
||||
"frame": frame_num,
|
||||
"detections": detections,
|
||||
"best_score": best,
|
||||
})
|
||||
|
||||
if len(hits) >= 100: # safety limit
|
||||
break
|
||||
|
||||
cap.release()
|
||||
elapsed = time.time() - t_start
|
||||
|
||||
return jsonify({
|
||||
"query": query,
|
||||
"object": prompt,
|
||||
"target": target_str or None,
|
||||
"resolved_target": resolved_label or None,
|
||||
"range": f"{range_start:.0f}-{range_end:.0f}",
|
||||
"interval_secs": interval,
|
||||
"scanned_frames": int((range_end - range_start) / interval) + 1,
|
||||
"hits": hits,
|
||||
"n_hits": len(hits),
|
||||
"elapsed_secs": round(elapsed, 1),
|
||||
})
|
||||
|
||||
@app.route("/shots/<filename>")
|
||||
def serve_shot(filename):
|
||||
path = os.path.join(SHOTS_DIR, filename)
|
||||
if not os.path.exists(path): return jsonify({"error": "Not found"}), 404
|
||||
return send_file(path, mimetype="image/jpeg")
|
||||
|
||||
@app.route("/health")
|
||||
def health():
|
||||
return jsonify({
|
||||
"status": "ok",
|
||||
"resource_id": RESOURCE_ID,
|
||||
"resource_type": RESOURCE_TYPE,
|
||||
"model": MODEL_NAME,
|
||||
"device": DEVICE,
|
||||
"port": PORT,
|
||||
})
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Register as resource
|
||||
register_resource()
|
||||
|
||||
# Start heartbeat thread
|
||||
t = threading.Thread(target=heartbeat_loop, daemon=True)
|
||||
t.start()
|
||||
|
||||
# Load model
|
||||
get_model()
|
||||
print(f"[GDINO] Frame API v2: http://0.0.0.0:{PORT}")
|
||||
print(f"[GDINO] Resource: {RESOURCE_ID} (type={RESOURCE_TYPE})")
|
||||
app.run(host="0.0.0.0", port=PORT, threaded=True)
|
||||
Reference in New Issue
Block a user