Fix qdrant_request() so that an empty dict {} is sent as the request body.
The previous check, `if body`, is falsy for an empty dict, so no body was sent at all, which caused an EOF error.
Changed to:
- data = json.dumps(body).encode() if body is not None else None
Also cleaned up count_seeds() so it passes its body the same way as the other functions.
685 lines
20 KiB
Python
685 lines
20 KiB
Python
#!/opt/homebrew/bin/python3.11
|
|
"""
|
|
Qdrant _faces and _seeds Collection Operations
|
|
|
|
Functions for _faces:
|
|
- ensure_faces_collection(): Create _faces collection if not exists
|
|
- generate_point_id(): Generate consistent point ID
|
|
- push_face_embeddings_batch(): Batch push embeddings to Qdrant
|
|
- update_trace_ids(): Update trace_id after face tracking
|
|
- get_file_faces(): Get all face points for a file
|
|
- get_trace_representatives(): Get representative embeddings per trace
|
|
|
|
Functions for _seeds:
|
|
- ensure_seeds_collection(): Create _seeds collection if not exists
|
|
- push_seed_embedding(): Push identity seed embedding
|
|
- get_seeds(): Get all seed points
|
|
- search_seeds(): Cosine search against seeds
|
|
- delete_seed(): Delete a seed point
|
|
|
|
Collection Schema:
|
|
- _faces: 512D, Cosine, payload: {file_uuid, frame, trace_id, bbox, confidence, identity_id, identity_uuid, stranger_id}
|
|
- _seeds: 512D, Cosine, payload: {identity_id, identity_uuid, name, source, file_uuid, trace_id, tmdb_id, created_at}
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import hashlib
|
|
import urllib.request
|
|
import urllib.error
|
|
from typing import Optional
|
|
from datetime import datetime
|
|
|
|
# Connection settings; each value can be overridden through the environment.
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
# NOTE(review): a literal key is used as the fallback when QDRANT_API_KEY is
# unset - confirm this is only a local test credential.
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "Test3200Test3200Test3200")

# Names of the two collections managed by this module.
FACES_COLLECTION = "_faces"
SEEDS_COLLECTION = "_seeds"

# Vector size used when either collection is created.
VECTOR_DIM = 512

# Number of points per upsert request and per scroll page.
BATCH_SIZE = int(os.getenv("QDRANT_BATCH_SIZE", "100"))
|
|
|
|
|
|
def qdrant_request(method: str, path: str, body: Optional[dict] = None) -> dict:
    """Make an HTTP request to Qdrant and return the decoded JSON response.

    Args:
        method: HTTP verb, e.g. "GET", "POST" or "PUT".
        path: Request path appended to QDRANT_URL; may include a query string.
        body: JSON-serializable request body. Any value other than None,
            including an empty dict, is serialized and sent; None sends no
            body at all.

    Returns:
        The response body parsed as JSON.

    Raises:
        RuntimeError: If Qdrant answers with an HTTP error status. The
            underlying HTTPError is attached as the exception's cause.
    """
    url = f"{QDRANT_URL}{path}"
    # Compare against None rather than truthiness so that {} is still sent.
    data = json.dumps(body).encode() if body is not None else None
    req = urllib.request.Request(url, data=data, method=method)
    req.add_header("Content-Type", "application/json")
    req.add_header("Api-Key", QDRANT_API_KEY)
    try:
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())
    except urllib.error.HTTPError as e:
        # Decode leniently: an undecodable error page must not mask the
        # HTTP status behind a UnicodeDecodeError.
        error_body = e.read().decode(errors="replace")
        raise RuntimeError(f"Qdrant HTTP {e.code}: {error_body}") from e
|
|
|
|
|
|
def ensure_faces_collection() -> bool:
    """Create the _faces collection if it does not exist yet.

    Issues a GET for the collection; a 404 answer triggers a PUT that creates
    it with VECTOR_DIM-sized vectors and Cosine distance.

    Returns:
        True once the collection exists (already present or just created).

    Raises:
        RuntimeError: If the existence check fails with an HTTP status other
            than 404, or if creating the collection fails.
    """
    url = f"{QDRANT_URL}/collections/{FACES_COLLECTION}"
    req = urllib.request.Request(url, method="GET")
    req.add_header("Api-Key", QDRANT_API_KEY)
    try:
        # Open inside a with-block so the response is always closed.
        with urllib.request.urlopen(req):
            return True  # Collection exists
    except urllib.error.HTTPError as e:
        # Only 404 means "missing"; any other status is a real failure.
        if e.code != 404:
            error_body = e.read().decode(errors="replace")
            raise RuntimeError(f"Qdrant check failed: {error_body}") from e

    # Create collection
    body = {
        "vectors": {
            "size": VECTOR_DIM,
            "distance": "Cosine"
        }
    }
    data = json.dumps(body).encode()
    req = urllib.request.Request(url, data=data, method="PUT")
    req.add_header("Content-Type", "application/json")
    req.add_header("Api-Key", QDRANT_API_KEY)
    try:
        with urllib.request.urlopen(req):
            print(f"[QDRANT] Created collection: {FACES_COLLECTION}")
            return True
    except urllib.error.HTTPError as e:
        error_body = e.read().decode(errors="replace")
        raise RuntimeError(f"Qdrant create collection failed: {error_body}") from e
|
|
|
|
|
|
def generate_point_id(file_uuid: str, frame: int, trace_id: int = 0) -> int:
    """Derive a deterministic point ID from file_uuid, frame and trace_id.

    The three values are joined as "<file_uuid>_<frame>_<trace_id>", hashed
    with MD5, and the leading 64 bits of the digest are returned as an
    unsigned integer, so equal inputs always give the same ID.

    Args:
        file_uuid: Video file UUID
        frame: Frame number
        trace_id: Trace ID (defaults to 0)

    Returns:
        Integer in the range [0, 2**64)
    """
    digest = hashlib.md5(f"{file_uuid}_{frame}_{trace_id}".encode()).digest()
    # The first 8 digest bytes correspond to the first 16 hex characters.
    return int.from_bytes(digest[:8], "big")
|
|
|
|
|
|
def push_face_embeddings_batch(
    file_uuid: str,
    faces: list,
    publisher=None
) -> int:
    """Batch push face embeddings to _faces collection

    Faces are sent in slices of BATCH_SIZE with wait=true. Every point is
    stored with identity_id, identity_uuid and stranger_id set to None.

    Args:
        file_uuid: Video file UUID
        faces: List of {frame, trace_id, bbox, confidence, embedding};
            trace_id defaults to 0 and confidence to 0.5 when absent
        publisher: RedisPublisher for progress reporting (optional); its
            progress() method is called after every batch

    Returns:
        Number of successfully pushed embeddings (0 for an empty input,
        in which case no request is made)

    Raises:
        RuntimeError: If Qdrant push fails
    """
    if not faces:
        return 0

    ensure_faces_collection()

    total = len(faces)
    pushed = 0
    url = f"{QDRANT_URL}/collections/{FACES_COLLECTION}/points?wait=true"

    for i in range(0, total, BATCH_SIZE):
        batch = faces[i:i + BATCH_SIZE]

        points = []
        for face in batch:
            trace_id = face.get("trace_id", 0)
            # NOTE(review): the ID depends only on file_uuid, frame and
            # trace_id, so two faces in the same frame with the same trace_id
            # (e.g. both left at the default 0) get the same ID - presumably
            # the later upsert replaces the earlier one; confirm with callers.
            points.append({
                "id": generate_point_id(file_uuid, face["frame"], trace_id),
                "vector": face["embedding"],
                "payload": {
                    "file_uuid": file_uuid,
                    "frame": face["frame"],
                    "trace_id": trace_id,
                    "bbox": face["bbox"],
                    "confidence": face.get("confidence", 0.5),
                    "identity_id": None,
                    "identity_uuid": None,
                    "stranger_id": None,
                }
            })

        data = json.dumps({"points": points}).encode()
        req = urllib.request.Request(url, data=data, method="PUT")
        req.add_header("Content-Type", "application/json")
        req.add_header("Api-Key", QDRANT_API_KEY)

        try:
            # The with-block closes the response instead of leaking one
            # open connection per batch.
            with urllib.request.urlopen(req):
                pushed += len(batch)
        except urllib.error.HTTPError as e:
            error_body = e.read().decode(errors="replace")
            raise RuntimeError(
                f"Qdrant push failed (batch {i//BATCH_SIZE}): HTTP {e.code} - {error_body}"
            ) from e

        if publisher:
            done = i + len(batch)
            pct = int(done * 100 / total)
            publisher.progress("face", done, total, f"Qdrant push {pct}%")

    print(f"[QDRANT] Pushed {pushed} embeddings to {FACES_COLLECTION}")
    return pushed
|
|
|
|
|
|
def update_trace_ids(file_uuid: str, trace_mapping: dict) -> int:
    """Update trace_id for all face points in a file

    Called by store_traced_faces.py after face tracking.

    Every stored face of the file is fetched, matched against trace_mapping
    by its frame and bounding box, and re-upserted (same id and vector) with
    the new trace_id in its payload. Faces without a mapping entry are left
    untouched.

    Args:
        file_uuid: Video file UUID
        trace_mapping: {frame: {bbox_key: trace_id}}
            bbox_key = f"{x}_{y}_{width}_{height}"

    Returns:
        Number of updated points

    Raises:
        RuntimeError: If a Qdrant request fails
    """
    updates = []
    # get_file_faces() performs the paginated scroll (payload + vector)
    # for this file, so it is reused instead of repeating that loop here.
    for point in get_file_faces(file_uuid):
        point_id = point["id"]
        payload = point.get("payload", {})

        # Tolerate a missing or null bbox; such a point simply yields a key
        # that does not match any mapping entry.
        bbox = payload.get("bbox") or {}
        bbox_key = f"{bbox.get('x')}_{bbox.get('y')}_{bbox.get('width')}_{bbox.get('height')}"

        # NOTE(review): the lookup uses the frame value exactly as stored in
        # the payload; if trace_mapping was loaded from JSON its keys would be
        # strings and never match an integer frame - confirm with the caller.
        trace_id = trace_mapping.get(payload.get("frame"), {}).get(bbox_key)
        if trace_id is None:
            continue

        payload["trace_id"] = trace_id
        updates.append({
            "id": point_id,
            "vector": point.get("vector", []),
            "payload": payload,
        })

    if not updates:
        return 0

    for i in range(0, len(updates), BATCH_SIZE):
        body = {"points": updates[i:i + BATCH_SIZE]}
        qdrant_request("PUT", f"/collections/{FACES_COLLECTION}/points?wait=true", body)

    print(f"[QDRANT] Updated {len(updates)} trace_ids in {FACES_COLLECTION}")
    return len(updates)
|
|
|
|
|
|
def delete_file_faces(file_uuid: str) -> int:
    """Delete all face points for a file

    Args:
        file_uuid: Video file UUID

    Returns:
        Number of deleted points, i.e. the number of points that matched
        file_uuid immediately before the delete request was sent

    Raises:
        RuntimeError: If a Qdrant request fails
    """
    # Count up front: the delete response reports an operation id rather
    # than how many points were removed, so it cannot serve as the count.
    deleted = count_file_faces(file_uuid)

    body = {
        "filter": {
            "must": [
                {"key": "file_uuid", "match": {"value": file_uuid}}
            ]
        }
    }
    qdrant_request("POST", f"/collections/{FACES_COLLECTION}/points/delete", body)
    print(f"[QDRANT] Deleted faces for file_uuid={file_uuid}")
    return deleted
|
|
|
|
|
|
def get_file_faces(file_uuid: str) -> list:
    """Fetch every face point stored for one file.

    Pages through the _faces collection with the scroll endpoint, BATCH_SIZE
    points at a time, until a page is empty or no next-page offset is given.

    Args:
        file_uuid: Video file UUID

    Returns:
        List of points with payload and vector
    """
    scroll_path = f"/collections/{FACES_COLLECTION}/points/scroll"
    collected = []
    cursor = None

    while True:
        request_body = {
            "limit": BATCH_SIZE,
            "with_payload": True,
            "with_vector": True,
            "filter": {
                "must": [
                    {"key": "file_uuid", "match": {"value": file_uuid}}
                ]
            },
        }
        # The first page is requested without an offset.
        if cursor:
            request_body["offset"] = cursor

        page = qdrant_request("POST", scroll_path, request_body).get("result", {})
        page_points = page.get("points", [])
        if not page_points:
            break

        collected.extend(page_points)
        cursor = page.get("next_page_offset")
        if not cursor:
            break

    return collected
|
|
|
|
|
|
def count_file_faces(file_uuid: str) -> int:
    """Return how many face points are stored for a file.

    Args:
        file_uuid: Video file UUID

    Returns:
        Number of face points (0 when the response carries no count)
    """
    file_filter = {
        "must": [
            {"key": "file_uuid", "match": {"value": file_uuid}}
        ]
    }
    response = qdrant_request(
        "POST",
        f"/collections/{FACES_COLLECTION}/points/count",
        {"filter": file_filter},
    )
    counted = response.get("result", {})
    return counted.get("count", 0)
|
|
|
|
|
|
def get_trace_representatives(file_uuid: str) -> dict:
    """Pick representative embeddings per trace for multi-angle matching.

    All faces of the file are grouped by trace_id (faces whose trace_id is 0
    are skipped) and each group is sorted by frame. Groups with more than
    three faces are reduced to the first, the middle and the last face.

    Args:
        file_uuid: Video file UUID

    Returns:
        {trace_id: [{'frame', 'embedding', 'bbox', 'confidence'}, ...]}
        with at most 3 entries per trace: start, middle, end
    """
    grouped = {}
    for point in get_file_faces(file_uuid):
        payload = point.get("payload", {})
        trace_id = payload.get("trace_id", 0)
        if trace_id == 0:
            continue

        grouped.setdefault(trace_id, []).append({
            "frame": payload.get("frame"),
            "embedding": point.get("vector", []),
            "bbox": payload.get("bbox", {}),
            "confidence": payload.get("confidence", 0.5),
        })

    for trace_id, samples in grouped.items():
        samples.sort(key=lambda sample: sample["frame"])
        # Short traces are kept whole; longer ones keep three samples.
        if len(samples) > 3:
            grouped[trace_id] = [
                samples[0],
                samples[len(samples) // 2],
                samples[-1],
            ]

    return grouped
|
|
|
|
|
|
def get_trace_centroid(file_uuid: str, trace_id: int) -> list:
    """Compute the centroid embedding of a trace.

    The centroid is the component-wise mean of the representatives that
    get_trace_representatives() returns for the trace.

    Args:
        file_uuid: Video file UUID
        trace_id: Trace ID

    Returns:
        Centroid embedding (512D); all zeros when the trace has no
        representatives
    """
    representatives = get_trace_representatives(file_uuid).get(trace_id, [])

    totals = [0.0] * VECTOR_DIM
    if not representatives:
        return totals

    # Accumulate component by component, one representative after another.
    for sample in representatives:
        for dim, component in enumerate(sample["embedding"]):
            totals[dim] = totals[dim] + component

    divisor = len(representatives)
    return [component_sum / divisor for component_sum in totals]
|
|
|
|
|
|
# ==================== _seeds Collection ====================
|
|
|
|
def ensure_seeds_collection() -> bool:
    """Create the _seeds collection if it does not exist yet.

    Issues a GET for the collection; a 404 answer triggers a PUT that creates
    it with VECTOR_DIM-sized vectors and Cosine distance.

    Returns:
        True once the collection exists (already present or just created).

    Raises:
        RuntimeError: If the existence check fails with an HTTP status other
            than 404, or if creating the collection fails.
    """
    url = f"{QDRANT_URL}/collections/{SEEDS_COLLECTION}"
    req = urllib.request.Request(url, method="GET")
    req.add_header("Api-Key", QDRANT_API_KEY)
    try:
        # Open inside a with-block so the response is always closed.
        with urllib.request.urlopen(req):
            return True  # Collection exists
    except urllib.error.HTTPError as e:
        # Only 404 means "missing"; any other status is a real failure.
        if e.code != 404:
            error_body = e.read().decode(errors="replace")
            raise RuntimeError(f"Qdrant check failed: {error_body}") from e

    body = {
        "vectors": {
            "size": VECTOR_DIM,
            "distance": "Cosine"
        }
    }
    data = json.dumps(body).encode()
    req = urllib.request.Request(url, data=data, method="PUT")
    req.add_header("Content-Type", "application/json")
    req.add_header("Api-Key", QDRANT_API_KEY)
    try:
        with urllib.request.urlopen(req):
            print(f"[QDRANT] Created collection: {SEEDS_COLLECTION}")
            return True
    except urllib.error.HTTPError as e:
        error_body = e.read().decode(errors="replace")
        raise RuntimeError(f"Qdrant create collection failed: {error_body}") from e
|
|
|
|
|
|
def push_seed_embedding(
    identity_id: int,
    identity_uuid: str,
    name: str,
    embedding: list,
    source: str = "tmdb",
    file_uuid: str = None,
    trace_id: int = None,
    tmdb_id: int = None,
) -> bool:
    """Push identity seed embedding to _seeds collection

    The point is upserted with wait=true and uses identity_id as its point
    ID. file_uuid, trace_id and tmdb_id are written to the payload only when
    they are truthy; created_at is set to the current local time.

    Args:
        identity_id: PG identity.id
        identity_uuid: Identity UUID
        name: Identity name
        embedding: 512D embedding
        source: 'tmdb' | 'manual' | 'propagation'
        file_uuid: File UUID (for manual/propagation seeds)
        trace_id: Trace ID (for propagation seeds)
        tmdb_id: TMDb ID (for TMDb seeds)

    Returns:
        True if successful

    Raises:
        RuntimeError: If Qdrant push fails
    """
    ensure_seeds_collection()

    payload = {
        "identity_id": identity_id,
        "identity_uuid": identity_uuid,
        "name": name,
        "source": source,
        "created_at": datetime.now().isoformat(),
    }

    # Optional fields are stored only when set (a value of 0 is skipped too).
    if file_uuid:
        payload["file_uuid"] = file_uuid
    if trace_id:
        payload["trace_id"] = trace_id
    if tmdb_id:
        payload["tmdb_id"] = tmdb_id

    body = {
        "points": [{
            # NOTE(review): identity_id doubles as the point ID, so there is
            # one point ID per identity - presumably a second seed for the
            # same identity replaces the first; confirm this is intended.
            "id": identity_id,
            "vector": embedding,
            "payload": payload,
        }]
    }

    url = f"{QDRANT_URL}/collections/{SEEDS_COLLECTION}/points?wait=true"
    data = json.dumps(body).encode()
    req = urllib.request.Request(url, data=data, method="PUT")
    req.add_header("Content-Type", "application/json")
    req.add_header("Api-Key", QDRANT_API_KEY)

    try:
        # The with-block closes the response instead of leaking it.
        with urllib.request.urlopen(req):
            print(f"[QDRANT] Pushed seed: {name} (id={identity_id}, source={source})")
            return True
    except urllib.error.HTTPError as e:
        error_body = e.read().decode(errors="replace")
        raise RuntimeError(f"Qdrant seed push failed: HTTP {e.code} - {error_body}") from e
|
|
|
|
|
|
def get_seeds(source: str = None) -> list:
    """Fetch seed points, optionally restricted to one source.

    Ensures the _seeds collection exists, then pages through it with the
    scroll endpoint, BATCH_SIZE points at a time, until a page is empty or
    no next-page offset is given.

    Args:
        source: Filter by source ('tmdb', 'manual', 'propagation'), or None for all

    Returns:
        List of seed points with payload and vector
    """
    ensure_seeds_collection()

    scroll_path = f"/collections/{SEEDS_COLLECTION}/points/scroll"
    collected = []
    cursor = None

    while True:
        request_body = {
            "limit": BATCH_SIZE,
            "with_payload": True,
            "with_vector": True,
        }
        # A filter is only attached when a source was requested.
        if source:
            request_body["filter"] = {
                "must": [
                    {"key": "source", "match": {"value": source}}
                ]
            }
        if cursor:
            request_body["offset"] = cursor

        page = qdrant_request("POST", scroll_path, request_body).get("result", {})
        page_points = page.get("points", [])
        if not page_points:
            break

        collected.extend(page_points)
        cursor = page.get("next_page_offset")
        if not cursor:
            break

    return collected
|
|
|
|
|
|
def search_seeds(query_embedding: list, limit: int = 10, threshold: float = 0.0) -> list:
    """Search the _seeds collection for the vectors closest to a query.

    Ensures the collection exists, asks Qdrant for up to `limit` hits and
    then drops every hit whose score is below `threshold`.

    Args:
        query_embedding: 512D query vector
        limit: Max results
        threshold: Minimum score threshold

    Returns:
        List of {identity_id, identity_uuid, name, source, score}
    """
    ensure_seeds_collection()

    response = qdrant_request(
        "POST",
        f"/collections/{SEEDS_COLLECTION}/points/search",
        {
            "vector": query_embedding,
            "limit": limit,
            "with_payload": True,
        },
    )

    matches = []
    for hit in response.get("result", []):
        score = hit.get("score", 0)
        # Keep the hit unless its score is strictly below the threshold.
        if not score < threshold:
            seed = hit.get("payload", {})
            matches.append({
                "identity_id": seed.get("identity_id"),
                "identity_uuid": seed.get("identity_uuid"),
                "name": seed.get("name"),
                "source": seed.get("source"),
                "score": score,
            })

    return matches
|
|
|
|
|
|
def delete_seed(identity_id: int) -> bool:
    """Remove one seed point from the _seeds collection.

    Args:
        identity_id: Identity ID (used as point_id)

    Returns:
        True if the delete response reports status "completed"
    """
    response = qdrant_request(
        "POST",
        f"/collections/{SEEDS_COLLECTION}/points/delete?wait=true",
        {"points": [identity_id]},
    )
    print(f"[QDRANT] Deleted seed: identity_id={identity_id}")
    status = response.get("result", {}).get("status")
    return status == "completed"
|
|
|
|
|
|
def count_seeds(source: str = None) -> int:
    """Return how many seed points exist, optionally for one source only.

    Args:
        source: Filter by source, or None for all

    Returns:
        Number of seed points (0 when the response carries no count)
    """
    ensure_seeds_collection()

    if source:
        body = {
            "filter": {
                "must": [
                    {"key": "source", "match": {"value": source}}
                ]
            }
        }
    else:
        # An empty JSON object is still sent as the request body.
        body = {}

    response = qdrant_request("POST", f"/collections/{SEEDS_COLLECTION}/points/count", body)
    counted = response.get("result", {})
    return counted.get("count", 0)
|
|
|
|
|
|
def update_identity_in_faces(file_uuid: str, trace_id: int, identity_id: int, identity_uuid: str) -> int:
    """Write identity_id/identity_uuid onto every face point of one trace.

    Called after identity binding confirmation.

    All points matching both file_uuid and trace_id are scrolled with their
    payload and vector, then re-upserted in batches of BATCH_SIZE with the
    two identity fields set.

    Args:
        file_uuid: Video file UUID
        trace_id: Trace ID
        identity_id: Identity ID
        identity_uuid: Identity UUID

    Returns:
        Number of updated points
    """
    scroll_path = f"/collections/{FACES_COLLECTION}/points/scroll"
    matched = []
    cursor = None

    while True:
        request_body = {
            "limit": BATCH_SIZE,
            "with_payload": True,
            "with_vector": True,
            "filter": {
                "must": [
                    {"key": "file_uuid", "match": {"value": file_uuid}},
                    {"key": "trace_id", "match": {"value": trace_id}},
                ]
            },
        }
        if cursor:
            request_body["offset"] = cursor

        page = qdrant_request("POST", scroll_path, request_body).get("result", {})
        page_points = page.get("points", [])
        if not page_points:
            break

        matched.extend(page_points)
        cursor = page.get("next_page_offset")
        if not cursor:
            break

    if not matched:
        return 0

    # Same id and vector, payload extended with the two identity fields.
    updates = [
        {
            "id": point["id"],
            "vector": point.get("vector", []),
            "payload": {
                **point.get("payload", {}),
                "identity_id": identity_id,
                "identity_uuid": identity_uuid,
            },
        }
        for point in matched
    ]

    upsert_path = f"/collections/{FACES_COLLECTION}/points?wait=true"
    for start in range(0, len(updates), BATCH_SIZE):
        qdrant_request("PUT", upsert_path, {"points": updates[start:start + BATCH_SIZE]})

    print(f"[QDRANT] Updated {len(updates)} face points with identity_id={identity_id}")
    return len(updates)