Merge branch 'main' of http://192.168.110.200:3000/admin/momentry_core
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
http::{header, StatusCode},
|
||||
response::{IntoResponse, Json, Response},
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
@@ -16,6 +17,22 @@ pub fn trace_agent_routes() -> Router<crate::api::types::AppState> {
|
||||
"/api/v1/file/:file_uuid/trace/:trace_id/faces",
|
||||
get(list_trace_faces),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/file/:file_uuid/trace/:trace_id/representative-face",
|
||||
get(get_representative_face),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/file/:file_uuid/trace/:trace_id/thumbnail",
|
||||
get(get_trace_thumbnail),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/file/:file_uuid/identities/:identity_uuid_a/co-occur-with/:identity_uuid_b",
|
||||
get(get_cooccurrence),
|
||||
)
|
||||
.route(
|
||||
"/api/v1/file/:file_uuid/tkg/rebuild",
|
||||
post(rebuild_tkg),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -328,3 +345,441 @@ async fn list_trace_faces(
|
||||
faces,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RepFaceBbox {
|
||||
x: i32,
|
||||
y: i32,
|
||||
width: i32,
|
||||
height: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RepFaceResult {
|
||||
frame_number: i64,
|
||||
timestamp_secs: f64,
|
||||
bbox: RepFaceBbox,
|
||||
confidence: f64,
|
||||
quality_score: f64,
|
||||
blur_score: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct RepFaceResponse {
|
||||
success: bool,
|
||||
file_uuid: String,
|
||||
trace_id: i32,
|
||||
face_count: i64,
|
||||
representative: RepFaceResult,
|
||||
}
|
||||
|
||||
/// Internal result of `select_rep_face`: the winning detection plus the
/// video metadata callers need to turn it into a timestamp or a thumbnail.
struct RepFaceSelection {
    /// Frame number of the chosen detection.
    frame: i64,
    /// Bounding-box x offset.
    x: i32,
    /// Bounding-box y offset.
    y: i32,
    /// Bounding-box width.
    w: i32,
    /// Bounding-box height.
    h: i32,
    /// Detector confidence.
    conf: f64,
    /// Lowest "blur mean" measured across the candidates.
    blur: f64,
    /// Ranking score (area * confidence) of the chosen detection.
    score: f64,
    /// Path of the source video on disk.
    video_path: String,
    /// Frame rate used to convert frame numbers to seconds.
    fps: f64,
    /// Number of detections recorded for the trace.
    face_count: i64,
}
|
||||
|
||||
async fn select_rep_face<F, T>(
|
||||
pool: &sqlx::PgPool,
|
||||
file_uuid: &str,
|
||||
trace_id: i32,
|
||||
err_fn: F,
|
||||
) -> Result<RepFaceSelection, T>
|
||||
where
|
||||
F: Fn(anyhow::Error) -> T,
|
||||
{
|
||||
use crate::core::db::schema;
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let video_table = schema::table_name("videos");
|
||||
|
||||
let fps: f64 = sqlx::query_scalar(&format!(
|
||||
"SELECT COALESCE(fps, 25.0) FROM {} WHERE file_uuid = $1", video_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| err_fn(anyhow::anyhow!("{}", e)))?
|
||||
.unwrap_or(25.0);
|
||||
|
||||
let face_count: (i64,) = sqlx::query_as(&format!(
|
||||
"SELECT COUNT(*) FROM {} WHERE file_uuid = $1 AND trace_id = $2", fd_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.bind(trace_id)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map_err(|e| err_fn(anyhow::anyhow!("{}", e)))?;
|
||||
|
||||
struct Candidate { frame: i64, x: i32, y: i32, w: i32, h: i32, conf: f64, score: f64 }
|
||||
|
||||
let rows = sqlx::query_as::<_, (i64, i32, i32, i32, i32, f64)>(&format!(
|
||||
"SELECT frame_number::bigint, x, y, width, height, confidence::float8 \
|
||||
FROM {} WHERE file_uuid = $1 AND trace_id = $2 AND confidence > 0.7 \
|
||||
AND ((metadata->>'qc_ok')::boolean IS NULL OR (metadata->>'qc_ok')::boolean = true) \
|
||||
ORDER BY (width::float8 * height::float8) * confidence::float8 DESC LIMIT 10",
|
||||
fd_table
|
||||
))
|
||||
.bind(file_uuid).bind(trace_id)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(|e| err_fn(anyhow::anyhow!("{}", e)))?;
|
||||
|
||||
if rows.is_empty() {
|
||||
return Err(err_fn(anyhow::anyhow!("No suitable face found")));
|
||||
}
|
||||
|
||||
let candidates: Vec<Candidate> = rows.into_iter()
|
||||
.map(|(frame, x, y, w, h, conf)| {
|
||||
let score = (w as f64 * h as f64) * conf;
|
||||
Candidate { frame, x, y, w, h, conf, score }
|
||||
})
|
||||
.collect();
|
||||
|
||||
let video_path: String = sqlx::query_scalar(&format!(
|
||||
"SELECT file_path FROM {} WHERE file_uuid = $1", video_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|e| err_fn(anyhow::anyhow!("{}", e)))?
|
||||
.ok_or_else(|| err_fn(anyhow::anyhow!("Video not found")))?;
|
||||
|
||||
let mut best = candidates[0].frame;
|
||||
let mut best_blur = f64::MAX;
|
||||
let mut best_idx = 0usize;
|
||||
|
||||
for (i, c) in candidates.iter().enumerate() {
|
||||
let seek = c.frame as f64 / fps;
|
||||
if let Ok(output) = tokio::process::Command::new("ffmpeg")
|
||||
.args(["-ss", &format!("{:.2}", seek), "-i", &video_path,
|
||||
"-vframes", "1", "-vf", &format!("crop={}:{}:{}:{},blurdetect", c.w, c.h, c.x, c.y),
|
||||
"-f", "null", "-"])
|
||||
.output().await
|
||||
{
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
for line in stderr.lines() {
|
||||
if let Some(blur_str) = line.split("blur mean: ").nth(1) {
|
||||
if let Ok(blur) = blur_str.trim().parse::<f64>() {
|
||||
if blur < best_blur { best_blur = blur; best = c.frame; best_idx = i; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let chosen = &candidates[best_idx];
|
||||
Ok(RepFaceSelection {
|
||||
frame: chosen.frame, x: chosen.x, y: chosen.y, w: chosen.w, h: chosen.h,
|
||||
conf: chosen.conf, blur: best_blur, score: chosen.score,
|
||||
video_path, fps, face_count: face_count.0,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_representative_face(
|
||||
State(state): State<crate::api::types::AppState>,
|
||||
Path((file_uuid, trace_id)): Path<(String, i32)>,
|
||||
) -> Result<Json<RepFaceResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let sel = select_rep_face(state.db.pool(), &file_uuid, trace_id, |e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
}).await?;
|
||||
|
||||
Ok(Json(RepFaceResponse {
|
||||
success: true,
|
||||
file_uuid,
|
||||
trace_id,
|
||||
face_count: sel.face_count,
|
||||
representative: RepFaceResult {
|
||||
frame_number: sel.frame,
|
||||
timestamp_secs: sel.frame as f64 / sel.fps,
|
||||
bbox: RepFaceBbox { x: sel.x, y: sel.y, width: sel.w, height: sel.h },
|
||||
confidence: sel.conf,
|
||||
quality_score: sel.score,
|
||||
blur_score: sel.blur,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
async fn get_trace_thumbnail(
|
||||
State(state): State<crate::api::types::AppState>,
|
||||
Path((file_uuid, trace_id)): Path<(String, i32)>,
|
||||
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
|
||||
let sel = select_rep_face(state.db.pool(), &file_uuid, trace_id, |e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
}).await?;
|
||||
|
||||
let seek = sel.frame as f64 / sel.fps;
|
||||
let tmp = std::env::temp_dir().join(format!("trace_{}_{}.jpg", file_uuid, trace_id));
|
||||
|
||||
let status = tokio::process::Command::new("ffmpeg")
|
||||
.args([
|
||||
"-ss", &format!("{:.2}", seek),
|
||||
"-i", &sel.video_path,
|
||||
"-vframes", "1",
|
||||
"-vf", &format!("crop={}:{}:{}:{},scale=320:320", sel.w, sel.h, sel.x, sel.y),
|
||||
"-q:v", "2",
|
||||
"-y", &tmp.to_string_lossy().to_string(),
|
||||
])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})?;
|
||||
|
||||
if !status.status.success() {
|
||||
return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": "FFmpeg failed"}))));
|
||||
}
|
||||
|
||||
let bytes = tokio::fs::read(&tmp).await.map_err(|e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})?;
|
||||
|
||||
let _ = tokio::fs::remove_file(&tmp).await;
|
||||
|
||||
Ok(Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(header::CONTENT_TYPE, "image/jpeg")
|
||||
.header(header::CACHE_CONTROL, "public, max-age=86400")
|
||||
.body(Body::from(bytes))
|
||||
.unwrap())
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct CoOccurIdentity {
|
||||
identity_uuid: String,
|
||||
name: String,
|
||||
trace_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct CoOccurRepFace {
|
||||
frame_number: i64,
|
||||
bbox: RepFaceBbox,
|
||||
confidence: f64,
|
||||
thumbnail_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct CoOccurrence {
|
||||
frame_number: i64,
|
||||
timestamp_secs: f64,
|
||||
total_cooccurrence_frames: i64,
|
||||
representative_face_a: Option<CoOccurRepFace>,
|
||||
representative_face_b: Option<CoOccurRepFace>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct CoOccurResponse {
|
||||
success: bool,
|
||||
file_uuid: String,
|
||||
identity_a: CoOccurIdentity,
|
||||
identity_b: CoOccurIdentity,
|
||||
first_cooccurrence: CoOccurrence,
|
||||
}
|
||||
|
||||
async fn get_cooccurrence(
|
||||
State(state): State<crate::api::types::AppState>,
|
||||
Path((file_uuid, identity_uuid_a, identity_uuid_b)): Path<(String, String, String)>,
|
||||
) -> Result<Json<CoOccurResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
use crate::core::db::schema;
|
||||
let id_table = schema::table_name("identities");
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
|
||||
// Stage 1: Get identity names and IDs
|
||||
let id_a = sqlx::query_as::<_, (i32, String)>(&format!(
|
||||
"SELECT id, name FROM {} WHERE uuid::text = $1 OR REPLACE(uuid::text, '-', '') = $1",
|
||||
id_table
|
||||
))
|
||||
.bind(&identity_uuid_a)
|
||||
.fetch_optional(state.db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
(StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Identity A not found"})))
|
||||
})?;
|
||||
|
||||
let id_b = sqlx::query_as::<_, (i32, String)>(&format!(
|
||||
"SELECT id, name FROM {} WHERE uuid::text = $1 OR REPLACE(uuid::text, '-', '') = $1",
|
||||
id_table
|
||||
))
|
||||
.bind(&identity_uuid_b)
|
||||
.fetch_optional(state.db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})?
|
||||
.ok_or_else(|| {
|
||||
(StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Identity B not found"})))
|
||||
})?;
|
||||
|
||||
// Stage 2: Find first frame where both identity_ids appear
|
||||
let cooccur: Option<(i64,)> = sqlx::query_as(
|
||||
&format!(
|
||||
"SELECT MIN(fd.frame_number)::bigint FROM {} fd \
|
||||
WHERE fd.file_uuid = $1 AND fd.identity_id = $2 \
|
||||
AND fd.frame_number IN ( \
|
||||
SELECT frame_number FROM {} \
|
||||
WHERE file_uuid = $1 AND identity_id = $3 \
|
||||
)",
|
||||
fd_table, fd_table
|
||||
)
|
||||
)
|
||||
.bind(&file_uuid)
|
||||
.bind(id_a.0)
|
||||
.bind(id_b.0)
|
||||
.fetch_optional(state.db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})?;
|
||||
|
||||
let (first_frame,) = cooccur.ok_or_else(|| {
|
||||
(StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "These two identities never appear together in this file"})))
|
||||
})?;
|
||||
|
||||
// Get fps for timestamp
|
||||
let video_table = schema::table_name("videos");
|
||||
let fps: f64 = sqlx::query_scalar(&format!(
|
||||
"SELECT COALESCE(fps, 25.0) FROM {} WHERE file_uuid = $1", video_table
|
||||
))
|
||||
.bind(&file_uuid)
|
||||
.fetch_optional(state.db.pool())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})?
|
||||
.unwrap_or(25.0);
|
||||
|
||||
// Stage 3: Get trace_ids for both at this frame
|
||||
let trace_a: Option<(i32,)> = sqlx::query_as(
|
||||
&format!("SELECT trace_id FROM {} WHERE file_uuid = $1 AND frame_number = $2 AND identity_id = $3 AND trace_id IS NOT NULL LIMIT 1", fd_table)
|
||||
)
|
||||
.bind(&file_uuid).bind(first_frame).bind(id_a.0)
|
||||
.fetch_optional(state.db.pool()).await
|
||||
.map_err(|e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})?;
|
||||
|
||||
let trace_b: Option<(i32,)> = sqlx::query_as(
|
||||
&format!("SELECT trace_id FROM {} WHERE file_uuid = $1 AND frame_number = $2 AND identity_id = $3 AND trace_id IS NOT NULL LIMIT 1", fd_table)
|
||||
)
|
||||
.bind(&file_uuid).bind(first_frame).bind(id_b.0)
|
||||
.fetch_optional(state.db.pool()).await
|
||||
.map_err(|e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
})?;
|
||||
|
||||
// Stage 4: Get representative faces for both traces (reusing select_rep_face)
|
||||
let rep_a = if let Some((tid,)) = trace_a {
|
||||
select_rep_face(state.db.pool(), &file_uuid, tid, |e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
}).await.ok().map(|sel| CoOccurRepFace {
|
||||
frame_number: sel.frame,
|
||||
bbox: RepFaceBbox { x: sel.x, y: sel.y, width: sel.w, height: sel.h },
|
||||
confidence: sel.conf,
|
||||
thumbnail_url: format!("/api/v1/file/{}/trace/{}/thumbnail", file_uuid, tid),
|
||||
})
|
||||
} else { None };
|
||||
|
||||
let rep_b = if let Some((tid,)) = trace_b {
|
||||
select_rep_face(state.db.pool(), &file_uuid, tid, |e| {
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(serde_json::json!({"error": e.to_string()})))
|
||||
}).await.ok().map(|sel| CoOccurRepFace {
|
||||
frame_number: sel.frame,
|
||||
bbox: RepFaceBbox { x: sel.x, y: sel.y, width: sel.w, height: sel.h },
|
||||
confidence: sel.conf,
|
||||
thumbnail_url: format!("/api/v1/file/{}/trace/{}/thumbnail", file_uuid, tid),
|
||||
})
|
||||
} else { None };
|
||||
|
||||
// Total co-occurrence frames (from TKG if available, otherwise from face_detections)
|
||||
let total_cooccurrence_frames: i64 = sqlx::query_scalar(
|
||||
&format!(
|
||||
"SELECT COUNT(DISTINCT fd.frame_number)::bigint FROM {} fd \
|
||||
WHERE fd.file_uuid = $1 AND fd.identity_id = $2 \
|
||||
AND fd.frame_number IN ( \
|
||||
SELECT frame_number FROM {} \
|
||||
WHERE file_uuid = $1 AND identity_id = $3 \
|
||||
)",
|
||||
fd_table, fd_table
|
||||
)
|
||||
)
|
||||
.bind(&file_uuid).bind(id_a.0).bind(id_b.0)
|
||||
.fetch_one(state.db.pool()).await
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok(Json(CoOccurResponse {
|
||||
success: true,
|
||||
file_uuid,
|
||||
identity_a: CoOccurIdentity {
|
||||
identity_uuid: identity_uuid_a,
|
||||
name: id_a.1,
|
||||
trace_id: trace_a.map(|t| t.0).unwrap_or(0),
|
||||
},
|
||||
identity_b: CoOccurIdentity {
|
||||
identity_uuid: identity_uuid_b,
|
||||
name: id_b.1,
|
||||
trace_id: trace_b.map(|t| t.0).unwrap_or(0),
|
||||
},
|
||||
first_cooccurrence: CoOccurrence {
|
||||
frame_number: first_frame,
|
||||
timestamp_secs: first_frame as f64 / fps,
|
||||
total_cooccurrence_frames,
|
||||
representative_face_a: rep_a,
|
||||
representative_face_b: rep_b,
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
use crate::core::config::OUTPUT_DIR;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TkgRebuildResponse {
|
||||
success: bool,
|
||||
file_uuid: String,
|
||||
result: Option<serde_json::Value>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
async fn rebuild_tkg(
|
||||
State(state): State<crate::api::types::AppState>,
|
||||
Path(file_uuid): Path<String>,
|
||||
) -> Json<TkgRebuildResponse> {
|
||||
let result = crate::core::processor::tkg::build_tkg(
|
||||
&state.db,
|
||||
&file_uuid,
|
||||
&OUTPUT_DIR,
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(r) => Json(TkgRebuildResponse {
|
||||
success: true,
|
||||
file_uuid,
|
||||
result: Some(serde_json::json!({
|
||||
"face_trace_nodes": r.face_trace_nodes,
|
||||
"object_nodes": r.object_nodes,
|
||||
"speaker_nodes": r.speaker_nodes,
|
||||
"co_occurrence_edges": r.co_occurrence_edges,
|
||||
"speaker_face_edges": r.speaker_face_edges,
|
||||
"face_face_edges": r.face_face_edges,
|
||||
})),
|
||||
error: None,
|
||||
}),
|
||||
Err(e) => Json(TkgRebuildResponse {
|
||||
success: false,
|
||||
file_uuid,
|
||||
result: None,
|
||||
error: Some(e.to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,92 @@ fn t(name: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
// ── Pose data from face.json ────────────────────────────────────────
|
||||
|
||||
/// One face entry read from `<file_uuid>.face.json`: its bounding box and
/// head pose in a given frame.
#[derive(Debug, Clone)]
struct FacePose {
    /// Frame number the face was detected in.
    frame: i64,
    /// Bounding-box `x` value from the JSON entry.
    x: f64,
    /// Bounding-box `y` value from the JSON entry.
    y: f64,
    /// Bounding-box width.
    w: f64,
    /// Bounding-box height.
    h: f64,
    /// Head yaw. NOTE(review): units are not visible here - confirm against the producer.
    yaw: f64,
    /// Head pitch (same units as `yaw`).
    pitch: f64,
    /// Head roll (same units as `yaw`).
    roll: f64,
}
|
||||
|
||||
fn load_face_pose_data(output_dir: &str, file_uuid: &str) -> Result<Vec<FacePose>> {
|
||||
let path = Path::new(output_dir).join(format!("{}.face.json", file_uuid));
|
||||
let content = std::fs::read_to_string(&path)
|
||||
.with_context(|| format!("Failed to read face.json: {}", path.display()))?;
|
||||
let json: serde_json::Value = serde_json::from_str(&content)?;
|
||||
|
||||
let mut poses = Vec::new();
|
||||
if let Some(frames) = json.get("frames").and_then(|v| v.as_array()) {
|
||||
for frame_entry in frames {
|
||||
let frame_num = frame_entry.get("frame").and_then(|v| v.as_i64()).unwrap_or(0);
|
||||
if let Some(faces) = frame_entry.get("faces").and_then(|v| v.as_array()) {
|
||||
for face in faces {
|
||||
let bbox = match face.get("bbox") {
|
||||
Some(b) => b,
|
||||
None => continue,
|
||||
};
|
||||
let pose = match face.get("pose") {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
poses.push(FacePose {
|
||||
frame: frame_num,
|
||||
x: bbox.get("x").and_then(|v| v.as_f64()).unwrap_or(0.0),
|
||||
y: bbox.get("y").and_then(|v| v.as_f64()).unwrap_or(0.0),
|
||||
w: bbox.get("width").and_then(|v| v.as_f64()).unwrap_or(0.0),
|
||||
h: bbox.get("height").and_then(|v| v.as_f64()).unwrap_or(0.0),
|
||||
yaw: pose.get("yaw").and_then(|v| v.as_f64()).unwrap_or(0.0),
|
||||
pitch: pose.get("pitch").and_then(|v| v.as_f64()).unwrap_or(0.0),
|
||||
roll: pose.get("roll").and_then(|v| v.as_f64()).unwrap_or(0.0),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(poses)
|
||||
}
|
||||
|
||||
/// Match a face from face_detections (frame, x, y, w, h) to its pose in face.json
|
||||
/// Uses bbox center distance to find the best match when multiple faces per frame.
|
||||
fn get_pose_for_face(frame: i64, x: f64, y: f64, w: f64, h: f64, poses: &[FacePose]) -> Option<(f64, f64, f64)> {
|
||||
let cx = x + w / 2.0;
|
||||
let cy = y + h / 2.0;
|
||||
let mut best_dist = f64::MAX;
|
||||
let mut result = None;
|
||||
for p in poses.iter().filter(|p| p.frame == frame) {
|
||||
let pcx = p.x + p.w / 2.0;
|
||||
let pcy = p.y + p.h / 2.0;
|
||||
let dist = (cx - pcx).abs() + (cy - pcy).abs();
|
||||
if dist < best_dist {
|
||||
best_dist = dist;
|
||||
result = Some((p.yaw, p.pitch, p.roll));
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Decides whether two faces in the same frame are turned towards each other.
///
/// The faces are ordered left/right by the horizontal centre of their boxes.
/// Mutual gaze requires the left face to look right (yaw above `threshold`)
/// and the right face to look left (yaw below `-threshold`). When the centres
/// coincide, face B is treated as the left one.
fn detect_mutual_gaze(
    bbox_a_x: f64, bbox_a_w: f64, yaw_a: f64,
    bbox_b_x: f64, bbox_b_w: f64, yaw_b: f64,
    threshold: f64,
) -> bool {
    let center_a = bbox_a_x + bbox_a_w / 2.0;
    let center_b = bbox_b_x + bbox_b_w / 2.0;

    // A on the left, B on the right -> A must look right (yaw > 0), B must look left (yaw < 0);
    // otherwise the roles are swapped.
    let (left_yaw, right_yaw) = if center_a < center_b {
        (yaw_a, yaw_b)
    } else {
        (yaw_b, yaw_a)
    };

    left_yaw > threshold && right_yaw < -threshold
}
|
||||
|
||||
// ── Input data structs ────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
@@ -108,13 +194,16 @@ pub struct TkgResult {
|
||||
|
||||
pub async fn build_tkg(db: &PostgresDb, file_uuid: &str, output_dir: &str) -> Result<TkgResult> {
|
||||
let pool = db.pool();
|
||||
let n_face = build_face_trace_nodes(pool, file_uuid).await?;
|
||||
let pose_data = load_face_pose_data(output_dir, file_uuid).unwrap_or_default();
|
||||
tracing::info!("[TKG] Loaded {} pose entries from face.json", pose_data.len());
|
||||
|
||||
let n_face = build_face_trace_nodes(pool, file_uuid, &pose_data).await?;
|
||||
let n_objects = build_yolo_object_nodes(pool, file_uuid, output_dir).await?;
|
||||
let n_speakers = build_speaker_nodes(pool, file_uuid, output_dir).await?;
|
||||
|
||||
let e_co = build_co_occurrence_edges(pool, file_uuid, output_dir).await?;
|
||||
let e_sf = build_speaker_face_edges(pool, file_uuid, output_dir).await?;
|
||||
let e_ff = build_face_face_edges(pool, file_uuid).await?;
|
||||
let e_ff = build_face_face_edges(pool, file_uuid, &pose_data).await?;
|
||||
|
||||
Ok(TkgResult {
|
||||
face_trace_nodes: n_face,
|
||||
@@ -128,16 +217,16 @@ pub async fn build_tkg(db: &PostgresDb, file_uuid: &str, output_dir: &str) -> Re
|
||||
|
||||
// ── Node builders ─────────────────────────────────────────────────
|
||||
|
||||
async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result<usize> {
|
||||
async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str, pose_data: &[FacePose]) -> Result<usize> {
|
||||
let face_table = t("face_detections");
|
||||
let nodes_table = t("tkg_nodes");
|
||||
|
||||
let rows = sqlx::query_as::<_, FaceTraceRow>(&format!(
|
||||
r#"
|
||||
SELECT trace_id,
|
||||
SELECT trace_id::bigint,
|
||||
COUNT(*)::bigint as frame_count,
|
||||
MIN(frame_number) as start_f,
|
||||
MAX(frame_number) as end_f,
|
||||
MIN(frame_number)::bigint as start_f,
|
||||
MAX(frame_number)::bigint as end_f,
|
||||
AVG(x::float8) as avg_x,
|
||||
AVG(y::float8) as avg_y,
|
||||
AVG(width::float8) as avg_w,
|
||||
@@ -153,10 +242,53 @@ async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result<usize>
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
// Load per-frame data for pose matching
|
||||
let frame_rows: Vec<(i64, i64, f64, f64, f64, f64)> = sqlx::query_as(
|
||||
&format!(
|
||||
"SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8 \
|
||||
FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL ORDER BY trace_id, frame_number",
|
||||
face_table
|
||||
)
|
||||
)
|
||||
.bind(file_uuid)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
// Group by trace_id: trace_id → Vec<(frame, x, y, w, h)>
|
||||
let mut trace_frames: HashMap<i64, Vec<(i64, f64, f64, f64, f64)>> = HashMap::new();
|
||||
for (tid, frame, x, y, w, h) in &frame_rows {
|
||||
trace_frames.entry(*tid).or_default().push((*frame, *x, *y, *w, *h));
|
||||
}
|
||||
|
||||
let mut count = 0;
|
||||
for row in &rows {
|
||||
let external_id = format!("trace_{}", row.trace_id);
|
||||
let label = format!("Face Trace {}", row.trace_id);
|
||||
let tid = row.trace_id;
|
||||
let external_id = format!("trace_{}", tid);
|
||||
let label = format!("Face Trace {}", tid);
|
||||
|
||||
// Compute average pose for this trace
|
||||
let mut yaw_sum = 0.0f64;
|
||||
let mut pitch_sum = 0.0f64;
|
||||
let mut roll_sum = 0.0f64;
|
||||
let mut pose_count = 0i64;
|
||||
|
||||
if let Some(frames) = trace_frames.get(&tid) {
|
||||
for (frame, x, y, w, h) in frames {
|
||||
if let Some((yaw, pitch, roll)) = get_pose_for_face(*frame, *x, *y, *w, *h, pose_data) {
|
||||
yaw_sum += yaw;
|
||||
pitch_sum += pitch;
|
||||
roll_sum += roll;
|
||||
pose_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (avg_yaw, avg_pitch, avg_roll) = if pose_count > 0 {
|
||||
(yaw_sum / pose_count as f64, pitch_sum / pose_count as f64, roll_sum / pose_count as f64)
|
||||
} else {
|
||||
(0.0, 0.0, 0.0)
|
||||
};
|
||||
|
||||
let props = serde_json::json!({
|
||||
"frame_count": row.frame_count,
|
||||
"start_frame": row.start_f,
|
||||
@@ -166,7 +298,11 @@ async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result<usize>
|
||||
"y": row.avg_y.unwrap_or(0.0).round() as i64,
|
||||
"width": row.avg_w.unwrap_or(0.0).round() as i64,
|
||||
"height": row.avg_h.unwrap_or(0.0).round() as i64,
|
||||
}
|
||||
},
|
||||
"avg_yaw": (avg_yaw * 1000.0).round() / 1000.0,
|
||||
"avg_pitch": (avg_pitch * 1000.0).round() / 1000.0,
|
||||
"avg_roll": (avg_roll * 1000.0).round() / 1000.0,
|
||||
"pose_count": pose_count,
|
||||
});
|
||||
|
||||
sqlx::query(&format!(
|
||||
@@ -312,7 +448,7 @@ async fn build_co_occurrence_edges(
|
||||
let edges_table = t("tkg_edges");
|
||||
|
||||
let face_rows = sqlx::query_as::<_, FaceDetectionRow>(&format!(
|
||||
r#"SELECT trace_id, frame_number, x, y, width, height
|
||||
r#"SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8
|
||||
FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL
|
||||
ORDER BY frame_number"#,
|
||||
face_table
|
||||
@@ -429,7 +565,7 @@ async fn build_speaker_face_edges(
|
||||
let edges_table = t("tkg_edges");
|
||||
|
||||
let traces = sqlx::query_as::<_, (i64, i64, i64)>(&format!(
|
||||
r#"SELECT trace_id, MIN(frame_number) as start_f, MAX(frame_number) as end_f
|
||||
r#"SELECT trace_id::bigint, MIN(frame_number)::bigint as start_f, MAX(frame_number)::bigint as end_f
|
||||
FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL
|
||||
GROUP BY trace_id"#,
|
||||
face_table
|
||||
@@ -533,14 +669,15 @@ async fn build_speaker_face_edges(
|
||||
Ok(edge_count)
|
||||
}
|
||||
|
||||
async fn build_face_face_edges(pool: &PgPool, file_uuid: &str) -> Result<usize> {
|
||||
async fn build_face_face_edges(pool: &PgPool, file_uuid: &str, pose_data: &[FacePose]) -> Result<usize> {
|
||||
let face_table = t("face_detections");
|
||||
let nodes_table = t("tkg_nodes");
|
||||
let edges_table = t("tkg_edges");
|
||||
|
||||
// Use SQL JOIN for fast co-occurrence detection
|
||||
let rows: Vec<(i64, i64, i64)> = sqlx::query_as(&format!(
|
||||
r#"
|
||||
SELECT a.trace_id AS tid_a, b.trace_id AS tid_b, a.frame_number
|
||||
SELECT a.trace_id::bigint AS tid_a, b.trace_id::bigint AS tid_b, a.frame_number::bigint
|
||||
FROM {} a
|
||||
JOIN {} b
|
||||
ON a.file_uuid = b.file_uuid
|
||||
@@ -557,53 +694,123 @@ async fn build_face_face_edges(pool: &PgPool, file_uuid: &str) -> Result<usize>
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
if rows.is_empty() {
|
||||
return Ok(0);
|
||||
// Also load per-frame bbox for mutual_gaze lookups
|
||||
let bbox_data: Vec<(i64, i64, f64, f64, f64, f64)> = sqlx::query_as(
|
||||
&format!(
|
||||
"SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8 \
|
||||
FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL ORDER BY trace_id, frame_number",
|
||||
face_table
|
||||
)
|
||||
)
|
||||
.bind(file_uuid)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let mut frame_map: HashMap<(i64, i64), (f64, f64, f64, f64)> = HashMap::new(); // (trace_id, frame) → (x, y, w, h)
|
||||
for (tid, frame, x, y, w, h) in &bbox_data {
|
||||
frame_map.insert((*tid, *frame), (*x, *y, *w, *h));
|
||||
}
|
||||
|
||||
// Deduplicate by pair
|
||||
let mut pair_frames: HashMap<(i64, i64), Vec<i64>> = HashMap::new();
|
||||
// Group by pair
|
||||
let mut pair_frames: HashMap<(i64, i64), Vec<(i64, bool)>> = HashMap::new();
|
||||
for (tid_a, tid_b, frame) in &rows {
|
||||
let key = if *tid_a < *tid_b {
|
||||
(*tid_a, *tid_b)
|
||||
} else {
|
||||
(*tid_b, *tid_a)
|
||||
let key = (*tid_a.min(tid_b), *tid_a.max(tid_b));
|
||||
let bbox_a = frame_map.get(&(*tid_a, *frame));
|
||||
let bbox_b = frame_map.get(&(*tid_b, *frame));
|
||||
|
||||
let gaze = match (bbox_a, bbox_b) {
|
||||
(Some(&(xa, ya, wa, ha)), Some(&(xb, yb, wb, hb))) => {
|
||||
get_pose_for_face(*frame, xa, ya, wa, ha, pose_data)
|
||||
.and_then(|(yaw_a, _, _)| {
|
||||
get_pose_for_face(*frame, xb, yb, wb, hb, pose_data)
|
||||
.map(|(yaw_b, _, _)| detect_mutual_gaze(xa, wa, yaw_a, xb, wb, yaw_b, 0.05))
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
_ => false,
|
||||
};
|
||||
pair_frames.entry(key).or_default().push(*frame);
|
||||
pair_frames.entry(key).or_default().push((*frame, gaze));
|
||||
}
|
||||
|
||||
let mut edge_count = 0;
|
||||
for ((tid_a, tid_b), frames) in &pair_frames {
|
||||
// Cache node IDs to avoid repeated queries
|
||||
let mut node_id_cache: HashMap<i64, i64> = HashMap::new();
|
||||
for ((tid_a, tid_b), frame_data) in &pair_frames {
|
||||
let ext_a = format!("trace_{}", tid_a);
|
||||
let ext_b = format!("trace_{}", tid_b);
|
||||
|
||||
let n_a: Option<(i64,)> = sqlx::query_as(&format!(
|
||||
"SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2",
|
||||
nodes_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.bind(&ext_a)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let n_b: Option<(i64,)> = sqlx::query_as(&format!(
|
||||
"SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2",
|
||||
nodes_table
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.bind(&ext_b)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
let (n_a_id, n_b_id) = match (n_a, n_b) {
|
||||
(Some((a,)), Some((b,))) => (a, b),
|
||||
_ => continue,
|
||||
let n_a_id = match node_id_cache.get(tid_a) {
|
||||
Some(id) => *id,
|
||||
None => {
|
||||
if let Some((id,)) = sqlx::query_as::<_, (i64,)>(&format!(
|
||||
"SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2",
|
||||
nodes_table
|
||||
))
|
||||
.bind(file_uuid).bind(&ext_a).fetch_optional(pool).await?
|
||||
{
|
||||
node_id_cache.insert(*tid_a, id);
|
||||
id
|
||||
} else { continue; }
|
||||
}
|
||||
};
|
||||
|
||||
let edge_props = serde_json::json!({
|
||||
"first_frame": frames[0],
|
||||
"frame_count": frames.len() as i64,
|
||||
});
|
||||
let n_b_id = match node_id_cache.get(tid_b) {
|
||||
Some(id) => *id,
|
||||
None => {
|
||||
if let Some((id,)) = sqlx::query_as::<_, (i64,)>(&format!(
|
||||
"SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2",
|
||||
nodes_table
|
||||
))
|
||||
.bind(file_uuid).bind(&ext_b).fetch_optional(pool).await?
|
||||
{
|
||||
node_id_cache.insert(*tid_b, id);
|
||||
id
|
||||
} else { continue; }
|
||||
}
|
||||
};
|
||||
|
||||
let frames: Vec<i64> = frame_data.iter().map(|(f, _)| *f).collect();
|
||||
let gaze_frames: Vec<i64> = frame_data.iter().filter(|(_, g)| *g).map(|(f, _)| *f).collect();
|
||||
let gaze_count = gaze_frames.len() as i64;
|
||||
let has_gaze = gaze_count > 0;
|
||||
|
||||
let edge_props = if has_gaze {
|
||||
// Compute average yaw values for gaze frames
|
||||
let mut yaw_a_sum = 0.0f64;
|
||||
let mut yaw_b_sum = 0.0f64;
|
||||
let mut gaze_sample = 0i64;
|
||||
for (frame, _) in frame_data.iter().filter(|(_, g)| *g) {
|
||||
let bbox_a = frame_map.get(&(*tid_a, *frame));
|
||||
let bbox_b = frame_map.get(&(*tid_b, *frame));
|
||||
if let (Some(&(xa, ya, wa, ha)), Some(&(xb, yb, wb, hb))) = (bbox_a, bbox_b) {
|
||||
let pose_a = get_pose_for_face(*frame, xa, ya, wa, ha, pose_data);
|
||||
let pose_b = get_pose_for_face(*frame, xb, yb, wb, hb, pose_data);
|
||||
if let (Some((ya, _, _)), Some((yb, _, _))) = (pose_a, pose_b) {
|
||||
yaw_a_sum += ya;
|
||||
yaw_b_sum += yb;
|
||||
gaze_sample += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
let (avg_ya, avg_yb) = if gaze_sample > 0 {
|
||||
(yaw_a_sum / gaze_sample as f64, yaw_b_sum / gaze_sample as f64)
|
||||
} else { (0.0, 0.0) };
|
||||
|
||||
serde_json::json!({
|
||||
"first_frame": frames[0],
|
||||
"frame_count": frames.len() as i64,
|
||||
"mutual_gaze": true,
|
||||
"gaze_frame_count": gaze_count,
|
||||
"yaw_a_avg": (avg_ya * 1000.0).round() / 1000.0,
|
||||
"yaw_b_avg": (avg_yb * 1000.0).round() / 1000.0,
|
||||
})
|
||||
} else {
|
||||
serde_json::json!({
|
||||
"first_frame": frames[0],
|
||||
"frame_count": frames.len() as i64,
|
||||
"mutual_gaze": false,
|
||||
})
|
||||
};
|
||||
|
||||
sqlx::query(&format!(
|
||||
r#"
|
||||
|
||||
Reference in New Issue
Block a user