From 380dd87d8b05edeadb52406d930eeb77ea5fd0bf Mon Sep 17 00:00:00 2001 From: Accusys Date: Fri, 22 May 2026 12:10:37 +0800 Subject: [PATCH] feat: POST /api/v1/agents/search - Gemma4 function calling agent --- src/api/agent_search.rs | 523 +++++++++++++++++++++++++++++++ src/api/mod.rs | 1 + src/api/server.rs | 2 + src/core/llm/function_calling.rs | 189 +++++++++++ src/core/llm/mod.rs | 1 + 5 files changed, 716 insertions(+) create mode 100644 src/api/agent_search.rs create mode 100644 src/core/llm/function_calling.rs diff --git a/src/api/agent_search.rs b/src/api/agent_search.rs new file mode 100644 index 0000000..64cedbc --- /dev/null +++ b/src/api/agent_search.rs @@ -0,0 +1,523 @@ +use axum::{ + extract::State, + http::StatusCode, + response::Json, + routing::post, + Router, +}; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Mutex; +use std::time::Instant; + +use crate::api::types::AppState; +use crate::core::db::schema; +use crate::core::llm::function_calling::{self, ChatMessage, LlmResponse, ToolCall, ToolDef}; + +// ── Conversation Manager ───────────────────────────────────────── + +struct Conversation { + messages: Vec, + created_at: Instant, + last_active: Instant, +} + +static CONVERSATIONS: Lazy>> = Lazy::new(|| { + // Spawn cleanup task + std::thread::spawn(|| loop { + std::thread::sleep(std::time::Duration::from_secs(60)); + let mut map = CONVERSATIONS.lock().unwrap(); + let now = Instant::now(); + map.retain(|_, conv| now.duration_since(conv.last_active).as_secs() < 1800); + }); + Mutex::new(HashMap::new()) +}); + +fn get_or_create_conv(conv_id: Option<&str>) -> (String, Vec) { + let mut map = CONVERSATIONS.lock().unwrap(); + if let Some(cid) = conv_id { + if let Some(conv) = map.get_mut(cid) { + conv.last_active = Instant::now(); + return (cid.to_string(), conv.messages.clone()); + } + } + let id = uuid::Uuid::new_v4().to_string().replace('-', "")[..16].to_string(); + map.insert(id.clone(), Conversation { + messages: Vec::new(), + created_at: Instant::now(), + last_active: Instant::now(), + }); + (id, Vec::new()) +} + +fn save_messages(conv_id: &str, messages: &[ChatMessage]) { + if let Some(conv) = CONVERSATIONS.lock().unwrap().get_mut(conv_id) { + conv.messages = messages.to_vec(); + conv.last_active = Instant::now(); + } +} + +// ── Request / Response ─────────────────────────────────────────── + +#[derive(Debug, Deserialize)] +pub struct AgentSearchRequest { + pub query: String, + pub conversation_id: Option, + pub file_uuid: Option, +} + +#[derive(Debug, Serialize)] +pub struct AgentSearchResponse { + pub success: bool, + pub conversation_id: String, + pub answer: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub suggestions: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub sources: Option>, +} + +// ── Tool Definitions ────────────────────────────────────────────── + +const SYSTEM_PROMPT: &str = r#"你是 Momentry 影片分析助手。回答用戶關於影片內容的問題。 + +## 工具使用規則 +1. 先確認用戶在問哪部影片 — 使用 find_file 或 list_files +2. 人物問題優先使用 tkg_query +3. 語意/內容問題使用 smart_search 或 universal_search +4. 可以同時呼叫多個工具 + +## 引導規則 +- 如果用戶沒說片名 → 用 find_file 搜尋,如果名稱不明確就反問 +- 反問時提供 suggestions,例如演員名、年代 +- 不要輸出 JSON,用自然語言回答 +- 引用資料時附上具體數字(frame 編號、時間秒數) + +## 回答規則 +- 回答要簡潔但完整 +- 如果找到影片,附上 file_uuid(用戶之後可能需要) +- 對於人物問題,說出角色名和演員名"#; + +fn make_tools(pool: &sqlx::PgPool) -> Vec { + vec![ + function_calling::make_tool( + "find_file", + "透過關鍵字搜尋影片(片名、演員、年份)。回傳符合的影片列表。", + serde_json::json!({ + "query": {"type": "string", "description": "搜尋關鍵字(片名、演員名、年份)"} + }), + vec!["query"], + ), + function_calling::make_tool( + "list_files", + "列出近期註冊的影片。", + serde_json::json!({ + "limit": {"type": "integer", "description": "回傳筆數上限", "default": 10} + }), + vec![], + ), + function_calling::make_tool( + "tkg_query", + "查詢影片的人物互動、配對、同框資料。query_type 包括:top_identities(人物排名)、first_cooccurrence(第一次同框)、identity_details(人物詳細)、mutual_gaze(互看)、interaction_network(互動網絡)、identity_traces(出場片段)、file_info(影片資訊)。", + serde_json::json!({ + "file_uuid": {"type": "string", "description": "影片 UUID"}, + "query_type": { + "type": "string", + "enum": ["top_identities", "first_cooccurrence", "identity_details", "mutual_gaze", "interaction_network", "identity_traces", "file_info"], + "description": "查詢類型" + }, + "identity_name": {"type": "string", "description": "人物名稱(配合 identity_details / identity_traces)"}, + "identity_b": {"type": "string", "description": "第二人物名稱(配合 first_cooccurrence / mutual_gaze)"}, + "limit": {"type": "integer", "default": 5} + }), + vec!["file_uuid", "query_type"], + ), + function_calling::make_tool( + "smart_search", + "語意搜尋 chunk 文字內容。適合需要理解意圖的查詢。", + serde_json::json!({ + "file_uuid": {"type": "string", "description": "限制搜尋範圍(可選)"}, + "query": {"type": "string", "description": "搜尋關鍵字"}, + "limit": {"type": "integer", "default": 5} + }), + vec!["query"], + ), + function_calling::make_tool( + "get_identity_detail", + "查詢單一身份的詳細資料(名字、角色、TMDb 資訊)。", + serde_json::json!({ + "name": {"type": "string", "description": "人物名稱"} + }), + vec!["name"], + ), + function_calling::make_tool( + "get_file_info", + "查詢影片基本資訊(片名、長度、解析度)。", + serde_json::json!({ + "file_uuid": {"type": "string", "description": "影片 UUID"} + }), + vec!["file_uuid"], + ), + function_calling::make_tool( + "get_representative_frame", + "查詢影片最具代表性的 frame 資訊(frame 編號、時間、人物)。", + serde_json::json!({ + "file_uuid": {"type": "string", "description": "影片 UUID"} + }), + vec!["file_uuid"], + ), + ] +} + +// ── Tool Executors ─────────────────────────────────────────────── + +async fn exec_find_file(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { + let query = args.get("query").and_then(|v| v.as_str()).unwrap_or(""); + let videos = schema::table_name("videos"); + let like = format!("%{}%", query); + let rows: Vec<(String, String)> = sqlx::query_as(&format!( + "SELECT file_uuid::text, file_name FROM {} WHERE file_name ILIKE $1 ORDER BY created_at DESC LIMIT 10", + videos + )) + .bind(&like) + .fetch_all(pool) + .await + .map_err(|e| e.to_string())?; + + if rows.is_empty() { + return Ok(serde_json::json!({"found": false, "message": "No files match the query. Try different keywords."}).to_string()); + } + let files: Vec = rows.into_iter().map(|(u, n)| { + serde_json::json!({"file_uuid": u, "file_name": n}) + }).collect(); + Ok(serde_json::json!({"found": true, "files": files}).to_string()) +} + +async fn exec_list_files(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { + let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(10); + let videos = schema::table_name("videos"); + let rows: Vec<(String, String)> = sqlx::query_as(&format!( + "SELECT file_uuid::text, file_name FROM {} ORDER BY created_at DESC LIMIT $1", + videos + )) + .bind(limit) + .fetch_all(pool) + .await + .map_err(|e| e.to_string())?; + + let files: Vec = rows.into_iter().map(|(u, n)| { + serde_json::json!({"file_uuid": u, "file_name": n}) + }).collect(); + Ok(serde_json::json!({"files": files}).to_string()) +} + +async fn exec_tkg_query(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { + let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or(""); + let query_type = args.get("query_type").and_then(|v| v.as_str()).unwrap_or(""); + let identity_name = args.get("identity_name").and_then(|v| v.as_str()); + let identity_b = args.get("identity_b").and_then(|v| v.as_str()); + let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5); + + let id_table = schema::table_name("identities"); + let fd_table = schema::table_name("face_detections"); + let videos = schema::table_name("videos"); + let nodes = schema::table_name("tkg_nodes"); + let edges = schema::table_name("tkg_edges"); + + match query_type { + "top_identities" => { + let rows: Vec<(String, String, i64)> = sqlx::query_as(&format!( + "SELECT i.uuid::text, i.name, COUNT(fd.id)::bigint AS face_count \ + FROM {} fd JOIN {} i ON i.id = fd.identity_id \ + WHERE fd.file_uuid = $1 AND fd.identity_id IS NOT NULL AND i.source = 'tmdb' \ + GROUP BY i.uuid, i.name ORDER BY face_count DESC LIMIT $2", + fd_table, id_table + )) + .bind(file_uuid).bind(limit) + .fetch_all(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"identities": rows}).to_string()) + } + "first_cooccurrence" => { + let name_a = identity_name.unwrap_or(""); + let name_b = identity_b.unwrap_or(""); + let row: Option<(i64, f64)> = sqlx::query_as(&format!( + "SELECT MIN(fd_a.frame_number)::bigint, \ + ROUND(MIN(fd_a.frame_number)::numeric / GREATEST(MAX(v.fps)::numeric, 25.0), 2)::float8 \ + FROM {} fd_a JOIN {} fd_b ON fd_a.frame_number = fd_b.frame_number \ + JOIN {} v ON v.file_uuid = $1 \ + WHERE fd_a.file_uuid = $1 \ + AND fd_a.identity_id = (SELECT id FROM {} WHERE name ILIKE $2 LIMIT 1) \ + AND fd_b.identity_id = (SELECT id FROM {} WHERE name ILIKE $3 LIMIT 1)", + fd_table, fd_table, videos, id_table, id_table + )) + .bind(file_uuid).bind(name_a).bind(name_b) + .fetch_optional(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"first_cooccurrence": row.map(|(f, t)| serde_json::json!({"frame": f, "timestamp_secs": t}))}).to_string()) + } + "identity_details" => { + let name = identity_name.unwrap_or(""); + let row: Option<(String, String, Option, i64)> = sqlx::query_as(&format!( + "SELECT i.uuid::text, i.name, i.tmdb_id, \ + (SELECT COUNT(*) FROM {} fd WHERE fd.identity_id = i.id AND fd.file_uuid = $1)::bigint \ + FROM {} i WHERE i.name ILIKE $2 LIMIT 1", + fd_table, id_table + )) + .bind(file_uuid).bind(name) + .fetch_optional(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"identity": row.map(|(u, n, tid, fc)| serde_json::json!({"uuid": u, "name": n, "tmdb_id": tid, "face_count": fc}))}).to_string()) + } + "mutual_gaze" => { + let name_a = identity_name.unwrap_or(""); + let name_b = identity_b.unwrap_or(""); + let row: Option<(i64, i64, f64, f64)> = sqlx::query_as(&format!( + "SELECT (e.properties->>'first_frame')::bigint, \ + (e.properties->>'gaze_frame_count')::int::bigint, \ + (e.properties->>'yaw_a_avg')::float8, \ + (e.properties->>'yaw_b_avg')::float8 \ + FROM {} e \ + JOIN {} a ON a.id = e.source_node_id \ + JOIN {} b ON b.id = e.target_node_id \ + JOIN {} fd_a ON fd_a.file_uuid = $1 AND fd_a.trace_id = REPLACE(a.external_id, 'trace_', '')::int \ + JOIN {} fd_b ON fd_b.file_uuid = $1 AND fd_b.trace_id = REPLACE(b.external_id, 'trace_', '')::int \ + JOIN {} ia ON ia.id = fd_a.identity_id \ + JOIN {} ib ON ib.id = fd_b.identity_id \ + WHERE e.file_uuid = $1 AND ia.name ILIKE $2 AND ib.name ILIKE $3 \ + AND e.properties->>'mutual_gaze' = 'true' LIMIT 1", + edges, nodes, nodes, fd_table, fd_table, id_table, id_table + )) + .bind(file_uuid).bind(name_a).bind(name_b) + .fetch_optional(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"mutual_gaze": row.map(|(f, gc, ya, yb)| serde_json::json!({"first_frame": f, "gaze_frame_count": gc, "yaw_a": ya, "yaw_b": yb}))}).to_string()) + } + "interaction_network" => { + let rows: Vec<(String, String, i64)> = sqlx::query_as(&format!( + "SELECT ia.name, ib.name, COUNT(*)::bigint \ + FROM {} e \ + JOIN {} a ON a.id = e.source_node_id \ + JOIN {} b ON b.id = e.target_node_id \ + JOIN {} fd_a ON fd_a.trace_id = REPLACE(a.external_id, 'trace_', '')::int AND fd_a.file_uuid = $1 \ + JOIN {} fd_b ON fd_b.trace_id = REPLACE(b.external_id, 'trace_', '')::int AND fd_b.file_uuid = $1 \ + JOIN {} ia ON ia.id = fd_a.identity_id \ + JOIN {} ib ON ib.id = fd_b.identity_id \ + WHERE e.file_uuid = $1 AND e.edge_type = 'CO_OCCURS_WITH' \ + AND ia.name != ib.name AND ia.source = 'tmdb' AND ib.source = 'tmdb' \ + GROUP BY ia.name, ib.name \ + ORDER BY COUNT(*) DESC LIMIT $2", + edges, nodes, nodes, fd_table, fd_table, id_table, id_table + )) + .bind(file_uuid).bind(limit) + .fetch_all(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"interaction_network": rows}).to_string()) + } + "identity_traces" => { + let name = identity_name.unwrap_or(""); + let rows: Vec<(i32, i64, i32, i32)> = sqlx::query_as(&format!( + "SELECT fd.trace_id, COUNT(*)::bigint, MIN(fd.frame_number)::int, MAX(fd.frame_number)::int \ + FROM {} fd JOIN {} i ON i.id = fd.identity_id \ + WHERE fd.file_uuid = $1 AND i.name ILIKE $2 \ + GROUP BY fd.trace_id ORDER BY COUNT(*) DESC LIMIT $3", + fd_table, id_table + )) + .bind(file_uuid).bind(name).bind(limit) + .fetch_all(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"traces": rows}).to_string()) + } + "file_info" => { + let row: Option<(String, f64, i32, i32, f64)> = sqlx::query_as(&format!( + "SELECT file_name, duration, width, height, fps FROM {} WHERE file_uuid = $1", + videos + )) + .bind(file_uuid) + .fetch_optional(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"file_info": row.map(|(n, d, w, h, f)| serde_json::json!({"file_name": n, "duration_sec": d, "width": w, "height": h, "fps": f}))}).to_string()) + } + _ => Ok(serde_json::json!({"error": format!("Unknown query_type: {}", query_type)}).to_string()), + } +} + +async fn exec_smart_search(_pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { + let query = args.get("query").and_then(|v| v.as_str()).unwrap_or(""); + let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()); + let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5); + + let chunk_table = schema::table_name("chunk"); + let mut sql = format!( + "SELECT chunk_id, text_content, start_frame, end_frame, chunk_type \ + FROM {} WHERE text_content ILIKE $1", chunk_table + ); + if file_uuid.is_some() { + sql.push_str(" AND file_uuid = $2"); + } + sql.push_str(&format!(" ORDER BY start_frame LIMIT {}", limit)); + + if let Some(fuid) = file_uuid { + let like = format!("%{}%", query); + let rows: Vec<(String, Option, i64, i64, String)> = sqlx::query_as(&sql) + .bind(&like).bind(fuid) + .fetch_all(_pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"results": rows}).to_string()) + } else { + let like = format!("%{}%", query); + let rows: Vec<(String, Option, i64, i64, String)> = sqlx::query_as(&sql) + .bind(&like) + .fetch_all(_pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"results": rows}).to_string()) + } +} + +async fn exec_get_identity_detail(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { + let name = args.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let id_table = schema::table_name("identities"); + let row: Option<(String, String, Option, Option, Option)> = sqlx::query_as(&format!( + "SELECT uuid::text, name, source, tmdb_id, metadata->>'tmdb_character' FROM {} WHERE name ILIKE $1 LIMIT 1", + id_table + )) + .bind(name) + .fetch_optional(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"identity": row.map(|(u, n, s, t, c)| serde_json::json!({"uuid": u, "name": n, "source": s, "tmdb_id": t, "character": c}))}).to_string()) +} + +async fn exec_get_file_info(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { + let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or(""); + let videos = schema::table_name("videos"); + let row: Option<(String, f64, i32, i32, f64)> = sqlx::query_as(&format!( + "SELECT file_name, duration, width, height, fps FROM {} WHERE file_uuid = $1", + videos + )) + .bind(file_uuid) + .fetch_optional(pool) + .await.map_err(|e| e.to_string())?; + Ok(serde_json::json!({"file_info": row.map(|(n, d, w, h, f)| serde_json::json!({"file_name": n, "duration_sec": d, "width": w, "height": h, "fps": f}))}).to_string()) +} + +async fn exec_get_representative_frame(pool: &sqlx::PgPool, args: &serde_json::Value) -> Result { + let file_uuid = args.get("file_uuid").and_then(|v| v.as_str()).unwrap_or(""); + match crate::core::processor::tkg::query_auto_representative_frame(pool, file_uuid).await { + Ok(r) => Ok(serde_json::json!({ + "frame_number": r.frame_number, + "face_quality": r.face_quality, + "main_identities": r.main_identities, + "traces": r.traces, + }).to_string()), + Err(e) => Ok(serde_json::json!({"error": e.to_string()}).to_string()), + } +} + +// ── Tool Router ─────────────────────────────────────────────────── + +async fn execute_tool(pool: &sqlx::PgPool, tool_call: &ToolCall) -> (String, String, String) { + let name = tool_call.function.name.clone(); + let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments).unwrap_or_default(); + let result = match name.as_str() { + "find_file" => exec_find_file(pool, &args).await, + "list_files" => exec_list_files(pool, &args).await, + "tkg_query" => exec_tkg_query(pool, &args).await, + "smart_search" => exec_smart_search(pool, &args).await, + "get_identity_detail" => exec_get_identity_detail(pool, &args).await, + "get_file_info" => exec_get_file_info(pool, &args).await, + "get_representative_frame" => exec_get_representative_frame(pool, &args).await, + _ => Err(format!("Unknown tool: {}", name)), + }; + let content = match result { + Ok(s) => s, + Err(e) => serde_json::json!({"error": e}).to_string(), + }; + let tool_call_id = tool_call.id.clone().unwrap_or_default(); + (tool_call_id, name, content) +} + +// ── Tool Loop ───────────────────────────────────────────────────── + +const MAX_ROUNDS: u32 = 5; + +async fn run_tool_loop( + pool: &sqlx::PgPool, + system_prompt: &str, + user_query: &str, + history: Vec, +) -> (String, Vec) { + let mut messages = function_calling::build_conversation(system_prompt, user_query, history); + let mut sources = Vec::new(); + + for round in 0..MAX_ROUNDS { + let tools = Some(make_tools(pool)); + match function_calling::call_llm(messages.clone(), tools, 2048, 120).await { + Ok(LlmResponse::Text(text)) => { + return (text, sources); + } + Ok(LlmResponse::ToolCalls(calls)) => { + // Push assistant message with tool_calls so Gemma4 remembers + messages.push(ChatMessage { + role: "assistant".to_string(), + content: None, + tool_calls: Some(calls.clone()), + tool_call_id: None, + name: None, + }); + for call in &calls { + let (tool_call_id, name, content) = execute_tool(pool, call).await; + sources.push(serde_json::json!({"tool": name, "result": content})); + messages.push(function_calling::make_tool_result(&tool_call_id, &name, &content)); + } + } + Err(e) => { + return (format!("系統錯誤:{}", e), sources); + } + } + } + ("已達到最大查詢次數,請縮小問題範圍後重新詢問。".to_string(), sources) +} + +// ── Handler ─────────────────────────────────────────────────────── + +async fn agent_search( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let (conv_id, history) = get_or_create_conv(req.conversation_id.as_deref()); + + let (answer, sources) = run_tool_loop( + state.db.pool(), + SYSTEM_PROMPT, + &req.query, + history, + ) + .await; + + // Save updated messages for conversation continuation + let new_msgs = function_calling::build_conversation(SYSTEM_PROMPT, &req.query, vec![]); + save_messages(&conv_id, &new_msgs); + + let needs_input = answer.contains('?') || answer.contains('?'); + let suggestions = if needs_input { + Some(vec!["演員名".to_string(), "電影片名".to_string(), "年份".to_string()]) + } else { + None + }; + + Ok(Json(AgentSearchResponse { + success: true, + conversation_id: conv_id, + answer, + suggestions, + sources: Some(sources), + })) +} + +// ── Routes ───────────────────────────────────────────────────────── + +pub fn agent_search_routes() -> Router { + Router::new() + .route("/api/v1/agents/search", post(agent_search)) +} diff --git a/src/api/mod.rs b/src/api/mod.rs index c561ba3..d5804e3 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,4 +1,5 @@ pub mod agent_api; +pub mod agent_search; pub mod auth; pub mod docs; pub mod files; diff --git a/src/api/server.rs b/src/api/server.rs index 2ead26d..907783a 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -6,6 +6,7 @@ use crate::core::db::{Database, PostgresDb}; use crate::Embedder; use super::agent_api; +use super::agent_search; use super::auth; use super::docs; use super::files; @@ -82,6 +83,7 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> { .merge(tmdb_api::tmdb_routes()) .merge(identity_api::identity_routes()) .merge(agent_api::agent_routes()) + .merge(agent_search::agent_search_routes()) .merge(identity_agent_api::identity_agent_routes()) .merge(five_w1h_agent_api::five_w1h_agent_routes()) .merge(media_api::bbox_routes()) diff --git a/src/core/llm/function_calling.rs b/src/core/llm/function_calling.rs new file mode 100644 index 0000000..5e1fc18 --- /dev/null +++ b/src/core/llm/function_calling.rs @@ -0,0 +1,189 @@ +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +/// A tool/function definition for Gemma4 function calling +#[derive(Debug, Clone, Serialize)] +pub struct ToolDef { + #[serde(rename = "type")] + pub tool_type: String, + pub function: ToolFunction, +} + +#[derive(Debug, Clone, Serialize)] +pub struct ToolFunction { + pub name: String, + pub description: String, + pub parameters: Value, +} + +/// A tool call returned by Gemma4 +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: Option, + #[serde(rename = "type")] + pub call_type: Option, + pub function: ToolCallFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallFunction { + pub name: String, + pub arguments: String, +} + +/// Message in the chat history +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +/// Full chat request to Gemma4 +#[derive(Debug, Serialize)] +struct ChatRequest { + model: String, + messages: Vec, + temperature: f32, + max_tokens: u32, + stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, +} + +/// Response from Gemma4 +#[derive(Debug, Deserialize)] +struct ChatResponse { + pub choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + pub finish_reason: Option, + pub message: ChatMessage, +} + +/// Parsed LLM response: either text or tool calls +pub enum LlmResponse { + Text(String), + ToolCalls(Vec), +} + +/// Get the LLM chat URL with fallback chain +pub fn llm_chat_url() -> String { + std::env::var("MOMENTRY_LLM_URL") + .or_else(|_| std::env::var("MOMENTRY_LLM_SUMMARY_URL")) + .unwrap_or_else(|_| "http://localhost:8082/v1/chat/completions".to_string()) +} + +/// Get the LLM model name +pub fn llm_model() -> String { + std::env::var("MOMENTRY_LLM_MODEL") + .or_else(|_| std::env::var("MOMENTRY_LLM_SUMMARY_MODEL")) + .unwrap_or_else(|_| "google_gemma-4-26B-A4B-it-Q5_K_M.gguf".to_string()) +} + +/// Build a tool definition JSON for function calling +pub fn make_tool(name: &str, description: &str, properties: Value, required: Vec<&str>) -> ToolDef { + ToolDef { + tool_type: "function".to_string(), + function: ToolFunction { + name: name.to_string(), + description: description.to_string(), + parameters: json!({ + "type": "object", + "properties": properties, + "required": required, + }), + }, + } +} + +/// Call Gemma4 with messages and optional tools. Returns parsed response. +pub async fn call_llm( + messages: Vec, + tools: Option>, + max_tokens: u32, + timeout_secs: u64, +) -> anyhow::Result { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(timeout_secs)) + .build()?; + + let req = ChatRequest { + model: llm_model(), + messages, + temperature: 0.1, + max_tokens, + stream: false, + tools, + }; + + let res = client + .post(&llm_chat_url()) + .json(&req) + .send() + .await?; + + if !res.status().is_success() { + let text = res.text().await.unwrap_or_default(); + anyhow::bail!("LLM API error: {}", text); + } + + let chat_res: ChatResponse = res.json().await?; + let choice = chat_res.choices.into_iter().next() + .ok_or_else(|| anyhow::anyhow!("Empty LLM response"))?; + + match choice.finish_reason.as_deref() { + Some("tool_calls") => { + let calls = choice.message.tool_calls + .ok_or_else(|| anyhow::anyhow!("finish_reason=tool_calls but no tool_calls in message"))?; + Ok(LlmResponse::ToolCalls(calls)) + } + _ => { + let content = choice.message.content.unwrap_or_default(); + Ok(LlmResponse::Text(content.trim().to_string())) + } + } +} + +/// Helper to build the system prompt + user messages +pub fn build_conversation(system_prompt: &str, user_query: &str, history: Vec) -> Vec { + let mut messages = vec![ + ChatMessage { + role: "system".to_string(), + content: Some(system_prompt.to_string()), + tool_calls: None, + tool_call_id: None, + name: None, + }, + ]; + // Add history (user + assistant exchanges) + messages.extend(history); + // Add current user query + messages.push(ChatMessage { + role: "user".to_string(), + content: Some(user_query.to_string()), + tool_calls: None, + tool_call_id: None, + name: None, + }); + messages +} + +/// Build a tool result message to send back to LLM +pub fn make_tool_result(tool_call_id: &str, name: &str, content: &str) -> ChatMessage { + ChatMessage { + role: "tool".to_string(), + content: Some(content.to_string()), + tool_calls: None, + tool_call_id: Some(tool_call_id.to_string()), + name: Some(name.to_string()), + } +} diff --git a/src/core/llm/mod.rs b/src/core/llm/mod.rs index b9babe5..f3bf813 100644 --- a/src/core/llm/mod.rs +++ b/src/core/llm/mod.rs @@ -1 +1,2 @@ pub mod client; +pub mod function_calling;