feat: POST /api/v1/agents/search - Gemma4 function calling agent
This commit is contained in:
189
src/core/llm/function_calling.rs
Normal file
189
src/core/llm/function_calling.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
/// A tool/function definition for Gemma4 function calling
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolDef {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
/// A tool call returned by Gemma4
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
pub call_type: Option<String>,
|
||||
pub function: ToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallFunction {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Message in the chat history
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
/// Full chat request to Gemma4
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
temperature: f32,
|
||||
max_tokens: u32,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<ToolDef>>,
|
||||
}
|
||||
|
||||
/// Response from Gemma4
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
pub choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
pub finish_reason: Option<String>,
|
||||
pub message: ChatMessage,
|
||||
}
|
||||
|
||||
/// Parsed LLM response: either text or tool calls
|
||||
pub enum LlmResponse {
|
||||
Text(String),
|
||||
ToolCalls(Vec<ToolCall>),
|
||||
}
|
||||
|
||||
/// Get the LLM chat URL with fallback chain.
///
/// Lookup order:
/// 1. `MOMENTRY_LLM_URL`
/// 2. `MOMENTRY_LLM_SUMMARY_URL`
/// 3. the built-in default `http://localhost:8082/v1/chat/completions`
///
/// A variable that is unset, not valid Unicode, or blank (empty or
/// whitespace only) is skipped, so `MOMENTRY_LLM_URL=` does not shadow the
/// later entries with an unusable empty URL.
pub fn llm_chat_url() -> String {
    ["MOMENTRY_LLM_URL", "MOMENTRY_LLM_SUMMARY_URL"]
        .iter()
        // Lazy: the second variable is only read when the first is unusable.
        .filter_map(|key| std::env::var(key).ok())
        .find(|value| !value.trim().is_empty())
        .unwrap_or_else(|| "http://localhost:8082/v1/chat/completions".to_string())
}
|
||||
|
||||
/// Get the LLM model name with fallback chain.
///
/// Lookup order:
/// 1. `MOMENTRY_LLM_MODEL`
/// 2. `MOMENTRY_LLM_SUMMARY_MODEL`
/// 3. the built-in default `google_gemma-4-26B-A4B-it-Q5_K_M.gguf`
///
/// A variable that is unset, not valid Unicode, or blank (empty or
/// whitespace only) is skipped, so an empty assignment does not result in a
/// request with an empty model name.
pub fn llm_model() -> String {
    ["MOMENTRY_LLM_MODEL", "MOMENTRY_LLM_SUMMARY_MODEL"]
        .iter()
        // Lazy: the second variable is only read when the first is unusable.
        .filter_map(|key| std::env::var(key).ok())
        .find(|value| !value.trim().is_empty())
        .unwrap_or_else(|| "google_gemma-4-26B-A4B-it-Q5_K_M.gguf".to_string())
}
|
||||
|
||||
/// Build a tool definition JSON for function calling
|
||||
pub fn make_tool(name: &str, description: &str, properties: Value, required: Vec<&str>) -> ToolDef {
|
||||
ToolDef {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: name.to_string(),
|
||||
description: description.to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required,
|
||||
}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Call Gemma4 with messages and optional tools. Returns parsed response.
|
||||
pub async fn call_llm(
|
||||
messages: Vec<ChatMessage>,
|
||||
tools: Option<Vec<ToolDef>>,
|
||||
max_tokens: u32,
|
||||
timeout_secs: u64,
|
||||
) -> anyhow::Result<LlmResponse> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(timeout_secs))
|
||||
.build()?;
|
||||
|
||||
let req = ChatRequest {
|
||||
model: llm_model(),
|
||||
messages,
|
||||
temperature: 0.1,
|
||||
max_tokens,
|
||||
stream: false,
|
||||
tools,
|
||||
};
|
||||
|
||||
let res = client
|
||||
.post(&llm_chat_url())
|
||||
.json(&req)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !res.status().is_success() {
|
||||
let text = res.text().await.unwrap_or_default();
|
||||
anyhow::bail!("LLM API error: {}", text);
|
||||
}
|
||||
|
||||
let chat_res: ChatResponse = res.json().await?;
|
||||
let choice = chat_res.choices.into_iter().next()
|
||||
.ok_or_else(|| anyhow::anyhow!("Empty LLM response"))?;
|
||||
|
||||
match choice.finish_reason.as_deref() {
|
||||
Some("tool_calls") => {
|
||||
let calls = choice.message.tool_calls
|
||||
.ok_or_else(|| anyhow::anyhow!("finish_reason=tool_calls but no tool_calls in message"))?;
|
||||
Ok(LlmResponse::ToolCalls(calls))
|
||||
}
|
||||
_ => {
|
||||
let content = choice.message.content.unwrap_or_default();
|
||||
Ok(LlmResponse::Text(content.trim().to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to build the system prompt + user messages
|
||||
pub fn build_conversation(system_prompt: &str, user_query: &str, history: Vec<ChatMessage>) -> Vec<ChatMessage> {
|
||||
let mut messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: Some(system_prompt.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
},
|
||||
];
|
||||
// Add history (user + assistant exchanges)
|
||||
messages.extend(history);
|
||||
// Add current user query
|
||||
messages.push(ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(user_query.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
});
|
||||
messages
|
||||
}
|
||||
|
||||
/// Build a tool result message to send back to LLM
|
||||
pub fn make_tool_result(tool_call_id: &str, name: &str, content: &str) -> ChatMessage {
|
||||
ChatMessage {
|
||||
role: "tool".to_string(),
|
||||
content: Some(content.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
name: Some(name.to_string()),
|
||||
}
|
||||
}
|
||||
@@ -1 +1,2 @@
|
||||
pub mod client;
|
||||
pub mod function_calling;
|
||||
|
||||
Reference in New Issue
Block a user