Release v1.0.0 candidate
This commit is contained in:
@@ -74,7 +74,7 @@ pub async fn ingest_rule3(pool: &PgPool, file_uuid: &str) -> Result<usize> {
|
||||
let rule1_rows: Vec<(String,)> = sqlx::query_as(
|
||||
r#"
|
||||
SELECT chunk_id FROM chunks
|
||||
WHERE uuid = $1 AND chunk_type = 'sentence' AND rule = 'rule_1'
|
||||
WHERE file_uuid = $1 AND chunk_type = 'sentence'
|
||||
AND start_frame >= $2
|
||||
AND end_frame <= $3
|
||||
"#,
|
||||
@@ -99,7 +99,7 @@ pub async fn ingest_rule3(pool: &PgPool, file_uuid: &str) -> Result<usize> {
|
||||
let texts: Vec<String> = sqlx::query_scalar(
|
||||
r#"
|
||||
SELECT text_content FROM chunks
|
||||
WHERE uuid = $1 AND chunk_type = 'sentence' AND rule = 'rule_1'
|
||||
WHERE file_uuid = $1 AND chunk_type = 'sentence'
|
||||
AND start_frame >= $2
|
||||
AND end_frame <= $3
|
||||
ORDER BY start_frame ASC
|
||||
@@ -135,7 +135,7 @@ pub async fn ingest_rule3(pool: &PgPool, file_uuid: &str) -> Result<usize> {
|
||||
);
|
||||
|
||||
// 4. Insert into dev.chunks
|
||||
let fps_query: Option<f64> = sqlx::query_scalar("SELECT fps FROM videos WHERE uuid = $1")
|
||||
let fps_query: Option<f64> = sqlx::query_scalar("SELECT fps FROM videos WHERE file_uuid = $1")
|
||||
.bind(file_uuid)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
@@ -150,11 +150,11 @@ pub async fn ingest_rule3(pool: &PgPool, file_uuid: &str) -> Result<usize> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO chunks (
|
||||
uuid, chunk_id, chunk_index, chunk_type,
|
||||
file_uuid, chunk_id, old_chunk_id, chunk_index, chunk_type,
|
||||
start_time, end_time, fps, start_frame, end_frame,
|
||||
content, text_content, summary_text, metadata, child_chunk_ids
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
ON CONFLICT (uuid, chunk_id) DO NOTHING
|
||||
) VALUES ($1, $2, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
|
||||
ON CONFLICT (file_uuid, old_chunk_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(file_uuid)
|
||||
|
||||
@@ -1241,7 +1241,7 @@ impl PostgresDb {
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
sqlx::query(&format!("DELETE FROM {} WHERE uuid = $1", chunks))
|
||||
sqlx::query(&format!("DELETE FROM {} WHERE file_uuid = $1", chunks))
|
||||
.bind(uuid)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
@@ -1279,7 +1279,7 @@ impl PostgresDb {
|
||||
pub async fn get_chunk_count(&self, uuid: &str) -> Result<(i64, i64)> {
|
||||
let chunks = schema::table_name("chunks");
|
||||
let sentence_count: i64 = sqlx::query_scalar(&format!(
|
||||
"SELECT COUNT(*) FROM {} WHERE uuid = $1 AND chunk_type = 'sentence'",
|
||||
"SELECT COUNT(*) FROM {} WHERE file_uuid = $1 AND chunk_type = 'sentence'",
|
||||
chunks
|
||||
))
|
||||
.bind(uuid)
|
||||
@@ -1287,7 +1287,7 @@ impl PostgresDb {
|
||||
.await?;
|
||||
|
||||
let time_count: i64 = sqlx::query_scalar(&format!(
|
||||
"SELECT COUNT(*) FROM {} WHERE uuid = $1 AND chunk_type = 'time_based'",
|
||||
"SELECT COUNT(*) FROM {} WHERE file_uuid = $1 AND chunk_type = 'time_based'",
|
||||
chunks
|
||||
))
|
||||
.bind(uuid)
|
||||
@@ -2567,9 +2567,9 @@ impl PostgresDb {
|
||||
|
||||
sqlx::query(&format!(
|
||||
r#"
|
||||
INSERT INTO {} (file_id, file_uuid, chunk_id, chunk_index, chunk_type, start_time, end_time, fps, start_frame, end_frame, text_content, content, metadata, vector_id, frame_count, pre_chunk_ids, parent_chunk_id, child_chunk_ids)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12::jsonb, $13::jsonb, $14, $15, $16, $17, $18)
|
||||
ON CONFLICT (file_uuid, chunk_id) DO UPDATE SET
|
||||
INSERT INTO {} (file_id, file_uuid, chunk_id, old_chunk_id, chunk_index, chunk_type, start_time, end_time, fps, start_frame, end_frame, text_content, content, metadata, vector_id, frame_count, pre_chunk_ids, parent_chunk_id, child_chunk_ids)
|
||||
VALUES ($1, $2, $3, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12::jsonb, $13::jsonb, $14, $15, $16, $17, $18)
|
||||
ON CONFLICT (file_uuid, old_chunk_id) DO UPDATE SET
|
||||
start_time = EXCLUDED.start_time,
|
||||
end_time = EXCLUDED.end_time,
|
||||
fps = EXCLUDED.fps,
|
||||
@@ -2642,9 +2642,9 @@ impl PostgresDb {
|
||||
|
||||
sqlx::query(&format!(
|
||||
r#"
|
||||
INSERT INTO {} (file_id, file_uuid, chunk_id, chunk_index, chunk_type, start_time, end_time, fps, start_frame, end_frame, text_content, content, metadata, vector_id, frame_count, pre_chunk_ids, parent_chunk_id, child_chunk_ids)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12::jsonb, $13::jsonb, $14, $15, $16, $17, $18)
|
||||
ON CONFLICT (file_uuid, chunk_id) DO UPDATE SET
|
||||
INSERT INTO {} (file_id, file_uuid, chunk_id, old_chunk_id, chunk_index, chunk_type, start_time, end_time, fps, start_frame, end_frame, text_content, content, metadata, vector_id, frame_count, pre_chunk_ids, parent_chunk_id, child_chunk_ids)
|
||||
VALUES ($1, $2, $3, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12::jsonb, $13::jsonb, $14, $15, $16, $17, $18)
|
||||
ON CONFLICT (file_uuid, old_chunk_id) DO UPDATE SET
|
||||
start_time = EXCLUDED.start_time,
|
||||
end_time = EXCLUDED.end_time,
|
||||
fps = EXCLUDED.fps,
|
||||
@@ -4453,7 +4453,7 @@ impl PostgresDb {
|
||||
COUNT(*) as chunks_count,
|
||||
COALESCE(SUM(end_frame - start_frame), 0) as chunks_frames
|
||||
FROM {}
|
||||
WHERE uuid = $1
|
||||
WHERE file_uuid = $1
|
||||
"#,
|
||||
chunks_table
|
||||
))
|
||||
@@ -4720,7 +4720,7 @@ impl PostgresDb {
|
||||
1 - (embedding <=> $1::vector) as similarity,
|
||||
bbox
|
||||
FROM {}
|
||||
WHERE file_uuid = $2
|
||||
WHERE uuid = $2
|
||||
AND embedding IS NOT NULL
|
||||
AND 1 - (embedding <=> $1::vector) >= $3
|
||||
ORDER BY embedding <=> $1::vector
|
||||
|
||||
@@ -88,6 +88,44 @@ impl QdrantDb {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 確保指定 collection 存在,不存在則自動建立
|
||||
pub async fn ensure_collection(&self, collection: &str, vector_dim: usize) -> Result<()> {
|
||||
let url = format!("{}/collections/{}", self.base_url, collection);
|
||||
|
||||
let exists = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("api-key", &self.api_key)
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false);
|
||||
|
||||
if exists {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let create_url = format!("{}/collections", self.base_url);
|
||||
let body = serde_json::json!({
|
||||
"vectors": {
|
||||
"size": vector_dim,
|
||||
"distance": "Cosine"
|
||||
}
|
||||
});
|
||||
|
||||
self.client
|
||||
.post(&create_url)
|
||||
.header("api-key", &self.api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context(format!("Failed to create Qdrant collection: {}", collection))?;
|
||||
|
||||
tracing::info!("Created Qdrant collection: {} (dim={})", collection, vector_dim);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 將向量寫入指定 collection(支援多 collection)
|
||||
pub async fn upsert_vector_to_collection(
|
||||
&self,
|
||||
@@ -687,14 +725,13 @@ pub async fn sync_face_embeddings(file_uuid: &str) -> Result<()> {
|
||||
use sqlx::Row;
|
||||
|
||||
let pool = sqlx::PgPool::connect(&DATABASE_URL).await?;
|
||||
let schema = crate::core::config::DATABASE_SCHEMA.as_str();
|
||||
let table = crate::core::db::schema::table_name("face_detections");
|
||||
|
||||
let qdrant: QdrantDb = QdrantDb::new();
|
||||
|
||||
let query = format!(
|
||||
"SELECT id, trace_id, frame_number, embedding FROM {}.{} WHERE file_uuid = $1 AND embedding IS NOT NULL",
|
||||
schema, table
|
||||
"SELECT id, trace_id, frame_number, embedding FROM {} WHERE file_uuid = $1 AND embedding IS NOT NULL",
|
||||
table
|
||||
);
|
||||
let rows = sqlx::query(&query).bind(file_uuid).fetch_all(&pool).await?;
|
||||
|
||||
|
||||
@@ -19,15 +19,34 @@ struct EmbedResponse {
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct OpenAIEmbedResponse {
|
||||
data: Vec<OpenAIEmbedData>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct OpenAIEmbedData {
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(model: String) -> Self {
|
||||
Self::with_url(model, Self::default_url())
|
||||
}
|
||||
|
||||
pub fn with_url(model: String, base_url: String) -> Self {
|
||||
Self {
|
||||
model,
|
||||
client: Client::new(),
|
||||
base_url: "http://localhost:11434".to_string(),
|
||||
base_url,
|
||||
}
|
||||
}
|
||||
|
||||
fn default_url() -> String {
|
||||
std::env::var("MOMENTRY_EMBED_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:11434".to_string())
|
||||
}
|
||||
|
||||
pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
|
||||
self.embed_with_prefix(text, "").await
|
||||
}
|
||||
@@ -41,32 +60,64 @@ impl Embedder {
|
||||
}
|
||||
|
||||
async fn embed_with_prefix(&self, text: &str, prefix: &str) -> Result<Vec<f32>> {
|
||||
let url = format!("{}/api/embeddings", self.base_url);
|
||||
let prompt = format!("{}{}", prefix, text);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&EmbedRequest {
|
||||
model: self.model.clone(),
|
||||
prompt,
|
||||
})
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send embedding request to Ollama")?;
|
||||
// Ollama API: POST {base_url}/api/embeddings with {model, prompt}
|
||||
// OpenAI-compatible: POST {base_url}/v1/embeddings with {input, model}
|
||||
let is_openai = self.base_url.contains(":1143"); // llama.cpp ports: 11436, 11437
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama API error ({}): {}", status, body);
|
||||
if is_openai {
|
||||
let url = format!("{}/v1/embeddings", self.base_url);
|
||||
let body = serde_json::json!({
|
||||
"input": prompt,
|
||||
"model": self.model,
|
||||
});
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send embedding request")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Embedding API error ({}): {}", status, body_text);
|
||||
}
|
||||
|
||||
let result: OpenAIEmbedResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse embedding response")?;
|
||||
|
||||
Ok(result.data.into_iter().next().map(|d| d.embedding).unwrap_or_default())
|
||||
} else {
|
||||
let url = format!("{}/api/embeddings", self.base_url);
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&EmbedRequest {
|
||||
model: self.model.clone(),
|
||||
prompt,
|
||||
})
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to send embedding request to Ollama")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body_text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama API error ({}): {}", status, body_text);
|
||||
}
|
||||
|
||||
let result: EmbedResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Ollama response")?;
|
||||
|
||||
Ok(result.embedding)
|
||||
}
|
||||
|
||||
let result: EmbedResponse = response
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse Ollama response")?;
|
||||
|
||||
Ok(result.embedding)
|
||||
}
|
||||
|
||||
pub async fn embed_chunk_content(&self, chunk: &crate::core::chunk::Chunk) -> Result<Vec<f32>> {
|
||||
|
||||
@@ -233,14 +233,24 @@ impl PythonExecutor {
|
||||
Ok(())
|
||||
};
|
||||
|
||||
// 錯誤時 rename .json.tmp → .json.err
|
||||
// 錯誤時 rename .json.tmp → .json.err(若 .tmp 非有效 JSON)
|
||||
// 若 .tmp 是有效 JSON,保留為 .json(保留部分結果)
|
||||
let mark_failed = || {
|
||||
if let Some(tmp) = &tmp_path {
|
||||
if tmp.exists() {
|
||||
if let Some(out) = &output_path {
|
||||
let mut err_path = out.to_path_buf();
|
||||
err_path.set_extension("json.err");
|
||||
let _ = std::fs::rename(tmp, &err_path);
|
||||
let is_valid = std::fs::read_to_string(tmp)
|
||||
.ok()
|
||||
.and_then(|c| serde_json::from_str::<serde_json::Value>(&c).ok())
|
||||
.is_some();
|
||||
if is_valid {
|
||||
let _ = std::fs::rename(tmp, out);
|
||||
tracing::warn!("[Executor] Partial output preserved: {:?}", out);
|
||||
} else {
|
||||
let mut err_path = out.to_path_buf();
|
||||
err_path.set_extension("json.err");
|
||||
let _ = std::fs::rename(tmp, &err_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,10 +65,17 @@ pub async fn process_scene_classification(
|
||||
});
|
||||
}
|
||||
|
||||
let coreml_path = "/Users/accusys/models/resnet18_places365.mlpackage";
|
||||
let mut args = vec![video_path, output_path];
|
||||
if std::path::Path::new(coreml_path).exists() {
|
||||
args.push("--model");
|
||||
args.push(coreml_path);
|
||||
}
|
||||
|
||||
executor
|
||||
.run(
|
||||
"scene_classifier.py",
|
||||
&[video_path, output_path],
|
||||
&args,
|
||||
uuid,
|
||||
"SCENE",
|
||||
Some(SCENE_TIMEOUT),
|
||||
|
||||
@@ -1,146 +1,235 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
use tracing::{error, info};
|
||||
use std::collections::HashMap;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::core::db::PostgresDb;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct FaceDetection {
|
||||
face_id: String,
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TmdbIdentity {
|
||||
id: i64,
|
||||
id: i32,
|
||||
name: String,
|
||||
face_embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
const MATCH_THRESHOLD: f32 = 0.55;
|
||||
|
||||
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() || a.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
dot / (norm_a * norm_b)
|
||||
if a.len() != b.len() || a.is_empty() { return 0.0; }
|
||||
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
|
||||
let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
if na == 0.0 || nb == 0.0 { 0.0 } else { dot / (na * nb) }
|
||||
}
|
||||
|
||||
/// Match unassigned face detections against TMDb-sourced identities.
|
||||
/// For each face detection with identity_id IS NULL, compute cosine similarity
|
||||
/// against all TMDb identities that have face_embedding set.
|
||||
/// If similarity > MATCH_THRESHOLD, bind the face to the identity.
|
||||
/// Match face detections against TMDb identities using iterative multi-angle propagation.
|
||||
/// Round 1: seed match against TMDb face_embeddings (threshold 0.50)
|
||||
/// Round 2+: propagate to remaining traces using matched faces as reference
|
||||
pub async fn match_faces_against_tmdb(db: &PostgresDb, file_uuid: &str) -> Result<usize> {
|
||||
// Step 1: Fetch unassigned face detections for this file
|
||||
let detections: Vec<FaceDetection> = sqlx::query_as::<_, (String, Vec<f32>)>(
|
||||
"SELECT face_id, embedding FROM dev.face_detections \
|
||||
WHERE file_uuid = $1 AND identity_id IS NULL AND embedding IS NOT NULL",
|
||||
let pool = db.pool();
|
||||
|
||||
// Step 1: Load TMDb identities with face embeddings
|
||||
let tmdb_rows = sqlx::query_as::<_, (i32, String, Vec<f32>)>(
|
||||
"SELECT id, name, face_embedding::real[] FROM dev.identities WHERE source='tmdb' AND face_embedding IS NOT NULL"
|
||||
)
|
||||
.fetch_all(pool).await?;
|
||||
|
||||
if tmdb_rows.is_empty() {
|
||||
info!("[TKG-MATCH] No TMDb identities with face embeddings");
|
||||
return Ok(0);
|
||||
}
|
||||
info!("[TKG-MATCH] {} TMDb seeds loaded", tmdb_rows.len());
|
||||
|
||||
// Step 2: Load face_detections grouped by trace_id
|
||||
let fd_rows = sqlx::query_as::<_, (i32, Vec<f32>)>(
|
||||
"SELECT trace_id, embedding FROM dev.face_detections \
|
||||
WHERE file_uuid=$1 AND trace_id IS NOT NULL AND embedding IS NOT NULL \
|
||||
ORDER BY trace_id"
|
||||
)
|
||||
.bind(file_uuid)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
.context("Failed to fetch unassigned face detections")?
|
||||
.into_iter()
|
||||
.map(|(face_id, embedding)| FaceDetection { face_id, embedding })
|
||||
.collect();
|
||||
.fetch_all(pool).await?;
|
||||
|
||||
if detections.is_empty() {
|
||||
info!(
|
||||
"[TMDB-FACE] No unassigned face detections for {}",
|
||||
file_uuid
|
||||
);
|
||||
if fd_rows.is_empty() {
|
||||
info!("[TKG-MATCH] No face detections for {}", file_uuid);
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
// Step 2: Fetch TMDb identities with face embeddings
|
||||
let identities: Vec<TmdbIdentity> = sqlx::query_as::<_, (i64, String, Vec<f32>)>(
|
||||
"SELECT id, name, face_embedding::real[] FROM dev.identities \
|
||||
WHERE source = 'tmdb' AND face_embedding IS NOT NULL",
|
||||
)
|
||||
.fetch_all(db.pool())
|
||||
.await
|
||||
.context("Failed to fetch TMDb identities")?
|
||||
.into_iter()
|
||||
.map(|(id, name, emb)| TmdbIdentity {
|
||||
id,
|
||||
name,
|
||||
face_embedding: emb,
|
||||
})
|
||||
.collect();
|
||||
|
||||
if identities.is_empty() {
|
||||
info!("[TMDB-FACE] No TMDb identities with face embeddings for matching");
|
||||
return Ok(0);
|
||||
let mut trace_faces: HashMap<i32, Vec<Vec<f32>>> = HashMap::new();
|
||||
for (tid, emb) in &fd_rows {
|
||||
trace_faces.entry(*tid).or_default().push(emb.clone());
|
||||
}
|
||||
// Dedup near-identical embeddings within trace
|
||||
for faces in trace_faces.values_mut() {
|
||||
faces.sort_by(|a, b| a[0].partial_cmp(&b[0]).unwrap_or(std::cmp::Ordering::Equal));
|
||||
faces.dedup_by(|a, b| cosine_similarity(a, b) > 0.99);
|
||||
}
|
||||
|
||||
info!(
|
||||
"[TMDB-FACE] Matching {} face detections against {} TMDb identities",
|
||||
detections.len(),
|
||||
identities.len()
|
||||
);
|
||||
let total = trace_faces.len();
|
||||
info!("[TKG-MATCH] {} traces with {} faces", total, fd_rows.len());
|
||||
|
||||
// Step 3: For each face detection, find best matching identity
|
||||
let mut bindings_created = 0usize;
|
||||
// Step 3: Iterative matching
|
||||
const TH: f32 = 0.50;
|
||||
let mut matched: HashMap<i32, (i32, String)> = HashMap::new(); // trace_id → (identity_id, name)
|
||||
|
||||
for det in &detections {
|
||||
let mut best_match: Option<(i64, f32)> = None;
|
||||
// Round 1: against TMDb seeds
|
||||
for (&tid, faces) in &trace_faces {
|
||||
let mut best_id = 0i32;
|
||||
let mut best_name = String::new();
|
||||
let mut best_sim = 0.0f32;
|
||||
for (id, name, tmdb_emb) in &tmdb_rows {
|
||||
for face in faces {
|
||||
let s = cosine_similarity(face, tmdb_emb);
|
||||
if s > best_sim { best_sim = s; best_id = *id; best_name = name.clone(); }
|
||||
}
|
||||
}
|
||||
if best_sim >= TH {
|
||||
matched.insert(tid, (best_id, best_name));
|
||||
}
|
||||
}
|
||||
info!("[TKG-MATCH] Round 1: {} ({}/{})", matched.len(), matched.len() * 100 / total, total);
|
||||
|
||||
for identity in &identities {
|
||||
let sim = cosine_similarity(&det.embedding, &identity.face_embedding);
|
||||
if sim > MATCH_THRESHOLD {
|
||||
match best_match {
|
||||
Some((_, best_sim)) if sim > best_sim => {
|
||||
best_match = Some((identity.id, sim));
|
||||
}
|
||||
None => {
|
||||
best_match = Some((identity.id, sim));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
// Round 2+: propagate
|
||||
for round_n in 2..=10 {
|
||||
let prev = matched.len();
|
||||
let mut seed_pool: HashMap<i32, Vec<&Vec<f32>>> = HashMap::new();
|
||||
for (&tid, (id, _)) in &matched {
|
||||
if let Some(faces) = trace_faces.get(&tid) {
|
||||
seed_pool.entry(*id).or_default().extend(faces.iter());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((identity_id, similarity)) = best_match {
|
||||
// Update face_detection with identity_id
|
||||
let _ = sqlx::query(
|
||||
"UPDATE dev.face_detections SET identity_id = $1, identity_confidence = $2 \
|
||||
WHERE file_uuid = $3 AND face_id = $4",
|
||||
)
|
||||
.bind(identity_id)
|
||||
.bind(similarity as f64)
|
||||
.bind(file_uuid)
|
||||
.bind(&det.face_id)
|
||||
.execute(db.pool())
|
||||
.await
|
||||
.ok();
|
||||
|
||||
// Also create identity_binding
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO dev.identity_bindings (identity_id, identity_type, identity_value, source, confidence) \
|
||||
VALUES ($1, 'face', $2, 'tmdb_agent', $3) \
|
||||
ON CONFLICT (identity_id, identity_type, identity_value) DO UPDATE SET confidence = EXCLUDED.confidence"
|
||||
)
|
||||
.bind(identity_id)
|
||||
.bind(&det.face_id)
|
||||
.bind(similarity as f64)
|
||||
.execute(db.pool())
|
||||
.await
|
||||
.ok();
|
||||
|
||||
bindings_created += 1;
|
||||
let mut new_matches: Vec<(i32, i32, String)> = Vec::new();
|
||||
for (&tid, faces) in &trace_faces {
|
||||
if matched.contains_key(&tid) || faces.is_empty() { continue; }
|
||||
let ref_face = &faces[0];
|
||||
let mut best_id = 0i32;
|
||||
let mut best_name = String::new();
|
||||
let mut best_sim = 0.0f32;
|
||||
for (&id, seed_faces) in &seed_pool {
|
||||
for seed in seed_faces {
|
||||
let s = cosine_similarity(ref_face, seed);
|
||||
if s > best_sim { best_sim = s; best_id = id; }
|
||||
}
|
||||
}
|
||||
if best_sim >= TH {
|
||||
// Look up name for this id
|
||||
for (id, name, _) in &tmdb_rows {
|
||||
if *id == best_id { best_name = name.clone(); break; }
|
||||
}
|
||||
new_matches.push((tid, best_id, best_name));
|
||||
}
|
||||
}
|
||||
for (tid, id, name) in new_matches {
|
||||
matched.insert(tid, (id, name));
|
||||
}
|
||||
let new = matched.len() - prev;
|
||||
if new < 5 { break; }
|
||||
}
|
||||
|
||||
info!(
|
||||
"[TMDB-FACE] Created {} face-to-TMDb bindings for {}",
|
||||
bindings_created, file_uuid
|
||||
);
|
||||
// Step 4: Quality control
|
||||
// 4a: Remove low-confidence traces (fewer than 4 face detections)
|
||||
let mut after_qc = HashMap::new();
|
||||
for (&tid, &(id, ref name)) in &matched {
|
||||
let cnt: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM dev.face_detections WHERE file_uuid=$1 AND trace_id=$2"
|
||||
)
|
||||
.bind(file_uuid).bind(tid)
|
||||
.fetch_one(pool).await.unwrap_or(0);
|
||||
if cnt >= 4 {
|
||||
after_qc.insert(tid, (id, name.clone()));
|
||||
} else {
|
||||
info!("[TKG-QC] trace {} removed: only {} face(s), need >= 4", tid, cnt);
|
||||
}
|
||||
}
|
||||
let matched = after_qc;
|
||||
let removed_low = total - matched.len();
|
||||
if removed_low > 0 {
|
||||
info!("[TKG-QC] Removed {} low-confidence traces (< 4 faces)", removed_low);
|
||||
}
|
||||
|
||||
Ok(bindings_created)
|
||||
// 4b: Temporal collision check
|
||||
let removed_collisions = quality_check_temporal_collisions(pool, file_uuid).await?;
|
||||
if removed_collisions > 0 {
|
||||
info!("[TKG-QC] Resolved {} temporal collisions", removed_collisions);
|
||||
}
|
||||
|
||||
// Step 5: Update DB
|
||||
let mut updated = 0usize;
|
||||
for (&tid, &(id, _)) in &matched {
|
||||
let r = sqlx::query(
|
||||
"UPDATE dev.face_detections SET identity_id=$1 WHERE file_uuid=$2 AND trace_id=$3"
|
||||
)
|
||||
.bind(id).bind(file_uuid).bind(tid)
|
||||
.execute(pool).await?;
|
||||
if r.rows_affected() > 0 { updated += 1; }
|
||||
}
|
||||
|
||||
info!("[TKG-MATCH] Done: {}/{} traces matched ({}%)",
|
||||
matched.len(), total, matched.len() * 100 / total);
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// Quality check: detect temporal collisions where two different traces of the same
|
||||
/// identity appear in the same frame (impossible for one person).
|
||||
/// Unbind the lower-confidence trace from the conflicting pair.
|
||||
/// RCA reference: docs_v1.0/API_V1.0.0/INTERNAL/RCA_TRACE39_TRACE45_COLLISION_V1.0.0.md
|
||||
async fn quality_check_temporal_collisions(pool: &sqlx::PgPool, file_uuid: &str) -> Result<usize> {
|
||||
// Find all collision pairs: same identity, same frame, different trace
|
||||
let collisions = sqlx::query_as::<_, (i32, i32, i32, i32)>(
|
||||
r#"
|
||||
SELECT a.identity_id, a.trace_id, b.trace_id, a.frame_number
|
||||
FROM dev.face_detections a
|
||||
JOIN dev.face_detections b
|
||||
ON a.file_uuid = b.file_uuid
|
||||
AND a.frame_number = b.frame_number
|
||||
AND a.trace_id < b.trace_id
|
||||
WHERE a.file_uuid = $1
|
||||
AND a.identity_id IS NOT NULL
|
||||
AND a.identity_id = b.identity_id
|
||||
ORDER BY a.identity_id, a.frame_number
|
||||
"#
|
||||
)
|
||||
.bind(file_uuid)
|
||||
.fetch_all(pool).await?;
|
||||
|
||||
if collisions.is_empty() {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
// Group collisions by (identity_id, trace_a, trace_b) and count frames
|
||||
use std::collections::HashMap;
|
||||
let mut collision_groups: HashMap<(i32, i32, i32), usize> = HashMap::new();
|
||||
for (id, ta, tb, _) in &collisions {
|
||||
*collision_groups.entry((*id, *ta, *tb)).or_default() += 1;
|
||||
}
|
||||
|
||||
let mut unbound = 0usize;
|
||||
for ((id, ta, tb), overlap_frames) in &collision_groups {
|
||||
// Get face detection count for each trace
|
||||
let cnt_a: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM dev.face_detections WHERE file_uuid=$1 AND trace_id=$2 AND identity_id=$3"
|
||||
)
|
||||
.bind(file_uuid).bind(ta).bind(id)
|
||||
.fetch_one(pool).await.unwrap_or(0);
|
||||
|
||||
let cnt_b: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM dev.face_detections WHERE file_uuid=$1 AND trace_id=$2 AND identity_id=$3"
|
||||
)
|
||||
.bind(file_uuid).bind(tb).bind(id)
|
||||
.fetch_one(pool).await.unwrap_or(0);
|
||||
|
||||
// Unbind the trace with fewer detections (likely the false positive)
|
||||
let victim = if cnt_a <= cnt_b { *ta } else { *tb };
|
||||
let victim_cnt = if cnt_a <= cnt_b { cnt_a } else { cnt_b };
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE dev.face_detections SET identity_id=NULL WHERE file_uuid=$1 AND trace_id=$2"
|
||||
)
|
||||
.bind(file_uuid).bind(victim)
|
||||
.execute(pool).await?;
|
||||
|
||||
unbound += 1;
|
||||
warn!("[TKG-QC] Collision identity={}: trace {} vs trace {} ({} overlap frames). Unbound trace {} ({} detections)",
|
||||
id, ta, tb, overlap_frames, victim, victim_cnt);
|
||||
}
|
||||
|
||||
Ok(unbound)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user