use crate::core::db::{Bm25Result, PostgresDb}; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::collections::HashSet; #[derive(Debug, Deserialize)] pub struct SmartSearchRequest { pub query: String, pub uuid: Option, pub limit: Option, } #[derive(Debug, Serialize)] pub struct SmartSearchResponse { pub query: String, pub parsed_dimensions: serde_json::Value, pub hits: Vec, pub total: usize, } #[derive(Debug, Deserialize, Serialize)] struct LlmDimensionResponse { pub who: Option, pub what: Option, pub when: Option, pub r#where: Option, pub why: Option, #[serde(default)] pub keywords: Vec, } /// POST /api/v1/n8n/search/smart pub async fn n8n_search_smart( db: &PostgresDb, req: SmartSearchRequest, ) -> Result> { let limit = req.limit.unwrap_or(10); let video_uuid = req.uuid.clone(); // 1. Call LLM to extract 5W1H (Fallback to keywords if LLM fails) let dimensions = match parse_query_with_llm(&req.query).await { Some(dims) => dims, None => LlmDimensionResponse { who: None, what: None, when: None, r#where: None, why: None, keywords: extract_keywords(&req.query), }, }; // Prepare search terms based on dimensions let keywords = dimensions.keywords.join(" "); let semantic_query = [ dimensions.who.clone(), dimensions.what.clone(), dimensions.r#where.clone(), dimensions.why.clone(), dimensions.when.clone(), ] .into_iter() .flatten() .collect::>() .join(" "); // 2. Multi-dimensional Search let mut hits: Vec = Vec::new(); let mut seen_chunk_ids: HashSet = HashSet::new(); // Helper function fn add_hit( hits: &mut Vec, seen_chunk_ids: &mut HashSet, sr: Bm25Result, boost: f32, ) { if seen_chunk_ids.insert(sr.chunk_id.clone()) { let score = sr.bm25_score * boost; let val = serde_json::json!({ "id": sr.chunk_id, "vid": sr.uuid, "start": sr.start_time, "end": sr.end_time, "text": sr.text, "score": score, "chunk_type": sr.chunk_type }); hits.push(val); } } // A. Keyword Search (BM25) if !keywords.is_empty() { if let Ok(results) = db .search_bm25(&keywords, video_uuid.as_deref(), limit) .await { for sr in results { add_hit(&mut hits, &mut seen_chunk_ids, sr, 1.0); } } } // B. Who Search (Person Matching) if let Some(who_query) = &dimensions.who { // 1. Search Person if let Ok(persons) = db.search_person_candidates(who_query, &video_uuid, 5).await { if !persons.is_empty() { let person_id = persons[0] .get("candidate_id") .and_then(|v| v.as_str()) .map(String::from); if let Some(pid) = person_id { // Heuristic: Search BM25 for the person's ID or Name let person_name = persons[0] .get("display_name") .and_then(|v| v.as_str()) .unwrap_or(who_query); // Re-run BM25 with person name to find specific chunks and boost them if let Ok(results) = db .search_bm25(person_name, video_uuid.as_deref(), limit) .await { for sr in results { let id = sr.chunk_id.clone(); if seen_chunk_ids.insert(id) { let score = sr.bm25_score * 1.5; let val = serde_json::json!({ "id": sr.chunk_id, "vid": sr.uuid, "start": sr.start_time, "end": sr.end_time, "text": sr.text, "score": score, "matched_person": pid }); hits.push(val); } } } } } } } // Sort by score hits.sort_by(|a, b| { let score_a = a.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0); let score_b = b.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0); score_b.partial_cmp(&score_a).unwrap() }); // Limit hits.truncate(limit); let total = hits.len(); Ok(SmartSearchResponse { query: req.query, parsed_dimensions: serde_json::json!(dimensions), hits, total, }) } fn extract_keywords(query: &str) -> Vec { // Simple keyword extraction: remove common stop words and punctuation let stop_words = [ "who", "what", "where", "when", "why", "how", "is", "the", "a", "an", "and", "or", "of", "in", "to", "for", "with", "by", "on", "at", "from", "up", "about", "into", "over", "after", ]; query .to_lowercase() .chars() .map(|c| if c.is_alphanumeric() { c } else { ' ' }) .collect::() .split_whitespace() .filter(|w| !stop_words.contains(w)) .map(String::from) .collect() } async fn parse_query_with_llm(query: &str) -> Option { let client = reqwest::Client::new(); // Test connectivity first if let Ok(resp) = client.get("http://127.0.0.1:8081/health").send().await { tracing::info!("LLM Health Check: {}", resp.status()); } else { tracing::error!("LLM Server is unreachable at 127.0.0.1:8081"); } // We use the OpenAI-compatible endpoint provided by llama.cpp server (default port 8081) let prompt = format!( r#"Analyze the user query and extract the following dimensions into a JSON object. If a dimension is not present, use null. Dimensions: "who" (person/subject), "what" (action/event), "where" (location), "when" (time), "why" (reason/intent), "keywords" (array of specific keywords). User Query: "{}" Output ONLY the JSON object. "#, query ); let payload = serde_json::json!({ "model": "gemma4", "messages": [ { "role": "user", "content": prompt } ], "temperature": 0.1, "stream": false }); if let Ok(response) = client .post("http://127.0.0.1:8081/v1/chat/completions") .json(&payload) .timeout(std::time::Duration::from_secs(60)) .send() .await { tracing::info!("LLM Response Status: {}", response.status()); if let Ok(json_resp) = response.json::().await { tracing::info!("LLM Response Body: {}", json_resp); if let Some(choices) = json_resp.get("choices").and_then(|v| v.as_array()) { if let Some(choice) = choices.get(0) { if let Some(content) = choice .get("message") .and_then(|m| m.get("content")) .and_then(|c| c.as_str()) { if let Some(start) = content.find("{") { if let Some(end) = content.rfind("}") { let json_str = &content[start..=end]; if let Ok(dims) = serde_json::from_str::(json_str) { tracing::info!("Parsed LLM Dimensions: {:?}", dims); return Some(dims); } else { tracing::warn!("Failed to parse LLM JSON: {}", json_str); } } } } } } } else { tracing::warn!("Failed to parse LLM response JSON"); } } else { tracing::warn!("LLM request failed or timed out"); } None }