diff --git a/docs_v1.0/DESIGN/TKG_QUERY_API_V1.0.md b/docs_v1.0/DESIGN/TKG_QUERY_API_V1.0.md index 821c33b..7fecfac 100644 --- a/docs_v1.0/DESIGN/TKG_QUERY_API_V1.0.md +++ b/docs_v1.0/DESIGN/TKG_QUERY_API_V1.0.md @@ -415,7 +415,96 @@ ORDER BY cooccurrence_count DESC --- -## 6. Phase 計畫 +## 7. TKG 擴充:Pose + Mutual Gaze + +### 7.1 現狀 + +| 元件 | 有 pose? | 有 mutual gaze? | +|------|-----------|-----------------| +| `face.json` | ✅ yaw/pitch/roll | ❌ 未計算 | +| `face_detections.metadata` | ❌ 無 | ❌ 無 | +| `tkg_nodes` (face_trace) | ❌ 無 | — | +| `tkg_edges` (CO_OCCURS_WITH) | ❌ 無 | ❌ 無 | + +### 7.2 目標 + +``` +face_processor → face.json (有 pose) + ↓ +tkg.rs (TKG builder) + ├── 讀取 face.json 的 pose 資料 + ├── 算 avg_yaw/pitch/roll → 寫入 face_trace node + ├── 對同框 face_trace 配對,判斷 mutual_gaze + └── 寫入 CO_OCCURS_WITH edge properties + ↓ +tkg_edges.properties.mutual_gaze = true + ↓ +API 查詢 → 互看 + 面積最大 → 代表 frame +``` + +### 7.3 face_trace node 新增 properties + +```json +{ + "frame_count": 53, + "start_frame": 38165, + "end_frame": 38321, + "avg_bbox": {"x": 731, "y": 215, "width": 228, "height": 228}, + "avg_yaw": 0.014, + "avg_pitch": 0.224, + "avg_roll": -0.069 +} +``` + +### 7.4 CO_OCCURS_WITH edge 新增 properties + +```json +{ + "first_frame": 38187, + "frame_count": 156, + "mutual_gaze": true, + "yaw_a_avg": 0.021, + "yaw_b_avg": -0.421, + "gaze_angle_delta": 0.442 +} +``` + +### 7.5 Mutual Gaze 判斷邏輯 + +```python +GAZE_THRESHOLD = 0.05 # rad + +def detect_mutual_gaze(frame_a, frame_b): + # 判斷 A 和 B 的左右位置關係 + if bbox_a.cx < bbox_b.cx: + # A 在左,B 在右 → A 要看右,B 要看左 + return yaw_a > GAZE_THRESHOLD and yaw_b < -GAZE_THRESHOLD + else: + # A 在右,B 在左 → A 要看左,B 要看右 + return yaw_a < -GAZE_THRESHOLD and yaw_b > GAZE_THRESHOLD +``` + +### 7.6 代表 Frame 選取邏輯(應用端) + +``` +1. 查 TKG: MATCH (a)-[e:CO_OCCURS_WITH]->(b) WHERE e.mutual_gaze = true + → 回傳互看 frame_count 最高的 face_trace 配對 +2. 取該配對的 frame 範圍內,面積×信心最高 frame +3. FFmpeg blurdetect → 選最清晰的作為代表 +``` + +### 7.7 Timeline + +| Phase | 內容 | 工時 | +|-------|------|------| +| 1 | tkg.rs 讀取 face.json 的 pose + 寫入 node | 4-6h | +| 2 | mutual_gaze 判斷 + 寫入 edge | 3-4h | +| 3 | 對 Charade 重跑 TKG + 驗證 | 1h | +| 4 | 代表 frame 選取邏輯 | 2-3h | + +--- + +## 8. Phase 計畫 | Phase | Query Types | 預計工時 | 依賴 | |-------|------------|---------|------| diff --git a/src/api/trace_agent_api.rs b/src/api/trace_agent_api.rs index 269fb42..c1d4dd1 100644 --- a/src/api/trace_agent_api.rs +++ b/src/api/trace_agent_api.rs @@ -29,6 +29,10 @@ pub fn trace_agent_routes() -> Router { "/api/v1/file/:file_uuid/identities/:identity_uuid_a/co-occur-with/:identity_uuid_b", get(get_cooccurrence), ) + .route( + "/api/v1/file/:file_uuid/tkg/rebuild", + post(rebuild_tkg), + ) } #[derive(Debug, Deserialize)] @@ -735,3 +739,47 @@ async fn get_cooccurrence( }, })) } + +use crate::core::config::OUTPUT_DIR; + +#[derive(Serialize)] +struct TkgRebuildResponse { + success: bool, + file_uuid: String, + result: Option, + error: Option, +} + +async fn rebuild_tkg( + State(state): State, + Path(file_uuid): Path, +) -> Json { + let result = crate::core::processor::tkg::build_tkg( + &state.db, + &file_uuid, + &OUTPUT_DIR, + ) + .await; + + match result { + Ok(r) => Json(TkgRebuildResponse { + success: true, + file_uuid, + result: Some(serde_json::json!({ + "face_trace_nodes": r.face_trace_nodes, + "object_nodes": r.object_nodes, + "speaker_nodes": r.speaker_nodes, + "co_occurrence_edges": r.co_occurrence_edges, + "speaker_face_edges": r.speaker_face_edges, + "face_face_edges": r.face_face_edges, + })), + error: None, + }), + Err(e) => Json(TkgRebuildResponse { + success: false, + file_uuid, + result: None, + error: Some(e.to_string()), + }), + } +} diff --git a/src/core/processor/tkg.rs b/src/core/processor/tkg.rs index 39a7626..9cd7d3d 100644 --- a/src/core/processor/tkg.rs +++ b/src/core/processor/tkg.rs @@ -15,6 +15,92 @@ fn t(name: &str) -> String { } } +// ── Pose data from face.json ──────────────────────────────────────── + +#[derive(Debug, Clone)] +struct FacePose { + frame: i64, + x: f64, + y: f64, + w: f64, + h: f64, + yaw: f64, + pitch: f64, + roll: f64, +} + +fn load_face_pose_data(output_dir: &str, file_uuid: &str) -> Result> { + let path = Path::new(output_dir).join(format!("{}.face.json", file_uuid)); + let content = std::fs::read_to_string(&path) + .with_context(|| format!("Failed to read face.json: {}", path.display()))?; + let json: serde_json::Value = serde_json::from_str(&content)?; + + let mut poses = Vec::new(); + if let Some(frames) = json.get("frames").and_then(|v| v.as_array()) { + for frame_entry in frames { + let frame_num = frame_entry.get("frame").and_then(|v| v.as_i64()).unwrap_or(0); + if let Some(faces) = frame_entry.get("faces").and_then(|v| v.as_array()) { + for face in faces { + let bbox = match face.get("bbox") { + Some(b) => b, + None => continue, + }; + let pose = match face.get("pose") { + Some(p) => p, + None => continue, + }; + poses.push(FacePose { + frame: frame_num, + x: bbox.get("x").and_then(|v| v.as_f64()).unwrap_or(0.0), + y: bbox.get("y").and_then(|v| v.as_f64()).unwrap_or(0.0), + w: bbox.get("width").and_then(|v| v.as_f64()).unwrap_or(0.0), + h: bbox.get("height").and_then(|v| v.as_f64()).unwrap_or(0.0), + yaw: pose.get("yaw").and_then(|v| v.as_f64()).unwrap_or(0.0), + pitch: pose.get("pitch").and_then(|v| v.as_f64()).unwrap_or(0.0), + roll: pose.get("roll").and_then(|v| v.as_f64()).unwrap_or(0.0), + }); + } + } + } + } + Ok(poses) +} + +/// Match a face from face_detections (frame, x, y, w, h) to its pose in face.json +/// Uses bbox center distance to find the best match when multiple faces per frame. +fn get_pose_for_face(frame: i64, x: f64, y: f64, w: f64, h: f64, poses: &[FacePose]) -> Option<(f64, f64, f64)> { + let cx = x + w / 2.0; + let cy = y + h / 2.0; + let mut best_dist = f64::MAX; + let mut result = None; + for p in poses.iter().filter(|p| p.frame == frame) { + let pcx = p.x + p.w / 2.0; + let pcy = p.y + p.h / 2.0; + let dist = (cx - pcx).abs() + (cy - pcy).abs(); + if dist < best_dist { + best_dist = dist; + result = Some((p.yaw, p.pitch, p.roll)); + } + } + result +} + +fn detect_mutual_gaze( + bbox_a_x: f64, bbox_a_w: f64, yaw_a: f64, + bbox_b_x: f64, bbox_b_w: f64, yaw_b: f64, + threshold: f64, +) -> bool { + let cx_a = bbox_a_x + bbox_a_w / 2.0; + let cx_b = bbox_b_x + bbox_b_w / 2.0; + if cx_a < cx_b { + // A 在左,B 在右 → A 要看右 (yaw > 0),B 要看左 (yaw < 0) + yaw_a > threshold && yaw_b < -threshold + } else { + // A 在右,B 在左 → A 要看左 (yaw < 0),B 要看右 (yaw > 0) + yaw_a < -threshold && yaw_b > threshold + } +} + // ── Input data structs ──────────────────────────────────────────── #[derive(Debug, Deserialize)] @@ -108,13 +194,16 @@ pub struct TkgResult { pub async fn build_tkg(db: &PostgresDb, file_uuid: &str, output_dir: &str) -> Result { let pool = db.pool(); - let n_face = build_face_trace_nodes(pool, file_uuid).await?; + let pose_data = load_face_pose_data(output_dir, file_uuid).unwrap_or_default(); + tracing::info!("[TKG] Loaded {} pose entries from face.json", pose_data.len()); + + let n_face = build_face_trace_nodes(pool, file_uuid, &pose_data).await?; let n_objects = build_yolo_object_nodes(pool, file_uuid, output_dir).await?; let n_speakers = build_speaker_nodes(pool, file_uuid, output_dir).await?; let e_co = build_co_occurrence_edges(pool, file_uuid, output_dir).await?; let e_sf = build_speaker_face_edges(pool, file_uuid, output_dir).await?; - let e_ff = build_face_face_edges(pool, file_uuid).await?; + let e_ff = build_face_face_edges(pool, file_uuid, &pose_data).await?; Ok(TkgResult { face_trace_nodes: n_face, @@ -128,16 +217,16 @@ pub async fn build_tkg(db: &PostgresDb, file_uuid: &str, output_dir: &str) -> Re // ── Node builders ───────────────────────────────────────────────── -async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result { +async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str, pose_data: &[FacePose]) -> Result { let face_table = t("face_detections"); let nodes_table = t("tkg_nodes"); let rows = sqlx::query_as::<_, FaceTraceRow>(&format!( r#" - SELECT trace_id, + SELECT trace_id::bigint, COUNT(*)::bigint as frame_count, - MIN(frame_number) as start_f, - MAX(frame_number) as end_f, + MIN(frame_number)::bigint as start_f, + MAX(frame_number)::bigint as end_f, AVG(x::float8) as avg_x, AVG(y::float8) as avg_y, AVG(width::float8) as avg_w, @@ -153,10 +242,53 @@ async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result .fetch_all(pool) .await?; + // Load per-frame data for pose matching + let frame_rows: Vec<(i64, i64, f64, f64, f64, f64)> = sqlx::query_as( + &format!( + "SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8 \ + FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL ORDER BY trace_id, frame_number", + face_table + ) + ) + .bind(file_uuid) + .fetch_all(pool) + .await?; + + // Group by trace_id: trace_id → Vec<(frame, x, y, w, h)> + let mut trace_frames: HashMap> = HashMap::new(); + for (tid, frame, x, y, w, h) in &frame_rows { + trace_frames.entry(*tid).or_default().push((*frame, *x, *y, *w, *h)); + } + let mut count = 0; for row in &rows { - let external_id = format!("trace_{}", row.trace_id); - let label = format!("Face Trace {}", row.trace_id); + let tid = row.trace_id; + let external_id = format!("trace_{}", tid); + let label = format!("Face Trace {}", tid); + + // Compute average pose for this trace + let mut yaw_sum = 0.0f64; + let mut pitch_sum = 0.0f64; + let mut roll_sum = 0.0f64; + let mut pose_count = 0i64; + + if let Some(frames) = trace_frames.get(&tid) { + for (frame, x, y, w, h) in frames { + if let Some((yaw, pitch, roll)) = get_pose_for_face(*frame, *x, *y, *w, *h, pose_data) { + yaw_sum += yaw; + pitch_sum += pitch; + roll_sum += roll; + pose_count += 1; + } + } + } + + let (avg_yaw, avg_pitch, avg_roll) = if pose_count > 0 { + (yaw_sum / pose_count as f64, pitch_sum / pose_count as f64, roll_sum / pose_count as f64) + } else { + (0.0, 0.0, 0.0) + }; + let props = serde_json::json!({ "frame_count": row.frame_count, "start_frame": row.start_f, @@ -166,7 +298,11 @@ async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result "y": row.avg_y.unwrap_or(0.0).round() as i64, "width": row.avg_w.unwrap_or(0.0).round() as i64, "height": row.avg_h.unwrap_or(0.0).round() as i64, - } + }, + "avg_yaw": (avg_yaw * 1000.0).round() / 1000.0, + "avg_pitch": (avg_pitch * 1000.0).round() / 1000.0, + "avg_roll": (avg_roll * 1000.0).round() / 1000.0, + "pose_count": pose_count, }); sqlx::query(&format!( @@ -312,7 +448,7 @@ async fn build_co_occurrence_edges( let edges_table = t("tkg_edges"); let face_rows = sqlx::query_as::<_, FaceDetectionRow>(&format!( - r#"SELECT trace_id, frame_number, x, y, width, height + r#"SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8 FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL ORDER BY frame_number"#, face_table @@ -429,7 +565,7 @@ async fn build_speaker_face_edges( let edges_table = t("tkg_edges"); let traces = sqlx::query_as::<_, (i64, i64, i64)>(&format!( - r#"SELECT trace_id, MIN(frame_number) as start_f, MAX(frame_number) as end_f + r#"SELECT trace_id::bigint, MIN(frame_number)::bigint as start_f, MAX(frame_number)::bigint as end_f FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL GROUP BY trace_id"#, face_table @@ -533,14 +669,15 @@ async fn build_speaker_face_edges( Ok(edge_count) } -async fn build_face_face_edges(pool: &PgPool, file_uuid: &str) -> Result { +async fn build_face_face_edges(pool: &PgPool, file_uuid: &str, pose_data: &[FacePose]) -> Result { let face_table = t("face_detections"); let nodes_table = t("tkg_nodes"); let edges_table = t("tkg_edges"); + // Use SQL JOIN for fast co-occurrence detection let rows: Vec<(i64, i64, i64)> = sqlx::query_as(&format!( r#" - SELECT a.trace_id AS tid_a, b.trace_id AS tid_b, a.frame_number + SELECT a.trace_id::bigint AS tid_a, b.trace_id::bigint AS tid_b, a.frame_number::bigint FROM {} a JOIN {} b ON a.file_uuid = b.file_uuid @@ -557,53 +694,123 @@ async fn build_face_face_edges(pool: &PgPool, file_uuid: &str) -> Result .fetch_all(pool) .await?; - if rows.is_empty() { - return Ok(0); + // Also load per-frame bbox for mutual_gaze lookups + let bbox_data: Vec<(i64, i64, f64, f64, f64, f64)> = sqlx::query_as( + &format!( + "SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8 \ + FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL ORDER BY trace_id, frame_number", + face_table + ) + ) + .bind(file_uuid) + .fetch_all(pool) + .await?; + + let mut frame_map: HashMap<(i64, i64), (f64, f64, f64, f64)> = HashMap::new(); // (trace_id, frame) → (x, y, w, h) + for (tid, frame, x, y, w, h) in &bbox_data { + frame_map.insert((*tid, *frame), (*x, *y, *w, *h)); } - // Deduplicate by pair - let mut pair_frames: HashMap<(i64, i64), Vec> = HashMap::new(); + // Group by pair + let mut pair_frames: HashMap<(i64, i64), Vec<(i64, bool)>> = HashMap::new(); for (tid_a, tid_b, frame) in &rows { - let key = if *tid_a < *tid_b { - (*tid_a, *tid_b) - } else { - (*tid_b, *tid_a) + let key = (*tid_a.min(tid_b), *tid_a.max(tid_b)); + let bbox_a = frame_map.get(&(*tid_a, *frame)); + let bbox_b = frame_map.get(&(*tid_b, *frame)); + + let gaze = match (bbox_a, bbox_b) { + (Some(&(xa, ya, wa, ha)), Some(&(xb, yb, wb, hb))) => { + get_pose_for_face(*frame, xa, ya, wa, ha, pose_data) + .and_then(|(yaw_a, _, _)| { + get_pose_for_face(*frame, xb, yb, wb, hb, pose_data) + .map(|(yaw_b, _, _)| detect_mutual_gaze(xa, wa, yaw_a, xb, wb, yaw_b, 0.05)) + }) + .unwrap_or(false) + } + _ => false, }; - pair_frames.entry(key).or_default().push(*frame); + pair_frames.entry(key).or_default().push((*frame, gaze)); } let mut edge_count = 0; - for ((tid_a, tid_b), frames) in &pair_frames { + // Cache node IDs to avoid repeated queries + let mut node_id_cache: HashMap = HashMap::new(); + for ((tid_a, tid_b), frame_data) in &pair_frames { let ext_a = format!("trace_{}", tid_a); let ext_b = format!("trace_{}", tid_b); - let n_a: Option<(i64,)> = sqlx::query_as(&format!( - "SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2", - nodes_table - )) - .bind(file_uuid) - .bind(&ext_a) - .fetch_optional(pool) - .await?; - - let n_b: Option<(i64,)> = sqlx::query_as(&format!( - "SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2", - nodes_table - )) - .bind(file_uuid) - .bind(&ext_b) - .fetch_optional(pool) - .await?; - - let (n_a_id, n_b_id) = match (n_a, n_b) { - (Some((a,)), Some((b,))) => (a, b), - _ => continue, + let n_a_id = match node_id_cache.get(tid_a) { + Some(id) => *id, + None => { + if let Some((id,)) = sqlx::query_as::<_, (i64,)>(&format!( + "SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2", + nodes_table + )) + .bind(file_uuid).bind(&ext_a).fetch_optional(pool).await? + { + node_id_cache.insert(*tid_a, id); + id + } else { continue; } + } }; - let edge_props = serde_json::json!({ - "first_frame": frames[0], - "frame_count": frames.len() as i64, - }); + let n_b_id = match node_id_cache.get(tid_b) { + Some(id) => *id, + None => { + if let Some((id,)) = sqlx::query_as::<_, (i64,)>(&format!( + "SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2", + nodes_table + )) + .bind(file_uuid).bind(&ext_b).fetch_optional(pool).await? + { + node_id_cache.insert(*tid_b, id); + id + } else { continue; } + } + }; + + let frames: Vec = frame_data.iter().map(|(f, _)| *f).collect(); + let gaze_frames: Vec = frame_data.iter().filter(|(_, g)| *g).map(|(f, _)| *f).collect(); + let gaze_count = gaze_frames.len() as i64; + let has_gaze = gaze_count > 0; + + let edge_props = if has_gaze { + // Compute average yaw values for gaze frames + let mut yaw_a_sum = 0.0f64; + let mut yaw_b_sum = 0.0f64; + let mut gaze_sample = 0i64; + for (frame, _) in frame_data.iter().filter(|(_, g)| *g) { + let bbox_a = frame_map.get(&(*tid_a, *frame)); + let bbox_b = frame_map.get(&(*tid_b, *frame)); + if let (Some(&(xa, ya, wa, ha)), Some(&(xb, yb, wb, hb))) = (bbox_a, bbox_b) { + let pose_a = get_pose_for_face(*frame, xa, ya, wa, ha, pose_data); + let pose_b = get_pose_for_face(*frame, xb, yb, wb, hb, pose_data); + if let (Some((ya, _, _)), Some((yb, _, _))) = (pose_a, pose_b) { + yaw_a_sum += ya; + yaw_b_sum += yb; + gaze_sample += 1; + } + } + } + let (avg_ya, avg_yb) = if gaze_sample > 0 { + (yaw_a_sum / gaze_sample as f64, yaw_b_sum / gaze_sample as f64) + } else { (0.0, 0.0) }; + + serde_json::json!({ + "first_frame": frames[0], + "frame_count": frames.len() as i64, + "mutual_gaze": true, + "gaze_frame_count": gaze_count, + "yaw_a_avg": (avg_ya * 1000.0).round() / 1000.0, + "yaw_b_avg": (avg_yb * 1000.0).round() / 1000.0, + }) + } else { + serde_json::json!({ + "first_frame": frames[0], + "frame_count": frames.len() as i64, + "mutual_gaze": false, + }) + }; sqlx::query(&format!( r#"