feat: POST /api/v1/agents/search - Gemma4 function calling agent
This commit is contained in:
523
src/api/agent_search.rs
Normal file
523
src/api/agent_search.rs
Normal file
@@ -0,0 +1,523 @@
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::post,
|
||||
Router,
|
||||
};
|
||||
use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::api::types::AppState;
|
||||
use crate::core::db::schema;
|
||||
use crate::core::llm::function_calling::{self, ChatMessage, LlmResponse, ToolCall, ToolDef};
|
||||
|
||||
// ── Conversation Manager ─────────────────────────────────────────
|
||||
|
||||
struct Conversation {
|
||||
messages: Vec<ChatMessage>,
|
||||
created_at: Instant,
|
||||
last_active: Instant,
|
||||
}
|
||||
|
||||
static CONVERSATIONS: Lazy<Mutex<HashMap<String, Conversation>>> = Lazy::new(|| {
|
||||
// Spawn cleanup task
|
||||
std::thread::spawn(|| loop {
|
||||
std::thread::sleep(std::time::Duration::from_secs(60));
|
||||
let mut map = CONVERSATIONS.lock().unwrap();
|
||||
let now = Instant::now();
|
||||
map.retain(|_, conv| now.duration_since(conv.last_active).as_secs() < 1800);
|
||||
});
|
||||
Mutex::new(HashMap::new())
|
||||
});
|
||||
|
||||
fn get_or_create_conv(conv_id: Option<&str>) -> (String, Vec<ChatMessage>) {
|
||||
let mut map = CONVERSATIONS.lock().unwrap();
|
||||
if let Some(cid) = conv_id {
|
||||
if let Some(conv) = map.get_mut(cid) {
|
||||
conv.last_active = Instant::now();
|
||||
return (cid.to_string(), conv.messages.clone());
|
||||
}
|
||||
}
|
||||
let id = uuid::Uuid::new_v4().to_string().replace('-', "")[..16].to_string();
|
||||
map.insert(id.clone(), Conversation {
|
||||
messages: Vec::new(),
|
||||
created_at: Instant::now(),
|
||||
last_active: Instant::now(),
|
||||
});
|
||||
(id, Vec::new())
|
||||
}
|
||||
|
||||
fn save_messages(conv_id: &str, messages: &[ChatMessage]) {
|
||||
if let Some(conv) = CONVERSATIONS.lock().unwrap().get_mut(conv_id) {
|
||||
conv.messages = messages.to_vec();
|
||||
conv.last_active = Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
// ── Request / Response ───────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AgentSearchRequest {
|
||||
pub query: String,
|
||||
pub conversation_id: Option<String>,
|
||||
pub file_uuid: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AgentSearchResponse {
|
||||
pub success: bool,
|
||||
pub conversation_id: String,
|
||||
pub answer: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub suggestions: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sources: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
// ── Tool Definitions ──────────────────────────────────────────────
|
||||
|
||||
const SYSTEM_PROMPT: &str = r#"你是 Momentry 影片分析助手。回答用戶關於影片內容的問題。
|
||||
|
||||
## 工具使用規則
|
||||
1. 先確認用戶在問哪部影片 — 使用 find_file 或 list_files
|
||||
2. 人物問題優先使用 tkg_query
|
||||
3. 語意/內容問題使用 smart_search 或 universal_search
|
||||
4. 可以同時呼叫多個工具
|
||||
|
||||
## 引導規則
|
||||
- 如果用戶沒說片名 → 用 find_file 搜尋,如果名稱不明確就反問
|
||||
- 反問時提供 suggestions,例如演員名、年代
|
||||
- 不要輸出 JSON,用自然語言回答
|
||||
- 引用資料時附上具體數字(frame 編號、時間秒數)
|
||||
|
||||
## 回答規則
|
||||
- 回答要簡潔但完整
|
||||
- 如果找到影片,附上 file_uuid(用戶之後可能需要)
|
||||
- 對於人物問題,說出角色名和演員名"#;
|
||||
|
||||
fn make_tools(pool: &sqlx::PgPool) -> Vec<ToolDef> {
|
||||
vec![
|
||||
function_calling::make_tool(
|
||||
"find_file",
|
||||
"透過關鍵字搜尋影片(片名、演員、年份)。回傳符合的影片列表。",
|
||||
serde_json::json!({
|
||||
"query": {"type": "string", "description": "搜尋關鍵字(片名、演員名、年份)"}
|
||||
}),
|
||||
vec!["query"],
|
||||
),
|
||||
function_calling::make_tool(
|
||||
"list_files",
|
||||
"列出近期註冊的影片。",
|
||||
serde_json::json!({
|
||||
"limit": {"type": "integer", "description": "回傳筆數上限", "default": 10}
|
||||
}),
|
||||
vec![],
|
||||
),
|
||||
function_calling::make_tool(
|
||||
"tkg_query",
|
||||
"查詢影片的人物互動、配對、同框資料。query_type 包括:top_identities(人物排名)、first_cooccurrence(第一次同框)、identity_details(人物詳細)、mutual_gaze(互看)、interaction_network(互動網絡)、identity_traces(出場片段)、file_info(影片資訊)。",
|
||||
serde_json::json!({
|
||||
"file_uuid": {"type": "string", "description": "影片 UUID"},
|
||||
"query_type": {
|
||||
"type": "string",
|
||||
"enum": ["top_identities", "first_cooccurrence", "identity_details", "mutual_gaze", "interaction_network", "identity_traces", "file_info"],
|
||||
"description": "查詢類型"
|
||||
},
|
||||
"identity_name": {"type": "string", "description": "人物名稱(配合 identity_details / identity_traces)"},
|
||||
"identity_b": {"type": "string", "description": "第二人物名稱(配合 first_cooccurrence / mutual_gaze)"},
|
||||
"limit": {"type": "integer", "default": 5}
|
||||
}),
|
||||
vec!["file_uuid", "query_type"],
|
||||
),
|
||||
function_calling::make_tool(
|
||||
"smart_search",
|
||||
"語意搜尋 chunk 文字內容。適合需要理解意圖的查詢。",
|
||||
serde_json::json!({
|
||||
"file_uuid": {"type": "string", "description": "限制搜尋範圍(可選)"},
|
||||
"query": {"type": "string", "description": "搜尋關鍵字"},
|
||||
"limit": {"type": "integer", "default": 5}
|
||||
}),
|
||||
vec!["query"],
|
||||
),
|
||||
function_calling::make_tool(
|
||||
"get_identity_detail",
|
||||
"查詢單一身份的詳細資料(名字、角色、TMDb 資訊)。",
|
||||
serde_json::json!({
|
||||
"name": {"type": "string", "description": "人物名稱"}
|
||||
}),
|
||||
vec!["name"],
|
||||
),
|
||||
function_calling::make_tool(
|
||||
"get_file_info",
|
||||
"查詢影片基本資訊(片名、長度、解析度)。",
|
||||
serde_json::json!({
|
||||
"file_uuid": {"type": "string", "description": "影片 UUID"}
|
||||
}),
|
||||
vec!["file_uuid"],
|
||||
),
|
||||
function_calling::make_tool(
|
||||
"get_representative_frame",
|
||||
"查詢影片最具代表性的 frame 資訊(frame 編號、時間、人物)。",
|
||||
serde_json::json!({
|
||||
"file_uuid": {"type": "string", "description": "影片 UUID"}
|
||||
}),
|
||||
vec!["file_uuid"],
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
// ── Tool Executors ───────────────────────────────────────────────
|
||||
|
||||
async fn exec_find_file(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
|
||||
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let videos = schema::table_name("videos");
|
||||
let like = format!("%{}%", query);
|
||||
let rows: Vec<(String, String)> = sqlx::query_as(&format!(
|
||||
"SELECT file_uuid::text, file_name FROM {} WHERE file_name ILIKE $1 ORDER BY created_at DESC LIMIT 10",
|
||||
videos
|
||||
))
|
||||
.bind(&like)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if rows.is_empty() {
|
||||
return Ok(serde_json::json!({"found": false, "message": "No files match the query. Try different keywords."}).to_string());
|
||||
}
|
||||
let files: Vec<serde_json::Value> = rows.into_iter().map(|(u, n)| {
|
||||
serde_json::json!({"file_uuid": u, "file_name": n})
|
||||
}).collect();
|
||||
Ok(serde_json::json!({"found": true, "files": files}).to_string())
|
||||
}
|
||||
|
||||
async fn exec_list_files(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
|
||||
let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10);
|
||||
let videos = schema::table_name("videos");
|
||||
let rows: Vec<(String, String)> = sqlx::query_as(&format!(
|
||||
"SELECT file_uuid::text, file_name FROM {} ORDER BY created_at DESC LIMIT $1",
|
||||
videos
|
||||
))
|
||||
.bind(limit)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let files: Vec<serde_json::Value> = rows.into_iter().map(|(u, n)| {
|
||||
serde_json::json!({"file_uuid": u, "file_name": n})
|
||||
}).collect();
|
||||
Ok(serde_json::json!({"files": files}).to_string())
|
||||
}
|
||||
|
||||
async fn exec_tkg_query(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
|
||||
let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let query_type = args.get("query_type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let identity_name = args.get("identity_name").and_then(|v| v.as_str());
|
||||
let identity_b = args.get("identity_b").and_then(|v| v.as_str());
|
||||
let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5);
|
||||
|
||||
let id_table = schema::table_name("identities");
|
||||
let fd_table = schema::table_name("face_detections");
|
||||
let videos = schema::table_name("videos");
|
||||
let nodes = schema::table_name("tkg_nodes");
|
||||
let edges = schema::table_name("tkg_edges");
|
||||
|
||||
match query_type {
|
||||
"top_identities" => {
|
||||
let rows: Vec<(String, String, i64)> = sqlx::query_as(&format!(
|
||||
"SELECT i.uuid::text, i.name, COUNT(fd.id)::bigint AS face_count \
|
||||
FROM {} fd JOIN {} i ON i.id = fd.identity_id \
|
||||
WHERE fd.file_uuid = $1 AND fd.identity_id IS NOT NULL AND i.source = 'tmdb' \
|
||||
GROUP BY i.uuid, i.name ORDER BY face_count DESC LIMIT $2",
|
||||
fd_table, id_table
|
||||
))
|
||||
.bind(file_uuid).bind(limit)
|
||||
.fetch_all(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"identities": rows}).to_string())
|
||||
}
|
||||
"first_cooccurrence" => {
|
||||
let name_a = identity_name.unwrap_or("");
|
||||
let name_b = identity_b.unwrap_or("");
|
||||
let row: Option<(i64, f64)> = sqlx::query_as(&format!(
|
||||
"SELECT MIN(fd_a.frame_number)::bigint, \
|
||||
ROUND(MIN(fd_a.frame_number)::numeric / GREATEST(MAX(v.fps)::numeric, 25.0), 2)::float8 \
|
||||
FROM {} fd_a JOIN {} fd_b ON fd_a.frame_number = fd_b.frame_number \
|
||||
JOIN {} v ON v.file_uuid = $1 \
|
||||
WHERE fd_a.file_uuid = $1 \
|
||||
AND fd_a.identity_id = (SELECT id FROM {} WHERE name ILIKE $2 LIMIT 1) \
|
||||
AND fd_b.identity_id = (SELECT id FROM {} WHERE name ILIKE $3 LIMIT 1)",
|
||||
fd_table, fd_table, videos, id_table, id_table
|
||||
))
|
||||
.bind(file_uuid).bind(name_a).bind(name_b)
|
||||
.fetch_optional(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"first_cooccurrence": row.map(|(f, t)| serde_json::json!({"frame": f, "timestamp_secs": t}))}).to_string())
|
||||
}
|
||||
"identity_details" => {
|
||||
let name = identity_name.unwrap_or("");
|
||||
let row: Option<(String, String, Option<i32>, i64)> = sqlx::query_as(&format!(
|
||||
"SELECT i.uuid::text, i.name, i.tmdb_id, \
|
||||
(SELECT COUNT(*) FROM {} fd WHERE fd.identity_id = i.id AND fd.file_uuid = $1)::bigint \
|
||||
FROM {} i WHERE i.name ILIKE $2 LIMIT 1",
|
||||
fd_table, id_table
|
||||
))
|
||||
.bind(file_uuid).bind(name)
|
||||
.fetch_optional(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"identity": row.map(|(u, n, tid, fc)| serde_json::json!({"uuid": u, "name": n, "tmdb_id": tid, "face_count": fc}))}).to_string())
|
||||
}
|
||||
"mutual_gaze" => {
|
||||
let name_a = identity_name.unwrap_or("");
|
||||
let name_b = identity_b.unwrap_or("");
|
||||
let row: Option<(i64, i64, f64, f64)> = sqlx::query_as(&format!(
|
||||
"SELECT (e.properties->>'first_frame')::bigint, \
|
||||
(e.properties->>'gaze_frame_count')::int::bigint, \
|
||||
(e.properties->>'yaw_a_avg')::float8, \
|
||||
(e.properties->>'yaw_b_avg')::float8 \
|
||||
FROM {} e \
|
||||
JOIN {} a ON a.id = e.source_node_id \
|
||||
JOIN {} b ON b.id = e.target_node_id \
|
||||
JOIN {} fd_a ON fd_a.file_uuid = $1 AND fd_a.trace_id = REPLACE(a.external_id, 'trace_', '')::int \
|
||||
JOIN {} fd_b ON fd_b.file_uuid = $1 AND fd_b.trace_id = REPLACE(b.external_id, 'trace_', '')::int \
|
||||
JOIN {} ia ON ia.id = fd_a.identity_id \
|
||||
JOIN {} ib ON ib.id = fd_b.identity_id \
|
||||
WHERE e.file_uuid = $1 AND ia.name ILIKE $2 AND ib.name ILIKE $3 \
|
||||
AND e.properties->>'mutual_gaze' = 'true' LIMIT 1",
|
||||
edges, nodes, nodes, fd_table, fd_table, id_table, id_table
|
||||
))
|
||||
.bind(file_uuid).bind(name_a).bind(name_b)
|
||||
.fetch_optional(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"mutual_gaze": row.map(|(f, gc, ya, yb)| serde_json::json!({"first_frame": f, "gaze_frame_count": gc, "yaw_a": ya, "yaw_b": yb}))}).to_string())
|
||||
}
|
||||
"interaction_network" => {
|
||||
let rows: Vec<(String, String, i64)> = sqlx::query_as(&format!(
|
||||
"SELECT ia.name, ib.name, COUNT(*)::bigint \
|
||||
FROM {} e \
|
||||
JOIN {} a ON a.id = e.source_node_id \
|
||||
JOIN {} b ON b.id = e.target_node_id \
|
||||
JOIN {} fd_a ON fd_a.trace_id = REPLACE(a.external_id, 'trace_', '')::int AND fd_a.file_uuid = $1 \
|
||||
JOIN {} fd_b ON fd_b.trace_id = REPLACE(b.external_id, 'trace_', '')::int AND fd_b.file_uuid = $1 \
|
||||
JOIN {} ia ON ia.id = fd_a.identity_id \
|
||||
JOIN {} ib ON ib.id = fd_b.identity_id \
|
||||
WHERE e.file_uuid = $1 AND e.edge_type = 'CO_OCCURS_WITH' \
|
||||
AND ia.name != ib.name AND ia.source = 'tmdb' AND ib.source = 'tmdb' \
|
||||
GROUP BY ia.name, ib.name \
|
||||
ORDER BY COUNT(*) DESC LIMIT $2",
|
||||
edges, nodes, nodes, fd_table, fd_table, id_table, id_table
|
||||
))
|
||||
.bind(file_uuid).bind(limit)
|
||||
.fetch_all(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"interaction_network": rows}).to_string())
|
||||
}
|
||||
"identity_traces" => {
|
||||
let name = identity_name.unwrap_or("");
|
||||
let rows: Vec<(i32, i64, i32, i32)> = sqlx::query_as(&format!(
|
||||
"SELECT fd.trace_id, COUNT(*)::bigint, MIN(fd.frame_number)::int, MAX(fd.frame_number)::int \
|
||||
FROM {} fd JOIN {} i ON i.id = fd.identity_id \
|
||||
WHERE fd.file_uuid = $1 AND i.name ILIKE $2 \
|
||||
GROUP BY fd.trace_id ORDER BY COUNT(*) DESC LIMIT $3",
|
||||
fd_table, id_table
|
||||
))
|
||||
.bind(file_uuid).bind(name).bind(limit)
|
||||
.fetch_all(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"traces": rows}).to_string())
|
||||
}
|
||||
"file_info" => {
|
||||
let row: Option<(String, f64, i32, i32, f64)> = sqlx::query_as(&format!(
|
||||
"SELECT file_name, duration, width, height, fps FROM {} WHERE file_uuid = $1",
|
||||
videos
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.fetch_optional(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"file_info": row.map(|(n, d, w, h, f)| serde_json::json!({"file_name": n, "duration_sec": d, "width": w, "height": h, "fps": f}))}).to_string())
|
||||
}
|
||||
_ => Ok(serde_json::json!({"error": format!("Unknown query_type: {}", query_type)}).to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn exec_smart_search(_pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
|
||||
let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let file_uuid = args.get("file_uuid").and_then(|v| v.as_str());
|
||||
let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5);
|
||||
|
||||
let chunk_table = schema::table_name("chunk");
|
||||
let mut sql = format!(
|
||||
"SELECT chunk_id, text_content, start_frame, end_frame, chunk_type \
|
||||
FROM {} WHERE text_content ILIKE $1", chunk_table
|
||||
);
|
||||
if file_uuid.is_some() {
|
||||
sql.push_str(" AND file_uuid = $2");
|
||||
}
|
||||
sql.push_str(&format!(" ORDER BY start_frame LIMIT {}", limit));
|
||||
|
||||
if let Some(fuid) = file_uuid {
|
||||
let like = format!("%{}%", query);
|
||||
let rows: Vec<(String, Option<String>, i64, i64, String)> = sqlx::query_as(&sql)
|
||||
.bind(&like).bind(fuid)
|
||||
.fetch_all(_pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"results": rows}).to_string())
|
||||
} else {
|
||||
let like = format!("%{}%", query);
|
||||
let rows: Vec<(String, Option<String>, i64, i64, String)> = sqlx::query_as(&sql)
|
||||
.bind(&like)
|
||||
.fetch_all(_pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"results": rows}).to_string())
|
||||
}
|
||||
}
|
||||
|
||||
async fn exec_get_identity_detail(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
|
||||
let name = args.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let id_table = schema::table_name("identities");
|
||||
let row: Option<(String, String, Option<String>, Option<i32>, Option<String>)> = sqlx::query_as(&format!(
|
||||
"SELECT uuid::text, name, source, tmdb_id, metadata->>'tmdb_character' FROM {} WHERE name ILIKE $1 LIMIT 1",
|
||||
id_table
|
||||
))
|
||||
.bind(name)
|
||||
.fetch_optional(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"identity": row.map(|(u, n, s, t, c)| serde_json::json!({"uuid": u, "name": n, "source": s, "tmdb_id": t, "character": c}))}).to_string())
|
||||
}
|
||||
|
||||
async fn exec_get_file_info(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
|
||||
let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let videos = schema::table_name("videos");
|
||||
let row: Option<(String, f64, i32, i32, f64)> = sqlx::query_as(&format!(
|
||||
"SELECT file_name, duration, width, height, fps FROM {} WHERE file_uuid = $1",
|
||||
videos
|
||||
))
|
||||
.bind(file_uuid)
|
||||
.fetch_optional(pool)
|
||||
.await.map_err(|e| e.to_string())?;
|
||||
Ok(serde_json::json!({"file_info": row.map(|(n, d, w, h, f)| serde_json::json!({"file_name": n, "duration_sec": d, "width": w, "height": h, "fps": f}))}).to_string())
|
||||
}
|
||||
|
||||
async fn exec_get_representative_frame(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result<String, String> {
|
||||
let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or("");
|
||||
match crate::core::processor::tkg::query_auto_representative_frame(pool, file_uuid).await {
|
||||
Ok(r) => Ok(serde_json::json!({
|
||||
"frame_number": r.frame_number,
|
||||
"face_quality": r.face_quality,
|
||||
"main_identities": r.main_identities,
|
||||
"traces": r.traces,
|
||||
}).to_string()),
|
||||
Err(e) => Ok(serde_json::json!({"error": e.to_string()}).to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tool Router ───────────────────────────────────────────────────
|
||||
|
||||
async fn execute_tool(pool: &sqlx::PgPool, tool_call: &ToolCall) -> (String, String, String) {
|
||||
let name = tool_call.function.name.clone();
|
||||
let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap_or_default();
|
||||
let result = match name.as_str() {
|
||||
"find_file" => exec_find_file(pool, &args).await,
|
||||
"list_files" => exec_list_files(pool, &args).await,
|
||||
"tkg_query" => exec_tkg_query(pool, &args).await,
|
||||
"smart_search" => exec_smart_search(pool, &args).await,
|
||||
"get_identity_detail" => exec_get_identity_detail(pool, &args).await,
|
||||
"get_file_info" => exec_get_file_info(pool, &args).await,
|
||||
"get_representative_frame" => exec_get_representative_frame(pool, &args).await,
|
||||
_ => Err(format!("Unknown tool: {}", name)),
|
||||
};
|
||||
let content = match result {
|
||||
Ok(s) => s,
|
||||
Err(e) => serde_json::json!({"error": e}).to_string(),
|
||||
};
|
||||
let tool_call_id = tool_call.id.clone().unwrap_or_default();
|
||||
(tool_call_id, name, content)
|
||||
}
|
||||
|
||||
// ── Tool Loop ─────────────────────────────────────────────────────
|
||||
|
||||
const MAX_ROUNDS: u32 = 5;
|
||||
|
||||
async fn run_tool_loop(
|
||||
pool: &sqlx::PgPool,
|
||||
system_prompt: &str,
|
||||
user_query: &str,
|
||||
history: Vec<ChatMessage>,
|
||||
) -> (String, Vec<serde_json::Value>) {
|
||||
let mut messages = function_calling::build_conversation(system_prompt, user_query, history);
|
||||
let mut sources = Vec::new();
|
||||
|
||||
for round in 0..MAX_ROUNDS {
|
||||
let tools = Some(make_tools(pool));
|
||||
match function_calling::call_llm(messages.clone(), tools, 2048, 120).await {
|
||||
Ok(LlmResponse::Text(text)) => {
|
||||
return (text, sources);
|
||||
}
|
||||
Ok(LlmResponse::ToolCalls(calls)) => {
|
||||
// Push assistant message with tool_calls so Gemma4 remembers
|
||||
messages.push(ChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: None,
|
||||
tool_calls: Some(calls.clone()),
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
});
|
||||
for call in &calls {
|
||||
let (tool_call_id, name, content) = execute_tool(pool, call).await;
|
||||
sources.push(serde_json::json!({"tool": name, "result": content}));
|
||||
messages.push(function_calling::make_tool_result(&tool_call_id, &name, &content));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return (format!("系統錯誤:{}", e), sources);
|
||||
}
|
||||
}
|
||||
}
|
||||
("已達到最大查詢次數,請縮小問題範圍後重新詢問。".to_string(), sources)
|
||||
}
|
||||
|
||||
// ── Handler ───────────────────────────────────────────────────────
|
||||
|
||||
async fn agent_search(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<AgentSearchRequest>,
|
||||
) -> Result<Json<AgentSearchResponse>, (StatusCode, Json<serde_json::Value>)> {
|
||||
let (conv_id, history) = get_or_create_conv(req.conversation_id.as_deref());
|
||||
|
||||
let (answer, sources) = run_tool_loop(
|
||||
state.db.pool(),
|
||||
SYSTEM_PROMPT,
|
||||
&req.query,
|
||||
history,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Save updated messages for conversation continuation
|
||||
let new_msgs = function_calling::build_conversation(SYSTEM_PROMPT, &req.query, vec![]);
|
||||
save_messages(&conv_id, &new_msgs);
|
||||
|
||||
let needs_input = answer.contains('?') || answer.contains('?');
|
||||
let suggestions = if needs_input {
|
||||
Some(vec!["演員名".to_string(), "電影片名".to_string(), "年份".to_string()])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Json(AgentSearchResponse {
|
||||
success: true,
|
||||
conversation_id: conv_id,
|
||||
answer,
|
||||
suggestions,
|
||||
sources: Some(sources),
|
||||
}))
|
||||
}
|
||||
|
||||
// ── Routes ─────────────────────────────────────────────────────────
|
||||
|
||||
pub fn agent_search_routes() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/api/v1/agents/search", post(agent_search))
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
pub mod agent_api;
|
||||
pub mod agent_search;
|
||||
pub mod auth;
|
||||
pub mod docs;
|
||||
pub mod files;
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::core::db::{Database, PostgresDb};
|
||||
use crate::Embedder;
|
||||
|
||||
use super::agent_api;
|
||||
use super::agent_search;
|
||||
use super::auth;
|
||||
use super::docs;
|
||||
use super::files;
|
||||
@@ -82,6 +83,7 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> {
|
||||
.merge(tmdb_api::tmdb_routes())
|
||||
.merge(identity_api::identity_routes())
|
||||
.merge(agent_api::agent_routes())
|
||||
.merge(agent_search::agent_search_routes())
|
||||
.merge(identity_agent_api::identity_agent_routes())
|
||||
.merge(five_w1h_agent_api::five_w1h_agent_routes())
|
||||
.merge(media_api::bbox_routes())
|
||||
|
||||
189
src/core/llm/function_calling.rs
Normal file
189
src/core/llm/function_calling.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// A tool/function definition for Gemma4 function calling
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolDef {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
/// A tool call returned by Gemma4
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
pub call_type: Option<String>,
|
||||
pub function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallFunction {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Message in the chat history
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Full chat request to Gemma4
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<ToolDef>>,
|
||||
}
|
||||
|
||||
/// Response from Gemma4
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
pub choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
pub finish_reason: Option<String>,
|
||||
pub message: ChatMessage,
|
||||
}
|
||||
|
||||
/// Parsed LLM response: either text or tool calls
|
||||
pub enum LlmResponse {
|
||||
Text(String),
|
||||
ToolCalls(Vec<ToolCall>),
|
||||
}
|
||||
|
||||
/// Get the LLM chat URL with fallback chain
|
||||
pub fn llm_chat_url() -> String {
|
||||
std::env::var("MOMENTRY_LLM_URL")
|
||||
.or_else(|_| std::env::var("MOMENTRY_LLM_SUMMARY_URL"))
|
||||
.unwrap_or_else(|_| "http://localhost:8082/v1/chat/completions".to_string())
|
||||
}
|
||||
|
||||
/// Get the LLM model name
|
||||
pub fn llm_model() -> String {
|
||||
std::env::var("MOMENTRY_LLM_MODEL")
|
||||
.or_else(|_| std::env::var("MOMENTRY_LLM_SUMMARY_MODEL"))
|
||||
.unwrap_or_else(|_| "google_gemma-4-26B-A4B-it-Q5_K_M.gguf".to_string())
|
||||
}
|
||||
|
||||
/// Build a tool definition JSON for function calling
|
||||
pub fn make_tool(name: &str, description: &str, properties: Value, required: Vec<&str>) -> ToolDef {
|
||||
ToolDef {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: name.to_string(),
|
||||
description: description.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Call Gemma4 with messages and optional tools. Returns parsed response.
|
||||
pub async fn call_llm(
|
||||
messages: Vec<ChatMessage>,
|
||||
tools: Option<Vec<ToolDef>>,
|
||||
max_tokens: u32,
|
||||
timeout_secs: u64,
|
||||
) -> anyhow::Result<LlmResponse> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(timeout_secs))
|
||||
.build()?;
|
||||
|
||||
let req = ChatRequest {
|
||||
model: llm_model(),
|
||||
messages,
|
||||
temperature: 0.1,
|
||||
max_tokens,
|
||||
stream: false,
|
||||
tools,
|
||||
};
|
||||
|
||||
let res = client
|
||||
.post(&llm_chat_url())
|
||||
.json(&req)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !res.status().is_success() {
|
||||
let text = res.text().await.unwrap_or_default();
|
||||
anyhow::bail!("LLM API error: {}", text);
|
||||
}
|
||||
|
||||
let chat_res: ChatResponse = res.json().await?;
|
||||
let choice = chat_res.choices.into_iter().next()
|
||||
.ok_or_else(|| anyhow::anyhow!("Empty LLM response"))?;
|
||||
|
||||
match choice.finish_reason.as_deref() {
|
||||
Some("tool_calls") => {
|
||||
let calls = choice.message.tool_calls
|
||||
.ok_or_else(|| anyhow::anyhow!("finish_reason=tool_calls but no tool_calls in message"))?;
|
||||
Ok(LlmResponse::ToolCalls(calls))
|
||||
}
|
||||
_ => {
|
||||
let content = choice.message.content.unwrap_or_default();
|
||||
Ok(LlmResponse::Text(content.trim().to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to build the system prompt + user messages
|
||||
pub fn build_conversation(system_prompt: &str, user_query: &str, history: Vec<ChatMessage>) -> Vec<ChatMessage> {
|
||||
let mut messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: Some(system_prompt.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
// Add history (user + assistant exchanges)
|
||||
messages.extend(history);
|
||||
// Add current user query
|
||||
messages.push(ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(user_query.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
});
|
||||
messages
|
||||
}
|
||||
|
||||
/// Build a tool result message to send back to LLM
|
||||
pub fn make_tool_result(tool_call_id: &str, name: &str, content: &str) -> ChatMessage {
|
||||
ChatMessage {
|
||||
role: "tool".to_string(),
|
||||
content: Some(content.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
name: Some(name.to_string()),
|
||||
}
|
||||
}
|
||||
@@ -1 +1,2 @@
|
||||
pub mod client;
|
||||
pub mod function_calling;
|
||||
|
||||
Reference in New Issue
Block a user