diff --git a/docs/SEARCH_SCORE_IMPROVEMENT.md b/docs/SEARCH_SCORE_IMPROVEMENT.md new file mode 100644 index 0000000..e7b7bca --- /dev/null +++ b/docs/SEARCH_SCORE_IMPROVEMENT.md @@ -0,0 +1,134 @@ +# Search Scoring Improvement: Score-based Merge for search/smart + +## 發現者 +WordPress 前端專案(search-chat 頁面) + +## 問題描述 + +### 症狀 +跨語言搜尋結果不一致: +- 搜尋「槍」(中文)→ 回傳無關結果(如「讓T-shirt」、「靠直的後製神器」) +- 搜尋 `gun`(英文)→ 回傳 "So where's your gun?"、"He has a gun" +- 兩者應該找到相同語意主題的結果(武器相關片段),但實際回傳完全不同的集合 + +### 影響範圍 +`GET/POST /api/v1/search/smart` endpoint + +## 根因分析 + +### 1. Qdrant 語意搜尋本身是正確的 + +直接查詢 Qdrant 驗證: + +``` +cos(search_query: 槍, search_document: "So where's your gun?") = 0.6905 +cos(search_query: 槍, search_document: "這是一把槍") = 0.8256 +cos(search_query: gun, search_document: "So where's your gun?") = 0.7435 +``` + +**embedding model (EmbeddingGemma-300m) 的 cross-lingual 對齊正常。** + +### 2. 問題在 RRF 合併邏輯 + +`search/smart` 用 **RRF (Reciprocal Rank Fusion)** 合併三組結果: + +```rust +let rrf_k = 60.0; +// RRF 貢獻 = 1 / (60 + rank + 1) +// Semantic rank 0: 貢獻 1/61 = 0.016 +// Keyword rank 0: 貢獻 1/61 = 0.016 +``` + +RRF 的權重只看**排名位置**,不看**實際相似度分數**。 +- cosine similarity = 0.69 的語意結果 → RRF 貢獻 0.016 +- ILIKE 隨便撈到的 keyword 匹配 → RRF 貢獻也是 0.016 +- 兩者在排序中權重完全相等 + +### 3. 
Keyword (ILIKE) 對跨語言有害 + +- `ILIKE '%槍%'` 只找到中文文字包含「槍」的 chunks +- `ILIKE '%gun%'` 只找到英文文字包含 "gun" 的 chunks +- 這兩組結果在語意上完全不同,卻透過 RRF 被提升到與語意結果同權重 +- 導致「槍」和 `gun` 的結果各自被自己的 ILIKE 匹配汙染 + +## 建議方案 + +### 核心原則 +向量高信心度時應該優先。 + +### 合併方式 + +將 RRF 改為 score-based merge,各來源分數定義: + +| 來源 | 分數 | 說明 | +|---|---|---| +| **Semantic (Qdrant)** | `cosine_similarity` (0~1) | 原始 Qdrant 分數,不加權 | +| **Identity** | 固定 `0.85` | 人名精準匹配,維持高度信心 | +| **Keyword (ILIKE)** | 固定 `0.5` | 降權至低分,只作為語意找不到時的補底 | + +最終分數 = `max(semantic, keyword, identity)` +依最終分數降冪排序。 + +### 預期效果 + +| 情況 | 排序行為 | +|---|---| +| cosine > 0.5 的語意結果 | 排在 keyword 前面 ✅ | +| cosine 在 0.3~0.5 | 排在 keyword(固定 `0.5`)之後;同一 chunk 若同時被 keyword 命中則以 `0.5` 計分(都不太確定,合理) | +| cosine < 0.3 | keyword 補底(語意沒找到,靠文字比對) | +| 跨語言查詢(槍 vs gun) | 各自的高分 cross-lingual 結果優先呈現 ✅ | + +### 不建議的方案 + +- **不要用 weight-based average**(如 `0.7*semantic + 0.3*keyword`):兩種模型的 score scale 不同,加權無法通用 +- **不要保留 RRF 只調 k 值**:k 值調再高也無法區分品質,只能稀釋影響 + +## 修改範圍 + +### 檔案 +`src/api/search.rs` 中的 `smart_search()` 函數 + +### 需要修改的區塊 + +1. **移除 RRF 常數**(`rrf_k = 60.0`) +2. **Semantic 結果**:保留 Qdrant 回傳的 `score`(已在 `h.score as f64` 取得) +3. **Keyword 結果**:固定設為 `0.5_f64`(忽略原本 `combined_score`) +4. **Identity 結果**:固定設為 `0.85_f64`(沿用原本硬編碼的 `0.85`,數值不變) +5. **排序邏輯**:改為 `max(semantic, keyword, identity)` 降冪 +6. **輸出 similarity**:改為回傳最終分數,而非 `rrf_score` + +### 注意事項 + +- Qdrant 回傳的 `score` 是 `f32`,需 cast 為 `f64` +- `keyword_results` 的 `combined_score` 實際上是 `1.0`(`search_bm25` 固定值),不應使用 +- 修改後需 **`cargo build --release`** 再重啟 server + +## 驗證測試 + +### 手動測試 + +```bash +# 1. 槍 vs gun 應該回傳相似主題 +curl -X POST 'http://localhost:3002/api/v1/search/smart' \ + -H 'X-API-Key: {KEY}' -H 'Content-Type: application/json' \ + -d '{"query":"槍","limit":10}' + +curl -X POST 'http://localhost:3002/api/v1/search/smart' \ + -H 'X-API-Key: {KEY}' -H 'Content-Type: application/json' \ + -d '{"query":"gun","limit":10}' + +# 2. 確認 similarity 值為實際 cosine (e.g. 
0.6~0.9) 而非 RRF 值 (~0.016) +``` + +### 預期結果 + +| Query | Top 結果應包含 | +|---|---| +| `槍` | gun 相關片段、「這是一把槍」、武器相關語意匹配 | +| `gun` | 與 `槍` 主題一致(都是武器) | +| `車` / `car` | 行車相關片段,非姓名含「車」的人物 | +| `So where's your gun?` | 自身為 top-1(self-match cosine ≈ 1.0) | + +## 附錄:前端處理 + +WordPress 側 (`snippet #37`) 已配合修正:`mode=semantic` 不再疊加 `search/universal`(ILIKE)結果,僅回傳 `search/smart` 的輸出。這部分無需 backend 配合。 diff --git a/src/api/llm_search.rs b/src/api/llm_search.rs new file mode 100644 index 0000000..f65190f --- /dev/null +++ b/src/api/llm_search.rs @@ -0,0 +1,91 @@ +use axum::{ + extract::State, + http::StatusCode, + response::Json, + routing::post, + Router, +}; +use serde::Deserialize; +use tracing::warn; + +use crate::core::llm::rerank::rerank_search_results; + +use super::search::{smart_search, SearchResult, SmartSearchRequest, SmartSearchResponse}; + +#[derive(Debug, Deserialize)] +pub struct LlmSmartSearchRequest { + #[serde(default)] + pub file_uuid: Option, + pub query: String, + pub limit: Option, +} + +pub async fn llm_smart_search_handler( + State(state): State, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let user_limit = req.limit.unwrap_or(10).max(1); + let llm_candidate_count = user_limit.saturating_mul(3).clamp(10, 20); // saturating: a huge client-supplied limit must not overflow before the clamp + + // 1. Get initial score-ranked results via the existing smart_search + let initial_req = SmartSearchRequest { + file_uuid: req.file_uuid.clone(), + query: req.query.clone(), + page: Some(1), + page_size: Some(llm_candidate_count), + limit: Some(llm_candidate_count), + }; + + let initial_response = smart_search(State(state.clone()), Json(initial_req)).await?; + let initial_results = initial_response.0.results; + + if initial_results.is_empty() { + return Ok(Json(SmartSearchResponse { + query: req.query, + results: vec![], + page: 1, + page_size: 0, + strategy: "llm_reranked".to_string(), + })); + } + + // 2. 
Build candidates: (original_index, summary_text) + let candidates: Vec<(usize, String)> = initial_results + .iter() + .enumerate() + .map(|(i, r)| (i, r.summary.clone().unwrap_or_default())) + .collect(); + + let candidate_refs: Vec<(usize, &str)> = + candidates.iter().map(|(i, t)| (*i, t.as_str())).collect(); + + // 3. LLM re-ranking + let ranked_indices = match rerank_search_results(&req.query, &candidate_refs).await { + Ok(indices) => indices, + Err(e) => { + warn!("LLM rerank failed, falling back to score order: {}", e); + (0..initial_results.len()).collect() + } + }; + + // 4. Re-order results + let mut reordered: Vec = ranked_indices + .into_iter() + .filter_map(|i| initial_results.get(i).cloned()) + .collect(); + + // 5. Trim to user's requested limit + reordered.truncate(user_limit); + + Ok(Json(SmartSearchResponse { + query: req.query, + results: reordered, + page: 1, + page_size: user_limit, + strategy: "llm_reranked".to_string(), + })) +} + +pub fn llm_smart_routes() -> Router { + Router::new().route("/api/v1/search/llm-smart", post(llm_smart_search_handler)) +} diff --git a/src/api/mod.rs b/src/api/mod.rs index f348fa2..0383cb1 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,8 +9,10 @@ pub mod identities; pub mod identity_agent_api; pub mod identity_api; pub mod identity_binding; +pub mod llm_search; pub mod media_api; pub mod middleware; +pub mod pipeline; pub mod processing; pub mod scan; pub mod search; diff --git a/src/api/pipeline.rs b/src/api/pipeline.rs new file mode 100644 index 0000000..c61e099 --- /dev/null +++ b/src/api/pipeline.rs @@ -0,0 +1,85 @@ +use axum::extract::Path; +use axum::routing::post; +use axum::{Json, Router}; +use serde_json::{json, Value}; + +use crate::core::db::postgres_db::PostgresDb; +use crate::core::pipeline as pipeline_core; +use crate::core::config; + +async fn handle_store_asrx(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { + let db = PostgresDb::new(&config::DATABASE_URL).await + .map_err(|e| { 
+ tracing::error!("DB error: {}", e); + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "DB connection failed"}))) + })?; + + pipeline_core::store_asrx_chunks(&db, &uuid).await + .map_err(|e| { + tracing::error!("store_asrx error: {}", e); + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) + })?; + + Ok(Json(json!({"success": true, "message": "ASRX chunks stored", "file_uuid": uuid}))) +} + +async fn handle_rule1(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { + let db = PostgresDb::new(&config::DATABASE_URL).await + .map_err(|e| { + tracing::error!("DB error: {}", e); + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "DB connection failed"}))) + })?; + + let count = pipeline_core::execute_rule1(&db, &uuid).await + .map_err(|e| { + tracing::error!("rule1 error: {}", e); + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) + })?; + + Ok(Json(json!({"success": true, "message": format!("Rule 1 complete: {} chunks", count), "file_uuid": uuid, "chunks": count}))) +} + +async fn handle_vectorize(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { + pipeline_core::vectorize_chunks(&uuid).await + .map_err(|e| { + tracing::error!("vectorize error: {}", e); + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) + })?; + + Ok(Json(json!({"success": true, "message": "Vectorization complete", "file_uuid": uuid}))) +} + +async fn handle_phase1(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { + pipeline_core::run_phase1(&uuid).await + .map_err(|e| { + tracing::error!("phase1 error: {}", e); + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) + })?; + + Ok(Json(json!({"success": true, "message": "Phase 1 complete", "file_uuid": uuid}))) +} + +async fn handle_complete(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { + let db = 
PostgresDb::new(&config::DATABASE_URL).await + .map_err(|e| { + tracing::error!("DB error: {}", e); + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "DB connection failed"}))) + })?; + + pipeline_core::mark_complete(&db, &uuid).await + .map_err(|e| { + tracing::error!("complete error: {}", e); + (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) + })?; + + Ok(Json(json!({"success": true, "message": "Video marked as completed", "file_uuid": uuid}))) +} + +pub fn pipeline_routes() -> Router { + Router::new() + .route("/api/v1/file/:file_uuid/store-asrx", post(handle_store_asrx)) + .route("/api/v1/file/:file_uuid/rule1", post(handle_rule1)) + .route("/api/v1/file/:file_uuid/vectorize", post(handle_vectorize)) + .route("/api/v1/file/:file_uuid/phase1", post(handle_phase1)) + .route("/api/v1/file/:file_uuid/complete", post(handle_complete)) +} diff --git a/src/api/search.rs b/src/api/search.rs index b45e955..89fbe5f 100644 --- a/src/api/search.rs +++ b/src/api/search.rs @@ -21,7 +21,7 @@ pub struct SmartSearchRequest { pub limit: Option, } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] pub struct SearchResult { pub id: i32, pub file_uuid: Option, @@ -47,12 +47,12 @@ pub struct SmartSearchResponse { pub strategy: String, } -/// Internal merged result with RRF scoring +/// Internal merged result with score-based merge #[derive(Debug)] struct MergedResult { file_uuid: String, chunk_id: String, - rrf_score: f64, + score: f64, semantic_score: Option, keyword_score: Option, identity_score: Option, @@ -140,8 +140,10 @@ pub async fn smart_search( }, )?; + const KEYWORD_FIXED_SCORE: f64 = 0.5; + const IDENTITY_FIXED_SCORE: f64 = 0.85; + let fetch_limit = limit * 3; - let rrf_k = 60.0; // 2. Semantic search via Qdrant let semantic_results: Vec<(String, String, f64)> = if let Some(file_uuid) = &req.file_uuid { @@ -176,6 +178,46 @@ pub async fn smart_search( } }; + // 3b. 
Video title search: if query matches a video title, get its chunks + const TITLE_MATCH_SCORE: f64 = 0.9; + let title_results: Vec<(String, String, f64)> = { + let clean_query = req.query.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_"); // bound as $1, so quotes must not be doubled; only LIKE metacharacters are escaped (backslash is the default escape) + let v_table = crate::core::db::schema::table_name("videos"); + let c_table = crate::core::db::schema::table_name("chunk"); + let video_rows: Vec<(String,)> = sqlx::query_as(&format!( + "SELECT file_uuid::text FROM {} WHERE file_name ILIKE $1 LIMIT 5", + v_table + )) + .bind(format!("%{}%", clean_query)) + .fetch_all(db.pool()) + .await + .unwrap_or_default(); + + let mut chunks = Vec::new(); + for (fu,) in video_rows.iter() { + if let Some(ref f) = req.file_uuid { + if fu != f { + continue; + } + } + let rows: Vec<(String, String)> = sqlx::query_as(&format!( + "SELECT chunk_id, file_uuid::text FROM {} \ + WHERE file_uuid = $1 AND embedding IS NOT NULL \ + AND chunk_type = 'sentence' \ + LIMIT 20", + c_table + )) + .bind(fu) + .fetch_all(db.pool()) + .await + .unwrap_or_default(); + for (cid, file_uuid) in rows { + chunks.push((file_uuid, cid, TITLE_MATCH_SCORE)); + } + } + chunks + }; + + // 4. Identity search: if query matches a person name, get their chunks let identity_results: Vec<(String, String, f64)> = { let id_table = crate::core::db::schema::table_name("identities"); @@ -211,24 +253,23 @@ pub async fn smart_search( id_chunks }; - // 5. RRF merge: combine results from all sources + // 5. 
Score-based merge: combine results from all sources let mut merged: HashMap<(String, String), MergedResult> = HashMap::new(); - // Add semantic results - for (rank, (file_uuid, chunk_id, score)) in semantic_results.iter().enumerate() { + // Add semantic results (use Qdrant cosine score directly) + for (file_uuid, chunk_id, score) in semantic_results.iter() { let key = (file_uuid.clone(), chunk_id.clone()); - let rrf_contribution = 1.0 / (rrf_k + rank as f64 + 1.0); merged .entry(key) .and_modify(|e| { - e.rrf_score += rrf_contribution; + e.score = e.score.max(*score); e.semantic_score = Some(*score); e.source = format!("{}_{}", e.source.strip_prefix("semantic+").unwrap_or(&e.source), "semantic"); }) .or_insert(MergedResult { file_uuid: file_uuid.clone(), chunk_id: chunk_id.clone(), - rrf_score: rrf_contribution, + score: *score, semantic_score: Some(*score), keyword_score: None, identity_score: None, @@ -236,54 +277,76 @@ pub async fn smart_search( }); } - // Add keyword results - for (rank, (file_uuid, chunk_id, score)) in keyword_results.iter().enumerate() { + // Add keyword results (fixed score 0.5) + let keyword_fixed = KEYWORD_FIXED_SCORE; + for (file_uuid, chunk_id, _) in keyword_results.iter() { let key = (file_uuid.clone(), chunk_id.clone()); - let rrf_contribution = 1.0 / (rrf_k + rank as f64 + 1.0); merged .entry(key) .and_modify(|e| { - e.rrf_score += rrf_contribution; - e.keyword_score = Some(*score); + e.score = e.score.max(keyword_fixed); + e.keyword_score = Some(keyword_fixed); e.source = format!("{}_keyword", e.source); }) .or_insert(MergedResult { file_uuid: file_uuid.clone(), chunk_id: chunk_id.clone(), - rrf_score: rrf_contribution, + score: keyword_fixed, semantic_score: None, - keyword_score: Some(*score), + keyword_score: Some(keyword_fixed), identity_score: None, source: "keyword".to_string(), }); } - // Add identity results (only if we found matching identities) - let has_identity_match = !identity_results.is_empty(); - for (rank, 
(file_uuid, chunk_id, score)) in identity_results.iter().enumerate() { + // Add title match results (high score 0.9) — query matched video title + let has_title_match = !title_results.is_empty(); + let title_fixed = TITLE_MATCH_SCORE; + for (file_uuid, chunk_id, _) in title_results.iter() { let key = (file_uuid.clone(), chunk_id.clone()); - let rrf_contribution = 1.0 / (rrf_k + rank as f64 + 1.0); merged .entry(key) .and_modify(|e| { - e.rrf_score += rrf_contribution; - e.identity_score = Some(*score); + e.score = e.score.max(title_fixed); + e.source = format!("{}_title", e.source); + }) + .or_insert(MergedResult { + file_uuid: file_uuid.clone(), + chunk_id: chunk_id.clone(), + score: title_fixed, + semantic_score: None, + keyword_score: None, + identity_score: None, + source: "title".to_string(), + }); + } + + // Add identity results (fixed score 0.85) + let has_identity_match = !identity_results.is_empty(); + let identity_fixed = IDENTITY_FIXED_SCORE; + for (file_uuid, chunk_id, _) in identity_results.iter() { + let key = (file_uuid.clone(), chunk_id.clone()); + merged + .entry(key) + .and_modify(|e| { + e.score = e.score.max(identity_fixed); + e.identity_score = Some(identity_fixed); e.source = format!("{}_identity", e.source); }) .or_insert(MergedResult { file_uuid: file_uuid.clone(), chunk_id: chunk_id.clone(), - rrf_score: rrf_contribution, + score: identity_fixed, semantic_score: None, keyword_score: None, - identity_score: Some(*score), + identity_score: Some(identity_fixed), source: "identity".to_string(), }); } - // Sort by RRF score descending + // Sort by score descending (score-based merge) let mut ranked: Vec<&MergedResult> = merged.values().collect(); - ranked.sort_by(|a, b| b.rrf_score.partial_cmp(&a.rrf_score).unwrap_or(std::cmp::Ordering::Equal)); + ranked.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)); // 6. 
Enrich top results from PG and build final response let mut final_results = Vec::new(); @@ -307,7 +370,7 @@ pub async fn smart_search( raw_text: None, summary: Some(pg.summary), metadata: pg.metadata.clone(), - similarity: Some(mr.rrf_score), + similarity: Some(mr.score), }); } } @@ -320,6 +383,9 @@ pub async fn smart_search( if has_identity_match { strategies.push("identity"); } + if has_title_match { + strategies.push("title"); + } Ok(Json(SmartSearchResponse { query: req.query, diff --git a/src/api/server.rs b/src/api/server.rs index 8512b33..f3b0431 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -19,6 +19,8 @@ use super::identities; use super::identity_agent_api; use super::identity_api; use super::identity_binding; +use super::llm_search; +use super::pipeline; use super::media_api; use super::middleware::unified_auth; use super::processing; @@ -117,7 +119,9 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> { .merge(media_api::bbox_routes()) .merge(trace_agent_api::trace_agent_routes()) .merge(search_routes()) + .merge(llm_search::llm_smart_routes()) .merge(universal_search_routes()) + .merge(pipeline::pipeline_routes()) .layer(axum::middleware::from_fn_with_state( state.api_state.clone(), unified_auth, diff --git a/src/cli/args.rs b/src/cli/args.rs index b120707..47f7331 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -40,6 +40,11 @@ pub enum Commands { /// UUID uuid: String, }, + /// Store ASRX chunks into pre_chunks table + StoreAsrx { + /// File UUID + uuid: String, + }, /// Generate story for cut scenes Story { /// UUID @@ -50,6 +55,16 @@ pub enum Commands { /// UUID (or 'all' for all) uuid: String, }, + /// Run Phase 1 release packaging + Phase1 { + /// File UUID + uuid: String, + }, + /// Mark video as completed + Complete { + /// File UUID + uuid: String, + }, /// Play video with overlays Play { /// Video path or UUID diff --git a/src/core/db/postgres_db.rs b/src/core/db/postgres_db.rs index 7c0340f..a2cd9df 
100644 --- a/src/core/db/postgres_db.rs +++ b/src/core/db/postgres_db.rs @@ -3308,10 +3308,38 @@ impl PostgresDb { pub async fn store_pre_chunk( &self, - _uuid: &str, - _chunk_type: &str, - _data: serde_json::Value, + uuid: &str, + processor_type: &str, + data: serde_json::Value, ) -> Result<()> { + let table = schema::table_name("pre_chunks"); + let pre_chunk: PreChunk = serde_json::from_value(data)?; + let start_time = pre_chunk.start_frame as f64 / pre_chunk.fps; + let end_time = pre_chunk.end_frame as f64 / pre_chunk.fps; + sqlx::query(&format!( + "INSERT INTO {} (file_uuid, file_id, source_type, source_file, chunk_type, \ + start_frame, end_frame, start_time, end_time, fps, data, text_content, \ + processed, chunk_id, processor_type) \ + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)", + table + )) + .bind(uuid) + .bind(pre_chunk.file_id) + .bind(&pre_chunk.source_type) + .bind(&pre_chunk.source_file) + .bind(&pre_chunk.chunk_type) + .bind(pre_chunk.start_frame) + .bind(pre_chunk.end_frame) + .bind(start_time) + .bind(end_time) + .bind(pre_chunk.fps) + .bind(&pre_chunk.raw_json) + .bind(&pre_chunk.text_content) + .bind(pre_chunk.processed) + .bind(&pre_chunk.chunk_id) + .bind(processor_type) + .execute(&self.pool) + .await?; Ok(()) } diff --git a/src/core/llm/mod.rs b/src/core/llm/mod.rs index f3bf813..9a66b09 100644 --- a/src/core/llm/mod.rs +++ b/src/core/llm/mod.rs @@ -1,2 +1,3 @@ pub mod client; pub mod function_calling; +pub mod rerank; diff --git a/src/core/llm/rerank.rs b/src/core/llm/rerank.rs new file mode 100644 index 0000000..5c0f272 --- /dev/null +++ b/src/core/llm/rerank.rs @@ -0,0 +1,168 @@ +use std::collections::HashSet; + +use anyhow::Result; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tracing::{debug, warn}; + +use crate::core::config; + +#[derive(Debug, Serialize)] +struct ChatRequest { + model: String, + messages: Vec, + temperature: f32, + max_tokens: u32, + stream: 
bool, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ChatMessage { + role: String, + content: String, +} + +#[derive(Debug, Deserialize)] +struct ChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: ChatMessage, +} + +#[derive(Debug, Deserialize)] +struct RerankResponse { + ranked: Vec, +} + +pub async fn rerank_search_results(query: &str, candidates: &[(usize, &str)]) -> Result> { + if candidates.is_empty() { + return Ok(vec![]); + } + + let mut chunks_text = String::new(); + for (i, (_, text)) in candidates.iter().enumerate() { + let display = if let Some((cut, _)) = text.char_indices().nth(100) { // cut on a char boundary: byte-slicing panics inside multi-byte UTF-8 + format!("{}...", &text[..cut]) + } else { + text.to_string() + }; + chunks_text.push_str(&format!("[{}] {}\n", i + 1, display)); + } + + let prompt = format!( + r#"You are a search relevance judge. Rank ALL chunks by relevance to the query. + +Query: "{}" + +Chunks: +{} + +Return a JSON object with ALL chunk numbers in order of relevance (most relevant first). +Example: {{"ranked": [5, 1, 3, 2, 4, 6, 7, 8, 9, 10]}} +Include every chunk number exactly once. 
Only respond with the JSON."#, + query, chunks_text + ); + + let client = Client::builder() + .timeout(Duration::from_secs(15)) + .build()?; + + let req = ChatRequest { + model: config::llm::CHAT_MODEL.clone(), + messages: vec![ + ChatMessage { + role: "system".to_string(), + content: "You are a precise search relevance judge.".to_string(), + }, + ChatMessage { + role: "user".to_string(), + content: prompt, + }, + ], + temperature: 0.1, + max_tokens: 512, + stream: false, + }; + + debug!("LLM rerank: {} candidates for query '{}'", candidates.len(), query); + + let res = client + .post(&*config::llm::CHAT_URL) + .json(&req) + .send() + .await?; + + if !res.status().is_success() { + let status = res.status(); + let body = res.text().await.unwrap_or_default(); + warn!("LLM rerank API error: {} — body: {}", status, body); + return Ok(candidates.iter().map(|(idx, _)| *idx).collect()); + } + + let chat_res: ChatResponse = res.json().await?; + let content = chat_res + .choices + .into_iter() + .next() + .map(|c| c.message.content) + .unwrap_or_default(); + + let content = content.trim(); + + // Strip markdown code fences if present + let content = if content.starts_with("```") { + let lines: Vec<&str> = content.lines().collect(); + let start = if lines.first().map(|l| l.contains("```")).unwrap_or(false) { 1 } else { 0 }; + let end = if lines.last().map(|l| l.contains("```")).unwrap_or(false) { + lines.len().saturating_sub(1) + } else { + lines.len() + }; + lines.get(start..end).unwrap_or_default().join("\n").trim().to_string() // a one-line fenced reply gives start > end; indexing would panic + } else { + content.to_string() + }; + + let json_start = content.find('{'); + let json_end = content.rfind('}'); + + if let (Some(start), Some(end)) = (json_start, json_end) { + let json_str = content.get(start..=end).unwrap_or_default(); // '}' may precede '{'; an empty slice then fails JSON parsing below + match serde_json::from_str::(json_str) { + Ok(parsed) => { + let mut 
seen = HashSet::new(); // tracks accepted indices so numbers the LLM repeats are kept only once + let mut ranked: Vec = parsed + .ranked + .into_iter() + .filter_map(|i| { + if i > 0 && i <= candidates.len() && seen.insert(candidates[i - 1].0) { + Some(candidates[i - 1].0) + } else { + None + } + }) + .collect(); + + if 
!ranked.is_empty() { + for (orig_idx, _) in candidates { + if !seen.contains(orig_idx) { + ranked.push(*orig_idx); + } + } + return Ok(ranked); + } + warn!("LLM rerank returned empty ranked list"); + } + Err(e) => { + warn!("Failed to parse LLM rerank JSON: {}", e); + } + } + } + + warn!("LLM rerank: could not parse response — content: {}", content.chars().take(200).collect::<String>()); + Ok(candidates.iter().map(|(idx, _)| *idx).collect()) +} diff --git a/src/core/mod.rs b/src/core/mod.rs index 8fd786c..5950d6b 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -12,6 +12,7 @@ pub mod ingestion; pub mod llm; pub mod overlay; pub mod person_identity; +pub mod pipeline; pub mod probe; pub mod processor; pub mod storage; diff --git a/src/core/pipeline/mod.rs b/src/core/pipeline/mod.rs new file mode 100644 index 0000000..0ffeb73 --- /dev/null +++ b/src/core/pipeline/mod.rs @@ -0,0 +1,172 @@ +use anyhow::{Context, Result}; + +use crate::core::chunk::rule1_ingest; +use crate::core::config; +use crate::core::db::postgres_db::PostgresDb; +use crate::core::db::qdrant_db::QdrantDb; +use crate::core::db::schema; +use crate::core::db::VectorPayload; +use crate::core::embedding::Embedder; +use crate::core::processor::asrx::AsrxResult; +use crate::core::processor::PythonExecutor; +use crate::core::storage::output_dir::OutputDir; + +pub async fn store_asrx_chunks(db: &PostgresDb, uuid: &str) -> Result<()> { + let output_dir = OutputDir::new(); + let asrx_path = output_dir.get_output_path(uuid, "asrx.json"); + + let json_str = std::fs::read_to_string(&asrx_path) + .with_context(|| format!("ASRX file not found: {:?}", asrx_path))?; + let result: AsrxResult = serde_json::from_str(&json_str) + .context("Failed to parse ASRX JSON")?; + + let segments_count = result.segments.len(); + let mut pre_chunks = Vec::new(); + let mut speaker_detections = Vec::new(); + + for (i, segment) in result.segments.iter().enumerate() { + let data = 
serde_json::json!({ + "text": segment.text, + "speaker_id": segment.speaker_id, + "timestamp": segment.start_time, + }); + pre_chunks.push((i as i64, Some(segment.start_time), data, None, None)); + speaker_detections.push(( + segment.speaker_id.clone().unwrap_or_default(), + segment.start_time, + segment.end_time, + segment.text.clone(), + None::, + 0.0, + )); + } + + db.store_raw_pre_chunks_batch(uuid, "asrx", &pre_chunks).await?; + db.store_raw_pre_chunks_batch(uuid, "asr", &pre_chunks).await?; + db.store_speaker_detections_batch(uuid, &speaker_detections).await?; + + println!("Stored {} ASRX pre-chunks for {}", segments_count, uuid); + Ok(()) +} + +pub async fn execute_rule1(db: &PostgresDb, uuid: &str) -> Result { + let video = db.get_video_by_uuid(uuid) + .await? + .context("Video not found")?; + let fps = video.fps; + + let count = rule1_ingest::execute_rule1(db, uuid, fps).await + .context("Rule 1 ingestion failed")?; + + println!("Rule 1 completed: {} chunks inserted for {}", count, uuid); + Ok(count) +} + +pub async fn vectorize_chunks(uuid: &str) -> Result<()> { + let db = PostgresDb::new(&config::DATABASE_URL).await?; + let qdrant = QdrantDb::new(); + let embedder = Embedder::new("embeddinggemma-300m".to_string()); + + let chunk_table = schema::table_name("chunk"); + let rows = sqlx::query_as::<_, (String, String, String, i64, i64, f64, f64, String)>( + &format!( + "SELECT chunk_id, chunk_type, text_content, start_frame, end_frame, \ + start_time, end_time, content::text \ + FROM {} WHERE file_uuid = $1 AND chunk_type = 'sentence' \ + AND embedding IS NULL \ + AND (text_content IS NOT NULL AND text_content != '') \ + ORDER BY id", + chunk_table + ), + ) + .bind(uuid) + .fetch_all(db.pool()) + .await?; + + if rows.is_empty() { + println!("No sentence chunks to vectorize for {}", uuid); + return Ok(()); + } + + let total = rows.len(); + let mut stored = 0usize; + + for (chunk_id, _chunk_type, text, start_frame, end_frame, start_time, end_time, 
_content_str) in &rows { + if text.is_empty() { + continue; + } + + match embedder.embed_document(text).await { + Ok(vector) => { + let payload = VectorPayload { // built first: Qdrant is written before PG (see upsert below) + file_uuid: uuid.to_string(), + chunk_id: chunk_id.clone(), + chunk_type: "sentence".to_string(), + start_frame: *start_frame, + end_frame: *end_frame, + start_time: *start_time, + end_time: *end_time, + text: Some(text.clone()), + }; + if let Err(e) = qdrant.upsert_vector(chunk_id, &vector, payload).await { // rows are selected by "embedding IS NULL"; assuming store_vector fills that column, a failed upsert must leave the row retryable + eprintln!("Qdrant upsert failed for {}: {}", chunk_id, e); + continue; + } + if let Err(e) = db.store_vector(chunk_id, &vector, uuid).await { + eprintln!("PG store failed for {}: {}", chunk_id, e); + continue; + } + stored += 1; + if stored % 50 == 0 { + println!("Vectorized {}/{} chunks for {}", stored, total, uuid); + } + } + Err(e) => { + eprintln!("Embedding failed for {}: {}", chunk_id, e); + } + } + } + + println!("Vectorization complete: {}/{} vectors for {}", stored, total, uuid); + Ok(()) +} + +pub async fn run_phase1(uuid: &str) -> Result<()> { + let executor = PythonExecutor::new() + .context("Failed to create PythonExecutor")?; + + executor + .run( + "release_pack.py", + &["--phase", "1", "--file-uuid", uuid], + None, + "RELEASE_P1", + Some(std::time::Duration::from_secs(120)), + ) + .await + .context("Phase 1 release pack failed")?; + + println!("Phase 1 release packaged for {}", uuid); + Ok(()) +} + +pub async fn mark_complete(db: &PostgresDb, uuid: &str) -> Result<()> { + use crate::core::db::MonitorJobStatus; + use crate::core::db::VideoStatus; + + let job_id = sqlx::query_scalar::<_, i32>( + &format!("SELECT id FROM {} WHERE uuid = $1 LIMIT 1", schema::table_name("monitor_jobs")), + ) + .bind(uuid) + .fetch_optional(db.pool()) + .await?; + + if let Some(job_id) = job_id { + 
db.update_job_status(job_id, MonitorJobStatus::Completed).await?; + println!("Job {} marked as completed", job_id); + } + + db.update_video_status(uuid, VideoStatus::Completed).await?; + println!("Video {} marked 
as completed", uuid); + Ok(()) +} diff --git a/src/main.rs b/src/main.rs index 82d1763..3c8c70d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,12 +16,17 @@ fn init_tracing() { .init(); } +fn load_env() { + let _ = dotenv::from_filename("/Users/accusys/momentry_core_0.1/.env").or_else(|_| dotenv::dotenv()); // legacy absolute path first (unchanged where it exists), else standard .env discovery so other machines still get their config +} + use cli::*; use processing::handlers::*; /// Main entry point #[tokio::main] async fn main() -> Result<()> { + load_env(); init_tracing(); let cli = Cli::parse(); @@ -41,12 +46,21 @@ async fn main() -> Result<()> { Commands::Chunk { uuid } => { handle_chunk(&uuid).await?; } + Commands::StoreAsrx { uuid } => { + handle_store_asrx(&uuid).await?; + } Commands::Story { uuid } => { handle_story(&uuid).await?; } Commands::Vectorize { uuid } => { handle_vectorize(&uuid).await?; } + Commands::Phase1 { uuid } => { + handle_phase1(&uuid).await?; + } + Commands::Complete { uuid } => { + handle_complete(&uuid).await?; + } Commands::Play { target } => { handle_play(&target).await?; } diff --git a/src/processing/handlers.rs b/src/processing/handlers.rs index 5b90fa4..48815a5 100644 --- a/src/processing/handlers.rs +++ b/src/processing/handlers.rs @@ -419,3 +419,26 @@ pub async fn handle_n8n( Ok(()) } + +/// Handle store-asrx command +pub async fn handle_store_asrx(uuid: &str) -> Result<()> { + let db = momentry_core::core::db::postgres_db::PostgresDb::new( + &momentry_core::core::config::DATABASE_URL, + ) + .await?; + momentry_core::core::pipeline::store_asrx_chunks(&db, uuid).await +} + +/// Handle phase1 command +pub async fn handle_phase1(uuid: &str) -> Result<()> { + momentry_core::core::pipeline::run_phase1(uuid).await +} + +/// Handle complete command +pub async fn handle_complete(uuid: &str) -> Result<()> { + let db = momentry_core::core::db::postgres_db::PostgresDb::new( + &momentry_core::core::config::DATABASE_URL, + ) + 
.await?; + momentry_core::core::pipeline::mark_complete(&db, uuid).await +}