feat: TKG extension - pose data + mutual gaze detection

This commit is contained in:
Accusys
2026-05-22 07:09:54 +08:00
parent a9e9285032
commit deb9516796
3 changed files with 393 additions and 49 deletions

View File

@@ -415,7 +415,96 @@ ORDER BY cooccurrence_count DESC
---
## 6. Phase 計畫
## 7. TKG 擴充Pose + Mutual Gaze
### 7.1 現狀
| 元件 | 有 pose | 有 mutual gaze |
|------|-----------|-----------------|
| `face.json` | ✅ yaw/pitch/roll | ❌ 未計算 |
| `face_detections.metadata` | ❌ 無 | ❌ 無 |
| `tkg_nodes` (face_trace) | ❌ 無 | — |
| `tkg_edges` (CO_OCCURS_WITH) | ❌ 無 | ❌ 無 |
### 7.2 目標
```
face_processor → face.json (有 pose)
tkg.rs (TKG builder)
├── 讀取 face.json 的 pose 資料
├── 算 avg_yaw/pitch/roll → 寫入 face_trace node
├── 對同框 face_trace 配對,判斷 mutual_gaze
└── 寫入 CO_OCCURS_WITH edge properties
tkg_edges.properties.mutual_gaze = true
API 查詢 → 互看 + 面積最大 → 代表 frame
```
### 7.3 face_trace node 新增 properties
```json
{
"frame_count": 53,
"start_frame": 38165,
"end_frame": 38321,
"avg_bbox": {"x": 731, "y": 215, "width": 228, "height": 228},
"avg_yaw": 0.014,
"avg_pitch": 0.224,
"avg_roll": -0.069
}
```
### 7.4 CO_OCCURS_WITH edge 新增 properties
```json
{
"first_frame": 38187,
"frame_count": 156,
"mutual_gaze": true,
"yaw_a_avg": 0.021,
"yaw_b_avg": -0.421,
"gaze_angle_delta": 0.442
}
```
### 7.5 Mutual Gaze 判斷邏輯
```python
GAZE_THRESHOLD = 0.05 # rad
def detect_mutual_gaze(frame_a, frame_b):
# 判斷 A 和 B 的左右位置關係
if bbox_a.cx < bbox_b.cx:
# A 在左B 在右 → A 要看右B 要看左
return yaw_a > GAZE_THRESHOLD and yaw_b < -GAZE_THRESHOLD
else:
# A 在右B 在左 → A 要看左B 要看右
return yaw_a < -GAZE_THRESHOLD and yaw_b > GAZE_THRESHOLD
```
### 7.6 代表 Frame 選取邏輯(應用端)
```
1. 查 TKG: MATCH (a)-[e:CO_OCCURS_WITH]->(b) WHERE e.mutual_gaze = true
→ 回傳互看 frame_count 最高的 face_trace 配對
2. 取該配對的 frame 範圍內,面積×信心最高 frame
3. FFmpeg blurdetect → 選最清晰的作為代表
```
### 7.7 Timeline
| Phase | 內容 | 工時 |
|-------|------|------|
| 1 | tkg.rs 讀取 face.json 的 pose + 寫入 node | 4-6h |
| 2 | mutual_gaze 判斷 + 寫入 edge | 3-4h |
| 3 | 對 Charade 重跑 TKG + 驗證 | 1h |
| 4 | 代表 frame 選取邏輯 | 2-3h |
---
## 8. Phase 計畫
| Phase | Query Types | 預計工時 | 依賴 |
|-------|------------|---------|------|

View File

