Files
momentry_core/src/api/identity_binding.rs

413 lines
13 KiB
Rust

use axum::{
extract::{Path, Query},
http::StatusCode,
response::Json,
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use crate::core::db::{Database, PostgresDb};
use crate::core::person_identity::{BindIdentityRequest, Identity, UnbindIdentityRequest};
#[derive(Debug, Clone, Serialize)]
pub struct ApiResponse<T: Serialize> {
pub success: bool,
pub message: String,
pub data: Option<T>,
}
// ============================================================================
// API Handlers
// ============================================================================
async fn get_db() -> Result<PostgresDb, (StatusCode, Json<serde_json::Value>)> {
PostgresDb::init().await.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": format!("DB init failed: {}", e) })),
)
})
}
/// 獲取 Identity (人物) 列表
pub async fn list_identities(
Query(params): Query<ListIdentitiesParams>,
) -> Result<Json<ApiResponse<Vec<Identity>>>, (StatusCode, Json<serde_json::Value>)> {
let db = get_db().await?;
let limit = params.limit.unwrap_or(100);
let offset = params.offset.unwrap_or(0);
let search = params.search.unwrap_or_default();
let identities = db
.list_identities(&search, limit, offset)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
Ok(Json(ApiResponse {
success: true,
message: format!("Found {} identities", identities.len()),
data: Some(identities),
}))
}
/// Binds a machine-generated signal (face / speaker) to an identity.
///
/// The target identity is taken from `identity_id` when the request carries one; otherwise
/// it is fetched or created from `name`. The binding source defaults to `"manual"`.
///
/// # Errors
/// * `400 Bad Request` — no identity could be resolved: the id does not exist, or the
///   request carried neither `identity_id` nor `name`.
/// * `500 Internal Server Error` — the database could not be initialised or a query failed.
///   Lookup and creation failures are reported here rather than being folded into the 400
///   case, so a database outage is not misreported as a client mistake.
pub async fn bind_identity(
    Json(req): Json<BindIdentityRequest>,
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;

    let identity = if let Some(id_id) = req.identity_id {
        db.get_identity_by_id(id_id).await.map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": e.to_string() })),
            )
        })?
    } else if let Some(name) = &req.name {
        let created = db.get_or_create_identity(name).await.map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": e.to_string() })),
            )
        })?;
        Some(created)
    } else {
        None
    };

    let identity = identity.ok_or_else(|| {
        (
            StatusCode::BAD_REQUEST,
            Json(serde_json::json!({ "error": "Identity not found or name required" })),
        )
    })?;

    // Build the default lazily so no allocation happens when a source was supplied.
    let source = req.source.unwrap_or_else(|| "manual".to_string());

    // NOTE(review): the trailing `1.0` is presumably the confidence assigned to bindings
    // made through this endpoint — confirm against `bind_identity` in the DB layer.
    db.bind_identity(
        identity.id as i64,
        &req.binding_type,
        &req.binding_value,
        &source,
        1.0,
    )
    .await
    .map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(serde_json::json!({ "error": e.to_string() })),
        )
    })?;

    Ok(Json(ApiResponse {
        success: true,
        message: format!(
            "Bound {} '{}' to Identity '{}'",
            req.binding_type, req.binding_value, identity.name
        ),
        data: None,
    }))
}
/// 解綁身份
pub async fn unbind_identity(
Json(req): Json<UnbindIdentityRequest>,
) -> Result<Json<ApiResponse<()>>, (StatusCode, Json<serde_json::Value>)> {
let db = get_db().await?;
db.unbind_identity(&req.binding_type, &req.binding_value)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
Ok(Json(ApiResponse {
success: true,
message: format!("Unbound {} '{}'", req.binding_type, req.binding_value),
data: None,
}))
}
/// 查詢機器 ID 對應的 Identity (人物)
pub async fn get_identity_info(
Path((binding_type, binding_value)): Path<(String, String)>,
) -> Result<Json<ApiResponse<Identity>>, (StatusCode, Json<serde_json::Value>)> {
let db = get_db().await?;
let identity = db
.get_identity_by_binding(&binding_type, &binding_value)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
let identity = identity.ok_or_else(|| {
(
StatusCode::NOT_FOUND,
Json(serde_json::json!({ "error": "Identity not found" })),
)
})?;
Ok(Json(ApiResponse {
success: true,
message: "Identity info retrieved".to_string(),
data: Some(identity),
}))
}
/// 列出未綁定的信號 (待標註列表)
pub async fn list_unbound_signals(
Query(params): Query<ListSignalsParams>,
) -> Result<Json<ApiResponse<Vec<String>>>, (StatusCode, Json<serde_json::Value>)> {
let db = get_db().await?;
let signals = db
.list_unbound_signals(&params.uuid, &params.binding_type)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
Ok(Json(ApiResponse {
success: true,
message: format!(
"Found {} unbound {} signals",
signals.len(),
params.binding_type
),
data: Some(signals),
}))
}
/// 獲取特定信號 (Face ID 或 Speaker ID) 出現的所有 Chunk (時間軸)
pub async fn get_signal_timeline(
Path((uuid, binding_type, binding_value)): Path<(String, String, String)>,
) -> Result<Json<ApiResponse<Vec<serde_json::Value>>>, (StatusCode, Json<serde_json::Value>)> {
let db = get_db().await?;
let chunks = db
.get_chunks_by_signal(&uuid, &binding_type, &binding_value)
.await
.map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(serde_json::json!({ "error": e.to_string() })),
)
})?;
Ok(Json(ApiResponse {
success: true,
message: format!(
"Found {} chunks for {} '{}'",
chunks.len(),
binding_type,
binding_value
),
data: Some(chunks),
}))
}
/// Request body for the audio-visual binding suggestion endpoint.
#[derive(Debug, Deserialize)]
pub struct AVSuggestRequest {
    /// UUID of the video whose face and speaker signals are compared.
    pub video_uuid: String,
    /// Minimum overlap score a face/speaker pair must reach to be suggested.
    /// The handler substitutes `0.6` when this is omitted.
    pub overlap_threshold: Option<f64>, // default 0.6
}
/// One suggested face/speaker pairing produced by the suggestion endpoint.
#[derive(Debug, Serialize)]
pub struct AVSuggestion {
    /// Face signal ID of the pair.
    pub face_id: String,
    /// Speaker signal ID of the pair.
    pub speaker_id: String,
    /// Overlap score computed for the pair; suggestions are returned sorted by this value,
    /// highest first.
    pub overlap_score: f64,
    /// ID of the identity the face is already bound to, if any.
    pub face_talent_id: Option<i32>,
    /// ID of the identity the speaker is already bound to, if any.
    pub speaker_talent_id: Option<i32>,
}
/// Suggests face–speaker bindings for a video based on temporal overlap.
///
/// Every unbound face signal is compared with every unbound speaker signal. A pair is
/// suggested when its overlap score (see `calculate_overlap`) reaches `overlap_threshold`
/// (default `0.6`), unless the two signals are already bound to *different* identities.
/// Suggestions are returned sorted by score, highest first.
///
/// Timeline and identity lookups are best-effort: a failed lookup is treated as an empty
/// timeline / no identity instead of failing the whole request.
///
/// # Errors
/// Returns `500 Internal Server Error` when the database cannot be initialised or when
/// either list of unbound signals cannot be loaded.
pub async fn suggest_audio_visual_bindings(
    Json(req): Json<AVSuggestRequest>,
) -> Result<Json<ApiResponse<Vec<AVSuggestion>>>, (StatusCode, Json<serde_json::Value>)> {
    let db = get_db().await?;
    let threshold = req.overlap_threshold.unwrap_or(0.6);

    let face_signals = db
        .list_unbound_signals(&req.video_uuid, "face")
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": format!("Face signals: {}", e) })),
            )
        })?;
    let speaker_signals = db
        .list_unbound_signals(&req.video_uuid, "speaker")
        .await
        .map_err(|e| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(serde_json::json!({ "error": format!("Speaker signals: {}", e) })),
            )
        })?;

    // Load each speaker timeline exactly once. Fetching inside the pairwise loop would cost
    // 2 * faces * speakers round-trips instead of faces + speakers. Skipped entirely when
    // there are no faces to compare against.
    let mut speaker_timelines = Vec::with_capacity(speaker_signals.len());
    if !face_signals.is_empty() {
        for speaker_id in &speaker_signals {
            let timeline = db
                .get_chunks_by_signal(&req.video_uuid, "speaker", speaker_id)
                .await
                .unwrap_or_default();
            speaker_timelines.push(timeline);
        }
    }

    let mut suggestions = Vec::new();
    for face_id in &face_signals {
        // The face timeline does not depend on the speaker, so fetch it once per face.
        let face_timeline = db
            .get_chunks_by_signal(&req.video_uuid, "face", face_id)
            .await
            .unwrap_or_default();

        for (speaker_id, speaker_timeline) in speaker_signals.iter().zip(&speaker_timelines) {
            let overlap = calculate_overlap(&face_timeline, speaker_timeline);
            if overlap >= threshold {
                // Look up existing bindings so conflicting pairs can be filtered out.
                let face_identity = db
                    .get_identity_by_binding("face", face_id)
                    .await
                    .ok()
                    .flatten();
                let speaker_identity = db
                    .get_identity_by_binding("speaker", speaker_id)
                    .await
                    .ok()
                    .flatten();

                // Both already bound, but to different identities: a conflict, not a match.
                if let (Some(fi), Some(si)) = (&face_identity, &speaker_identity) {
                    if fi.id != si.id {
                        continue;
                    }
                }

                suggestions.push(AVSuggestion {
                    face_id: face_id.clone(),
                    speaker_id: speaker_id.clone(),
                    overlap_score: overlap,
                    face_talent_id: face_identity.as_ref().map(|i| i.id as i32),
                    speaker_talent_id: speaker_identity.as_ref().map(|i| i.id as i32),
                });
            }
        }
    }

    // `total_cmp` is a total order on f64, so sorting has no panic path.
    suggestions.sort_by(|a, b| b.overlap_score.total_cmp(&a.overlap_score));

    Ok(Json(ApiResponse {
        success: true,
        message: format!("Found {} AV suggestions", suggestions.len()),
        data: Some(suggestions),
    }))
}
/// Computes a normalized temporal-overlap score between a face timeline and a speaker
/// timeline.
///
/// Each chunk is read as a JSON object whose `content` object carries numeric `start_time`
/// and `end_time` fields; chunks missing any of these are ignored.
///
/// The score is the summed pairwise intersection of face and speaker ranges, divided by the
/// smaller of the two timelines' total durations, and capped at `1.0`. It is `0.0` when
/// either timeline has no positive total duration.
///
/// Ranges within one timeline are not merged, so overlapping ranges on the same side count
/// more than once; the cap keeps the result within `0.0..=1.0`.
fn calculate_overlap(
    face_chunks: &[serde_json::Value],
    speaker_chunks: &[serde_json::Value],
) -> f64 {
    // Collects `(start_time, end_time)` from every chunk that carries both as numbers.
    fn time_ranges(chunks: &[serde_json::Value]) -> Vec<(f64, f64)> {
        chunks
            .iter()
            .filter_map(|chunk| {
                let content = chunk.get("content")?;
                let start = content.get("start_time")?.as_f64()?;
                let end = content.get("end_time")?.as_f64()?;
                Some((start, end))
            })
            .collect()
    }

    // Sums range lengths (`end - start`). Inverted ranges contribute nothing rather than a
    // negative length.
    fn total_duration(ranges: &[(f64, f64)]) -> f64 {
        ranges
            .iter()
            .map(|&(start, end)| (end - start).max(0.0))
            .sum()
    }

    let face_ranges = time_ranges(face_chunks);
    let speaker_ranges = time_ranges(speaker_chunks);

    let mut overlap_duration = 0.0;
    for &(face_start, face_end) in &face_ranges {
        for &(speaker_start, speaker_end) in &speaker_ranges {
            let start = face_start.max(speaker_start);
            let end = face_end.min(speaker_end);
            if start < end {
                overlap_duration += end - start;
            }
        }
    }

    // Normalize by the shorter timeline's total duration. Summing raw end timestamps here
    // would inflate the denominator for any range that does not start at zero.
    let min_duration = total_duration(&face_ranges).min(total_duration(&speaker_ranges));
    if min_duration > 0.0 {
        (overlap_duration / min_duration).min(1.0)
    } else {
        0.0
    }
}
// ============================================================================
// Query Parameters & Router Setup
// ============================================================================
/// Query-string parameters accepted by `list_identities`.
#[derive(Debug, Deserialize)]
pub struct ListIdentitiesParams {
    /// Search string passed to the identity listing; an empty string is used when omitted.
    pub search: Option<String>,
    /// Maximum number of identities to return; `100` is used when omitted.
    pub limit: Option<i32>,
    /// Number of identities to skip; `0` is used when omitted.
    pub offset: Option<i32>,
}
/// Query-string parameters accepted by `list_unbound_signals`.
#[derive(Debug, Deserialize)]
pub struct ListSignalsParams {
    /// UUID to search within — presumably the video UUID, as in the AV suggestion
    /// request; confirm against the DB layer.
    pub uuid: String,
    /// Kind of signal to list.
    pub binding_type: String, // "face" or "speaker"
}
/// Builds the router exposing the identity-binding endpoints.
///
/// Registered routes:
/// * `POST /api/v1/identities/bind`
/// * `POST /api/v1/identities/unbind`
/// * `GET  /api/v1/identity/:binding_type/:binding_value`
/// * `GET  /api/v1/signals/unbound`
/// * `GET  /api/v1/signals/:uuid/:binding_type/:binding_value/timeline`
/// * `POST /api/v1/identities/suggest-av`
///
/// NOTE(review): `list_identities` is not registered here — presumably it is routed
/// elsewhere; confirm.
pub fn identity_binding_routes() -> Router<crate::api::server::AppState> {
    // Binding management.
    let router = Router::new();
    let router = router.route("/api/v1/identities/bind", post(bind_identity));
    let router = router.route("/api/v1/identities/unbind", post(unbind_identity));
    let router = router.route(
        "/api/v1/identity/:binding_type/:binding_value",
        get(get_identity_info),
    );

    // Signal discovery.
    let router = router.route("/api/v1/signals/unbound", get(list_unbound_signals));

    // Signal timeline.
    let router = router.route(
        "/api/v1/signals/:uuid/:binding_type/:binding_value/timeline",
        get(get_signal_timeline),
    );

    // Audio-visual binding suggestions.
    router.route(
        "/api/v1/identities/suggest-av",
        post(suggest_audio_visual_bindings),
    )
}