feat: add _seeds collection helper functions for Identity Agent
- Add ensure_seeds_collection(): create _seeds collection (512D, Cosine)
- Add push_seed_embedding(): push identity seed with payload {identity_id, uuid, name, source, file_uuid, trace_id, tmdb_id}
- Add get_seeds(): get all seeds (optional source filter)
- Add search_seeds(): cosine search against seeds
- Add delete_seed(): delete seed by identity_id
- Add count_seeds(): count seeds (optional source filter)
- Add get_trace_representatives(): get 3 representatives per trace for multi-angle matching
- Add get_trace_centroid(): get centroid embedding for a trace
- Add update_identity_in_faces(): update identity_id/uuid for all face points with trace_id
Point ID strategy: identity_id directly as point_id for _seeds collection
All functions tested successfully
This commit is contained in:
@@ -1,17 +1,25 @@
|
||||
#!/opt/homebrew/bin/python3.11
|
||||
"""
|
||||
Qdrant _faces Collection Operations
|
||||
Qdrant _faces and _seeds Collection Operations
|
||||
|
||||
Functions:
|
||||
Functions for _faces:
|
||||
- ensure_faces_collection(): Create _faces collection if not exists
|
||||
- generate_point_id(): Generate consistent point ID
|
||||
- push_face_embeddings_batch(): Batch push embeddings to Qdrant
|
||||
- update_trace_ids(): Update trace_id after face tracking
|
||||
- get_file_faces(): Get all face points for a file
|
||||
- get_trace_representatives(): Get representative embeddings per trace
|
||||
|
||||
Functions for _seeds:
|
||||
- ensure_seeds_collection(): Create _seeds collection if not exists
|
||||
- push_seed_embedding(): Push identity seed embedding
|
||||
- get_seeds(): Get all seed points
|
||||
- search_seeds(): Cosine search against seeds
|
||||
- delete_seed(): Delete a seed point
|
||||
|
||||
Collection Schema:
|
||||
- Name: _faces (fixed, no schema prefix)
|
||||
- Vector: 512D, Cosine distance
|
||||
- Payload: {file_uuid, frame, trace_id, bbox, confidence, identity_id, identity_uuid, stranger_id}
|
||||
- _faces: 512D, Cosine, payload: {file_uuid, frame, trace_id, bbox, confidence, identity_id, identity_uuid, stranger_id}
|
||||
- _seeds: 512D, Cosine, payload: {identity_id, identity_uuid, name, source, file_uuid, trace_id, tmdb_id, created_at}
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -20,10 +28,12 @@ import hashlib
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
|
||||
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "Test3200Test3200Test3200")
|
||||
FACES_COLLECTION = "_faces"
|
||||
SEEDS_COLLECTION = "_seeds"
|
||||
VECTOR_DIM = 512
|
||||
BATCH_SIZE = int(os.environ.get("QDRANT_BATCH_SIZE", "100"))
|
||||
|
||||
@@ -305,4 +315,371 @@ def count_file_faces(file_uuid: str) -> int:
|
||||
}
|
||||
}
|
||||
result = qdrant_request("POST", f"/collections/{FACES_COLLECTION}/points/count", body)
|
||||
return result.get("result", {}).get("count", 0)
|
||||
return result.get("result", {}).get("count", 0)
|
||||
|
||||
|
||||
def get_trace_representatives(file_uuid: str) -> dict:
|
||||
"""Get representative embeddings per trace for multi-angle matching
|
||||
|
||||
Args:
|
||||
file_uuid: Video file UUID
|
||||
|
||||
Returns:
|
||||
{trace_id: [{'frame', 'embedding', 'bbox'}, ...]}
|
||||
Each trace has 3 representatives: start, middle, end
|
||||
"""
|
||||
all_points = get_file_faces(file_uuid)
|
||||
|
||||
traces = {}
|
||||
for point in all_points:
|
||||
payload = point.get("payload", {})
|
||||
vector = point.get("vector", [])
|
||||
trace_id = payload.get("trace_id", 0)
|
||||
|
||||
if trace_id == 0:
|
||||
continue
|
||||
|
||||
if trace_id not in traces:
|
||||
traces[trace_id] = []
|
||||
|
||||
traces[trace_id].append({
|
||||
"frame": payload.get("frame"),
|
||||
"embedding": vector,
|
||||
"bbox": payload.get("bbox", {}),
|
||||
"confidence": payload.get("confidence", 0.5),
|
||||
})
|
||||
|
||||
for trace_id in traces:
|
||||
points = traces[trace_id]
|
||||
points.sort(key=lambda x: x["frame"])
|
||||
|
||||
if len(points) <= 3:
|
||||
traces[trace_id] = points
|
||||
else:
|
||||
start = points[0]
|
||||
end = points[-1]
|
||||
middle_idx = len(points) // 2
|
||||
middle = points[middle_idx]
|
||||
traces[trace_id] = [start, middle, end]
|
||||
|
||||
return traces
|
||||
|
||||
|
||||
def get_trace_centroid(file_uuid: str, trace_id: int) -> list:
|
||||
"""Get centroid embedding for a trace
|
||||
|
||||
Args:
|
||||
file_uuid: Video file UUID
|
||||
trace_id: Trace ID
|
||||
|
||||
Returns:
|
||||
Centroid embedding (512D)
|
||||
"""
|
||||
reps = get_trace_representatives(file_uuid).get(trace_id, [])
|
||||
|
||||
if not reps:
|
||||
return [0.0] * VECTOR_DIM
|
||||
|
||||
centroid = [0.0] * VECTOR_DIM
|
||||
for rep in reps:
|
||||
for i, v in enumerate(rep["embedding"]):
|
||||
centroid[i] += v
|
||||
|
||||
count = len(reps)
|
||||
for i in range(VECTOR_DIM):
|
||||
centroid[i] /= count
|
||||
|
||||
return centroid
|
||||
|
||||
|
||||
# ==================== _seeds Collection ====================
|
||||
|
||||
def ensure_seeds_collection() -> bool:
|
||||
"""Create _seeds collection if not exists"""
|
||||
url = f"{QDRANT_URL}/collections/{SEEDS_COLLECTION}"
|
||||
req = urllib.request.Request(url, method="GET")
|
||||
req.add_header("Api-Key", QDRANT_API_KEY)
|
||||
try:
|
||||
urllib.request.urlopen(req)
|
||||
return True # Collection exists
|
||||
except urllib.error.HTTPError as e:
|
||||
if e.code != 404:
|
||||
raise RuntimeError(f"Qdrant check failed: {e.read().decode()}")
|
||||
|
||||
body = {
|
||||
"vectors": {
|
||||
"size": VECTOR_DIM,
|
||||
"distance": "Cosine"
|
||||
}
|
||||
}
|
||||
create_url = f"{QDRANT_URL}/collections/{SEEDS_COLLECTION}"
|
||||
data = json.dumps(body).encode()
|
||||
req = urllib.request.Request(create_url, data=data, method="PUT")
|
||||
req.add_header("Content-Type", "application/json")
|
||||
req.add_header("Api-Key", QDRANT_API_KEY)
|
||||
try:
|
||||
urllib.request.urlopen(req)
|
||||
print(f"[QDRANT] Created collection: {SEEDS_COLLECTION}")
|
||||
return True
|
||||
except urllib.error.HTTPError as e:
|
||||
raise RuntimeError(f"Qdrant create collection failed: {e.read().decode()}")
|
||||
|
||||
|
||||
def push_seed_embedding(
|
||||
identity_id: int,
|
||||
identity_uuid: str,
|
||||
name: str,
|
||||
embedding: list,
|
||||
source: str = "tmdb",
|
||||
file_uuid: str = None,
|
||||
trace_id: int = None,
|
||||
tmdb_id: int = None,
|
||||
) -> bool:
|
||||
"""Push identity seed embedding to _seeds collection
|
||||
|
||||
Args:
|
||||
identity_id: PG identity.id
|
||||
identity_uuid: Identity UUID
|
||||
name: Identity name
|
||||
embedding: 512D embedding
|
||||
source: 'tmdb' | 'manual' | 'propagation'
|
||||
file_uuid: File UUID (for manual/propagation seeds)
|
||||
trace_id: Trace ID (for propagation seeds)
|
||||
tmdb_id: TMDb ID (for TMDb seeds)
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Qdrant push fails
|
||||
"""
|
||||
ensure_seeds_collection()
|
||||
|
||||
payload = {
|
||||
"identity_id": identity_id,
|
||||
"identity_uuid": identity_uuid,
|
||||
"name": name,
|
||||
"source": source,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
if file_uuid:
|
||||
payload["file_uuid"] = file_uuid
|
||||
if trace_id:
|
||||
payload["trace_id"] = trace_id
|
||||
if tmdb_id:
|
||||
payload["tmdb_id"] = tmdb_id
|
||||
|
||||
body = {
|
||||
"points": [{
|
||||
"id": identity_id, # Use identity_id as point_id
|
||||
"vector": embedding,
|
||||
"payload": payload,
|
||||
}]
|
||||
}
|
||||
|
||||
url = f"{QDRANT_URL}/collections/{SEEDS_COLLECTION}/points?wait=true"
|
||||
data = json.dumps(body).encode()
|
||||
req = urllib.request.Request(url, data=data, method="PUT")
|
||||
req.add_header("Content-Type", "application/json")
|
||||
req.add_header("Api-Key", QDRANT_API_KEY)
|
||||
|
||||
try:
|
||||
urllib.request.urlopen(req)
|
||||
print(f"[QDRANT] Pushed seed: {name} (id={identity_id}, source={source})")
|
||||
return True
|
||||
except urllib.error.HTTPError as e:
|
||||
error_body = e.read().decode()
|
||||
raise RuntimeError(f"Qdrant seed push failed: HTTP {e.code} - {error_body}")
|
||||
|
||||
|
||||
def get_seeds(source: str = None) -> list:
|
||||
"""Get all seed points
|
||||
|
||||
Args:
|
||||
source: Filter by source ('tmdb', 'manual', 'propagation'), or None for all
|
||||
|
||||
Returns:
|
||||
List of seed points with payload and vector
|
||||
"""
|
||||
ensure_seeds_collection()
|
||||
|
||||
all_points = []
|
||||
offset = None
|
||||
|
||||
while True:
|
||||
body = {
|
||||
"limit": BATCH_SIZE,
|
||||
"with_payload": True,
|
||||
"with_vector": True,
|
||||
}
|
||||
|
||||
if source:
|
||||
body["filter"] = {
|
||||
"must": [
|
||||
{"key": "source", "match": {"value": source}}
|
||||
]
|
||||
}
|
||||
|
||||
if offset:
|
||||
body["offset"] = offset
|
||||
|
||||
result = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/scroll", body)
|
||||
batch = result.get("result", {}).get("points", [])
|
||||
if not batch:
|
||||
break
|
||||
all_points.extend(batch)
|
||||
offset = result.get("result", {}).get("next_page_offset")
|
||||
if not offset:
|
||||
break
|
||||
|
||||
return all_points
|
||||
|
||||
|
||||
def search_seeds(query_embedding: list, limit: int = 10, threshold: float = 0.0) -> list:
|
||||
"""Cosine search against seeds
|
||||
|
||||
Args:
|
||||
query_embedding: 512D query vector
|
||||
limit: Max results
|
||||
threshold: Minimum score threshold
|
||||
|
||||
Returns:
|
||||
List of {identity_id, identity_uuid, name, source, score}
|
||||
"""
|
||||
ensure_seeds_collection()
|
||||
|
||||
body = {
|
||||
"vector": query_embedding,
|
||||
"limit": limit,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
result = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/search", body)
|
||||
points = result.get("result", [])
|
||||
|
||||
results = []
|
||||
for point in points:
|
||||
score = point.get("score", 0)
|
||||
if score < threshold:
|
||||
continue
|
||||
|
||||
payload = point.get("payload", {})
|
||||
results.append({
|
||||
"identity_id": payload.get("identity_id"),
|
||||
"identity_uuid": payload.get("identity_uuid"),
|
||||
"name": payload.get("name"),
|
||||
"source": payload.get("source"),
|
||||
"score": score,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def delete_seed(identity_id: int) -> bool:
|
||||
"""Delete a seed point
|
||||
|
||||
Args:
|
||||
identity_id: Identity ID (used as point_id)
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
body = {
|
||||
"points": [identity_id]
|
||||
}
|
||||
|
||||
result = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/delete?wait=true", body)
|
||||
print(f"[QDRANT] Deleted seed: identity_id={identity_id}")
|
||||
return result.get("result", {}).get("status") == "completed"
|
||||
|
||||
|
||||
def count_seeds(source: str = None) -> int:
|
||||
"""Count seed points
|
||||
|
||||
Args:
|
||||
source: Filter by source, or None for all
|
||||
|
||||
Returns:
|
||||
Number of seed points
|
||||
"""
|
||||
ensure_seeds_collection()
|
||||
|
||||
body = {}
|
||||
if source:
|
||||
body["filter"] = {
|
||||
"must": [
|
||||
{"key": "source", "match": {"value": source}}
|
||||
]
|
||||
}
|
||||
|
||||
result = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/count", body)
|
||||
return result.get("result", {}).get("count", 0)
|
||||
|
||||
|
||||
def update_identity_in_faces(file_uuid: str, trace_id: int, identity_id: int, identity_uuid: str) -> int:
|
||||
"""Update identity_id/identity_uuid for all face points with trace_id
|
||||
|
||||
Called after identity binding confirmation.
|
||||
|
||||
Args:
|
||||
file_uuid: Video file UUID
|
||||
trace_id: Trace ID
|
||||
identity_id: Identity ID
|
||||
identity_uuid: Identity UUID
|
||||
|
||||
Returns:
|
||||
Number of updated points
|
||||
"""
|
||||
all_points = []
|
||||
offset = None
|
||||
|
||||
while True:
|
||||
body = {
|
||||
"limit": BATCH_SIZE,
|
||||
"with_payload": True,
|
||||
"with_vector": True,
|
||||
"filter": {
|
||||
"must": [
|
||||
{"key": "file_uuid", "match": {"value": file_uuid}},
|
||||
{"key": "trace_id", "match": {"value": trace_id}},
|
||||
]
|
||||
}
|
||||
}
|
||||
if offset:
|
||||
body["offset"] = offset
|
||||
|
||||
result = qdrant_request("POST", f"/collections/{FACES_COLLECTION}/points/scroll", body)
|
||||
batch = result.get("result", {}).get("points", [])
|
||||
if not batch:
|
||||
break
|
||||
all_points.extend(batch)
|
||||
offset = result.get("result", {}).get("next_page_offset")
|
||||
if not offset:
|
||||
break
|
||||
|
||||
if not all_points:
|
||||
return 0
|
||||
|
||||
updates = []
|
||||
for point in all_points:
|
||||
point_id = point["id"]
|
||||
vector = point.get("vector", [])
|
||||
payload = point.get("payload", {})
|
||||
|
||||
payload["identity_id"] = identity_id
|
||||
payload["identity_uuid"] = identity_uuid
|
||||
|
||||
updates.append({
|
||||
"id": point_id,
|
||||
"vector": vector,
|
||||
"payload": payload,
|
||||
})
|
||||
|
||||
for i in range(0, len(updates), BATCH_SIZE):
|
||||
batch = updates[i:i + BATCH_SIZE]
|
||||
body = {"points": batch}
|
||||
qdrant_request("PUT", f"/collections/{FACES_COLLECTION}/points?wait=true", body)
|
||||
|
||||
print(f"[QDRANT] Updated {len(updates)} face points with identity_id={identity_id}")
|
||||
return len(updates)
|
||||
Reference in New Issue
Block a user