// Files
// momentry_core/src/api/n8n_search.rs
//
// 265 lines
// 8.6 KiB
// Rust

use crate::core::db::{Bm25Result, PostgresDb};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Request body for `POST /api/v1/n8n/search/smart`.
#[derive(Debug, Deserialize)]
pub struct SmartSearchRequest {
/// Free-text search query. It is sent to the LLM for 5W1H parsing, used for
/// plain keyword extraction if that parse fails, and echoed in the response.
pub query: String,
/// Optional video UUID, passed as a filter to the BM25 and person-candidate
/// searches (presumably restricts results to that one video — confirm in
/// `PostgresDb`).
pub uuid: Option<String>,
/// Maximum number of hits returned (10 when omitted). Also used as the
/// result limit for each individual BM25 query.
pub limit: Option<usize>,
}
/// Response body for `POST /api/v1/n8n/search/smart`.
#[derive(Debug, Serialize)]
pub struct SmartSearchResponse {
/// The request's query, echoed back unchanged.
pub query: String,
/// The 5W1H dimensions and keywords the search was built from: the LLM's
/// parse, or a keyword-only fallback when the LLM is unavailable.
pub parsed_dimensions: serde_json::Value,
/// Matching chunks as JSON objects (`id`, `vid`, `start`, `end`, `text`,
/// `score`, plus `chunk_type` and/or `matched_person`), highest score first.
pub hits: Vec<serde_json::Value>,
/// Number of entries in `hits` after truncation to the request limit; this is
/// not the total number of matches in the index.
pub total: usize,
}
#[derive(Debug, Deserialize, Serialize)]
struct LlmDimensionResponse {
pub who: Option<String>,
pub what: Option<String>,
pub when: Option<String>,
pub r#where: Option<String>,
pub why: Option<String>,
#[serde(default)]
pub keywords: Vec<String>,
}
/// POST /api/v1/n8n/search/smart
pub async fn n8n_search_smart(
db: &PostgresDb,
req: SmartSearchRequest,
) -> Result<SmartSearchResponse, Box<dyn std::error::Error + Send + Sync>> {
let limit = req.limit.unwrap_or(10);
let video_uuid = req.uuid.clone();
// 1. Call LLM to extract 5W1H (Fallback to keywords if LLM fails)
let dimensions = match parse_query_with_llm(&req.query).await {
Some(dims) => dims,
None => LlmDimensionResponse {
who: None,
what: None,
when: None,
r#where: None,
why: None,
keywords: extract_keywords(&req.query),
},
};
// Prepare search terms based on dimensions
let keywords = dimensions.keywords.join(" ");
let semantic_query = [
dimensions.who.clone(),
dimensions.what.clone(),
dimensions.r#where.clone(),
dimensions.why.clone(),
dimensions.when.clone(),
]
.into_iter()
.flatten()
.collect::<Vec<_>>()
.join(" ");
// 2. Multi-dimensional Search
let mut hits: Vec<serde_json::Value> = Vec::new();
let mut seen_chunk_ids: HashSet<String> = HashSet::new();
// Helper function
fn add_hit(
hits: &mut Vec<serde_json::Value>,
seen_chunk_ids: &mut HashSet<String>,
sr: Bm25Result,
boost: f32,
) {
if seen_chunk_ids.insert(sr.chunk_id.clone()) {
let score = sr.bm25_score * boost;
let val = serde_json::json!({
"id": sr.chunk_id,
"vid": sr.uuid,
"start": sr.start_time,
"end": sr.end_time,
"text": sr.text,
"score": score,
"chunk_type": sr.chunk_type
});
hits.push(val);
}
}
// A. Keyword Search (BM25)
if !keywords.is_empty() {
if let Ok(results) = db
.search_bm25(&keywords, video_uuid.as_deref(), limit)
.await
{
for sr in results {
add_hit(&mut hits, &mut seen_chunk_ids, sr, 1.0);
}
}
}
// B. Who Search (Person Matching)
if let Some(who_query) = &dimensions.who {
// 1. Search Person
if let Ok(persons) = db.search_person_candidates(who_query, &video_uuid, 5).await {
if !persons.is_empty() {
let person_id = persons[0]
.get("candidate_id")
.and_then(|v| v.as_str())
.map(String::from);
if let Some(pid) = person_id {
// Heuristic: Search BM25 for the person's ID or Name
let person_name = persons[0]
.get("display_name")
.and_then(|v| v.as_str())
.unwrap_or(who_query);
// Re-run BM25 with person name to find specific chunks and boost them
if let Ok(results) = db
.search_bm25(person_name, video_uuid.as_deref(), limit)
.await
{
for sr in results {
let id = sr.chunk_id.clone();
if seen_chunk_ids.insert(id) {
let score = sr.bm25_score * 1.5;
let val = serde_json::json!({
"id": sr.chunk_id,
"vid": sr.uuid,
"start": sr.start_time,
"end": sr.end_time,
"text": sr.text,
"score": score,
"matched_person": pid
});
hits.push(val);
}
}
}
}
}
}
}
// Sort by score
hits.sort_by(|a, b| {
let score_a = a.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
let score_b = b.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
score_b.partial_cmp(&score_a).unwrap()
});
// Limit
hits.truncate(limit);
let total = hits.len();
Ok(SmartSearchResponse {
query: req.query,
parsed_dimensions: serde_json::json!(dimensions),
hits,
total,
})
}
/// Cheap, LLM-free keyword extraction, used as the fallback when the LLM
/// query parse is unavailable.
///
/// Lower-cases `query` and treats every non-alphanumeric character as a
/// separator. It then drops a fixed list of English stop words (question
/// words, articles, common prepositions). Token order is preserved and
/// duplicates are kept.
fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "who", "what", "where", "when", "why", "how", "is", "the", "a", "an", "and", "or", "of",
        "in", "to", "for", "with", "by", "on", "at", "from", "up", "about", "into", "over",
        "after",
    ];

    // Punctuation and symbols become spaces, so "alice-bob" splits in two.
    let normalized = query
        .to_lowercase()
        .replace(|ch: char| !ch.is_alphanumeric(), " ");

    let mut keywords = Vec::new();
    for token in normalized.split_whitespace() {
        if STOP_WORDS.contains(&token) {
            continue;
        }
        keywords.push(token.to_owned());
    }
    keywords
}
async fn parse_query_with_llm(query: &str) -> Option<LlmDimensionResponse> {
let client = reqwest::Client::new();
// Test connectivity first
if let Ok(resp) = client.get("http://127.0.0.1:8081/health").send().await {
tracing::info!("LLM Health Check: {}", resp.status());
} else {
tracing::error!("LLM Server is unreachable at 127.0.0.1:8081");
}
// We use the OpenAI-compatible endpoint provided by llama.cpp server (default port 8081)
let prompt = format!(
r#"Analyze the user query and extract the following dimensions into a JSON object.
If a dimension is not present, use null.
Dimensions: "who" (person/subject), "what" (action/event), "where" (location), "when" (time), "why" (reason/intent), "keywords" (array of specific keywords).
User Query: "{}"
Output ONLY the JSON object.
"#,
query
);
let payload = serde_json::json!({
"model": "gemma4",
"messages": [
{
"role": "user",
"content": prompt
}
],
"temperature": 0.1,
"stream": false
});
if let Ok(response) = client
.post("http://127.0.0.1:8081/v1/chat/completions")
.json(&payload)
.timeout(std::time::Duration::from_secs(60))
.send()
.await
{
tracing::info!("LLM Response Status: {}", response.status());
if let Ok(json_resp) = response.json::<serde_json::Value>().await {
tracing::info!("LLM Response Body: {}", json_resp);
if let Some(choices) = json_resp.get("choices").and_then(|v| v.as_array()) {
if let Some(choice) = choices.get(0) {
if let Some(content) = choice
.get("message")
.and_then(|m| m.get("content"))
.and_then(|c| c.as_str())
{
if let Some(start) = content.find("{") {
if let Some(end) = content.rfind("}") {
let json_str = &content[start..=end];
if let Ok(dims) =
serde_json::from_str::<LlmDimensionResponse>(json_str)
{
tracing::info!("Parsed LLM Dimensions: {:?}", dims);
return Some(dims);
} else {
tracing::warn!("Failed to parse LLM JSON: {}", json_str);
}
}
}
}
}
}
} else {
tracing::warn!("Failed to parse LLM response JSON");
}
} else {
tracing::warn!("LLM request failed or timed out");
}
None
}