@@ -29,6 +29,10 @@ pub fn trace_agent_routes() -> Router<crate::api::types::AppState> {
"/api/v1/file/:file_uuid/identities/:identity_uuid_a/co-occur-with/:identity_uuid_b",
get(get_cooccurrence),
)
.route(
"/api/v1/file/:file_uuid/tkg/rebuild",
post(rebuild_tkg),
)
}
#[derive(Debug, Deserialize)]
@@ -735,3 +739,47 @@ async fn get_cooccurrence(
},
}))
}
use crate::core::config::OUTPUT_DIR;
#[derive(Serialize)]
struct TkgRebuildResponse {
success: bool,
file_uuid: String,
result: Option<serde_json::Value>,
error: Option<String>,
}
async fn rebuild_tkg(
State(state): State<crate::api::types::AppState>,
Path(file_uuid): Path<String>,
) -> Json<TkgRebuildResponse> {
let result = crate::core::processor::tkg::build_tkg(
&state.db,
&file_uuid,
&OUTPUT_DIR,
)
.await;
match result {
Ok(r) => Json(TkgRebuildResponse {
success: true,
file_uuid,
result: Some(serde_json::json!({
"face_trace_nodes": r.face_trace_nodes,
"object_nodes": r.object_nodes,
"speaker_nodes": r.speaker_nodes,
"co_occurrence_edges": r.co_occurrence_edges,
"speaker_face_edges": r.speaker_face_edges,
"face_face_edges": r.face_face_edges,
})),
error: None,
}),
Err(e) => Json(TkgRebuildResponse {
success: false,
file_uuid,
result: None,
error: Some(e.to_string()),
}),
}
}

View File

