#!/opt/homebrew/bin/python3.11
"""
Generate Seed Embeddings - Extract embeddings from TMDb profile photos

Flow:
1. Query PG identities: source='tmdb' AND tmdb_profile IS NOT NULL
2. Download profile image from TMDb
3. Extract face embedding using CoreML FaceNet
4. Push to Qdrant _seeds collection

TMDb Image URL format:
    https://image.tmdb.org/t/p/original{tmdb_profile_path}

Usage:
    python generate_seed_embeddings.py
    python generate_seed_embeddings.py --limit 10
    python generate_seed_embeddings.py --dry-run  # Don't push to Qdrant
    python generate_seed_embeddings.py --tmdb-api-key YOUR_KEY

Output:
    JSON with generated seed count and status
"""

import os
import sys
import json
import argparse
import contextlib
import tempfile
import urllib.request
import urllib.error
import urllib.parse
from typing import Any, Optional, List, Dict

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Make sibling modules and the utils/ package importable when run as a script.
# Insertion order is preserved from the original: utils/ ends up first.
sys.path.insert(0, SCRIPT_DIR)
sys.path.insert(0, os.path.join(SCRIPT_DIR, "utils"))

from qdrant_faces import push_seed_embedding, ensure_seeds_collection

# Config
DB_URL = os.environ.get("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
SCHEMA = os.environ.get("DATABASE_SCHEMA", "dev")
TMDB_API_KEY = os.environ.get("TMDB_API_KEY", "")
TMDB_IMAGE_BASE = "https://image.tmdb.org/t/p/original"

# CoreML FaceNet
FACENET_PATH = os.path.join(SCRIPT_DIR, "..", "models", "facenet512.mlpackage")

# Loaded CoreML models keyed by path, so the model is read from disk once per
# process instead of once per image.
_MODEL_CACHE: Dict[str, Any] = {}


def get_tmdb_identities(limit: Optional[int] = None) -> List[Dict]:
    """Query PG for TMDb identities with profile photos

    Args:
        limit: Max identities to process. ``None`` or ``0`` means no limit.

    Returns:
        List of {id, uuid, name, tmdb_id, tmdb_profile}, ordered by id
    """
    import psycopg2
    import psycopg2.extras

    # NOTE(review): the schema name comes from the DATABASE_SCHEMA environment
    # variable and is interpolated as-is; it must be a trusted identifier.
    table = "identities" if SCHEMA == "public" else f"{SCHEMA}.identities"

    query = f"""
        SELECT id, uuid, name, tmdb_id, tmdb_profile
        FROM {table}
        WHERE source = 'tmdb' AND tmdb_profile IS NOT NULL
        ORDER BY id
    """

    # Bind LIMIT as a query parameter rather than formatting it into the SQL.
    params = None
    if limit:
        query += " LIMIT %s"
        params = (limit,)

    conn = psycopg2.connect(DB_URL)
    try:
        with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            cur.execute(query, params)
            rows = cur.fetchall()
    finally:
        # Close the connection even when the query fails.
        conn.close()

    return [dict(row) for row in rows]


def _resolve_image_url(tmdb_profile: str) -> str:
    """Turn a stored TMDb profile value (full URL or bare path) into a download URL."""
    if tmdb_profile.startswith("http"):
        url = tmdb_profile
    else:
        url = f"{TMDB_IMAGE_BASE}{tmdb_profile}"

    # Use 'original' size for better quality
    if "/w185" in url:
        url = url.replace("/w185", "/original")
    return url


def _image_suffix(url: str) -> str:
    """Return the file extension (with leading dot) of the URL path, default '.jpg'."""
    # Parse the path so query strings and dots in the host name are ignored.
    path = urllib.parse.urlparse(url).path
    return os.path.splitext(path)[1] or ".jpg"


def download_tmdb_image(tmdb_profile: str, tmdb_id: int) -> Optional[str]:
    """Download TMDb profile image to temp file

    Args:
        tmdb_profile: TMDb profile URL or path
            - Full URL: 'https://image.tmdb.org/t/p/w185/xxx.jpg'
            - Path only: '/xxx.jpg'
        tmdb_id: TMDb ID for logging

    Returns:
        Path to downloaded temp file, or None if failed. The caller owns the
        file and is responsible for deleting it.
    """
    if not tmdb_profile:
        return None

    url = _resolve_image_url(tmdb_profile)

    try:
        req = urllib.request.Request(url)
        with urllib.request.urlopen(req, timeout=30) as resp:
            data = resp.read()

        # mkstemp creates the file atomically, unlike the deprecated mktemp.
        fd, tmp_path = tempfile.mkstemp(suffix=_image_suffix(url))
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(data)
        except Exception:
            # Do not leave a partial file behind if the write fails.
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
            raise

        print(f"[TMDB] Downloaded: tmdb_id={tmdb_id} -> {tmp_path}")
        return tmp_path
    except urllib.error.HTTPError as e:
        print(f"[TMDB] Download failed (HTTP {e.code}): tmdb_id={tmdb_id}")
        return None
    except Exception as e:
        print(f"[TMDB] Download failed: tmdb_id={tmdb_id} - {e}")
        return None


def extract_face_embedding(image_path: str) -> Optional[List[float]]:
    """Extract 512D face embedding from image using CoreML FaceNet

    The whole image is resized to 160x160 and fed to the model; no face
    detection or cropping is performed here.

    Args:
        image_path: Path to image file

    Returns:
        Embedding as a flat list of floats, or None if failed
    """
    import coremltools as ct
    import numpy as np
    import cv2

    # Load CoreML model (cached after the first successful load)
    model = _MODEL_CACHE.get(FACENET_PATH)
    if model is None:
        try:
            model = ct.models.MLModel(FACENET_PATH)
        except Exception as e:
            print(f"[COREML] Model load failed: {e}")
            return None
        _MODEL_CACHE[FACENET_PATH] = model

    # Read image
    try:
        img = cv2.imread(image_path)
        if img is None:
            print(f"[COREML] Image read failed: {image_path}")
            return None

        # Resize to 160x160
        resized = cv2.resize(img, (160, 160))

        # NOTE(review): cv2.imread returns channels in BGR order and no
        # conversion is applied here - confirm the model expects BGR input.
        # Normalize to [-1, 1] and convert HWC to CHW
        normalized = (resized.astype(np.float32) / 127.5) - 1.0
        normalized = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW

        # Add batch dim: (1, 3, 160, 160)
        input_array = np.expand_dims(normalized, axis=0)

        # Run model
        result = model.predict({"input": input_array})

        # Find output key (var_xxx)
        emb_key = next((k for k in result if k.startswith("var_")), None)
        if emb_key is None:
            print(
                "[COREML] Embedding extraction failed: "
                f"no 'var_*' output among {sorted(result)}"
            )
            return None

        return result[emb_key].flatten().tolist()
    except Exception as e:
        print(f"[COREML] Embedding extraction failed: {e}")
        return None


def _record_failure(result: Dict, identity_id, name, error: str) -> None:
    """Count one failed identity and append its error entry to the result dict."""
    result["failed"] += 1
    result["errors"].append({
        "identity_id": identity_id,
        "name": name,
        "error": error,
    })


def generate_seed_embeddings(limit: Optional[int] = None, dry_run: bool = False) -> Dict:
    """Generate embeddings for all TMDb identities

    Args:
        limit: Max identities to process
        dry_run: Don't push to Qdrant

    Returns:
        Result dict with keys total, processed, success, failed and errors.
        In a dry run nothing is pushed, so "success" stays 0.
    """
    result = {
        "total": 0,
        "processed": 0,
        "success": 0,
        "failed": 0,
        "errors": [],
    }

    identities = get_tmdb_identities(limit)
    result["total"] = len(identities)

    if not identities:
        print("[SEED] No TMDb identities with profile photos")
        return result

    print(f"[SEED] Found {len(identities)} TMDb identities")

    if not dry_run:
        ensure_seeds_collection()

    for identity in identities:
        identity_id = identity["id"]
        identity_uuid = str(identity["uuid"])
        name = identity["name"]
        tmdb_id = identity.get("tmdb_id")
        tmdb_profile = identity.get("tmdb_profile")

        result["processed"] += 1

        # Download image
        tmp_path = download_tmdb_image(tmdb_profile, tmdb_id)
        if not tmp_path:
            _record_failure(result, identity_id, name, "download_failed")
            continue

        # Extract embedding; always clean up the temp file, even if
        # extraction raises (e.g. a missing optional dependency).
        try:
            embedding = extract_face_embedding(tmp_path)
        finally:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)

        if not embedding:
            _record_failure(result, identity_id, name, "embedding_failed")
            continue

        # Push to Qdrant
        if dry_run:
            print(f"[SEED] DRY RUN: Would push seed: {name} (id={identity_id})")
            continue

        try:
            push_seed_embedding(
                identity_id=identity_id,
                identity_uuid=identity_uuid,
                name=name,
                embedding=embedding,
                source="tmdb",
                tmdb_id=tmdb_id,
            )
            result["success"] += 1
        except Exception as e:
            _record_failure(result, identity_id, name, str(e))

    print(f"[SEED] Done: {result['success']} seeds generated, {result['failed']} failed")
    return result


def main():
    """Parse CLI arguments, run seed generation and emit the JSON result."""
    # Without this declaration the assignment below would only create a local
    # variable and the --tmdb-api-key flag would have no effect.
    global TMDB_API_KEY

    parser = argparse.ArgumentParser(description="Generate Seed Embeddings from TMDb")
    parser.add_argument("--limit", type=int, help="Max identities to process")
    parser.add_argument("--dry-run", action="store_true", help="Don't push to Qdrant")
    parser.add_argument("--tmdb-api-key", help="TMDb API key (optional, for rate limiting)")
    parser.add_argument("--output", help="Output JSON file path")
    args = parser.parse_args()

    if args.tmdb_api_key:
        TMDB_API_KEY = args.tmdb_api_key

    result = generate_seed_embeddings(args.limit, args.dry_run)

    output_json = json.dumps(result, indent=2, ensure_ascii=False)

    if args.output:
        # ensure_ascii=False can emit non-ASCII names, so write UTF-8 explicitly
        # instead of relying on the locale's default encoding.
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(output_json)
        print(f"[SEED] Output saved to {args.output}")
    else:
        print(output_json)


if __name__ == "__main__":
    main()