From 580c4b4017d28b351621828e913739b91e6a1442 Mon Sep 17 00:00:00 2001
From: Accusys
Date: Thu, 25 Jun 2026 00:47:25 +0800
Subject: [PATCH] feat: add _seeds collection helper functions for Identity Agent

- Add ensure_seeds_collection(): create _seeds collection (512D, Cosine)
- Add push_seed_embedding(): push identity seed with payload
  {identity_id, identity_uuid, name, source, file_uuid, trace_id, tmdb_id,
  created_at}
- Add get_seeds(): get all seeds (optional source filter)
- Add search_seeds(): cosine search against seeds
- Add delete_seed(): delete seed by identity_id
- Add count_seeds(): count seeds (optional source filter)
- Add get_trace_representatives(): get up to 3 representatives per trace
  for multi-angle matching
- Add get_trace_centroid(): get centroid embedding for a trace
- Add update_identity_in_faces(): update identity_id/identity_uuid for all
  face points with trace_id

Point ID strategy: identity_id directly as point_id for _seeds collection

All functions tested successfully

Review follow-up (the changes below were not re-run against a live Qdrant):
- update_identity_in_faces() uses set-payload by filter instead of
  scrolling and re-upserting every point together with its vector
- HTTP responses are closed and wrapped errors keep their cause
- created_at is a timezone-aware UTC timestamp
- get_trace_representatives()/get_trace_centroid() tolerate points with a
  missing frame or vector
- delete_seed() only logs success when Qdrant reports "completed"
---
 scripts/utils/qdrant_faces.py | 389 +++++++++++++++++++++++++++++++++-
 1 file changed, 383 insertions(+), 6 deletions(-)

diff --git a/scripts/utils/qdrant_faces.py b/scripts/utils/qdrant_faces.py
index ac6ed1a..9bc460b 100644
--- a/scripts/utils/qdrant_faces.py
+++ b/scripts/utils/qdrant_faces.py
@@ -1,17 +1,28 @@
 #!/opt/homebrew/bin/python3.11
 """
-Qdrant _faces Collection Operations
+Qdrant _faces and _seeds Collection Operations
 
-Functions:
+Functions for _faces:
 - ensure_faces_collection(): Create _faces collection if not exists
 - generate_point_id(): Generate consistent point ID
 - push_face_embeddings_batch(): Batch push embeddings to Qdrant
 - update_trace_ids(): Update trace_id after face tracking
+- get_file_faces(): Get all face points for a file
+- get_trace_representatives(): Get representative embeddings per trace
+- get_trace_centroid(): Get centroid embedding for a trace
+- update_identity_in_faces(): Set identity_id/identity_uuid on a trace's face points
+
+Functions for _seeds:
+- ensure_seeds_collection(): Create _seeds collection if not exists
+- push_seed_embedding(): Push identity seed embedding
+- get_seeds(): Get all seed points
+- search_seeds(): Cosine search against seeds
+- delete_seed(): Delete a seed point
+- count_seeds(): Count seed points
 
 Collection Schema:
-- Name: _faces (fixed, no schema prefix)
-- Vector: 512D, Cosine distance
-- Payload: {file_uuid, frame, trace_id, bbox, confidence, identity_id, identity_uuid, stranger_id}
+- _faces: 512D, Cosine, payload: {file_uuid, frame, trace_id, bbox, confidence, identity_id, identity_uuid, stranger_id}
+- _seeds: 512D, Cosine, payload: {identity_id, identity_uuid, name, source, file_uuid, trace_id, tmdb_id, created_at}
 """
 
 import os
@@ -20,10 +31,12 @@ import hashlib
 import urllib.request
 import urllib.error
 from typing import Optional
+from datetime import datetime, timezone
 
 QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
 QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "Test3200Test3200Test3200")
 FACES_COLLECTION = "_faces"
