174 lines
5.4 KiB
Python
174 lines
5.4 KiB
Python
#!/opt/homebrew/bin/python3.11
|
|
"""
|
|
Extract face embeddings for a video file using InsightFace (the CoreML FaceNet model is loaded, but the stored embeddings come from InsightFace).
|
|
Updates face_detections.embedding in PostgreSQL.
|
|
|
|
Usage: python3 scripts/extract_video_embeddings.py --file-uuid <uuid> --video-path <path> [--schema <schema>]
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
import io
|
|
import warnings
|
|
import cv2
|
|
import numpy as np
|
|
import psycopg2
|
|
from psycopg2.extras import execute_values
|
|
|
|
# Silence every Python warning raised for the rest of the process.
warnings.filterwarnings(action="ignore")

# Connection string; the DATABASE_URL environment variable overrides the default.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgres://accusys@localhost:5432/momentry",
)

# Model files live in a "models" directory next to this script's parent directory.
MODELS_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "..",
    "models",
)
FACENET_PATH = os.path.join(MODELS_DIR, "facenet512.mlpackage")
|
|
|
|
|
|
def get_schema():
    """Return the PostgreSQL schema selected by the environment.

    Returns:
        "dev" when the DATABASE_URL environment variable contains
        ``search_path=dev`` or the DATABASE_SCHEMA environment variable
        equals ``dev``; otherwise "public".
    """
    db_url = os.getenv("DATABASE_URL", "")
    # os.environ maps variable *names* to values, so a membership test for
    # "DATABASE_SCHEMA=dev" can never match; compare the variable's value.
    if "search_path=dev" in db_url or os.environ.get("DATABASE_SCHEMA") == "dev":
        return "dev"
    return "public"
|
|
|
|
|
|
def _bbox_iou(x, y, w, h, fx1, fy1, fx2, fy2):
    """Return the IoU of a stored (x, y, w, h) box and a detected corner box."""
    inter_w = max(0, min(x + w, fx2) - max(x, fx1))
    inter_h = max(0, min(y + h, fy2) - max(y, fy1))
    inter = inter_w * inter_h
    union = w * h + (fx2 - fx1) * (fy2 - fy1) - inter
    return inter / union if union > 0 else 0


def _load_face_analyzer():
    """Build the InsightFace analyzer while hiding the libraries' stdout output."""
    import contextlib

    # redirect_stdout restores the previous sys.stdout even if loading fails.
    with contextlib.redirect_stdout(io.StringIO()):
        from insightface.app import FaceAnalysis
        import coremltools as ct

        app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
        app.prepare(ctx_id=0, det_thresh=0.5)

        # NOTE(review): the CoreML FaceNet model is loaded but its result is
        # never used; embeddings below come from InsightFace. The load is kept
        # so a missing model file still fails as before -- confirm whether it
        # can be dropped.
        ct.models.MLModel(FACENET_PATH)
    return app


def extract_video_embeddings(file_uuid: str, video_path: str, schema: str = "dev"):
    """Extract face embeddings from video frames and store them in PostgreSQL.

    Rows of ``<schema>.face_detections`` for ``file_uuid`` whose ``embedding``
    is NULL are grouped by frame. Each such frame is read from the video and
    run through InsightFace; every stored box is matched to the detected face
    with the highest IoU, and when that IoU exceeds 0.3 the detected face's
    embedding is written back to the row.

    Args:
        file_uuid: Value matched against ``face_detections.file_uuid``.
        video_path: Path of the video passed to ``cv2.VideoCapture``.
        schema: Database schema that contains ``face_detections``.

    Returns:
        None. Progress is printed to stdout with an ``[EMBED]`` prefix. If the
        video cannot be opened, an error is printed to stderr and the database
        is left untouched.
    """
    from psycopg2 import sql

    app = _load_face_analyzer()

    # Quote the schema as an identifier instead of formatting it into the SQL
    # text; it can come straight from the command line.
    table = sql.SQL("{}.face_detections").format(sql.Identifier(schema))

    cap = cv2.VideoCapture(video_path)
    conn = None
    try:
        if not cap.isOpened():
            print(f"[EMBED] Error: cannot open video: {video_path}", file=sys.stderr)
            return

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"[EMBED] Video: {total_frames} frames, {fps} fps")

        conn = psycopg2.connect(DATABASE_URL)
        with conn.cursor() as cur:
            # Face detections that still lack an embedding.
            cur.execute(
                sql.SQL(
                    """
                    SELECT id, frame_number, x, y, width, height
                    FROM {}
                    WHERE file_uuid = %s AND embedding IS NULL
                    ORDER BY frame_number
                    """
                ).format(table),
                (file_uuid,),
            )
            face_records = cur.fetchall()
            print(f"[EMBED] Faces without embedding: {len(face_records)}")

            if not face_records:
                print("[EMBED] All faces have embeddings")
                return

            # frame number -> stored boxes on that frame
            frame_faces = {}
            for face_id, frame_num, x, y, w, h in face_records:
                frame_faces.setdefault(frame_num, []).append((face_id, x, y, w, h))

            batch_updates = []
            processed_frames = 0

            for frame_num in sorted(frame_faces):
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
                ret, frame = cap.read()
                if not ret:
                    # Unreadable frame: its rows keep a NULL embedding.
                    continue

                detected = app.get(frame)

                for face_id, x, y, w, h in frame_faces[frame_num]:
                    # Pick the detected face that overlaps the stored box most.
                    best_face = None
                    best_iou = 0
                    for det_face in detected:
                        iou = _bbox_iou(x, y, w, h, *det_face.bbox)
                        if iou > best_iou:
                            best_iou = iou
                            best_face = det_face

                    if best_face is not None and best_iou > 0.3:
                        embedding = best_face.embedding
                        if embedding is not None and len(embedding) > 0:
                            batch_updates.append((embedding.tolist(), face_id))

                processed_frames += 1
                if processed_frames % 100 == 0:
                    print(f"[EMBED] Progress: {processed_frames} frames, {len(batch_updates)} embeddings")

            if batch_updates:
                print(f"[EMBED] Updating {len(batch_updates)} embeddings...")
                update_query = sql.SQL(
                    """
                    UPDATE {}
                    SET embedding = %s
                    WHERE id = %s
                    """
                ).format(table)
                for emb, face_id in batch_updates:
                    cur.execute(update_query, (emb, face_id))
                conn.commit()

            # Report how many rows for this file now carry an embedding.
            cur.execute(
                sql.SQL(
                    """
                    SELECT COUNT(embedding) FROM {}
                    WHERE file_uuid = %s
                    """
                ).format(table),
                (file_uuid,),
            )
            embed_count = cur.fetchone()[0]
            print(f"[EMBED] Done: {embed_count} faces with embeddings")
    finally:
        # Release the capture and the connection on every exit path,
        # including the early returns and any exception raised above.
        cap.release()
        if conn is not None:
            conn.close()
|
|
|
|
|
|
def main():
    """Parse the command-line options and run the embedding extraction."""
    cli = argparse.ArgumentParser(description="Extract face embeddings from video")

    # (flag, add_argument keyword options), registered in this order.
    option_specs = (
        ("--file-uuid", {"required": True, "help": "Video file UUID"}),
        ("--video-path", {"required": True, "help": "Video file path"}),
        ("--schema", {"default": get_schema(), "help": "Database schema"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()
    extract_video_embeddings(opts.file_uuid, opts.video_path, opts.schema)
|
|
|
|
|
|
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()