@@ -15,6 +15,92 @@ fn t(name: &str) -> String {
}
}
// ── Pose data from face.json ────────────────────────────────────────
#[derive(Debug, Clone)]
struct FacePose {
frame: i64,
x: f64,
y: f64,
w: f64,
h: f64,
yaw: f64,
pitch: f64,
roll: f64,
}
fn load_face_pose_data(output_dir: &str, file_uuid: &str) -> Result<Vec<FacePose>> {
let path = Path::new(output_dir).join(format!("{}.face.json", file_uuid));
let content = std::fs::read_to_string(&path)
.with_context(|| format!("Failed to read face.json: {}", path.display()))?;
let json: serde_json::Value = serde_json::from_str(&content)?;
let mut poses = Vec::new();
if let Some(frames) = json.get("frames").and_then(|v| v.as_array()) {
for frame_entry in frames {
let frame_num = frame_entry.get("frame").and_then(|v| v.as_i64()).unwrap_or(0);
if let Some(faces) = frame_entry.get("faces").and_then(|v| v.as_array()) {
for face in faces {
let bbox = match face.get("bbox") {
Some(b) => b,
None => continue,
};
let pose = match face.get("pose") {
Some(p) => p,
None => continue,
};
poses.push(FacePose {
frame: frame_num,
x: bbox.get("x").and_then(|v| v.as_f64()).unwrap_or(0.0),
y: bbox.get("y").and_then(|v| v.as_f64()).unwrap_or(0.0),
w: bbox.get("width").and_then(|v| v.as_f64()).unwrap_or(0.0),
h: bbox.get("height").and_then(|v| v.as_f64()).unwrap_or(0.0),
yaw: pose.get("yaw").and_then(|v| v.as_f64()).unwrap_or(0.0),
pitch: pose.get("pitch").and_then(|v| v.as_f64()).unwrap_or(0.0),
roll: pose.get("roll").and_then(|v| v.as_f64()).unwrap_or(0.0),
});
}
}
}
}
Ok(poses)
}
/// Match a face from face_detections (frame, x, y, w, h) to its pose in face.json
/// Uses bbox center distance to find the best match when multiple faces per frame.
fn get_pose_for_face(frame: i64, x: f64, y: f64, w: f64, h: f64, poses: &[FacePose]) -> Option<(f64, f64, f64)> {
let cx = x + w / 2.0;
let cy = y + h / 2.0;
let mut best_dist = f64::MAX;
let mut result = None;
for p in poses.iter().filter(|p| p.frame == frame) {
let pcx = p.x + p.w / 2.0;
let pcy = p.y + p.h / 2.0;
let dist = (cx - pcx).abs() + (cy - pcy).abs();
if dist < best_dist {
best_dist = dist;
result = Some((p.yaw, p.pitch, p.roll));
}
}
result
}
fn detect_mutual_gaze(
bbox_a_x: f64, bbox_a_w: f64, yaw_a: f64,
bbox_b_x: f64, bbox_b_w: f64, yaw_b: f64,
threshold: f64,
) -> bool {
let cx_a = bbox_a_x + bbox_a_w / 2.0;
let cx_b = bbox_b_x + bbox_b_w / 2.0;
if cx_a < cx_b {
// A 在左B 在右 → A 要看右 (yaw > 0)B 要看左 (yaw < 0)
yaw_a > threshold && yaw_b < -threshold
} else {
// A 在右B 在左 → A 要看左 (yaw < 0)B 要看右 (yaw > 0)
yaw_a < -threshold && yaw_b > threshold
}
}
// ── Input data structs ────────────────────────────────────────────
#[derive(Debug, Deserialize)]
@@ -108,13 +194,16 @@ pub struct TkgResult {
pub async fn build_tkg(db: &PostgresDb, file_uuid: &str, output_dir: &str) -> Result<TkgResult> {
let pool = db.pool();
let n_face = build_face_trace_nodes(pool, file_uuid).await?;
let pose_data = load_face_pose_data(output_dir, file_uuid).unwrap_or_default();
tracing::info!("[TKG] Loaded {} pose entries from face.json", pose_data.len());
let n_face = build_face_trace_nodes(pool, file_uuid, &pose_data).await?;
let n_objects = build_yolo_object_nodes(pool, file_uuid, output_dir).await?;
let n_speakers = build_speaker_nodes(pool, file_uuid, output_dir).await?;
let e_co = build_co_occurrence_edges(pool, file_uuid, output_dir).await?;
let e_sf = build_speaker_face_edges(pool, file_uuid, output_dir).await?;
let e_ff = build_face_face_edges(pool, file_uuid).await?;
let e_ff = build_face_face_edges(pool, file_uuid, &pose_data).await?;
Ok(TkgResult {
face_trace_nodes: n_face,
@@ -128,16 +217,16 @@ pub async fn build_tkg(db: &PostgresDb, file_uuid: &str, output_dir: &str) -> Re
// ── Node builders ─────────────────────────────────────────────────
async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result<usize> {
async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str, pose_data: &[FacePose]) -> Result<usize> {
let face_table = t("face_detections");
let nodes_table = t("tkg_nodes");
let rows = sqlx::query_as::<_, FaceTraceRow>(&format!(
r#"
SELECT trace_id,
SELECT trace_id::bigint,
COUNT(*)::bigint as frame_count,
MIN(frame_number) as start_f,
MAX(frame_number) as end_f,
MIN(frame_number)::bigint as start_f,
MAX(frame_number)::bigint as end_f,
AVG(x::float8) as avg_x,
AVG(y::float8) as avg_y,
AVG(width::float8) as avg_w,
@@ -153,10 +242,53 @@ async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result<usize>
.fetch_all(pool)
.await?;
// Load per-frame data for pose matching
let frame_rows: Vec<(i64, i64, f64, f64, f64, f64)> = sqlx::query_as(
&format!(
"SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8 \
FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL ORDER BY trace_id, frame_number",
face_table
)
)
.bind(file_uuid)
.fetch_all(pool)
.await?;
// Group by trace_id: trace_id → Vec<(frame, x, y, w, h)>
let mut trace_frames: HashMap<i64, Vec<(i64, f64, f64, f64, f64)>> = HashMap::new();
for (tid, frame, x, y, w, h) in &frame_rows {
trace_frames.entry(*tid).or_default().push((*frame, *x, *y, *w, *h));
}
let mut count = 0;
for row in &rows {
let external_id = format!("trace_{}", row.trace_id);
let label = format!("Face Trace {}", row.trace_id);
let tid = row.trace_id;
let external_id = format!("trace_{}", tid);
let label = format!("Face Trace {}", tid);
// Compute average pose for this trace
let mut yaw_sum = 0.0f64;
let mut pitch_sum = 0.0f64;
let mut roll_sum = 0.0f64;
let mut pose_count = 0i64;
if let Some(frames) = trace_frames.get(&tid) {
for (frame, x, y, w, h) in frames {
if let Some((yaw, pitch, roll)) = get_pose_for_face(*frame, *x, *y, *w, *h, pose_data) {
yaw_sum += yaw;
pitch_sum += pitch;
roll_sum += roll;
pose_count += 1;
}
}
}
let (avg_yaw, avg_pitch, avg_roll) = if pose_count > 0 {
(yaw_sum / pose_count as f64, pitch_sum / pose_count as f64, roll_sum / pose_count as f64)
} else {
(0.0, 0.0, 0.0)
};
let props = serde_json::json!({
"frame_count": row.frame_count,
"start_frame": row.start_f,
@@ -166,7 +298,11 @@ async fn build_face_trace_nodes(pool: &PgPool, file_uuid: &str) -> Result<usize>
"y": row.avg_y.unwrap_or(0.0).round() as i64,
"width": row.avg_w.unwrap_or(0.0).round() as i64,
"height": row.avg_h.unwrap_or(0.0).round() as i64,
}
},
"avg_yaw": (avg_yaw * 1000.0).round() / 1000.0,
"avg_pitch": (avg_pitch * 1000.0).round() / 1000.0,
"avg_roll": (avg_roll * 1000.0).round() / 1000.0,
"pose_count": pose_count,
});
sqlx::query(&format!(
@@ -312,7 +448,7 @@ async fn build_co_occurrence_edges(
let edges_table = t("tkg_edges");
let face_rows = sqlx::query_as::<_, FaceDetectionRow>(&format!(
r#"SELECT trace_id, frame_number, x, y, width, height
r#"SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8
FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL
ORDER BY frame_number"#,
face_table
@@ -429,7 +565,7 @@ async fn build_speaker_face_edges(
let edges_table = t("tkg_edges");
let traces = sqlx::query_as::<_, (i64, i64, i64)>(&format!(
r#"SELECT trace_id, MIN(frame_number) as start_f, MAX(frame_number) as end_f
r#"SELECT trace_id::bigint, MIN(frame_number)::bigint as start_f, MAX(frame_number)::bigint as end_f
FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL
GROUP BY trace_id"#,
face_table
@@ -533,14 +669,15 @@ async fn build_speaker_face_edges(
Ok(edge_count)
}
async fn build_face_face_edges(pool: &PgPool, file_uuid: &str) -> Result<usize> {
async fn build_face_face_edges(pool: &PgPool, file_uuid: &str, pose_data: &[FacePose]) -> Result<usize> {
let face_table = t("face_detections");
let nodes_table = t("tkg_nodes");
let edges_table = t("tkg_edges");
// Use SQL JOIN for fast co-occurrence detection
let rows: Vec<(i64, i64, i64)> = sqlx::query_as(&format!(
r#"
SELECT a.trace_id AS tid_a, b.trace_id AS tid_b, a.frame_number
SELECT a.trace_id::bigint AS tid_a, b.trace_id::bigint AS tid_b, a.frame_number::bigint
FROM {} a
JOIN {} b
ON a.file_uuid = b.file_uuid
@@ -557,53 +694,123 @@ async fn build_face_face_edges(pool: &PgPool, file_uuid: &str) -> Result<usize>
.fetch_all(pool)
.await?;
if rows.is_empty() {
return Ok(0);
// Also load per-frame bbox for mutual_gaze lookups
let bbox_data: Vec<(i64, i64, f64, f64, f64, f64)> = sqlx::query_as(
&format!(
"SELECT trace_id::bigint, frame_number::bigint, x::float8, y::float8, width::float8, height::float8 \
FROM {} WHERE file_uuid = $1 AND trace_id IS NOT NULL ORDER BY trace_id, frame_number",
face_table
)
)
.bind(file_uuid)
.fetch_all(pool)
.await?;
let mut frame_map: HashMap<(i64, i64), (f64, f64, f64, f64)> = HashMap::new(); // (trace_id, frame) → (x, y, w, h)
for (tid, frame, x, y, w, h) in &bbox_data {
frame_map.insert((*tid, *frame), (*x, *y, *w, *h));
}
// Deduplicate by pair
let mut pair_frames: HashMap<(i64, i64), Vec<i64>> = HashMap::new();
// Group by pair
let mut pair_frames: HashMap<(i64, i64), Vec<(i64, bool)>> = HashMap::new();
for (tid_a, tid_b, frame) in &rows {
let key = if *tid_a < *tid_b {
(*tid_a, *tid_b)
} else {
(*tid_b, *tid_a)
let key = (*tid_a.min(tid_b), *tid_a.max(tid_b));
let bbox_a = frame_map.get(&(*tid_a, *frame));
let bbox_b = frame_map.get(&(*tid_b, *frame));
let gaze = match (bbox_a, bbox_b) {
(Some(&(xa, ya, wa, ha)), Some(&(xb, yb, wb, hb))) => {
get_pose_for_face(*frame, xa, ya, wa, ha, pose_data)
.and_then(|(yaw_a, _, _)| {
get_pose_for_face(*frame, xb, yb, wb, hb, pose_data)
.map(|(yaw_b, _, _)| detect_mutual_gaze(xa, wa, yaw_a, xb, wb, yaw_b, 0.05))
})
.unwrap_or(false)
}
_ => false,
};
pair_frames.entry(key).or_default().push(*frame);
pair_frames.entry(key).or_default().push((*frame, gaze));
}
let mut edge_count = 0;
for ((tid_a, tid_b), frames) in &pair_frames {
// Cache node IDs to avoid repeated queries
let mut node_id_cache: HashMap<i64, i64> = HashMap::new();
for ((tid_a, tid_b), frame_data) in &pair_frames {
let ext_a = format!("trace_{}", tid_a);
let ext_b = format!("trace_{}", tid_b);
let n_a: Option<(i64,)> = sqlx::query_as(&format!(
"SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2",
nodes_table
))
.bind(file_uuid)
.bind(&ext_a)
.fetch_optional(pool)
.await?;
let n_b: Option<(i64,)> = sqlx::query_as(&format!(
"SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2",
nodes_table
))
.bind(file_uuid)
.bind(&ext_b)
.fetch_optional(pool)
.await?;
let (n_a_id, n_b_id) = match (n_a, n_b) {
(Some((a,)), Some((b,))) => (a, b),
_ => continue,
let n_a_id = match node_id_cache.get(tid_a) {
Some(id) => *id,
None => {
if let Some((id,)) = sqlx::query_as::<_, (i64,)>(&format!(
"SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2",
nodes_table
))
.bind(file_uuid).bind(&ext_a).fetch_optional(pool).await?
{
node_id_cache.insert(*tid_a, id);
id
} else { continue; }
}
};
let edge_props = serde_json::json!({
"first_frame": frames[0],
"frame_count": frames.len() as i64,
});
let n_b_id = match node_id_cache.get(tid_b) {
Some(id) => *id,
None => {
if let Some((id,)) = sqlx::query_as::<_, (i64,)>(&format!(
"SELECT id FROM {} WHERE file_uuid=$1 AND node_type='face_trace' AND external_id=$2",
nodes_table
))
.bind(file_uuid).bind(&ext_b).fetch_optional(pool).await?
{
node_id_cache.insert(*tid_b, id);
id
} else { continue; }
}
};
let frames: Vec<i64> = frame_data.iter().map(|(f, _)| *f).collect();
let gaze_frames: Vec<i64> = frame_data.iter().filter(|(_, g)| *g).map(|(f, _)| *f).collect();
let gaze_count = gaze_frames.len() as i64;
let has_gaze = gaze_count > 0;
let edge_props = if has_gaze {
// Compute average yaw values for gaze frames
let mut yaw_a_sum = 0.0f64;
let mut yaw_b_sum = 0.0f64;
let mut gaze_sample = 0i64;
for (frame, _) in frame_data.iter().filter(|(_, g)| *g) {
let bbox_a = frame_map.get(&(*tid_a, *frame));
let bbox_b = frame_map.get(&(*tid_b, *frame));
if let (Some(&(xa, ya, wa, ha)), Some(&(xb, yb, wb, hb))) = (bbox_a, bbox_b) {
let pose_a = get_pose_for_face(*frame, xa, ya, wa, ha, pose_data);
let pose_b = get_pose_for_face(*frame, xb, yb, wb, hb, pose_data);
if let (Some((ya, _, _)), Some((yb, _, _))) = (pose_a, pose_b) {
yaw_a_sum += ya;
yaw_b_sum += yb;
gaze_sample += 1;
}
}
}
let (avg_ya, avg_yb) = if gaze_sample > 0 {
(yaw_a_sum / gaze_sample as f64, yaw_b_sum / gaze_sample as f64)
} else { (0.0, 0.0) };
serde_json::json!({
"first_frame": frames[0],
"frame_count": frames.len() as i64,
"mutual_gaze": true,
"gaze_frame_count": gaze_count,
"yaw_a_avg": (avg_ya * 1000.0).round() / 1000.0,
"yaw_b_avg": (avg_yb * 1000.0).round() / 1000.0,
})
} else {
serde_json::json!({
"first_frame": frames[0],
"frame_count": frames.len() as i64,
"mutual_gaze": false,
})
};
sqlx::query(&format!(
r#"