+SEEDS_COLLECTION = "_seeds"
 VECTOR_DIM = 512
 BATCH_SIZE = int(os.environ.get("QDRANT_BATCH_SIZE", "100"))
 
@@ -305,4 +318,368 @@ def count_file_faces(file_uuid: str) -> int:
         }
     }
     result = qdrant_request("POST", f"/collections/{FACES_COLLECTION}/points/count", body)
-    return result.get("result", {}).get("count", 0)
\ No newline at end of file
+    return result.get("result", {}).get("count", 0)
+
+
+def get_trace_representatives(file_uuid: str) -> dict:
+    """Get representative embeddings per trace for multi-angle matching
+
+    Face points whose trace_id is 0, null or missing are treated as
+    untracked and skipped. Within a trace, points are ordered by frame
+    number; points without a frame sort last.
+
+    Args:
+        file_uuid: Video file UUID
+
+    Returns:
+        {trace_id: [{'frame', 'embedding', 'bbox', 'confidence'}, ...]}
+        Traces with more than 3 points are reduced to 3 representatives
+        (first, middle and last frame); shorter traces keep all points.
+    """
+    traces = {}
+    for point in get_file_faces(file_uuid):
+        payload = point.get("payload", {})
+        trace_id = payload.get("trace_id", 0)
+
+        if not trace_id:
+            continue
+
+        traces.setdefault(trace_id, []).append({
+            "frame": payload.get("frame"),
+            "embedding": point.get("vector", []),
+            "bbox": payload.get("bbox", {}),
+            "confidence": payload.get("confidence", 0.5),
+        })
+
+    for trace_id, points in traces.items():
+        # A missing frame would make a plain sort on "frame" raise TypeError
+        points.sort(key=lambda p: (p["frame"] is None, p["frame"] or 0))
+
+        if len(points) > 3:
+            traces[trace_id] = [points[0], points[len(points) // 2], points[-1]]
+
+    return traces
+
+
+def get_trace_centroid(file_uuid: str, trace_id: int) -> list:
+    """Get centroid embedding for a trace
+
+    The centroid is the element-wise mean of the trace's representative
+    embeddings (at most 3, see get_trace_representatives), not of every
+    face point in the trace. Representatives whose vector is missing or is
+    not VECTOR_DIM long are ignored so they cannot skew the mean or raise
+    IndexError.
+
+    Args:
+        file_uuid: Video file UUID
+        trace_id: Trace ID
+
+    Returns:
+        Centroid embedding (512D); all zeros if the trace has no usable
+        embedding
+    """
+    reps = get_trace_representatives(file_uuid).get(trace_id, [])
+    vectors = [
+        rep["embedding"] for rep in reps
+        if isinstance(rep["embedding"], list) and len(rep["embedding"]) == VECTOR_DIM
+    ]
+
+    if not vectors:
+        return [0.0] * VECTOR_DIM
+
+    count = len(vectors)
+    return [sum(column) / count for column in zip(*vectors)]
+
+
+# ==================== _seeds Collection ====================
+
+def ensure_seeds_collection() -> bool:
+    """Create _seeds collection if not exists
+
+    Returns:
+        True if the collection already existed or was created
+
+    Raises:
+        RuntimeError: If the existence check fails with anything other than
+            HTTP 404, or if the create request fails
+    """
+    url = f"{QDRANT_URL}/collections/{SEEDS_COLLECTION}"
+    req = urllib.request.Request(url, method="GET")
+    req.add_header("Api-Key", QDRANT_API_KEY)
+    try:
+        # "with" closes the HTTP response instead of leaking the connection
+        with urllib.request.urlopen(req):
+            return True  # Collection exists
+    except urllib.error.HTTPError as e:
+        if e.code != 404:
+            raise RuntimeError(f"Qdrant check failed: {e.read().decode()}") from e
+
+    body = {
+        "vectors": {
+            "size": VECTOR_DIM,
+            "distance": "Cosine"
+        }
+    }
+    data = json.dumps(body).encode()
+    req = urllib.request.Request(url, data=data, method="PUT")
+    req.add_header("Content-Type", "application/json")
+    req.add_header("Api-Key", QDRANT_API_KEY)
+    try:
+        with urllib.request.urlopen(req):
+            pass
+    except urllib.error.HTTPError as e:
+        raise RuntimeError(f"Qdrant create collection failed: {e.read().decode()}") from e
+
+    print(f"[QDRANT] Created collection: {SEEDS_COLLECTION}")
+    return True
+
+
+def push_seed_embedding(
+    identity_id: int,
+    identity_uuid: str,
+    name: str,
+    embedding: list,
+    source: str = "tmdb",
+    file_uuid: Optional[str] = None,
+    trace_id: Optional[int] = None,
+    tmdb_id: Optional[int] = None,
+) -> bool:
+    """Push identity seed embedding to _seeds collection
+
+    identity_id is used directly as the Qdrant point ID, so pushing a seed
+    for an identity that already has one replaces the existing seed.
+
+    Args:
+        identity_id: PG identity.id
+        identity_uuid: Identity UUID
+        name: Identity name
+        embedding: 512D embedding
+        source: 'tmdb' | 'manual' | 'propagation'
+        file_uuid: File UUID (for manual/propagation seeds)
+        trace_id: Trace ID (for propagation seeds)
+        tmdb_id: TMDb ID (for TMDb seeds)
+
+    Returns:
+        True if successful
+
+    Raises:
+        RuntimeError: If Qdrant push fails
+    """
+    ensure_seeds_collection()
+
+    payload = {
+        "identity_id": identity_id,
+        "identity_uuid": identity_uuid,
+        "name": name,
+        "source": source,
+        # Timezone-aware UTC so the timestamp is unambiguous across hosts
+        "created_at": datetime.now(timezone.utc).isoformat(),
+    }
+
+    # Optional fields are stored only when set; falsy values (None, "", 0)
+    # are omitted from the payload
+    if file_uuid:
+        payload["file_uuid"] = file_uuid
+    if trace_id:
+        payload["trace_id"] = trace_id
+    if tmdb_id:
+        payload["tmdb_id"] = tmdb_id
+
+    body = {
+        "points": [{
+            "id": identity_id,  # Use identity_id as point_id
+            "vector": embedding,
+            "payload": payload,
+        }]
+    }
+
+    url = f"{QDRANT_URL}/collections/{SEEDS_COLLECTION}/points?wait=true"
+    data = json.dumps(body).encode()
+    req = urllib.request.Request(url, data=data, method="PUT")
+    req.add_header("Content-Type", "application/json")
+    req.add_header("Api-Key", QDRANT_API_KEY)
+
+    try:
+        with urllib.request.urlopen(req):
+            pass
+    except urllib.error.HTTPError as e:
+        error_body = e.read().decode()
+        raise RuntimeError(f"Qdrant seed push failed: HTTP {e.code} - {error_body}") from e
+
+    print(f"[QDRANT] Pushed seed: {name} (id={identity_id}, source={source})")
+    return True
+
+
+def get_seeds(source: Optional[str] = None) -> list:
+    """Get all seed points
+
+    Args:
+        source: Filter by source ('tmdb', 'manual', 'propagation'), or None for all
+
+    Returns:
+        List of seed points with payload and vector
+    """
+    ensure_seeds_collection()
+
+    # Everything except the page offset is identical for every request
+    body = {
+        "limit": BATCH_SIZE,
+        "with_payload": True,
+        "with_vector": True,
+    }
+    if source:
+        body["filter"] = {
+            "must": [
+                {"key": "source", "match": {"value": source}}
+            ]
+        }
+
+    all_points = []
+    while True:
+        result = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/scroll", body)
+        page = result.get("result", {})
+        batch = page.get("points", [])
+        if not batch:
+            break
+        all_points.extend(batch)
+
+        # Seed point IDs are integer identity_ids, so test against None:
+        # 0 is a valid point ID, not an end-of-scroll marker
+        offset = page.get("next_page_offset")
+        if offset is None:
+            break
+        body["offset"] = offset
+
+    return all_points
+
+
+def search_seeds(query_embedding: list, limit: int = 10, threshold: float = 0.0) -> list:
+    """Cosine search against seeds
+
+    Args:
+        query_embedding: 512D query vector
+        limit: Max results
+        threshold: Minimum score threshold
+
+    Returns:
+        List of {identity_id, identity_uuid, name, source, score}
+    """
+    ensure_seeds_collection()
+
+    body = {
+        "vector": query_embedding,
+        "limit": limit,
+        "with_payload": True,
+    }
+
+    result = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/search", body)
+    points = result.get("result", [])
+
+    results = []
+    for point in points:
+        score = point.get("score", 0)
+        if score < threshold:
+            continue
+
+        payload = point.get("payload", {})
+        results.append({
+            "identity_id": payload.get("identity_id"),
+            "identity_uuid": payload.get("identity_uuid"),
+            "name": payload.get("name"),
+            "source": payload.get("source"),
+            "score": score,
+        })
+
+    return results
+
+
+def delete_seed(identity_id: int) -> bool:
+    """Delete a seed point
+
+    Args:
+        identity_id: Identity ID (used as point_id)
+
+    Returns:
+        True if Qdrant reports the delete operation as completed
+    """
+    body = {
+        "points": [identity_id]
+    }
+
+    result = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/delete?wait=true", body)
+    status = result.get("result", {}).get("status")
+    deleted = status == "completed"
+    # Only report success once Qdrant confirms the operation
+    if deleted:
+        print(f"[QDRANT] Deleted seed: identity_id={identity_id}")
+    else:
+        print(f"[QDRANT] Delete seed not confirmed: identity_id={identity_id}, status={status}")
+    return deleted
+
+
+def count_seeds(source: Optional[str] = None) -> int:
+    """Count seed points
+
+    Args:
+        source: Filter by source, or None for all
+
+    Returns:
+        Number of seed points
+    """
+    ensure_seeds_collection()
+
+    body = {}
+    if source:
+        body["filter"] = {
+            "must": [
+                {"key": "source", "match": {"value": source}}
+            ]
+        }
+
+    result = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/count", body)
+    return result.get("result", {}).get("count", 0)
+
+
+def update_identity_in_faces(file_uuid: str, trace_id: int, identity_id: int, identity_uuid: str) -> int:
+    """Update identity_id/identity_uuid for all face points with trace_id
+
+    Called after identity binding confirmation.
+
+    The matching points are updated server-side with Qdrant's set-payload
+    operation selected by filter, so vectors are not downloaded and
+    re-uploaded and the other payload fields are left untouched.
+
+    Args:
+        file_uuid: Video file UUID
+        trace_id: Trace ID
+        identity_id: Identity ID
+        identity_uuid: Identity UUID
+
+    Returns:
+        Number of face points matching file_uuid and trace_id when the
+        update was issued
+    """
+    points_filter = {
+        "must": [
+            {"key": "file_uuid", "match": {"value": file_uuid}},
+            {"key": "trace_id", "match": {"value": trace_id}},
+        ]
+    }
+
+    count_body = {"filter": points_filter, "exact": True}
+    result = qdrant_request("POST", f"/collections/{FACES_COLLECTION}/points/count", count_body)
+    matched = result.get("result", {}).get("count", 0)
+    if not matched:
+        return 0
+
+    body = {
+        "payload": {
+            "identity_id": identity_id,
+            "identity_uuid": identity_uuid,
+        },
+        "filter": points_filter,
+    }
+    qdrant_request("POST", f"/collections/{FACES_COLLECTION}/points/payload?wait=true", body)
+
+    print(f"[QDRANT] Updated {matched} face points with identity_id={identity_id}")
+    return matched