- Remove session-ses_2f27.md (161 KB raw session log)
- Remove 49 ROOT_* duplicate files across REFERENCE/
- Remove 14 duplicate files between the REFERENCE/ root and history/
- Remove asr_legacy.rs (dead code, replaced by asr.rs)
- Remove src/core/worker/ (duplicate JobWorker)
- Remove src/core/layers/ (empty directory)
- Remove 4 .bak files in src/
- Remove 7 dead private methods in worker/processor.rs
- Remove the backup directory from git tracking
230 lines
6.5 KiB
Python
230 lines
6.5 KiB
Python
#!/opt/homebrew/bin/python3.11
"""
Face Embedding Extractor

Responsibility: extract a face embedding vector (512-dim via ArcFace) for each
Face ID found in a video's frames, and store it in the database.
"""

import sys
import os
import json

import numpy as np
import psycopg2
import cv2

# Prepend this script's own directory to the module search path.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# DeepFace is optional: when it cannot be imported, HAS_DEEPFACE is False and
# embedding extraction is skipped.
try:
    from deepface import DeepFace
except ImportError:
    HAS_DEEPFACE = False
    print("[Warning] DeepFace not found. Install via: pip install deepface")
else:
    HAS_DEEPFACE = True

# PostgreSQL connection string and output root; both overridable via env vars.
DB_URL = os.getenv("DATABASE_URL", "postgresql://accusys@localhost:5432/momentry")
OUTPUT_DIR = os.getenv("MOMENTRY_OUTPUT_DIR", "./output")
|
|
|
|
|
|
def get_db_connection():
    """Open a new psycopg2 connection using the module-level DB_URL.

    Returns:
        A fresh connection object; the caller is responsible for closing it.
    """
    connection = psycopg2.connect(dsn=DB_URL)
    return connection
|
|
|
|
|
|
def _resolve_face_id(face: dict, frame_num):
    """Return a face's identifier: "id", else "face_id", else "face_<frame_num>".

    Missing, None and "" count as absent. Unlike a chained ``or``, a numeric id
    of 0 is kept instead of being replaced by the per-frame fallback.
    """
    for key in ("id", "face_id"):
        value = face.get(key)
        if value is not None and value != "":
            return value
    return f"face_{frame_num}"


def _resolve_bbox(face: dict):
    """Return a face's box as a 4-item [x, y, w, h] sequence, or None if unusable.

    Reads "bbox" first, falling back to separate "x"/"y"/"width"/"height" keys
    (only when all four are present, instead of raising KeyError).
    """
    bbox = face.get("bbox")
    if not bbox and all(k in face for k in ("x", "y", "width", "height")):
        bbox = [face["x"], face["y"], face["width"], face["height"]]
    if bbox and len(bbox) == 4:
        return bbox
    return None


def _clamp_bbox(bbox, frame_w: int, frame_h: int):
    """Clip an [x, y, w, h] box to the frame; return (x0, y0, x1, y1) or None if empty.

    Works on corner coordinates so that clipping a negative x/y also shrinks
    the box, keeping the crop inside the original detection.
    """
    x, y, w, h = (int(v) for v in bbox)
    x0, y0 = max(0, x), max(0, y)
    x1, y1 = min(x + w, frame_w), min(y + h, frame_h)
    if x1 <= x0 or y1 <= y0:
        return None
    return x0, y0, x1, y1


def _collect_face_crops(cap, frames: list, max_samples: int) -> dict:
    """Gather up to ``max_samples`` cropped images per face id from an opened capture.

    Args:
        cap: An opened ``cv2.VideoCapture``; it is NOT released here.
        frames: The "frames" list from the face JSON; each entry may carry a
            "frame_number" and a "faces" list.
        max_samples: Maximum number of crops kept per face id.

    Returns:
        {face_id: [crop, ...]}. A face id with a usable box always gets an
        entry, even if none of its crops survive clipping (empty list).
    """
    face_crops: dict = {}
    for frame_info in frames:
        faces_in_frame = frame_info.get("faces", [])
        if not faces_in_frame:
            continue  # nothing to crop; skip the costly seek + decode
        frame_num = frame_info.get("frame_number", 0)

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ok, frame = cap.read()
        if not ok:
            continue
        frame_h, frame_w = frame.shape[:2]

        for face in faces_in_frame:
            bbox = _resolve_bbox(face)
            if bbox is None:
                continue
            crops = face_crops.setdefault(_resolve_face_id(face, frame_num), [])
            if len(crops) >= max_samples:
                continue
            box = _clamp_bbox(bbox, frame_w, frame_h)
            if box is None:
                continue
            x0, y0, x1, y1 = box
            # copy() so each stored crop does not keep the whole decoded frame alive.
            crops.append(frame[y0:y1, x0:x1].copy())
    return face_crops


def _average_embedding(fid, crops: list):
    """Return the element-wise mean ArcFace embedding of ``crops`` as a list, or None.

    Best-effort: a crop that DeepFace fails on is reported and skipped rather
    than aborting the whole face. Returns None when no crop yields an embedding.
    """
    vectors = []
    for crop in crops:
        try:
            # model_name="ArcFace" outputs a 512-dim vector. enforce_detection=False:
            # per DeepFace docs, do not raise if no face is detected in the
            # (already cropped) image.
            result = DeepFace.represent(
                img_path=crop, model_name="ArcFace", enforce_detection=False
            )
            if result:
                vectors.append(np.asarray(result[0]["embedding"]))
        except Exception as exc:
            print(f" [Warning] Embedding failed for a crop of {fid}: {exc}")
    if not vectors:
        return None
    # Mean pooling over this face's samples.
    return np.mean(vectors, axis=0).tolist()


