265 lines
8.6 KiB
Rust
265 lines
8.6 KiB
Rust
use crate::core::db::{Bm25Result, PostgresDb};
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashSet;
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
pub struct SmartSearchRequest {
|
|
pub query: String,
|
|
pub uuid: Option<String>,
|
|
pub limit: Option<usize>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
pub struct SmartSearchResponse {
|
|
pub query: String,
|
|
pub parsed_dimensions: serde_json::Value,
|
|
pub hits: Vec<serde_json::Value>,
|
|
pub total: usize,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize)]
|
|
struct LlmDimensionResponse {
|
|
pub who: Option<String>,
|
|
pub what: Option<String>,
|
|
pub when: Option<String>,
|
|
pub r#where: Option<String>,
|
|
pub why: Option<String>,
|
|
#[serde(default)]
|
|
pub keywords: Vec<String>,
|
|
}
|
|
|
|
/// POST /api/v1/n8n/search/smart
|
|
pub async fn n8n_search_smart(
|
|
db: &PostgresDb,
|
|
req: SmartSearchRequest,
|
|
) -> Result<SmartSearchResponse, Box<dyn std::error::Error + Send + Sync>> {
|
|
let limit = req.limit.unwrap_or(10);
|
|
let video_uuid = req.uuid.clone();
|
|
|
|
// 1. Call LLM to extract 5W1H (Fallback to keywords if LLM fails)
|
|
let dimensions = match parse_query_with_llm(&req.query).await {
|
|
Some(dims) => dims,
|
|
None => LlmDimensionResponse {
|
|
who: None,
|
|
what: None,
|
|
when: None,
|
|
r#where: None,
|
|
why: None,
|
|
keywords: extract_keywords(&req.query),
|
|
},
|
|
};
|
|
|
|
// Prepare search terms based on dimensions
|
|
let keywords = dimensions.keywords.join(" ");
|
|
|
|
let semantic_query = [
|
|
dimensions.who.clone(),
|
|
dimensions.what.clone(),
|
|
dimensions.r#where.clone(),
|
|
dimensions.why.clone(),
|
|
dimensions.when.clone(),
|
|
]
|
|
.into_iter()
|
|
.flatten()
|
|
.collect::<Vec<_>>()
|
|
.join(" ");
|
|
|
|
// 2. Multi-dimensional Search
|
|
let mut hits: Vec<serde_json::Value> = Vec::new();
|
|
let mut seen_chunk_ids: HashSet<String> = HashSet::new();
|
|
|
|
// Helper function
|
|
fn add_hit(
|
|
hits: &mut Vec<serde_json::Value>,
|
|
seen_chunk_ids: &mut HashSet<String>,
|
|
sr: Bm25Result,
|
|
boost: f32,
|
|
) {
|
|
if seen_chunk_ids.insert(sr.chunk_id.clone()) {
|
|
let score = sr.bm25_score * boost;
|
|
let val = serde_json::json!({
|
|
"id": sr.chunk_id,
|
|
"vid": sr.uuid,
|
|
"start": sr.start_time,
|
|
"end": sr.end_time,
|
|
"text": sr.text,
|
|
"score": score,
|
|
"chunk_type": sr.chunk_type
|
|
});
|
|
hits.push(val);
|
|
}
|
|
}
|
|
|
|
// A. Keyword Search (BM25)
|
|
if !keywords.is_empty() {
|
|
if let Ok(results) = db
|
|
.search_bm25(&keywords, video_uuid.as_deref(), limit)
|
|
.await
|
|
{
|
|
for sr in results {
|
|
add_hit(&mut hits, &mut seen_chunk_ids, sr, 1.0);
|
|
}
|
|
}
|
|
}
|
|
|
|
// B. Who Search (Person Matching)
|
|
if let Some(who_query) = &dimensions.who {
|
|
// 1. Search Person
|
|
if let Ok(persons) = db.search_person_candidates(who_query, &video_uuid, 5).await {
|
|
if !persons.is_empty() {
|
|
let person_id = persons[0]
|
|
.get("candidate_id")
|
|
.and_then(|v| v.as_str())
|
|
.map(String::from);
|
|
|
|
if let Some(pid) = person_id {
|
|
// Heuristic: Search BM25 for the person's ID or Name
|
|
let person_name = persons[0]
|
|
.get("display_name")
|
|
.and_then(|v| v.as_str())
|
|
.unwrap_or(who_query);
|
|
|
|
// Re-run BM25 with person name to find specific chunks and boost them
|
|
if let Ok(results) = db
|
|
.search_bm25(person_name, video_uuid.as_deref(), limit)
|
|
.await
|
|
{
|
|
for sr in results {
|
|
let id = sr.chunk_id.clone();
|
|
if seen_chunk_ids.insert(id) {
|
|
let score = sr.bm25_score * 1.5;
|
|
let val = serde_json::json!({
|
|
"id": sr.chunk_id,
|
|
"vid": sr.uuid,
|
|
"start": sr.start_time,
|
|
"end": sr.end_time,
|
|
"text": sr.text,
|
|
"score": score,
|
|
"matched_person": pid
|
|
});
|
|
hits.push(val);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sort by score
|
|
hits.sort_by(|a, b| {
|
|
let score_a = a.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
|
let score_b = b.get("score").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
|
score_b.partial_cmp(&score_a).unwrap()
|
|
});
|
|
|
|
// Limit
|
|
hits.truncate(limit);
|
|
let total = hits.len();
|
|
|
|
Ok(SmartSearchResponse {
|
|
query: req.query,
|
|
parsed_dimensions: serde_json::json!(dimensions),
|
|
hits,
|
|
total,
|
|
})
|
|
}
|
|
|
|
/// Simple keyword extraction used when no LLM-parsed dimensions are available.
///
/// The query is split on every non-alphanumeric character, each token is
/// lowercased, and common English stop words are dropped. Order is preserved
/// and duplicates are kept.
///
/// Tokens are lowercased *after* splitting: lowercasing the whole query first
/// can introduce combining marks (e.g. `İ` becomes `i` + U+0307) that would
/// then be treated as separators and cut a single word in two.
fn extract_keywords(query: &str) -> Vec<String> {
    const STOP_WORDS: &[&str] = &[
        "who", "what", "where", "when", "why", "how", "is", "the", "a", "an", "and", "or", "of",
        "in", "to", "for", "with", "by", "on", "at", "from", "up", "about", "into", "over",
        "after",
    ];

    query
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| !token.is_empty())
        .map(str::to_lowercase)
        .filter(|word| !STOP_WORDS.contains(&word.as_str()))
        .collect()
}
|
|
|
|
async fn parse_query_with_llm(query: &str) -> Option<LlmDimensionResponse> {
|
|
let client = reqwest::Client::new();
|
|
|
|
// Test connectivity first
|
|
if let Ok(resp) = client.get("http://127.0.0.1:8081/health").send().await {
|
|
tracing::info!("LLM Health Check: {}", resp.status());
|
|
} else {
|
|
tracing::error!("LLM Server is unreachable at 127.0.0.1:8081");
|
|
}
|
|
|
|
// We use the OpenAI-compatible endpoint provided by llama.cpp server (default port 8081)
|
|
let prompt = format!(
|
|
r#"Analyze the user query and extract the following dimensions into a JSON object.
|
|
If a dimension is not present, use null.
|
|
Dimensions: "who" (person/subject), "what" (action/event), "where" (location), "when" (time), "why" (reason/intent), "keywords" (array of specific keywords).
|
|
|
|
User Query: "{}"
|
|
|
|
Output ONLY the JSON object.
|
|
"#,
|
|
query
|
|
);
|
|
|
|
let payload = serde_json::json!({
|
|
"model": "gemma4",
|
|
"messages": [
|
|
{
|
|
"role": "user",
|
|
"content": prompt
|
|
}
|
|
],
|
|
"temperature": 0.1,
|
|
"stream": false
|
|
});
|
|
|
|
if let Ok(response) = client
|
|
.post("http://127.0.0.1:8081/v1/chat/completions")
|
|
.json(&payload)
|
|
.timeout(std::time::Duration::from_secs(60))
|
|
.send()
|
|
.await
|
|
{
|
|
tracing::info!("LLM Response Status: {}", response.status());
|
|
if let Ok(json_resp) = response.json::<serde_json::Value>().await {
|
|
tracing::info!("LLM Response Body: {}", json_resp);
|
|
if let Some(choices) = json_resp.get("choices").and_then(|v| v.as_array()) {
|
|
if let Some(choice) = choices.get(0) {
|
|
if let Some(content) = choice
|
|
.get("message")
|
|
.and_then(|m| m.get("content"))
|
|
.and_then(|c| c.as_str())
|
|
{
|
|
if let Some(start) = content.find("{") {
|
|
if let Some(end) = content.rfind("}") {
|
|
let json_str = &content[start..=end];
|
|
if let Ok(dims) =
|
|
serde_json::from_str::<LlmDimensionResponse>(json_str)
|
|
{
|
|
tracing::info!("Parsed LLM Dimensions: {:?}", dims);
|
|
return Some(dims);
|
|
} else {
|
|
tracing::warn!("Failed to parse LLM JSON: {}", json_str);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
tracing::warn!("Failed to parse LLM response JSON");
|
|
}
|
|
} else {
|
|
tracing::warn!("LLM request failed or timed out");
|
|
}
|
|
|
|
None
|
|
}
|