- Add build_skin_tone_trace_nodes() to tkg.rs (Fitzpatrick I-VI classification) - Add skin_tone_trace_nodes field to TkgResult - Standardize node naming: _trace -> _track (text uses _region) - Add external_id format column to Node Types table - Add storage names to Edge Types table - Create TKG_FORMATION_V1.0.md with Phase 0-4 definition, flow diagram, queries - Add cross-reference from identity_agent_v4.0.md to TKG Formation - Update Python scripts to executable mode
301 lines
8.6 KiB
Python
Executable File
301 lines
8.6 KiB
Python
Executable File
#!/opt/homebrew/bin/python3.11
|
|
"""
|
|
Generate Seed Embeddings - Extract embeddings from TMDb profile photos
|
|
|
|
Flow:
|
|
1. Query PG identities: source='tmdb' AND tmdb_profile IS NOT NULL
|
|
2. Download profile image from TMDb
|
|
3. Extract face embedding using CoreML FaceNet
|
|
4. Push to Qdrant _seeds collection
|
|
|
|
TMDb Image URL format:
|
|
https://image.tmdb.org/t/p/original{tmdb_profile_path}
|
|
|
|
Usage:
|
|
python generate_seed_embeddings.py
|
|
python generate_seed_embeddings.py --limit 10
|
|
python generate_seed_embeddings.py --dry-run # Don't push to Qdrant
|
|
python generate_seed_embeddings.py --tmdb-api-key YOUR_KEY
|
|
|
|
Output:
|
|
JSON with generated seed count and status
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import argparse
|
|
import tempfile
|
|
import urllib.request
|
|
import urllib.error
|
|
from typing import Optional, List, Dict
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "utils"))
|
|
|
|
from qdrant_faces import push_seed_embedding, ensure_seeds_collection
|
|
|
|
# Runtime configuration. The first three values can be overridden through the
# process environment; the literals are the fallbacks used when a variable is
# not set.
DB_URL = os.getenv("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
SCHEMA = os.getenv("DATABASE_SCHEMA", "dev")
TMDB_API_KEY = os.getenv("TMDB_API_KEY", "")

# Prefix prepended to a TMDb profile path to form a full image URL.
TMDB_IMAGE_BASE = "https://image.tmdb.org/t/p/original"

# CoreML FaceNet model, resolved relative to the directory holding this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
FACENET_PATH = os.path.join(SCRIPT_DIR, "..", "models", "facenet512.mlpackage")
|
|
|
|
|
|
def get_tmdb_identities(limit: Optional[int] = None) -> List[Dict]:
    """Query PG for TMDb identities with profile photos.

    Selects rows from the ``identities`` table (schema-qualified unless
    SCHEMA is "public") where ``source = 'tmdb'`` and ``tmdb_profile`` is
    not NULL, ordered by ``id``.

    Args:
        limit: Max identities to return. ``None`` or ``0`` means no limit.

    Returns:
        List of dicts with keys id, uuid, name, tmdb_id, tmdb_profile.
    """
    import psycopg2
    import psycopg2.extras

    # SCHEMA is process configuration (environment), not request input. It is
    # interpolated because identifiers cannot be bound as query parameters.
    table = "identities" if SCHEMA == "public" else f"{SCHEMA}.identities"

    query = f"""
        SELECT id, uuid, name, tmdb_id, tmdb_profile
        FROM {table}
        WHERE source = 'tmdb' AND tmdb_profile IS NOT NULL
        ORDER BY id
    """

    params = None
    if limit:
        # Bind LIMIT as a parameter rather than formatting it into the SQL.
        query += " LIMIT %s"
        params = (int(limit),)

    conn = psycopg2.connect(DB_URL)
    try:
        cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        try:
            cur.execute(query, params)
            rows = cur.fetchall()
        finally:
            cur.close()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()

    return [dict(row) for row in rows]
|
|
|
|
|
|
def download_tmdb_image(tmdb_profile: str, tmdb_id: int) -> Optional[str]:
    """Download TMDb profile image to temp file.

    Args:
        tmdb_profile: TMDb profile URL or path
            - Full URL: 'https://image.tmdb.org/t/p/w185/xxx.jpg'
            - Path only: '/xxx.jpg'
        tmdb_id: TMDb ID, used only in log messages

    Returns:
        Path to downloaded temp file, or None if failed. The caller owns the
        file and is responsible for deleting it.
    """
    from urllib.parse import urlparse

    if not tmdb_profile:
        return None

    # Handle full URL or path
    if tmdb_profile.startswith("http"):
        url = tmdb_profile
    else:
        url = f"{TMDB_IMAGE_BASE}{tmdb_profile}"

    # Use 'original' size for better quality
    if "/w185" in url:
        url = url.replace("/w185", "/original")

    tmp_path = None
    try:
        # Take the extension from the URL *path* only. Splitting the whole URL
        # on "." would return host/path fragments (containing "/") for URLs
        # without an extension, and would include any query string.
        ext = os.path.splitext(urlparse(url).path)[1]
        suffix = ext if len(ext) > 1 else ".jpg"

        req = urllib.request.Request(url)
        with urllib.request.urlopen(req, timeout=30) as resp:
            data = resp.read()

        # mkstemp creates the file atomically; mktemp only returns a name and
        # is deprecated because another process can claim that name first.
        fd, tmp_path = tempfile.mkstemp(suffix=suffix)
        with os.fdopen(fd, "wb") as f:
            f.write(data)

        print(f"[TMDB] Downloaded: tmdb_id={tmdb_id} -> {tmp_path}")
        return tmp_path
    except urllib.error.HTTPError as e:
        print(f"[TMDB] Download failed (HTTP {e.code}): tmdb_id={tmdb_id}")
    except Exception as e:
        print(f"[TMDB] Download failed: tmdb_id={tmdb_id} - {e}")

    # Failure path: do not leave a partially written temp file behind.
    if tmp_path is not None:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
    return None
|
|
|
|
|
|
# Lazily loaded CoreML model, shared by every call in this process so the
# model is read from disk once instead of once per image.
_FACENET_MODEL = None


def extract_face_embedding(image_path: str) -> Optional[List[float]]:
    """Extract face embedding from image using CoreML FaceNet.

    The image is resized to 160x160, scaled to [-1, 1], converted from HWC to
    CHW, given a batch dimension, and passed to the model under the input
    name "input". The first output whose name starts with "var_" is flattened
    and returned.

    Args:
        image_path: Path to image file

    Returns:
        Embedding as a flat list of floats (512D per the model name), or None
        if the model could not be loaded, the image could not be read, or
        inference failed.
    """
    import coremltools as ct
    import numpy as np
    import cv2

    global _FACENET_MODEL

    # Load CoreML model (first call only; a failed load is retried next call)
    if _FACENET_MODEL is None:
        try:
            _FACENET_MODEL = ct.models.MLModel(FACENET_PATH)
        except Exception as e:
            print(f"[COREML] Model load failed: {e}")
            return None
    model = _FACENET_MODEL

    try:
        # Read image
        # NOTE(review): cv2.imread returns channels in BGR order and no
        # BGR->RGB conversion is done here; confirm the exported model
        # expects BGR input.
        img = cv2.imread(image_path)
        if img is None:
            print(f"[COREML] Image read failed: {image_path}")
            return None

        # Resize to 160x160
        resized = cv2.resize(img, (160, 160))

        # Normalize to [-1, 1] and convert HWC to CHW
        normalized = (resized.astype(np.float32) / 127.5) - 1.0
        normalized = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW

        # Add batch dim: (1, 3, 160, 160)
        input_array = np.expand_dims(normalized, axis=0)

        # Run model
        result = model.predict({"input": input_array})

        # Find output key (var_xxx)
        emb_key = next((k for k in result if k.startswith("var_")), None)
        if emb_key is None:
            print(
                "[COREML] Embedding extraction failed: "
                f"no 'var_*' output among {sorted(result)}"
            )
            return None

        return result[emb_key].flatten().tolist()
    except Exception as e:
        print(f"[COREML] Embedding extraction failed: {e}")
        return None
|
|
|
|
|
|
def generate_seed_embeddings(limit: Optional[int] = None, dry_run: bool = False) -> Dict:
    """Generate embeddings for all TMDb identities.

    For each identity: download the profile image, extract a face embedding,
    delete the temp image, then push the embedding to Qdrant (unless dry_run).

    Args:
        limit: Max identities to process
        dry_run: Don't push to Qdrant

    Returns:
        Result dict with keys:
            total: number of identities returned by the query
            processed: number of identities attempted
            success: number of embeddings pushed (stays 0 in dry-run mode)
            failed: number of identities that failed at any step
            errors: list of {identity_id, name, error}
    """
    result = {
        "total": 0,
        "processed": 0,
        "success": 0,
        "failed": 0,
        "errors": [],
    }

    def record_failure(identity_id, name, error):
        # Count one failed identity and remember why it failed.
        result["failed"] += 1
        result["errors"].append({
            "identity_id": identity_id,
            "name": name,
            "error": error,
        })

    identities = get_tmdb_identities(limit)
    result["total"] = len(identities)

    if not identities:
        print("[SEED] No TMDb identities with profile photos")
        return result

    print(f"[SEED] Found {len(identities)} TMDb identities")

    if not dry_run:
        ensure_seeds_collection()

    for identity in identities:
        identity_id = identity["id"]
        identity_uuid = str(identity["uuid"])
        name = identity["name"]
        tmdb_id = identity.get("tmdb_id")
        tmdb_profile = identity.get("tmdb_profile")

        result["processed"] += 1

        # Download image
        tmp_path = download_tmdb_image(tmdb_profile, tmdb_id)
        if not tmp_path:
            record_failure(identity_id, name, "download_failed")
            continue

        # Extract embedding; the temp file is removed even if extraction
        # raises, so an unexpected error cannot leak files.
        try:
            embedding = extract_face_embedding(tmp_path)
        finally:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass

        if not embedding:
            record_failure(identity_id, name, "embedding_failed")
            continue

        # Push to Qdrant
        if dry_run:
            print(f"[SEED] DRY RUN: Would push seed: {name} (id={identity_id})")
            continue

        try:
            push_seed_embedding(
                identity_id=identity_id,
                identity_uuid=identity_uuid,
                name=name,
                embedding=embedding,
                source="tmdb",
                tmdb_id=tmdb_id,
            )
        except Exception as e:
            record_failure(identity_id, name, str(e))
        else:
            result["success"] += 1

    print(f"[SEED] Done: {result['success']} seeds generated, {result['failed']} failed")
    return result
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses --limit, --dry-run, --tmdb-api-key and --output, runs
    generate_seed_embeddings(), and writes the result as JSON either to the
    --output file or to stdout.
    """
    # Without this declaration the assignment below would only bind a local
    # name and the module-level TMDB_API_KEY would never change.
    # NOTE(review): nothing in the visible part of this file reads
    # TMDB_API_KEY after it is set - confirm whether it is still needed.
    global TMDB_API_KEY

    parser = argparse.ArgumentParser(description="Generate Seed Embeddings from TMDb")
    parser.add_argument("--limit", type=int, help="Max identities to process")
    parser.add_argument("--dry-run", action="store_true", help="Don't push to Qdrant")
    parser.add_argument("--tmdb-api-key", help="TMDb API key (optional, for rate limiting)")
    parser.add_argument("--output", help="Output JSON file path")
    args = parser.parse_args()

    if args.tmdb_api_key:
        TMDB_API_KEY = args.tmdb_api_key

    result = generate_seed_embeddings(args.limit, args.dry_run)

    output_json = json.dumps(result, indent=2, ensure_ascii=False)

    if args.output:
        # ensure_ascii=False can emit non-ASCII characters (e.g. names), so
        # the file encoding is fixed instead of relying on the locale default.
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(output_json)
        print(f"[SEED] Output saved to {args.output}")
    else:
        print(output_json)


if __name__ == "__main__":
    main()