def extract_face_embeddings(
    uuid: str,
    video_path: str,
    face_json_path: str | None = None,
    max_samples_per_face: int = 5,
):
    """Extract one averaged face embedding per Face ID found in a video.

    Reads face detections from a face JSON file, crops each face from the
    referenced video frames (at most ``max_samples_per_face`` per id), embeds
    the crops with DeepFace/ArcFace, and mean-pools them.

    Args:
        uuid: Video UUID (used for log messages).
        video_path: Path of the video file to read frames from.
        face_json_path: Face JSON to read. Defaults to
            ``OUTPUT_DIR/quick_preview/preview.face.json``.
        max_samples_per_face: Maximum number of crops embedded per face id.

    Returns:
        {face_id: embedding list}. Empty when DeepFace is unavailable, the face
        JSON is missing or has no frames, or the video cannot be opened.
    """
    if not HAS_DEEPFACE:
        return {}

    # 1. Load the face-detection JSON.
    if face_json_path is None:
        # NOTE(review): this default path does not depend on `uuid`, so every
        # video reads the same preview file — confirm that is intended.
        face_json_path = os.path.join(OUTPUT_DIR, "quick_preview", "preview.face.json")
    if not os.path.exists(face_json_path):
        print(f" [Skip] No Face data for {uuid}")
        return {}
    with open(face_json_path, "r", encoding="utf-8") as f:
        face_data = json.load(f)

    frames = face_data.get("frames", [])
    if not frames:
        return {}

    # 2. Open the video and collect face crops; always release the capture.
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print(f" [Error] Cannot open video {video_path}")
            return {}
        print(f" [Extraction] Processing frames for {uuid}...")
        face_crops = _collect_face_crops(cap, frames, max_samples_per_face)
    finally:
        cap.release()

    # 3. Embed each face's crops with DeepFace and average them.
    face_embeddings = {}
    for fid, crops in face_crops.items():
        print(f" [Embedding] Processing {fid} ({len(crops)} crops)...")
        embedding = _average_embedding(fid, crops)
        if embedding is None:
            print(f" [Warning] No valid embedding extracted for {fid}")
        else:
            face_embeddings[fid] = embedding
    return face_embeddings
|
|
|
|
|
|
def _save_face_embedding(cur, fid: str, vector) -> None:
    """Store one face's embedding using an open cursor.

    If an identity binding of type 'face' already links ``fid`` to a talent,
    that talent's face_embedding is overwritten. Otherwise a talent named
    ``Face_<fid>`` is inserted (or its embedding updated on a real_name
    conflict) and bound to ``fid``.
    """
    # Look for an existing binding.
    cur.execute(
        """
        SELECT t.id FROM talents t
        JOIN identity_bindings b ON t.id = b.talent_id
        WHERE b.binding_type = 'face' AND b.binding_value = %s
        """,
        (fid,),
    )
    row = cur.fetchone()

    if row:
        talent_id = row[0]
        # Update the bound talent's vector.
        cur.execute(
            """
            UPDATE talents SET face_embedding = %s WHERE id = %s
            """,
            (vector, talent_id),
        )
        print(
            f" [DB] Updated embedding for bound Face {fid} (Talent #{talent_id})"
        )
        return

    # Create a new talent (or refresh the one already holding this real_name).
    cur.execute(
        """
        INSERT INTO talents (real_name, face_embedding)
        VALUES (%s, %s)
        ON CONFLICT (real_name) DO UPDATE SET face_embedding = EXCLUDED.face_embedding
        RETURNING id
        """,
        (f"Face_{fid}", vector),
    )
    talent_id = cur.fetchone()[0]

    # Bind the face id to that talent.
    cur.execute(
        """
        INSERT INTO identity_bindings (talent_id, binding_type, binding_value, source, confidence)
        VALUES (%s, 'face', %s, 'auto_extracted', 0.9)
        ON CONFLICT (binding_type, binding_value) DO NOTHING
        """,
        (talent_id, fid),
    )
    print(
        f" [DB] Created new Talent 'Face_{fid}' (#{talent_id}) with embedding"
    )


def save_embeddings_to_db(uuid: str, embeddings: dict):
    """Persist extracted face embeddings to the talents / identity_bindings tables.

    All faces are written in a single transaction. It is committed when every
    face succeeds and rolled back if any statement raises. The connection is
    always closed.

    Args:
        uuid: Video UUID. Not used by this function; kept for interface
            compatibility.
        embeddings: Mapping of face id -> embedding vector (list of floats).
    """
    if not embeddings:
        return

    conn = get_db_connection()
    try:
        # psycopg2: `with conn` commits on success and rolls back on exception,
        # but does not close the connection (hence the finally below).
        with conn, conn.cursor() as cur:
            for fid, vector in embeddings.items():
                # Face ids may arrive as ints from JSON, while binding_value also
                # holds string ids (e.g. the "face_<n>" fallback). Normalize to
                # str so the SELECT compares like with like.
                _save_face_embedding(cur, str(fid), vector)
    finally:
        conn.close()
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses --uuid and --video-path, exits with status 1 if the video file does
    not exist, then extracts face embeddings and saves them to the database.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Extract Face Embeddings")
    cli.add_argument("--uuid", required=True, help="Video UUID")
    cli.add_argument("--video-path", required=True, help="Path to video file")
    opts = cli.parse_args()

    video_uuid, video_path = opts.uuid, opts.video_path
    if not os.path.exists(video_path):
        print(f"Error: Video file not found at {video_path}")
        sys.exit(1)

    print(f"Starting Face Embedding Extraction for {video_uuid}")

    # Extract, then persist.
    save_embeddings_to_db(video_uuid, extract_face_embeddings(video_uuid, video_path))

    print("Done.")


if __name__ == "__main__":
    main()
|