From 17e4e15860e059aeecee352d7d1a635d557d27ad Mon Sep 17 00:00:00 2001 From: Accusys Date: Sat, 13 Jun 2026 16:25:52 +0800 Subject: [PATCH] feat: add Vision LLM integration (CLIP + Qwen3-VL cascade) - Add Qwen3-VL dynamic management (start/stop/status CLI) - Add CLIP + Qwen3-VL cascade detection strategy - Add Vision CLI commands (vision start/stop/status, detect) - Add cascade_vision processor module - Add clip processor module - Add qwen_vl_manager module Changes: - scripts/start_qwen3vl.sh, stop_qwen3vl.sh: Qwen3-VL management scripts - src/core/vision/: Qwen3-VL manager module - src/core/processor/cascade_vision.rs: CLIP + Qwen3-VL cascade logic - src/core/processor/clip.rs: CLIP classification and detection - src/api/clip_api.rs: CLIP API endpoints - src/cli/vision.rs: Vision CLI implementation - src/cli/args.rs: Add Vision and Detect commands - src/main.rs: Integrate Vision CLI - src/core/mod.rs: Add vision module - src/core/processor/mod.rs: Add cascade_vision module --- scripts/clip_classifier.py | 232 ++++++++++++++++++++ scripts/start_qwen3vl.sh | 35 +++ scripts/stop_qwen3vl.sh | 30 +++ src/api/agent_api.rs | 5 +- src/api/agent_search.rs | 124 ++++++++--- src/api/clip_api.rs | 194 +++++++++++++++++ src/api/five_w1h_agent_api.rs | 6 +- src/api/identity_agent_api.rs | 12 +- src/api/identity_api.rs | 3 +- src/api/identity_binding.rs | 10 +- src/api/llm_search.rs | 8 +- src/api/media_api.rs | 258 ++++++++++++++++++---- src/api/pipeline.rs | 131 ++++++++---- src/api/search.rs | 87 +++++++- src/api/server.rs | 3 +- src/api/trace_agent_api.rs | 31 ++- src/bin/sync_qdrant_from_pg.rs | 44 ++-- src/bin/vectorize_missing.rs | 23 +- src/cli/args.rs | 28 +++ src/cli/mod.rs | 1 + src/cli/vision.rs | 95 +++++++++ src/core/config.rs | 10 + src/core/db/postgres_db.rs | 2 +- src/core/llm/client.rs | 9 +- src/core/llm/function_calling.rs | 36 ++-- src/core/llm/rerank.rs | 31 ++- src/core/mod.rs | 1 + src/core/pipeline/mod.rs | 51 +++-- src/core/processor/asrx.rs | 5 +- src/core/processor/cascade_vision.rs | 308 +++++++++++++++++++++++++++ src/core/processor/clip.rs | 290 +++++++++++++++++++++++++ src/core/processor/mod.rs | 4 + src/core/vision/mod.rs | 1 + src/core/vision/qwen_vl_manager.rs | 218 +++++++++++++++++++ src/main.rs | 11 + src/worker/job_worker.rs | 138 ++++++------ src/worker/processor.rs | 4 +- 37 files changed, 2185 insertions(+), 294 deletions(-) create mode 100644 scripts/clip_classifier.py create mode 100755 scripts/start_qwen3vl.sh create mode 100755 scripts/stop_qwen3vl.sh create mode 100644 src/api/clip_api.rs create mode 100644 src/cli/vision.rs create mode 100644 src/core/processor/cascade_vision.rs create mode 100644 src/core/processor/clip.rs create mode 100644 src/core/vision/mod.rs create mode 100644 src/core/vision/qwen_vl_manager.rs diff --git a/scripts/clip_classifier.py b/scripts/clip_classifier.py new file mode 100644 index 0000000..e6aafc9 --- /dev/null +++ b/scripts/clip_classifier.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +""" +CLIP Zero-Shot Classifier +Uses OpenAI CLIP for reliable scene and object classification. + +Advantages over LLaVA Vision: +- Zero-shot classification (no prompt induction) +- Reliable confidence scores +- Fast inference +- No hallucinations +""" + +import argparse +import json +import sys +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +try: + import torch + from PIL import Image + from transformers import CLIPProcessor, CLIPModel + HAS_CLIP = True +except ImportError as e: + print(f"[ERROR] Required packages not found: {e}", file=sys.stderr) + print("[ERROR] Install with: pip install transformers torch pillow", file=sys.stderr) + HAS_CLIP = False + sys.exit(1) + + +class CLIPClassifier: + def __init__(self, model_name: str = "openai/clip-vit-base-patch32"): + """ + Initialize CLIP model. + + Args: + model_name: HuggingFace model name (default: openai/clip-vit-base-patch32) + """ + print(f"[CLIP] Loading model: {model_name}") + self.model = CLIPModel.from_pretrained(model_name) + self.processor = CLIPProcessor.from_pretrained(model_name) + self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") + self.model.to(self.device) + print(f"[CLIP] Model loaded on device: {self.device}") + + def classify_image( + self, + image_path: str, + labels: List[str], + top_k: int = 5 + ) -> List[Dict[str, float]]: + """ + Classify a single image with given labels. + + Args: + image_path: Path to image file + labels: List of candidate labels (e.g., ["person in room", "outdoor scene", "snow landscape"]) + top_k: Number of top predictions to return + + Returns: + List of {"label": str, "confidence": float} sorted by confidence + """ + try: + image = Image.open(image_path).convert("RGB") + except Exception as e: + print(f"[ERROR] Failed to load image {image_path}: {e}", file=sys.stderr) + return [] + + # Prepare inputs + inputs = self.processor( + text=labels, + images=image, + return_tensors="pt", + padding=True + ).to(self.device) + + # Get predictions + with torch.no_grad(): + outputs = self.model(**inputs) + logits_per_image = outputs.logits_per_image + probs = logits_per_image.softmax(dim=1).cpu().numpy()[0] + + # Sort by confidence + results = [ + {"label": label, "confidence": float(prob)} + for label, prob in zip(labels, probs) + ] + results.sort(key=lambda x: x["confidence"], reverse=True) + + return results[:top_k] + + def classify_images( + self, + image_paths: List[str], + labels: List[str], + top_k: int = 5 + ) -> Dict[str, List[Dict[str, float]]]: + """ + Classify multiple images with given labels. + + Args: + image_paths: List of image paths + labels: List of candidate labels + top_k: Number of top predictions per image + + Returns: + Dict mapping image_path -> predictions + """ + results = {} + for img_path in image_paths: + results[img_path] = self.classify_image(img_path, labels, top_k) + return results + + def detect_objects( + self, + image_path: str, + objects: List[str], + threshold: float = 0.15 + ) -> List[Dict[str, float]]: + """ + Detect if specific objects are present in image. + + Args: + image_path: Path to image file + objects: List of objects to detect (e.g., ["gun", "knife", "weapon"]) + threshold: Confidence threshold (default: 0.15) + + Returns: + List of detected objects with confidence >= threshold + """ + predictions = self.classify_image(image_path, objects, top_k=len(objects)) + detected = [p for p in predictions if p["confidence"] >= threshold] + return detected + + def batch_detect_objects( + self, + image_paths: List[str], + objects: List[str], + threshold: float = 0.15 + ) -> Dict[str, List[Dict[str, float]]]: + """ + Detect objects across multiple images. + + Args: + image_paths: List of image paths + objects: List of objects to detect + threshold: Confidence threshold + + Returns: + Dict mapping image_path -> detected objects + """ + results = {} + for img_path in image_paths: + detected = self.detect_objects(img_path, objects, threshold) + if detected: + results[img_path] = detected + return results + + +def main(): + parser = argparse.ArgumentParser( + description="CLIP Zero-Shot Classifier", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Scene classification + python clip_classifier.py image.jpg --labels "indoor room,outdoor scene,person in room" --top-k 3 + + # Object detection + python clip_classifier.py image.jpg --detect "gun,weapon,knife" --threshold 0.2 + + # Batch processing + python clip_classifier.py images.txt --batch --labels "indoor,outdoor" +""" + ) + + parser.add_argument("input", help="Image path or text file with image paths (for batch)") + parser.add_argument("--labels", help="Comma-separated labels for classification") + parser.add_argument("--detect", help="Comma-separated objects to detect") + parser.add_argument("--threshold", type=float, default=0.15, help="Detection threshold (default: 0.15)") + parser.add_argument("--top-k", type=int, default=5, help="Top-k predictions (default: 5)") + parser.add_argument("--batch", action="store_true", help="Batch mode (input is text file)") + parser.add_argument("--output", help="Output JSON file (default: stdout)") + parser.add_argument("--model", default="openai/clip-vit-base-patch32", help="CLIP model name") + + args = parser.parse_args() + + if not HAS_CLIP: + sys.exit(1) + + # Initialize classifier + classifier = CLIPClassifier(args.model) + + # Prepare image paths + if args.batch: + with open(args.input, "r") as f: + image_paths = [line.strip() for line in f if line.strip()] + else: + image_paths = [args.input] + + # Run classification + results = {} + + if args.detect: + # Object detection mode + objects = [obj.strip() for obj in args.detect.split(",")] + print(f"[CLIP] Detecting objects: {objects}") + results = classifier.batch_detect_objects(image_paths, objects, args.threshold) + + elif args.labels: + # Scene classification mode + labels = [label.strip() for label in args.labels.split(",")] + print(f"[CLIP] Classifying with {len(labels)} labels") + results = classifier.classify_images(image_paths, labels, args.top_k) + + else: + print("[ERROR] Must specify --labels or --detect", file=sys.stderr) + sys.exit(1) + + # Output results + output_json = json.dumps(results, indent=2, ensure_ascii=False) + + if args.output: + with open(args.output, "w", encoding="utf-8") as f: + f.write(output_json) + print(f"[CLIP] Results saved to {args.output}") + else: + print(output_json) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/start_qwen3vl.sh b/scripts/start_qwen3vl.sh new file mode 100755 index 0000000..de6fda1 --- /dev/null +++ b/scripts/start_qwen3vl.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Start Qwen3-VL server + +MODEL_PATH="/Users/accusys/models/Qwen3VL-8B-Instruct-Q8_0.gguf" +MMPROJ_PATH="/Users/accusys/models/mmproj-Qwen3VL-8B-Instruct-F16.gguf" +LOG_FILE="/Users/accusys/momentry_core/logs/qwen3vl_8086.log" +PID_FILE="/tmp/qwen3vl.pid" + +# Kill existing process if running +if [ -f "$PID_FILE" ]; then + OLD_PID=$(cat "$PID_FILE") + if ps -p "$OLD_PID" > /dev/null 2>&1; then + kill "$OLD_PID" + sleep 2 + fi + rm "$PID_FILE" +fi + +# Start server +nohup /opt/homebrew/bin/llama-server \ + --model "$MODEL_PATH" \ + --mmproj "$MMPROJ_PATH" \ + --host 127.0.0.1 \ + --port 8086 \ + --ctx-size 8192 \ + --n-gpu-layers 99 \ + --threads 8 \ + --batch-size 512 \ + --media-path /Users/accusys/momentry/output_dev \ + > "$LOG_FILE" 2>&1 & + +echo $! > "$PID_FILE" +echo "Qwen3-VL started with PID $(cat $PID_FILE)" +echo "Log file: $LOG_FILE" +echo "Health check: http://localhost:8086/health" \ No newline at end of file diff --git a/scripts/stop_qwen3vl.sh b/scripts/stop_qwen3vl.sh new file mode 100755 index 0000000..160488d --- /dev/null +++ b/scripts/stop_qwen3vl.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Stop Qwen3-VL server + +PID_FILE="/tmp/qwen3vl.pid" + +if [ -f "$PID_FILE" ]; then + PID=$(cat "$PID_FILE") + if ps -p "$PID" > /dev/null 2>&1; then + kill "$PID" + sleep 2 + if ps -p "$PID" > /dev/null 2>&1; then + kill -9 "$PID" + fi + echo "Qwen3-VL stopped (PID: $PID)" + else + echo "Process already stopped (PID: $PID)" + fi + rm "$PID_FILE" +else + echo "No PID file found at $PID_FILE" + echo "Searching for running process..." + RUNNING_PID=$(ps aux | grep "Qwen3VL-8B" | grep -v grep | awk '{print $2}') + if [ -n "$RUNNING_PID" ]; then + echo "Found running process (PID: $RUNNING_PID)" + kill "$RUNNING_PID" + echo "Process killed" + else + echo "No running process found" + fi +fi \ No newline at end of file diff --git a/src/api/agent_api.rs b/src/api/agent_api.rs index 5fad979..570ba42 100644 --- a/src/api/agent_api.rs +++ b/src/api/agent_api.rs @@ -1,8 +1,8 @@ use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router}; -use reqwest::Client; use serde::{Deserialize, Serialize}; use crate::api::types::AppState; +use crate::core::llm::function_calling::LLM_CLIENT; pub fn agent_routes() -> Router { Router::new().route("/api/v1/agents/translate", post(translate_text)) @@ -42,7 +42,6 @@ async fn translate_text( ); // Call LLM via configurable endpoint - let client = Client::new(); let llm_url = crate::core::config::llm::CHAT_URL.as_str(); let model = crate::core::config::llm::CHAT_MODEL.as_str(); @@ -57,7 +56,7 @@ async fn translate_text( "temperature": 0.1 }); - let response = client.post(llm_url).json(&body).send().await.map_err(|e| { + let response = LLM_CLIENT.post(llm_url).json(&body).send().await.map_err(|e| { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to call LLM: {}", e), diff --git a/src/api/agent_search.rs b/src/api/agent_search.rs index b0fbdd7..7830ffb 100644 --- a/src/api/agent_search.rs +++ b/src/api/agent_search.rs @@ -91,19 +91,63 @@ const SYSTEM_PROMPT: &str = r#"你是 Momentry 影片分析助手。回答用戶 6. 用文字反查人物使用 identity_text(輸入關鍵字→找出誰說/提到這段話) 7. 語意/內容問題使用 smart_search 或 universal_search 8. 畫面分析使用 analyze_frame — 可以分析影片中的任何畫面內容(場景、人物表情、動作、物件等) -9. 可以同時呼叫多個工具 +9. **可以同時呼叫多個工具,但需符合以下條件:** + - ✅ 查詢多部影片的相同資訊(如:3部影片的人物列表) + - ✅ 需要組合多個來源的資訊才能回答(如:file_info + tkg_query) + - ❌ 不要為了「嘗試所有可能」而盲目並行呼叫 + - ❌ 如果單一工具已返回足夠答案,不需要額外呼叫 -## 引導規則 -- 如果用戶沒說片名 → 用 find_file 搜尋,如果名稱不明確就反問 -- 反問時提供 suggestions,例如演員名、年代 -- **如果影片的 has_data 為 false,代表尚未完成處理,不要推薦用戶使用。引導用戶選擇 has_data=true 的影片** -- 不要輸出 JSON,用自然語言回答 -- 引用資料時附上具體數字(frame 編號、時間秒數) +## 引導規則(優化版) +- **搜尋優先原則**: + 1. **所有問題都先嘗試搜尋,不要過早判斷用戶是否說了片名** + 2. 根據搜尋結果和答案性質決定是否反問: + - **列举型問題**(找出所有、列出)→ ✅ 不反問,列出所有結果 + - **指定型問題**(这部、那个)→ ⚠️ 反問選擇具體哪個 + - **統計型問題**(多少、幾個)→ ✅ 不反問,統計所有結果 + - **分析型問題**(分析、描述)→ ⚠️ 視問題表述決定 -## 回答規則 -- 回答要簡潔但完整 -- 如果找到影片,附上 file_uuid(用戶之後可能需要) -- 對於人物問題,說出角色名和演員名"#; +- **反問條件(精確)**: + 1. **答案需要分辨才反問**,不是「找到多部影片就反問」 + 2. 判断标准: + - ✅ 如果問題要求「所有」「列出」→ 答案不需要分辨 → 不反問 + - ⚠️ 如果問題要求「这部」「那个」→ 答案需要分辨 → 反問 + - ⚠️ 如果問題不明確 → 根據常理判断是否需要分辨 + +- **反問優化**: + 1. 反問時提供智能 suggestions(依問題類型調整) + 2. 人物問題 → suggestions: ["演員名", "角色名", "年代"] + 3. 內容問題 → suggestions: ["片名", "年代", "主題關鍵字"] + 4. 畫面問題 → suggestions: ["片名", "時間範圍", "場景描述"] + +- **特殊情況**: + - 如果影片的 has_data 為 false → 不要推薦,引導選擇 has_data=true + - 如果搜尋結果直接包含答案 → 直接回答,不額外呼叫工具 + - 如果找不到影片 → 反問提供更多資訊(片名、演員、年份) + +- **回答格式**: + - 不要輸出 JSON,用自然語言回答 + - 引用資料時附上具體數字(frame 編號、時間秒數) + +## 回答規則(優化版) +- 回答長度依問題類型調整: + - 簡單查詢(如「列出影片」)→ 簡潔列表回答(1-2句) + - 分析問題(如「描述情節」)→ 詳細回答(3-5句) + - 計數問題(如「有幾個場景」)→ 直接回答數字 + 簡短說明 + +- 回答格式: + - ✅ 如果找到影片,附上 file_uuid(用戶之後可能需要) + - ✅ 對於人物問題,說出角色名和演員名(如果有) + - ✅ 引用資料時附上具體數字(frame 編號、時間秒數) + - ❌ 不要輸出 JSON 格式,用自然語言回答 + - ❌ 不要編造資料,如果找不到就明確說「找不到」 + +## 停止規則(重要) +- **如果已經找到足夠資訊回答用戶問題,立即停止呼叫工具,直接回答** +- **如果連續 2 轪呼叫工具都返回空結果或相同資訊,停止並告知用戶「找不到更多相關資訊」** +- **如果用戶問題不明確或範圍過大,停止並反問用戶(提供 suggestions)** +- **如果單一工具呼叫返回完整答案,不需要額外呼叫其他工具補充** +- **優化效率:避免重複呼叫相同工具或查詢相同內容** +- **成本控制:主動判斷是否需要繼續,不要盲目嘗試所有工具**"#; fn make_tools(pool: &sqlx::PgPool) -> Vec { vec![ @@ -825,8 +869,12 @@ async fn exec_analyze_frame( async fn execute_tool(pool: &sqlx::PgPool, tool_call: &ToolCall) -> (String, String, String) { let name = tool_call.function.name.clone(); + let tool_call_id = tool_call.id.clone().unwrap_or_default(); let args: serde_json::Value = - serde_json::from_str(&tool_call.function.arguments).unwrap_or_default(); + match serde_json::from_str(&tool_call.function.arguments) { + Ok(v) => v, + Err(e) => return (tool_call_id, name, serde_json::json!({"error": format!("Invalid arguments: {}", e)}).to_string()), + }; let result = match name.as_str() { "find_file" => exec_find_file(pool, &args).await, "list_files" => exec_list_files(pool, &args).await, @@ -844,31 +892,42 @@ async fn execute_tool(pool: &sqlx::PgPool, tool_call: &ToolCall) -> (String, Str Ok(s) => s, Err(e) => serde_json::json!({"error": e}).to_string(), }; - let tool_call_id = tool_call.id.clone().unwrap_or_default(); (tool_call_id, name, content) } // ── Tool Loop ───────────────────────────────────────────────────── -const MAX_ROUNDS: u32 = 5; +const MAX_ROUNDS: u32 = 15; async fn run_tool_loop( pool: &sqlx::PgPool, system_prompt: &str, user_query: &str, history: Vec, -) -> (String, Vec) { +) -> (String, Vec, Vec) { let mut messages = function_calling::build_conversation(system_prompt, user_query, history); let mut sources = Vec::new(); for round in 0..MAX_ROUNDS { - let tools = Some(make_tools(pool)); - match function_calling::call_llm(messages.clone(), tools, 2048, 120).await { + let tools = make_tools(pool); + tracing::info!( + "[AGENT] Round {} started, message_count: {}, tools_available: {}", + round + 1, + messages.len(), + tools.len() + ); + + match function_calling::call_llm(messages.clone(), Some(tools.clone()), 2048, 120).await { Ok(LlmResponse::Text(text)) => { - return (text, sources); + tracing::info!( + "[AGENT] Loop finished: rounds_used={}, total_tools_called={}, answer_length={} chars", + round + 1, + sources.len(), + text.len() + ); + return (text, messages, sources); } Ok(LlmResponse::ToolCalls(calls)) => { - // Push assistant message with tool_calls so Gemma4 remembers messages.push(ChatMessage { role: "assistant".to_string(), content: None, @@ -878,21 +937,32 @@ async fn run_tool_loop( }); for call in &calls { let (tool_call_id, name, content) = execute_tool(pool, call).await; + tracing::info!( + "[AGENT] Tool called: {}, result_size: {} chars, round: {}", + name, + content.len(), + round + 1 + ); sources.push(serde_json::json!({"tool": name, "result": content})); messages.push(function_calling::make_tool_result( - &tool_call_id, - &name, - &content, + &tool_call_id, &name, &content, )); } } Err(e) => { - return (format!("系統錯誤:{}", e), sources); + tracing::error!("[AGENT] LLM call failed: {}", e); + return (format!("系統錯誤:{}", e), messages, sources); } } } + tracing::warn!( + "[AGENT] Max rounds reached: rounds_used={}, total_tools_called={}", + MAX_ROUNDS, + sources.len() + ); ( "已達到最大查詢次數,請縮小問題範圍後重新詢問。".to_string(), + messages, sources, ) } @@ -905,12 +975,12 @@ async fn agent_search( ) -> Result, (StatusCode, Json)> { let (conv_id, history) = get_or_create_conv(req.conversation_id.as_deref()); - let (answer, sources) = + let (answer, messages, sources) = run_tool_loop(state.db.pool(), SYSTEM_PROMPT, &req.query, history).await; - // Save updated messages for conversation continuation - let new_msgs = function_calling::build_conversation(SYSTEM_PROMPT, &req.query, vec![]); - save_messages(&conv_id, &new_msgs); + // Save messages (skip system prompt — build_conversation re-adds it) + let history: Vec = messages.into_iter().skip(1).collect(); + save_messages(&conv_id, &history); let needs_input = answer.contains('?') || answer.contains('?'); let suggestions = if needs_input { diff --git a/src/api/clip_api.rs b/src/api/clip_api.rs new file mode 100644 index 0000000..b1a9314 --- /dev/null +++ b/src/api/clip_api.rs @@ -0,0 +1,194 @@ +use axum::{ + extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, + routing::{get, post}, + Json, Router, +}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +use crate::core::processor::{classify_image, classify_images, detect_objects, ClipPrediction}; +use crate::api::types::AppState; + +#[derive(Debug, Deserialize)] +pub struct ClassifyRequest { + image_path: String, + labels: String, + #[serde(default = "default_top_k")] + top_k: usize, + #[serde(default)] + model: Option, +} + +fn default_top_k() -> usize { + 5 +} + +#[derive(Debug, Deserialize)] +pub struct DetectRequest { + image_path: String, + objects: String, + #[serde(default = "default_threshold")] + threshold: f32, + #[serde(default)] + model: Option, +} + +fn default_threshold() -> f32 { + 0.15 +} + +#[derive(Debug, Deserialize)] +pub struct BatchClassifyRequest { + image_paths: String, + labels: String, + #[serde(default = "default_top_k")] + top_k: usize, + #[serde(default)] + model: Option, +} + +#[derive(Debug, Serialize)] +pub struct ClassifyResponse { + success: bool, + predictions: Vec, +} + +#[derive(Debug, Serialize)] +pub struct DetectResponse { + success: bool, + detected: Vec, +} + +#[derive(Debug, Serialize)] +pub struct BatchClassifyResponse { + success: bool, + results: HashMap>, +} + +#[derive(Debug, Serialize)] +pub struct ErrorResponse { + success: bool, + error: String, +} + +pub fn clip_routes() -> Router { + Router::new() + .route("/api/v1/clip/classify", post(classify_image_endpoint)) + .route("/api/v1/clip/detect", post(detect_objects_endpoint)) + .route("/api/v1/clip/batch", post(batch_classify_endpoint)) +} + +async fn classify_image_endpoint( + State(_state): State, + Json(req): Json, +) -> Response { + let labels: Vec<&str> = req.labels.split(',').map(|s| s.trim()).collect(); + + let result = classify_image( + &req.image_path, + &labels, + Some(req.top_k), + req.model.as_deref(), + ).await; + + match result { + Ok(predictions) => { + tracing::info!( + "[CLIP_API] Classified {} -> top: {} ({:.3})", + req.image_path, + predictions.first().map(|p| p.label.as_str()).unwrap_or("none"), + predictions.first().map(|p| p.confidence).unwrap_or(0.0) + ); + Json(ClassifyResponse { + success: true, + predictions, + }).into_response() + } + Err(e) => { + tracing::error!("[CLIP_API] Classification failed: {}", e); + Json(ErrorResponse { + success: false, + error: e.to_string(), + }).into_response() + } + } +} + +async fn detect_objects_endpoint( + State(_state): State, + Json(req): Json, +) -> Response { + let objects: Vec<&str> = req.objects.split(',').map(|s| s.trim()).collect(); + + let result = detect_objects( + &req.image_path, + &objects, + Some(req.threshold), + req.model.as_deref(), + ).await; + + match result { + Ok(detected) => { + if !detected.is_empty() { + tracing::info!( + "[CLIP_API] Detected {} objects in {}: {}", + detected.len(), + req.image_path, + detected.iter().map(|p| p.label.as_str()).collect::>().join(", ") + ); + } else { + tracing::info!("[CLIP_API] No objects detected in {} (threshold: {:.2})", req.image_path, req.threshold); + } + Json(DetectResponse { + success: true, + detected, + }).into_response() + } + Err(e) => { + tracing::error!("[CLIP_API] Detection failed: {}", e); + Json(ErrorResponse { + success: false, + error: e.to_string(), + }).into_response() + } + } +} + +async fn batch_classify_endpoint( + State(_state): State, + Json(req): Json, +) -> Response { + let image_paths: Vec<&str> = req.image_paths.split(',').map(|s| s.trim()).collect(); + let labels: Vec<&str> = req.labels.split(',').map(|s| s.trim()).collect(); + + let result = classify_images( + &image_paths, + &labels, + Some(req.top_k), + req.model.as_deref(), + ).await; + + match result { + Ok(results_vec) => { + let results: HashMap> = results_vec + .into_iter() + .map(|r| (r.image_path, r.predictions)) + .collect(); + + tracing::info!("[CLIP_API] Batch classified {} images", results.len()); + Json(BatchClassifyResponse { + success: true, + results, + }).into_response() + } + Err(e) => { + tracing::error!("[CLIP_API] Batch classification failed: {}", e); + Json(ErrorResponse { + success: false, + error: e.to_string(), + }).into_response() + } + } +} \ No newline at end of file diff --git a/src/api/five_w1h_agent_api.rs b/src/api/five_w1h_agent_api.rs index 705b717..23d3aa7 100644 --- a/src/api/five_w1h_agent_api.rs +++ b/src/api/five_w1h_agent_api.rs @@ -5,8 +5,9 @@ use axum::{ routing::{get, post}, Router, }; -use reqwest::Client; use serde::{Deserialize, Serialize}; + +use crate::core::llm::function_calling::LLM_CLIENT; use sqlx::Row; use crate::api::types::AppState; @@ -381,8 +382,7 @@ Rules: "stream": false }); - let client = Client::new(); - let resp = client + let resp = LLM_CLIENT .post(llm_base_url()) .json(&body) .timeout(std::time::Duration::from_secs(180)) diff --git a/src/api/identity_agent_api.rs b/src/api/identity_agent_api.rs index 89787ca..30495f1 100644 --- a/src/api/identity_agent_api.rs +++ b/src/api/identity_agent_api.rs @@ -1002,15 +1002,17 @@ pub async fn bind_speakers(pool: &sqlx::PgPool, file_uuid: &str) -> anyhow::Resu // Also update speaker_detections with the identity_id let sd_table = schema::table_name("speaker_detections"); - let _ = sqlx::query( - &format!("UPDATE {} SET identity_id = $1, confidence = $2 \ - WHERE file_uuid = $3 AND speaker_id = $4 AND identity_id IS NULL", sd_table) - ) + let _ = sqlx::query(&format!( + "UPDATE {} SET identity_id = $1, confidence = $2 \ + WHERE file_uuid = $3 AND speaker_id = $4 AND identity_id IS NULL", + sd_table + )) .bind(identity_id) .bind(overlap_ratio) .bind(file_uuid) .bind(&best_speaker) - .execute(pool).await; + .execute(pool) + .await; bindings += 1; } diff --git a/src/api/identity_api.rs b/src/api/identity_api.rs index a768137..bb41fd9 100644 --- a/src/api/identity_api.rs +++ b/src/api/identity_api.rs @@ -1510,7 +1510,8 @@ async fn search_identities_by_text( let chunk_table = schema::table_name("chunk"); let like_q = format!("%{}%", params.q.replace('%', "%%")); let page = params.page.unwrap_or(1).max(1); - let page_size = params.page_size + let page_size = params + .page_size .or(params.limit) .unwrap_or(20) .min(100) diff --git a/src/api/identity_binding.rs b/src/api/identity_binding.rs index 09e7878..ed64be9 100644 --- a/src/api/identity_binding.rs +++ b/src/api/identity_binding.rs @@ -734,6 +734,8 @@ pub async fn bind_identity_trace( Json(req): Json, ) -> Result>, (StatusCode, Json)> { let fd_table = crate::core::db::schema::table_name("face_detections"); + let video_table = crate::core::db::schema::table_name("videos"); + let video_table = crate::core::db::schema::table_name("videos"); let id_table = crate::core::db::schema::table_name("identities"); let history_table = crate::core::db::schema::table_name("identity_history"); @@ -854,6 +856,7 @@ pub async fn get_identity_traces( ) -> Result, (StatusCode, String)> { let id_table = crate::core::db::schema::table_name("identities"); let fd_table = crate::core::db::schema::table_name("face_detections"); + let video_table = crate::core::db::schema::table_name("videos"); let page = params.page.unwrap_or(1); let page_size = params.page_size.unwrap_or(20); @@ -879,12 +882,13 @@ pub async fn get_identity_traces( COUNT(*)::bigint AS frame_count, MIN(fd.frame_number)::int AS first_frame, MAX(fd.frame_number)::int AS last_frame, - ROUND(MIN(fd.frame_number)::numeric / 25.0, 1)::float8 AS first_sec, - ROUND(MAX(fd.frame_number)::numeric / 25.0, 1)::float8 AS last_sec, + ROUND(MIN(fd.frame_number)::numeric / NULLIF(v.fps, 0)::numeric, 1)::float8 AS first_sec, + ROUND(MAX(fd.frame_number)::numeric / NULLIF(v.fps, 0)::numeric, 1)::float8 AS last_sec, ROUND(AVG(fd.confidence)::numeric, 4)::float8 AS avg_confidence FROM {} fd + LEFT JOIN dev.videos v ON fd.file_uuid = v.file_uuid WHERE fd.identity_id = $1 - GROUP BY fd.file_uuid, fd.trace_id + GROUP BY fd.file_uuid, fd.trace_id, v.fps ORDER BY fd.file_uuid, fd.trace_id LIMIT $2 OFFSET $3"#, fd_table diff --git a/src/api/llm_search.rs b/src/api/llm_search.rs index f65190f..ecc4ecc 100644 --- a/src/api/llm_search.rs +++ b/src/api/llm_search.rs @@ -1,10 +1,4 @@ -use axum::{ - extract::State, - http::StatusCode, - response::Json, - routing::post, - Router, -}; +use axum::{extract::State, http::StatusCode, response::Json, routing::post, Router}; use serde::Deserialize; use tracing::warn; diff --git a/src/api/media_api.rs b/src/api/media_api.rs index 407b4fa..9033662 100644 --- a/src/api/media_api.rs +++ b/src/api/media_api.rs @@ -63,6 +63,7 @@ pub fn bbox_routes() -> Router { ) .route("/api/v1/file/:file_uuid/video", get(stream_video)) .route("/api/v1/file/:file_uuid/thumbnail", get(face_thumbnail)) + .route("/api/v1/file/:file_uuid/chunk/:chunk_id/thumbnail", get(chunk_thumbnail)) .route("/api/v1/file/:file_uuid/clip", get(video_clip)) } @@ -745,13 +746,14 @@ async fn face_thumbnail( .join(format!("{}.jpg", frame)); if cached_path.exists() { - tracing::debug!("[thumbnail] Using cached face crop: {}", cached_path.display()); - let bytes = tokio::fs::read(&cached_path) - .await - .map_err(|e| { - tracing::warn!("[thumbnail] Failed to read cached file: {}", e); - StatusCode::INTERNAL_SERVER_ERROR - })?; + tracing::debug!( + "[thumbnail] Using cached face crop: {}", + cached_path.display() + ); + let bytes = tokio::fs::read(&cached_path).await.map_err(|e| { + tracing::warn!("[thumbnail] Failed to read cached file: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; // Validate cached JPEG crate::core::thumbnail::validator::validate_jpeg(&bytes).map_err(|e| { @@ -766,7 +768,7 @@ async fn face_thumbnail( .body(Body::from(bytes)) .unwrap()); } - + // Cached file not found, fallback to ffmpeg tracing::debug!("[thumbnail] Cached file not found, falling back to ffmpeg"); } @@ -841,6 +843,99 @@ async fn face_thumbnail( .unwrap()) } +async fn chunk_thumbnail( + State(state): State, + Path((file_uuid, chunk_id)): Path<(String, String)>, +) -> Result { + let videos_table = schema::table_name("videos"); + let chunk_table = schema::table_name("chunk"); + + let output_dir = crate::core::config::OUTPUT_DIR.as_str(); + let cached_path = std::path::PathBuf::from(output_dir) + .join(".chunk_thumbs") + .join(&file_uuid) + .join(format!("{}.jpg", chunk_id)); + + if cached_path.exists() { + let bytes = tokio::fs::read(&cached_path).await.map_err(|e| { + tracing::warn!("[chunk_thumbnail] Failed to read cache: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + return Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "image/jpeg") + .header(header::CACHE_CONTROL, "public, max-age=86400") + .body(Body::from(bytes)) + .unwrap()); + } + + let row: (f64, f64, f64) = sqlx::query_as(&format!( + "SELECT start_time, end_time, fps FROM {} WHERE file_uuid = $1 AND chunk_id = $2 LIMIT 1", + chunk_table + )) + .bind(&file_uuid) + .bind(&chunk_id) + .fetch_optional(state.db.pool()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? + .ok_or(StatusCode::NOT_FOUND)?; + + let (start_time, end_time, fps) = row; + + let start_frame = (start_time * fps).round() as i64; + let end_frame = (end_time * fps).round() as i64; + let mid_frame = (start_frame + end_frame) / 2; + + let video: Option<(String, Option)> = sqlx::query_as(&format!( + "SELECT file_path, total_frames FROM {} WHERE file_uuid = $1", + videos_table + )) + .bind(&file_uuid) + .fetch_optional(state.db.pool()) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + let (file_path, total_frames) = video.ok_or(StatusCode::NOT_FOUND)?; + + let frame = match total_frames { + Some(t) if t > 0 => mid_frame.min(t - 1).max(0), + _ => mid_frame.max(0), + }; + + let select = format!("select=eq(n\\,{})", frame); + let output = ffmpeg_cmd() + .args([ + "-i", &file_path, + "-vf", &select, + "-frames:v", "1", + "-f", "image2pipe", + "-vcodec", "mjpeg", + "-", + ]) + .output() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if !output.status.success() { + return Err(StatusCode::INTERNAL_SERVER_ERROR); + } + + crate::core::thumbnail::validator::validate_jpeg(&output.stdout).map_err(|e| { + tracing::warn!("[chunk_thumbnail] JPEG validation failed: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + + if let Some(parent) = cached_path.parent() { + let _ = tokio::fs::create_dir_all(parent).await; + } + let _ = tokio::fs::write(&cached_path, &output.stdout).await; + + Ok(Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "image/jpeg") + .header(header::CACHE_CONTROL, "public, max-age=86400") + .body(Body::from(output.stdout)) + .unwrap()) +} + #[derive(Debug, serde::Deserialize)] struct ClipQuery { start_frame: Option, @@ -945,13 +1040,17 @@ async fn stranger_video_inner( use axum::http::header; use uuid::Uuid; - tracing::info!("[stranger_video] Starting for file={}, stranger={}", file_uuid, stranger_id); + tracing::info!( + "[stranger_video] Starting for file={}, stranger={}", + file_uuid, + stranger_id + ); let (mode, audio) = parse_video_params(¶ms); let videos_table = schema::table_name("videos"); tracing::debug!("[stranger_video] videos_table: {}", videos_table); - + let row: Option<(String, f64, i32, i32)> = sqlx::query_as(&format!( "SELECT file_path, COALESCE(fps, 24.0), COALESCE(width, 0), COALESCE(height, 0) FROM {} WHERE file_uuid = $1", videos_table @@ -963,18 +1062,22 @@ async fn stranger_video_inner( tracing::error!("[stranger_video] Video query error: {}", e); StatusCode::INTERNAL_SERVER_ERROR })?; - + let (video_path, fps, _width, _height) = row.ok_or_else(|| { tracing::error!("[stranger_video] Video not found for uuid={}", file_uuid); StatusCode::NOT_FOUND })?; - - tracing::info!("[stranger_video] Found video: path={}, fps={}", video_path, fps); + + tracing::info!( + "[stranger_video] Found video: path={}, fps={}", + video_path, + fps + ); // Query face detections by stranger_id directly let face_table = schema::table_name("face_detections"); tracing::debug!("[stranger_video] face_table: {}", face_table); - + // frame_number is BIGINT (i64) in database let rows: Vec<(i64, i32, i32, i32, i32)> = sqlx::query_as(&format!( "SELECT frame_number, x, y, width, height FROM {} WHERE file_uuid = $1 AND stranger_id = $2 ORDER BY frame_number", @@ -982,15 +1085,18 @@ async fn stranger_video_inner( )) .bind(&file_uuid).bind(stranger_id) .fetch_all(state.db.pool()).await - .unwrap_or_else(|e| { - tracing::error!("[stranger_video] Face query error: {}", e); - vec![] + .unwrap_or_else(|e| { + tracing::error!("[stranger_video] Face query error: {}", e); + vec![] }); tracing::info!("[stranger_video] Found {} faces", rows.len()); if rows.is_empty() { - tracing::error!("[stranger_video] No faces found for stranger_id={}", stranger_id); + tracing::error!( + "[stranger_video] No faces found for stranger_id={}", + stranger_id + ); return Err(StatusCode::NOT_FOUND); } @@ -1004,8 +1110,13 @@ async fn stranger_video_inner( let duration = (last_frame - first_frame) as f64 / fps + padding * 2.0; let seek = (start_sec - padding).max(0.0); - tracing::info!("[stranger_video] Frame range: {} - {}, time: {:.2}s - {:.2}s", - first_frame, last_frame, seek, seek + duration); + tracing::info!( + "[stranger_video] Frame range: {} - {}, time: {:.2}s - {:.2}s", + first_frame, + last_frame, + seek, + seek + duration + ); // Only support normal mode for stranger video let tmp = std::env::temp_dir().join(format!("stranger_{}.mp4", Uuid::new_v4())); @@ -1017,37 +1128,98 @@ async fn stranger_video_inner( cmd_args.push("-an"); } cmd_args.extend_from_slice(&["-y", &tmp_str]); - + tracing::debug!("[stranger_video] ffmpeg args: {:?}", cmd_args); - - let result = ffmpeg_cmd() - .args(&cmd_args) - .output() - .map_err(|e| { - tracing::error!("[stranger_video] ffmpeg spawn error: {}", e); - StatusCode::INTERNAL_SERVER_ERROR - })?; - + + let result = ffmpeg_cmd().args(&cmd_args).output().map_err(|e| { + tracing::error!("[stranger_video] ffmpeg spawn error: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; + if !result.status.success() { - tracing::error!("[stranger_video] ffmpeg failed: {}", String::from_utf8_lossy(&result.stderr)); + tracing::error!( + "[stranger_video] ffmpeg failed: {}", + String::from_utf8_lossy(&result.stderr) + ); return Err(StatusCode::INTERNAL_SERVER_ERROR); } - - tracing::info!("[stranger_video] ffmpeg success, output size: {} bytes", result.stdout.len()); - - let data = tokio::fs::read(&tmp) - .await - .map_err(|e| { - tracing::error!("[stranger_video] Read output error: {}", e); - StatusCode::INTERNAL_SERVER_ERROR - })?; + + tracing::info!( + "[stranger_video] ffmpeg success, output size: {} bytes", + result.stdout.len() + ); + + let data = tokio::fs::read(&tmp).await.map_err(|e| { + tracing::error!("[stranger_video] Read output error: {}", e); + StatusCode::INTERNAL_SERVER_ERROR + })?; let _ = std::fs::remove_file(&tmp); - - tracing::info!("[stranger_video] Returning video, size: {} bytes", data.len()); - + + tracing::info!( + "[stranger_video] Returning video, size: {} bytes", + data.len() + ); + Ok(Response::builder() .header(header::CONTENT_TYPE, "video/mp4") .header(header::CONTENT_LENGTH, data.len()) .body(Body::from(data)) .unwrap()) } + +// ── Media Proxy: Unified endpoint for WordPress frontend ── +// Accepts the same query param format as the (inactive) WordPress snippet 61. +// Dispatches to the appropriate existing handler based on `type`. +// Caddy rewrites /wp-json/momentry/v1/media → /api/v1/media-proxy{?} + +/// Dispatch query params to the appropriate handler +async fn media_proxy_handler( + State(state): State, + Query(params): Query>, + request: axum::http::Request, +) -> Result { + let uuid = params + .get("uuid") + .or_else(|| params.get("file_uuid")) + .ok_or(StatusCode::BAD_REQUEST)?; + + let type_ = params + .get("type") + .map(String::as_str) + .ok_or(StatusCode::BAD_REQUEST)?; + + match type_ { + "thumbnail" => { + let thumb_query = ThumbQuery { + frame: params.get("frame").and_then(|v| v.parse().ok()), + x: params.get("x").and_then(|v| v.parse().ok()), + y: params.get("y").and_then(|v| v.parse().ok()), + w: params.get("w").and_then(|v| v.parse().ok()), + h: params.get("h").and_then(|v| v.parse().ok()), + trace_id: params.get("trace_id").and_then(|v| v.parse().ok()), + }; + face_thumbnail(State(state), Path(uuid.clone()), Query(thumb_query)) + .await + .map(IntoResponse::into_response) + } + "video" => stream_video(State(state), Path(uuid.clone()), Query(params), request) + .await + .map(IntoResponse::into_response), + "chunk_thumbnail" => { + let chunk_id = params + .get("chunk_id") + .ok_or(StatusCode::BAD_REQUEST)?; + chunk_thumbnail( + State(state), + Path((uuid.clone(), chunk_id.clone())), + ) + .await + .map(IntoResponse::into_response) + } + _ => Err(StatusCode::BAD_REQUEST), + } +} + +pub fn media_proxy_routes() -> Router { + Router::new().route("/api/v1/media-proxy", get(media_proxy_handler)) +} diff --git a/src/api/pipeline.rs b/src/api/pipeline.rs index c61e099..3932f24 100644 --- a/src/api/pipeline.rs +++ b/src/api/pipeline.rs @@ -3,81 +3,126 @@ use axum::routing::post; use axum::{Json, Router}; use serde_json::{json, Value}; +use crate::core::config; use crate::core::db::postgres_db::PostgresDb; use crate::core::pipeline as pipeline_core; -use crate::core::config; -async fn handle_store_asrx(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { - let db = PostgresDb::new(&config::DATABASE_URL).await - .map_err(|e| { - tracing::error!("DB error: {}", e); - (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "DB connection failed"}))) - })?; +async fn handle_store_asrx( + Path(uuid): Path, +) -> Result, (axum::http::StatusCode, Json)> { + let db = PostgresDb::new(&config::DATABASE_URL).await.map_err(|e| { + tracing::error!("DB error: {}", e); + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "DB connection failed"})), + ) + })?; - pipeline_core::store_asrx_chunks(&db, &uuid).await + pipeline_core::store_asrx_chunks(&db, &uuid) + .await .map_err(|e| { tracing::error!("store_asrx error: {}", e); - (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": e.to_string()})), + ) })?; - Ok(Json(json!({"success": true, "message": "ASRX chunks stored", "file_uuid": uuid}))) + Ok(Json( + json!({"success": true, "message": "ASRX chunks stored", "file_uuid": uuid}), + )) } -async fn handle_rule1(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { - let db = PostgresDb::new(&config::DATABASE_URL).await - .map_err(|e| { - tracing::error!("DB error: {}", e); - (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "DB connection failed"}))) - })?; +async fn handle_rule1( + Path(uuid): Path, +) -> Result, (axum::http::StatusCode, Json)> { + let db = PostgresDb::new(&config::DATABASE_URL).await.map_err(|e| { + tracing::error!("DB error: {}", e); + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "DB connection failed"})), + ) + })?; - let count = pipeline_core::execute_rule1(&db, &uuid).await + let count = pipeline_core::execute_rule1(&db, &uuid) + .await .map_err(|e| { tracing::error!("rule1 error: {}", e); - (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": e.to_string()})), + ) })?; - Ok(Json(json!({"success": true, "message": format!("Rule 1 complete: {} chunks", count), "file_uuid": uuid, "chunks": count}))) + Ok(Json( + json!({"success": true, "message": format!("Rule 1 complete: {} chunks", count), "file_uuid": uuid, "chunks": count}), + )) } -async fn handle_vectorize(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { - pipeline_core::vectorize_chunks(&uuid).await - .map_err(|e| { - tracing::error!("vectorize error: {}", e); - (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) - })?; +async fn handle_vectorize( + Path(uuid): Path, +) -> Result, (axum::http::StatusCode, Json)> { + pipeline_core::vectorize_chunks(&uuid).await.map_err(|e| { + tracing::error!("vectorize error: {}", e); + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": e.to_string()})), + ) + })?; - Ok(Json(json!({"success": true, "message": "Vectorization complete", "file_uuid": uuid}))) + Ok(Json( + json!({"success": true, "message": "Vectorization complete", "file_uuid": uuid}), + )) } -async fn handle_phase1(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { - pipeline_core::run_phase1(&uuid).await - .map_err(|e| { - tracing::error!("phase1 error: {}", e); - (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) - })?; +async fn handle_phase1( + Path(uuid): Path, +) -> Result, (axum::http::StatusCode, Json)> { + pipeline_core::run_phase1(&uuid).await.map_err(|e| { + tracing::error!("phase1 error: {}", e); + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": e.to_string()})), + ) + })?; - Ok(Json(json!({"success": true, "message": "Phase 1 complete", "file_uuid": uuid}))) + Ok(Json( + json!({"success": true, "message": "Phase 1 complete", "file_uuid": uuid}), + )) } -async fn handle_complete(Path(uuid): Path) -> Result, (axum::http::StatusCode, Json)> { - let db = PostgresDb::new(&config::DATABASE_URL).await - .map_err(|e| { - tracing::error!("DB error: {}", e); - (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": "DB connection failed"}))) - })?; +async fn handle_complete( + Path(uuid): Path, +) -> Result, (axum::http::StatusCode, Json)> { + let db = PostgresDb::new(&config::DATABASE_URL).await.map_err(|e| { + tracing::error!("DB error: {}", e); + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": "DB connection failed"})), + ) + })?; - pipeline_core::mark_complete(&db, &uuid).await + pipeline_core::mark_complete(&db, &uuid) + .await .map_err(|e| { tracing::error!("complete error: {}", e); - (axum::http::StatusCode::INTERNAL_SERVER_ERROR, Json(json!({"error": e.to_string()}))) + ( + axum::http::StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({"error": e.to_string()})), + ) })?; - Ok(Json(json!({"success": true, "message": "Video marked as completed", "file_uuid": uuid}))) + Ok(Json( + json!({"success": true, "message": "Video marked as completed", "file_uuid": uuid}), + )) } pub fn pipeline_routes() -> Router { Router::new() - .route("/api/v1/file/:file_uuid/store-asrx", post(handle_store_asrx)) + .route( + "/api/v1/file/:file_uuid/store-asrx", + post(handle_store_asrx), + ) .route("/api/v1/file/:file_uuid/rule1", post(handle_rule1)) .route("/api/v1/file/:file_uuid/vectorize", post(handle_vectorize)) .route("/api/v1/file/:file_uuid/phase1", post(handle_phase1)) diff --git a/src/api/search.rs b/src/api/search.rs index 89fbe5f..348fefa 100644 --- a/src/api/search.rs +++ b/src/api/search.rs @@ -36,6 +36,9 @@ pub struct SearchResult { pub summary: Option, pub metadata: Option, pub similarity: Option, + pub file_name: Option, + pub serve_url: Option, + pub thumbnail_url: Option, } #[derive(Debug, Serialize)] @@ -81,6 +84,9 @@ async fn enrich_from_pg( summary: Some(p.summary), metadata: p.metadata.clone(), similarity: Some(qdrant_score as f64), + file_name: None, + serve_url: None, + thumbnail_url: None, }), Ok(None) => None, Err(e) => { @@ -105,6 +111,9 @@ fn pg_result_to_search(p: &SemanticSearchResult) -> SearchResult { summary: Some(p.summary.clone()), metadata: p.metadata.clone(), similarity: p.similarity, + file_name: None, + serve_url: None, + thumbnail_url: None, } } @@ -156,7 +165,10 @@ pub async fn smart_search( .map(|h| (h.uuid, h.chunk_id, h.score as f64)) .collect() } else { - let qdrant_hits = qdrant.search(&embedding, fetch_limit).await.unwrap_or_default(); + let qdrant_hits = qdrant + .search(&embedding, fetch_limit) + .await + .unwrap_or_default(); qdrant_hits .into_iter() .map(|h| (h.uuid, h.chunk_id, h.score as f64)) @@ -264,7 +276,11 @@ pub async fn smart_search( .and_modify(|e| { e.score = e.score.max(*score); e.semantic_score = Some(*score); - e.source = format!("{}_{}", e.source.strip_prefix("semantic+").unwrap_or(&e.source), "semantic"); + e.source = format!( + "{}_{}", + e.source.strip_prefix("semantic+").unwrap_or(&e.source), + "semantic" + ); }) .or_insert(MergedResult { file_uuid: file_uuid.clone(), @@ -346,17 +362,36 @@ pub async fn smart_search( // Sort by score descending (score-based merge) let mut ranked: Vec<&MergedResult> = merged.values().collect(); - ranked.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)); + ranked.sort_by(|a, b| { + b.score + .partial_cmp(&a.score) + .unwrap_or(std::cmp::Ordering::Equal) + }); // 6. Enrich top results from PG and build final response + let query_lower = req.query.to_lowercase(); let mut final_results = Vec::new(); - for mr in ranked.iter().take(limit) { + for mr in ranked.iter().take(limit * 3) { // 取更多結果以便過濾 if let Some(pg) = db .get_chunk_by_file_and_chunk_id(&mr.file_uuid, &mr.chunk_id) .await .ok() .flatten() { + // 關鍵字過濾 + let summary_lower = pg.summary.to_lowercase(); + let query_words: Vec = query_lower.split_whitespace().map(|s| s.to_string()).collect(); + + // 檢查是否包含所有查詢詞(完整單詞) + let text_match = !pg.summary.is_empty() && { + let bordered = format!(" {} ", summary_lower); + query_words.iter().all(|w| bordered.contains(&format!(" {} ", w))) + }; + + if !text_match { + continue; + } + final_results.push(SearchResult { id: 0, file_uuid: pg.file_uuid.clone(), @@ -371,10 +406,52 @@ pub async fn smart_search( summary: Some(pg.summary), metadata: pg.metadata.clone(), similarity: Some(mr.score), + file_name: None, + serve_url: None, + thumbnail_url: pg.file_uuid.as_ref().map(|fu| format!( + "/wp-json/momentry/v1/media?type=chunk_thumbnail&file_uuid={}&chunk_id={}", + fu, mr.chunk_id + )), }); } } + // Trim to requested limit + final_results.truncate(limit); + + // 7. Enrich results with file_name and serve_url from videos table + if !final_results.is_empty() { + let v_table = crate::core::db::schema::table_name("videos"); + let file_uuids: Vec = final_results + .iter() + .filter_map(|r| r.file_uuid.clone()) + .collect(); + let file_rows: Vec<(String, String, String)> = sqlx::query_as(&format!( + "SELECT file_uuid::text, file_name, file_path FROM {} WHERE file_uuid = ANY($1)", + v_table + )) + .bind(&file_uuids) + .fetch_all(db.pool()) + .await + .unwrap_or_default(); + let file_map: std::collections::HashMap = file_rows + .into_iter() + .map(|(uuid, name, path)| (uuid, (name, path))) + .collect(); + let storage_root = crate::core::config::STORAGE_ROOT.as_str(); + let serve_base = crate::core::config::SERVE_BASE_URL.as_str(); + for r in &mut final_results { + if let Some(ref uuid) = r.file_uuid { + if let Some((name, path)) = file_map.get(uuid) { + r.file_name = Some(name.clone()); + if let Some(relative) = path.strip_prefix(storage_root) { + r.serve_url = Some(format!("{}{}", serve_base, relative)); + } + } + } + } + } + // Determine strategy string let mut strategies = vec!["semantic"]; if !keyword_results.is_empty() { @@ -400,4 +477,4 @@ pub async fn smart_search( pub fn search_routes() -> Router { Router::new().route("/api/v1/search/smart", post(smart_search)) -} \ No newline at end of file +} diff --git a/src/api/server.rs b/src/api/server.rs index f3b0431..a71266b 100644 --- a/src/api/server.rs +++ b/src/api/server.rs @@ -20,9 +20,9 @@ use super::identity_agent_api; use super::identity_api; use super::identity_binding; use super::llm_search; -use super::pipeline; use super::media_api; use super::middleware::unified_auth; +use super::pipeline; use super::processing; use super::scan; use super::search::search_routes; @@ -117,6 +117,7 @@ pub async fn start_server(host: &str, port: u16) -> anyhow::Result<()> { .merge(identity_agent_api::identity_agent_routes()) .merge(five_w1h_agent_api::five_w1h_agent_routes()) .merge(media_api::bbox_routes()) + .merge(media_api::media_proxy_routes()) .merge(trace_agent_api::trace_agent_routes()) .merge(search_routes()) .merge(llm_search::llm_smart_routes()) diff --git a/src/api/trace_agent_api.rs b/src/api/trace_agent_api.rs index ffff929..51ff3ff 100644 --- a/src/api/trace_agent_api.rs +++ b/src/api/trace_agent_api.rs @@ -593,7 +593,11 @@ async fn get_trace_thumbnail_inner( // For trace_id=0 (untracked/stranger), check unbound directory instead let output_dir = crate::core::config::OUTPUT_DIR.as_str(); let trace_id_str = trace_id.to_string(); - let trace_dir_name = if trace_id == 0 { "unbound" } else { &trace_id_str }; + let trace_dir_name = if trace_id == 0 { + "unbound" + } else { + &trace_id_str + }; let trace_dir = std::path::PathBuf::from(output_dir) .join(".faces") .join(&file_uuid) @@ -605,15 +609,16 @@ async fn get_trace_thumbnail_inner( while let Some(Ok(entry)) = entries.next() { let path = entry.path(); if path.extension().map_or(false, |e| e == "jpg") { - tracing::info!("[trace_thumbnail] Using cached face crop: {}", path.display()); - let bytes = tokio::fs::read(&path) - .await - .map_err(|e| { - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({"error": e.to_string()})), - ) - })?; + tracing::info!( + "[trace_thumbnail] Using cached face crop: {}", + path.display() + ); + let bytes = tokio::fs::read(&path).await.map_err(|e| { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({"error": e.to_string()})), + ) + })?; // Validate cached JPEG crate::core::thumbnail::validator::validate_jpeg(&bytes).map_err(|e| { @@ -647,7 +652,11 @@ async fn get_trace_thumbnail_inner( let seek = sel.frame as f64 / sel.fps; let tmp = std::env::temp_dir().join(format!("trace_{}_{}.jpg", file_uuid, trace_id)); - tracing::debug!("[trace_thumbnail] Fallback to ffmpeg for trace {} frame {}", trace_id, sel.frame); + tracing::debug!( + "[trace_thumbnail] Fallback to ffmpeg for trace {} frame {}", + trace_id, + sel.frame + ); let status = tokio::process::Command::new("ffmpeg") .args([ diff --git a/src/bin/sync_qdrant_from_pg.rs b/src/bin/sync_qdrant_from_pg.rs index 7790d7a..592bdcf 100644 --- a/src/bin/sync_qdrant_from_pg.rs +++ b/src/bin/sync_qdrant_from_pg.rs @@ -6,7 +6,9 @@ async fn main() -> Result<()> { dotenv::from_filename("/Users/accusys/momentry_core_0.1/.env.development").ok(); tracing_subscriber::fmt::init(); - let pg = PostgresDb::init().await.context("Failed to init PostgreSQL")?; + let pg = PostgresDb::init() + .await + .context("Failed to init PostgreSQL")?; let qdrant = QdrantDb::new(); let chunk_table = momentry_core::core::db::schema::table_name("chunk"); @@ -17,8 +19,8 @@ async fn main() -> Result<()> { ]; for uuid in &uuids { - let rows = sqlx::query_as::<_, (String, String, i64, i64, f64, f64, String, String)>( - &format!( + let rows = + sqlx::query_as::<_, (String, String, i64, i64, f64, f64, String, String)>(&format!( "SELECT chunk_id, text_content, start_frame, end_frame, \ start_time, end_time, embedding::text, content::text \ FROM {} \ @@ -28,14 +30,16 @@ async fn main() -> Result<()> { AND (text_content IS NOT NULL AND text_content != '') \ ORDER BY id", chunk_table - ), - ) - .bind(uuid) - .fetch_all(pg.pool()) - .await?; + )) + .bind(uuid) + .fetch_all(pg.pool()) + .await?; let total = rows.len(); - println!("[{}] Found {} sentence chunks with embeddings to sync to Qdrant", uuid, total); + println!( + "[{}] Found {} sentence chunks with embeddings to sync to Qdrant", + uuid, total + ); if total == 0 { continue; @@ -45,7 +49,17 @@ async fn main() -> Result<()> { let mut stored = 0usize; let mut errors = 0usize; - for (chunk_id, text, start_frame, end_frame, start_time, end_time, vector_text, _content_str) in &rows { + for ( + chunk_id, + text, + start_frame, + end_frame, + start_time, + end_time, + vector_text, + _content_str, + ) in &rows + { let vector: Vec = serde_json::from_str(vector_text) .map_err(|e| anyhow::anyhow!("Failed to parse vector for {}: {}", chunk_id, e))?; @@ -73,9 +87,11 @@ async fn main() -> Result<()> { println!( " [{}] {}/{} ({:.1}%) | {:.0} vec/s | {} errors", uuid.get(..8).unwrap_or(uuid), - stored, total, + stored, + total, 100.0 * stored as f64 / total as f64, - rate, errors, + rate, + errors, ); } } @@ -84,7 +100,9 @@ async fn main() -> Result<()> { println!( "[{}] Done! {}/{} vectors synced ({} errors) in {:.1}s ({:.0} vec/s avg)", uuid.get(..8).unwrap_or(uuid), - stored, total, errors, + stored, + total, + errors, elapsed.as_secs_f64(), stored as f64 / elapsed.as_secs_f64(), ); diff --git a/src/bin/vectorize_missing.rs b/src/bin/vectorize_missing.rs index 785265f..cdd1a37 100644 --- a/src/bin/vectorize_missing.rs +++ b/src/bin/vectorize_missing.rs @@ -1,7 +1,5 @@ use anyhow::{Context, Result}; -use momentry_core::{ - Database, Embedder, PostgresDb, QdrantDb, VectorPayload, -}; +use momentry_core::{Database, Embedder, PostgresDb, QdrantDb, VectorPayload}; use std::time::{Duration, Instant}; #[tokio::main] @@ -9,16 +7,17 @@ async fn main() -> Result<()> { dotenv::from_filename("/Users/accusys/momentry_core_0.1/.env.development").ok(); tracing_subscriber::fmt::init(); - let pg = PostgresDb::init().await.context("Failed to init PostgreSQL")?; + let pg = PostgresDb::init() + .await + .context("Failed to init PostgreSQL")?; let qdrant = QdrantDb::new(); let embedder = Embedder::new("embeddinggemma-300m".to_string()); let uuid = "63acd3bb02b5b9dfbb9d6db499fcc864"; let chunk_table = momentry_core::core::db::schema::table_name("chunk"); - let rows = sqlx::query_as::<_, (String, String, i64, i64, f64, f64, String)>( - &format!( - "SELECT chunk_id, text_content, start_frame, end_frame, \ + let rows = sqlx::query_as::<_, (String, String, i64, i64, f64, f64, String)>(&format!( + "SELECT chunk_id, text_content, start_frame, end_frame, \ start_time, end_time, content::text \ FROM {} \ WHERE file_uuid = $1 \ @@ -26,15 +25,17 @@ async fn main() -> Result<()> { AND embedding IS NULL \ AND (text_content IS NOT NULL AND text_content != '') \ ORDER BY id", - chunk_table - ), - ) + chunk_table + )) .bind(uuid) .fetch_all(pg.pool()) .await?; let total = rows.len(); - println!("Found {} sentence chunks without embedding for {}", total, uuid); + println!( + "Found {} sentence chunks without embedding for {}", + total, uuid + ); if total == 0 { println!("Nothing to vectorize. Exiting."); diff --git a/src/cli/args.rs b/src/cli/args.rs index 47f7331..f68d09b 100644 --- a/src/cli/args.rs +++ b/src/cli/args.rs @@ -50,6 +50,24 @@ pub enum Commands { /// UUID uuid: String, }, + /// Detect objects in an image using CLIP or Qwen3-VL + Detect { + /// Image path + #[arg(short, long)] + image: String, + /// Objects to detect (comma separated) + #[arg(short, long, value_delimiter = ',')] + objects: Vec, + /// Use cascade mode (CLIP first, then Qwen3-VL for high confidence) + #[arg(long, default_value = "false")] + cascade: bool, + /// CLIP confidence threshold for cascade (default: 0.7) + #[arg(long, default_value = "0.7")] + threshold: f32, + }, + /// Vision LLM management + #[command(subcommand)] + Vision(VisionCommands), /// Vectorize chunks Vectorize { /// UUID (or 'all' for all) @@ -215,6 +233,16 @@ pub enum N8nAction { Verify, } +#[derive(Subcommand)] +pub enum VisionCommands { + /// Start Qwen3-VL server + Start, + /// Stop Qwen3-VL server + Stop, + /// Check Qwen3-VL status + Status, +} + /// Parse key type from string pub fn parse_key_type(s: Option<&str>) -> momentry_core::core::api_key::ApiKeyType { use momentry_core::core::api_key::ApiKeyType; diff --git a/src/cli/mod.rs b/src/cli/mod.rs index a55956b..48ad50e 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -1,5 +1,6 @@ //! CLI command definitions and argument parsing pub mod args; +pub mod vision; pub use args::*; diff --git a/src/cli/vision.rs b/src/cli/vision.rs new file mode 100644 index 0000000..88c3545 --- /dev/null +++ b/src/cli/vision.rs @@ -0,0 +1,95 @@ +use anyhow::Result; +use std::path::PathBuf; + +use momentry_core::core::vision::qwen_vl_manager::QwenVLManager; +use momentry_core::core::processor::cascade_vision::CascadeVisionProcessor; + +pub async fn handle_vision_command(cmd: crate::cli::args::VisionCommands) -> Result<()> { + let manager = QwenVLManager::new(); + + match cmd { + crate::cli::args::VisionCommands::Start => { + println!("Starting Qwen3-VL server..."); + manager.ensure_running().await?; + println!("✅ Qwen3-VL server started successfully"); + println!("Health check: http://localhost:8086/health"); + } + crate::cli::args::VisionCommands::Stop => { + println!("Stopping Qwen3-VL server..."); + manager.stop_server().await?; + println!("✅ Qwen3-VL server stopped"); + } + crate::cli::args::VisionCommands::Status => { + println!("Checking Qwen3-VL status..."); + let status = manager.get_status().await?; + + println!("Status:"); + println!(" Running: {}", if status.running { "✅ Yes" } else { "❌ No" }); + println!(" Port: {}", status.port); + println!(" Model: {}", status.model_path); + println!(" Last request: {} seconds ago", status.last_request); + println!(" PID file: {}", status.pid_file); + println!(" Log file: {}", status.log_file); + } + } + + Ok(()) +} + +pub async fn handle_detect_command( + image: String, + objects: Vec, + cascade: bool, + threshold: f32, +) -> Result<()> { + let image_path = PathBuf::from(&image); + + if !image_path.exists() { + anyhow::bail!("Image file not found: {}", image); + } + + println!("Detecting objects in: {}", image); + println!("Objects: {}", objects.join(", ")); + println!("Mode: {}", if cascade { "Cascade (CLIP + Qwen3-VL)" } else { "CLIP only" }); + println!("Threshold: {:.2}", threshold); + println!(); + + if cascade { + let processor = CascadeVisionProcessor::with_threshold(threshold); + let result = processor.detect_objects(&image_path, &objects.iter().map(|s| s.as_str()).collect::>()).await?; + + println!("Detection Results:"); + println!(" Model used: {}", result.model_used); + println!(" CLIP confidence: {:.3}", result.clip_confidence); + println!(" Qwen3-VL used: {}", if result.qwenvl_used { "✅ Yes" } else { "❌ No" }); + println!(" Processing time: {} ms", result.processing_time_ms); + println!(" Detections:"); + + for detection in &result.detections { + println!(" - {}: {:.3}", detection.label, detection.confidence); + } + + if result.detections.is_empty() { + println!(" (No objects detected)"); + } + } else { + use momentry_core::core::processor::clip::detect_objects; + + let objects_str: Vec<&str> = objects.iter().map(|s| s.as_str()).collect(); + let predictions = detect_objects(&image, &objects_str, Some(threshold), None).await?; + + println!("Detection Results:"); + println!(" Model used: CLIP"); + println!(" Detections:"); + + for prediction in &predictions { + println!(" - {}: {:.3}", prediction.label, prediction.confidence); + } + + if predictions.is_empty() { + println!(" (No objects detected above threshold {:.2})", threshold); + } + } + + Ok(()) +} \ No newline at end of file diff --git a/src/core/config.rs b/src/core/config.rs index 3612bb3..a42fa68 100644 --- a/src/core/config.rs +++ b/src/core/config.rs @@ -92,6 +92,16 @@ pub static MEDIA_BASE_URL: Lazy = Lazy::new(|| { .unwrap_or_else(|_| "https://wp.momentry.ddns.net".to_string()) }); +pub static STORAGE_ROOT: Lazy = Lazy::new(|| { + env::var("MOMENTRY_STORAGE_ROOT") + .unwrap_or_else(|_| "/Users/accusys/momentry/var/sftpgo/data".to_string()) +}); + +pub static SERVE_BASE_URL: Lazy = Lazy::new(|| { + env::var("MOMENTRY_SERVE_BASE_URL") + .unwrap_or_else(|_| "https://m5wp.momentry.ddns.net/files".to_string()) +}); + pub static SERVER_PORT: Lazy = Lazy::new(|| { env::var("MOMENTRY_SERVER_PORT") .unwrap_or_else(|_| "3002".to_string()) diff --git a/src/core/db/postgres_db.rs b/src/core/db/postgres_db.rs index a2cd9df..cd18144 100644 --- a/src/core/db/postgres_db.rs +++ b/src/core/db/postgres_db.rs @@ -2862,7 +2862,7 @@ impl PostgresDb { let rows = if let Some(u) = file_uuid { sqlx::query(&format!( "SELECT chunk_id, file_uuid, chunk_type, text_content, start_time, end_time, 1.0::float8 as score \ - FROM {} WHERE file_uuid=$1 AND text_content ILIKE $2 LIMIT $3", table) + FROM {} WHERE file_uuid=$1 AND text_content ILIKE $2 AND text_content != '' LIMIT $3", table) ) .bind(u).bind(&like).bind(limit) .fetch_all(&self.pool).await? diff --git a/src/core/llm/client.rs b/src/core/llm/client.rs index d09bd66..ac2da84 100644 --- a/src/core/llm/client.rs +++ b/src/core/llm/client.rs @@ -1,10 +1,10 @@ use anyhow::Result; -use reqwest::Client; use serde::{Deserialize, Serialize}; use std::time::Duration; use tracing::{debug, error, warn}; use crate::core::config; +use crate::core::llm::function_calling::LLM_CLIENT; #[derive(Debug, Serialize)] struct ChatRequest { @@ -39,10 +39,6 @@ pub async fn generate_5w1h_summary(scene_text: &str) -> Result { return Ok("LLM Disabled".to_string()); } - let client = Client::builder() - .timeout(Duration::from_secs(*config::llm::SUMMARY_TIMEOUT_SECS)) - .build()?; - let prompt = format!( r#"Analyze the following video scene transcript and provide a concise 5W1H+ summary in JSON format. Focus on: Who, What, Where, When, Why, How, and Key Objects/Actions. @@ -82,9 +78,10 @@ pub async fn generate_5w1h_summary(scene_text: &str) -> Result { debug!("Calling LLM for summary: {}", *config::llm::SUMMARY_URL); - let res = client + let res = LLM_CLIENT .post(&*config::llm::SUMMARY_URL) .json(&req) + .timeout(Duration::from_secs(*config::llm::SUMMARY_TIMEOUT_SECS)) .send() .await?; diff --git a/src/core/llm/function_calling.rs b/src/core/llm/function_calling.rs index dd2c9ac..98ab15c 100644 --- a/src/core/llm/function_calling.rs +++ b/src/core/llm/function_calling.rs @@ -1,8 +1,18 @@ +use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use crate::core::config; +/// Shared HTTP client with connection pooling for all LLM calls +pub static LLM_CLIENT: Lazy = Lazy::new(|| { + reqwest::Client::builder() + .pool_max_idle_per_host(32) + .pool_idle_timeout(std::time::Duration::from_secs(300)) + .build() + .expect("Failed to create shared LLM HTTP client") +}); + /// A tool/function definition for Gemma4 function calling #[derive(Debug, Clone, Serialize)] pub struct ToolDef { @@ -126,11 +136,11 @@ pub async fn call_llm_vision( "stream": false, }); - let client = reqwest::Client::builder() + let res = LLM_CLIENT + .post(&llm_vision_url()) + .json(&req) .timeout(std::time::Duration::from_secs(timeout_secs)) - .build()?; - - let res = client.post(&llm_vision_url()).json(&req).send().await?; + .send().await?; if !res.status().is_success() { let text = res.text().await.unwrap_or_default(); anyhow::bail!("Vision LLM API error: {}", text); @@ -182,13 +192,11 @@ pub async fn call_llm( max_tokens: u32, timeout_secs: u64, ) -> anyhow::Result { - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(if timeout_secs > 0 { - timeout_secs - } else { - *config::llm::CHAT_TIMEOUT_SECS - })) - .build()?; + let timeout = if timeout_secs > 0 { + timeout_secs + } else { + *config::llm::CHAT_TIMEOUT_SECS + }; let req = ChatRequest { model: llm_model(), @@ -199,7 +207,11 @@ pub async fn call_llm( tools, }; - let res = client.post(&llm_chat_url()).json(&req).send().await?; + let res = LLM_CLIENT + .post(&llm_chat_url()) + .json(&req) + .timeout(std::time::Duration::from_secs(timeout)) + .send().await?; if !res.status().is_success() { let text = res.text().await.unwrap_or_default(); diff --git a/src/core/llm/rerank.rs b/src/core/llm/rerank.rs index 5c0f272..1db2f3b 100644 --- a/src/core/llm/rerank.rs +++ b/src/core/llm/rerank.rs @@ -1,12 +1,12 @@ use std::collections::HashSet; use anyhow::Result; -use reqwest::Client; use serde::{Deserialize, Serialize}; use std::time::Duration; use tracing::{debug, warn}; use crate::core::config; +use crate::core::llm::function_calling::LLM_CLIENT; #[derive(Debug, Serialize)] struct ChatRequest { @@ -38,7 +38,10 @@ struct RerankResponse { ranked: Vec, } -pub async fn rerank_search_results(query: &str, candidates: &[(usize, &str)]) -> Result> { +pub async fn rerank_search_results( + query: &str, + candidates: &[(usize, &str)], +) -> Result> { if candidates.is_empty() { return Ok(vec![]); } @@ -67,10 +70,6 @@ Include every chunk number exactly once. Only respond with the JSON."#, query, chunks_text ); - let client = Client::builder() - .timeout(Duration::from_secs(15)) - .build()?; - let req = ChatRequest { model: config::llm::CHAT_MODEL.clone(), messages: vec![ @@ -88,11 +87,16 @@ Include every chunk number exactly once. Only respond with the JSON."#, stream: false, }; - debug!("LLM rerank: {} candidates for query '{}'", candidates.len(), query); + debug!( + "LLM rerank: {} candidates for query '{}'", + candidates.len(), + query + ); - let res = client + let res = LLM_CLIENT .post(&*config::llm::CHAT_URL) .json(&req) + .timeout(Duration::from_secs(15)) .send() .await?; @@ -116,7 +120,11 @@ Include every chunk number exactly once. Only respond with the JSON."#, // Strip markdown code fences if present let content = if content.starts_with("```") { let lines: Vec<&str> = content.lines().collect(); - let start = if lines.first().map(|l| l.contains("```")).unwrap_or(false) { 1 } else { 0 }; + let start = if lines.first().map(|l| l.contains("```")).unwrap_or(false) { + 1 + } else { + 0 + }; let end = if lines.last().map(|l| l.contains("```")).unwrap_or(false) { lines.len().saturating_sub(1) } else { @@ -163,6 +171,9 @@ Include every chunk number exactly once. Only respond with the JSON."#, } } - warn!("LLM rerank: could not parse response — content: {}", &content[..content.len().min(200)]); + warn!( + "LLM rerank: could not parse response — content: {}", + &content[..content.len().min(200)] + ); Ok(candidates.iter().map(|(idx, _)| *idx).collect()) } diff --git a/src/core/mod.rs b/src/core/mod.rs index 5950d6b..a787d07 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -20,3 +20,4 @@ pub mod text; pub mod thumbnail; pub mod time; pub mod tmdb; +pub mod vision; diff --git a/src/core/pipeline/mod.rs b/src/core/pipeline/mod.rs index 0ffeb73..38196e0 100644 --- a/src/core/pipeline/mod.rs +++ b/src/core/pipeline/mod.rs @@ -17,8 +17,8 @@ pub async fn store_asrx_chunks(db: &PostgresDb, uuid: &str) -> Result<()> { let json_str = std::fs::read_to_string(&asrx_path) .with_context(|| format!("ASRX file not found: {:?}", asrx_path))?; - let result: AsrxResult = serde_json::from_str(&json_str) - .context("Failed to parse ASRX JSON")?; + let result: AsrxResult = + serde_json::from_str(&json_str).context("Failed to parse ASRX JSON")?; let segments_count = result.segments.len(); let mut pre_chunks = Vec::new(); @@ -41,21 +41,26 @@ pub async fn store_asrx_chunks(db: &PostgresDb, uuid: &str) -> Result<()> { )); } - db.store_raw_pre_chunks_batch(uuid, "asrx", &pre_chunks).await?; - db.store_raw_pre_chunks_batch(uuid, "asr", &pre_chunks).await?; - db.store_speaker_detections_batch(uuid, &speaker_detections).await?; + db.store_raw_pre_chunks_batch(uuid, "asrx", &pre_chunks) + .await?; + db.store_raw_pre_chunks_batch(uuid, "asr", &pre_chunks) + .await?; + db.store_speaker_detections_batch(uuid, &speaker_detections) + .await?; println!("Stored {} ASRX pre-chunks for {}", segments_count, uuid); Ok(()) } pub async fn execute_rule1(db: &PostgresDb, uuid: &str) -> Result { - let video = db.get_video_by_uuid(uuid) + let video = db + .get_video_by_uuid(uuid) .await? .context("Video not found")?; let fps = video.fps; - let count = rule1_ingest::execute_rule1(db, uuid, fps).await + let count = rule1_ingest::execute_rule1(db, uuid, fps) + .await .context("Rule 1 ingestion failed")?; println!("Rule 1 completed: {} chunks inserted for {}", count, uuid); @@ -68,17 +73,15 @@ pub async fn vectorize_chunks(uuid: &str) -> Result<()> { let embedder = Embedder::new("embeddinggemma-300m".to_string()); let chunk_table = schema::table_name("chunk"); - let rows = sqlx::query_as::<_, (String, String, String, i64, i64, f64, f64, String)>( - &format!( - "SELECT chunk_id, chunk_type, text_content, start_frame, end_frame, \ + let rows = sqlx::query_as::<_, (String, String, String, i64, i64, f64, f64, String)>(&format!( + "SELECT chunk_id, chunk_type, text_content, start_frame, end_frame, \ start_time, end_time, content::text \ FROM {} WHERE file_uuid = $1 AND chunk_type = 'sentence' \ AND embedding IS NULL \ AND (text_content IS NOT NULL AND text_content != '') \ ORDER BY id", - chunk_table - ), - ) + chunk_table + )) .bind(uuid) .fetch_all(db.pool()) .await?; @@ -91,7 +94,9 @@ pub async fn vectorize_chunks(uuid: &str) -> Result<()> { let total = rows.len(); let mut stored = 0usize; - for (chunk_id, _chunk_type, text, start_frame, end_frame, start_time, end_time, _content_str) in &rows { + for (chunk_id, _chunk_type, text, start_frame, end_frame, start_time, end_time, _content_str) in + &rows + { if text.is_empty() { continue; } @@ -127,13 +132,15 @@ pub async fn vectorize_chunks(uuid: &str) -> Result<()> { } } - println!("Vectorization complete: {}/{} vectors for {}", stored, total, uuid); + println!( + "Vectorization complete: {}/{} vectors for {}", + stored, total, uuid + ); Ok(()) } pub async fn run_phase1(uuid: &str) -> Result<()> { - let executor = PythonExecutor::new() - .context("Failed to create PythonExecutor")?; + let executor = PythonExecutor::new().context("Failed to create PythonExecutor")?; executor .run( @@ -154,15 +161,17 @@ pub async fn mark_complete(db: &PostgresDb, uuid: &str) -> Result<()> { use crate::core::db::MonitorJobStatus; use crate::core::db::VideoStatus; - let job_id = sqlx::query_scalar::<_, i32>( - &format!("SELECT id FROM {} WHERE uuid = $1 LIMIT 1", schema::table_name("monitor_jobs")), - ) + let job_id = sqlx::query_scalar::<_, i32>(&format!( + "SELECT id FROM {} WHERE uuid = $1 LIMIT 1", + schema::table_name("monitor_jobs") + )) .bind(uuid) .fetch_optional(db.pool()) .await?; if let Some(job_id) = job_id { - db.update_job_status(job_id, MonitorJobStatus::Completed).await?; + db.update_job_status(job_id, MonitorJobStatus::Completed) + .await?; println!("Job {} marked as completed", job_id); } diff --git a/src/core/processor/asrx.rs b/src/core/processor/asrx.rs index 2ab85a2..ee031ed 100644 --- a/src/core/processor/asrx.rs +++ b/src/core/processor/asrx.rs @@ -44,10 +44,7 @@ pub async fn process_asrx( let executor = PythonExecutor::new()?; let script_path = executor.script_path("asrx_processor.py"); - tracing::info!( - "[ASRX] Starting hybrid speaker diarization: {}", - video_path - ); + tracing::info!("[ASRX] Starting hybrid speaker diarization: {}", video_path); if !script_path.exists() { tracing::error!("[ASRX] Script not found: {:?}", script_path); diff --git a/src/core/processor/cascade_vision.rs b/src/core/processor/cascade_vision.rs new file mode 100644 index 0000000..f324e06 --- /dev/null +++ b/src/core/processor/cascade_vision.rs @@ -0,0 +1,308 @@ +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::path::Path; +use std::time::Duration; +use tracing::{debug, info, warn}; + +use crate::core::processor::clip::{ClipPrediction, detect_objects}; +use crate::core::vision::qwen_vl_manager::QwenVLManager; + +const DEFAULT_CLIP_THRESHOLD: f32 = 0.7; +const QWENVL_TIMEOUT: Duration = Duration::from_secs(30); + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CascadeDetectionResult { + pub detections: Vec, + pub model_used: String, + pub clip_confidence: f32, + pub qwenvl_used: bool, + pub processing_time_ms: u64, +} + +pub struct CascadeVisionProcessor { + clip_threshold: f32, + qwen_vl_manager: QwenVLManager, +} + +impl CascadeVisionProcessor { + pub fn new() -> Self { + Self { + clip_threshold: DEFAULT_CLIP_THRESHOLD, + qwen_vl_manager: QwenVLManager::new(), + } + } + + pub fn with_threshold(threshold: f32) -> Self { + Self { + clip_threshold: threshold, + qwen_vl_manager: QwenVLManager::new(), + } + } + + pub async fn detect_objects(&self, image_path: &Path, objects: &[&str]) -> Result { + let start_time = std::time::Instant::now(); + + info!( + "[Cascade] Starting detection for {:?} with {} object classes (threshold: {:.2})", + image_path, + objects.len(), + self.clip_threshold + ); + + let clip_result = self.run_clip_detection(image_path, objects).await?; + + let max_clip_confidence = clip_result + .iter() + .map(|p| p.confidence) + .fold(0.0_f32, |max, val| if val > max { val } else { max }); + + debug!( + "[Cascade] CLIP max confidence: {:.3} (threshold: {:.2})", + max_clip_confidence, + self.clip_threshold + ); + + if max_clip_confidence > self.clip_threshold { + info!( + "[Cascade] High confidence ({:.3} > {:.2}) → triggering Qwen3-VL", + max_clip_confidence, + self.clip_threshold + ); + + let qwenvl_result = self.run_qwenvl_detection(image_path, objects).await?; + + let processing_time = start_time.elapsed().as_millis() as u64; + + return Ok(CascadeDetectionResult { + detections: qwenvl_result, + model_used: "qwen3vl".to_string(), + clip_confidence: max_clip_confidence, + qwenvl_used: true, + processing_time_ms: processing_time, + }); + } + + info!( + "[Cascade] Low confidence ({:.3} <= {:.2}) → using CLIP results only", + max_clip_confidence, + self.clip_threshold + ); + + let processing_time = start_time.elapsed().as_millis() as u64; + + Ok(CascadeDetectionResult { + detections: clip_result, + model_used: "clip".to_string(), + clip_confidence: max_clip_confidence, + qwenvl_used: false, + processing_time_ms: processing_time, + }) + } + + async fn run_clip_detection(&self, image_path: &Path, objects: &[&str]) -> Result> { + let image_path_str = image_path.display().to_string(); + + debug!("[Cascade] Running CLIP detection for {:?}", image_path); + + let predictions = detect_objects(&image_path_str, objects, None, None) + .await + .context("CLIP detection failed")?; + + debug!( + "[Cascade] CLIP detected {} objects", + predictions.len() + ); + + Ok(predictions) + } + + async fn run_qwenvl_detection(&self, image_path: &Path, objects: &[&str]) -> Result> { + let image_path_str = image_path.display().to_string(); + + debug!("[Cascade] Running Qwen3-VL detection for {:?}", image_path); + + self.qwen_vl_manager.ensure_running().await?; + + let prompt = self.build_detection_prompt(objects); + + let client = reqwest::Client::new(); + let url = format!("http://localhost:{}/v1/chat/completions", self.qwen_vl_manager.get_port()); + + let request_body = serde_json::json!({ + "model": "Qwen3VL-8B-Instruct-Q8_0", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": prompt + }, + { + "type": "image_url", + "image_url": { + "url": format!("file://{}", image_path_str) + } + } + ] + } + ], + "max_tokens": 500, + "temperature": 0.1 + }); + + let response = client + .post(&url) + .json(&request_body) + .timeout(QWENVL_TIMEOUT) + .send() + .await + .context("Qwen3-VL API request failed")?; + + if !response.status().is_success() { + warn!("[Cascade] Qwen3-VL API error: {}", response.status()); + anyhow::bail!("Qwen3-VL API returned error: {}", response.status()); + } + + let response_json: serde_json::Value = response + .json() + .await + .context("Failed to parse Qwen3-VL response")?; + + let content = response_json + .get("choices") + .and_then(|choices| choices.get(0)) + .and_then(|choice| choice.get("message")) + .and_then(|message| message.get("content")) + .and_then(|content| content.as_str()) + .unwrap_or(""); + + debug!("[Cascade] Qwen3-VL response: {}", content); + + let detections = self.parse_qwenvl_response(content, objects); + + self.qwen_vl_manager.update_last_request_time().await; + + info!( + "[Cascade] Qwen3-VL detected {} objects", + detections.len() + ); + + Ok(detections) + } + + fn build_detection_prompt(&self, objects: &[&str]) -> String { + let object_list = objects.join(", "); + + format!( + "Analyze this image and detect the following objects: {}.\n\ + For each detected object, provide:\n\ + 1. The object name\n\ + 2. A confidence score (0.0 to 1.0)\n\ + 3. A brief description of what you see\n\ + \n\ + Format your response as JSON:\n\ + {{\"detections\": [{{\"label\": \"object_name\", \"confidence\": 0.95, \"description\": \"brief description\"}}]}}\n\ + \n\ + If no objects are detected, return: {{\"detections\": []}}\n\ + \n\ + IMPORTANT: Only detect objects that are clearly visible and identifiable. Do not guess or hallucinate.", + object_list + ) + } + + fn parse_qwenvl_response(&self, content: &str, _objects: &[&str]) -> Vec { + let json_start = content.find('{'); + let json_end = content.rfind('}'); + + if json_start.is_none() || json_end.is_none() { + debug!("[Cascade] No JSON found in Qwen3-VL response"); + return Vec::new(); + } + + let json_str = &content[json_start.unwrap()..=json_end.unwrap()]; + + let parsed: serde_json::Value = serde_json::from_str(json_str) + .unwrap_or(serde_json::json!({"detections": []})); + + let detections = parsed + .get("detections") + .and_then(|d| d.as_array()) + .map(|arr| arr.clone()) + .unwrap_or_else(|| Vec::new()); + + detections + .iter() + .filter_map(|d| { + let label = d.get("label").and_then(|l| l.as_str()).unwrap_or(""); + let confidence = d.get("confidence").and_then(|c| c.as_f64()).unwrap_or(0.0) as f32; + + if !label.is_empty() && confidence > 0.0 { + Some(ClipPrediction { + label: label.to_string(), + confidence, + }) + } else { + None + } + }) + .collect() + } +} + +impl Default for CascadeVisionProcessor { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_build_detection_prompt() { + let processor = CascadeVisionProcessor::new(); + let objects = vec!["gun", "weapon", "person"]; + let prompt = processor.build_detection_prompt(&objects); + + assert!(prompt.contains("gun, weapon, person")); + assert!(prompt.contains("confidence score")); + assert!(prompt.contains("JSON")); + } + + #[test] + fn test_parse_qwenvl_response() { + let processor = CascadeVisionProcessor::new(); + let response = "{\"detections\": [{\"label\": \"gun\", \"confidence\": 0.95, \"description\": \"a handgun\"}]}"; + let objects = vec!["gun"]; + + let detections = processor.parse_qwenvl_response(response, &objects); + + assert_eq!(detections.len(), 1); + assert_eq!(detections[0].label, "gun"); + assert!((detections[0].confidence - 0.95).abs() < 0.001); + } + + #[test] + fn test_parse_empty_response() { + let processor = CascadeVisionProcessor::new(); + let response = "{\"detections\": []}"; + let objects = vec!["gun"]; + + let detections = processor.parse_qwenvl_response(response, &objects); + + assert_eq!(detections.len(), 0); + } + + #[test] + fn test_parse_invalid_json() { + let processor = CascadeVisionProcessor::new(); + let response = "This is not JSON"; + let objects = vec!["gun"]; + + let detections = processor.parse_qwenvl_response(response, &objects); + + assert_eq!(detections.len(), 0); + } +} \ No newline at end of file diff --git a/src/core/processor/clip.rs b/src/core/processor/clip.rs new file mode 100644 index 0000000..45f793b --- /dev/null +++ b/src/core/processor/clip.rs @@ -0,0 +1,290 @@ +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +use super::executor::PythonExecutor; + +const CLIP_TIMEOUT: Duration = Duration::from_secs(300); + +/// CLIP classification prediction +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ClipPrediction { + pub label: String, + pub confidence: f32, +} + +/// CLIP classification result for a single image +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ClipImageResult { + pub image_path: String, + pub predictions: Vec, +} + +/// CLIP object detection result +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ClipDetectionResult { + pub image_path: String, + pub detected_objects: Vec, +} + +/// Classify a single image with given labels +pub async fn classify_image( + image_path: &str, + labels: &[&str], + top_k: Option, + model_name: Option<&str>, +) -> Result> { + let executor = PythonExecutor::new()?; + let script_path = executor.script_path("clip_classifier.py"); + + if !script_path.exists() { + anyhow::bail!("clip_classifier.py not found at {:?}", script_path); + } + + let top_k = top_k.unwrap_or(5); + let model = model_name.unwrap_or("openai/clip-vit-base-patch32"); + + let mut args = vec![ + image_path.to_string(), + "--labels".to_string(), + labels.join(","), + "--top-k".to_string(), + top_k.to_string(), + "--model".to_string(), + model.to_string(), + ]; + + let output_path = format!("{}.clip.json", image_path); + args.push("--output".to_string()); + args.push(output_path.clone()); + + tracing::info!( + "[CLIP] Classifying image: {} with {} labels", + image_path, + labels.len() + ); + + executor + .run( + "clip_classifier.py", + &args.iter().map(|s| s.as_str()).collect::>(), + None, + "CLIP", + Some(CLIP_TIMEOUT), + ) + .await + .context("Failed to run CLIP classifier")?; + + let json_str = std::fs::read_to_string(&output_path) + .context("Failed to read CLIP output")?; + + let results: std::collections::HashMap> = + serde_json::from_str(&json_str) + .context("Failed to parse CLIP output")?; + + let predictions = results + .get(image_path) + .cloned() + .unwrap_or_default(); + + tracing::info!( + "[CLIP] Top prediction: {} ({:.3})", + predictions.first().map(|p| p.label.as_str()).unwrap_or("none"), + predictions.first().map(|p| p.confidence).unwrap_or(0.0) + ); + + Ok(predictions) +} + +/// Detect objects in an image +pub async fn detect_objects( + image_path: &str, + objects: &[&str], + threshold: Option, + model_name: Option<&str>, +) -> Result> { + let executor = PythonExecutor::new()?; + let script_path = executor.script_path("clip_classifier.py"); + + if !script_path.exists() { + anyhow::bail!("clip_classifier.py not found at {:?}", script_path); + } + + let threshold = threshold.unwrap_or(0.15); + let model = model_name.unwrap_or("openai/clip-vit-base-patch32"); + + let mut args = vec![ + image_path.to_string(), + "--detect".to_string(), + objects.join(","), + "--threshold".to_string(), + threshold.to_string(), + "--model".to_string(), + model.to_string(), + ]; + + let output_path = format!("{}.clip.json", image_path); + args.push("--output".to_string()); + args.push(output_path.clone()); + + tracing::info!( + "[CLIP] Detecting {} objects in: {} (threshold: {:.2})", + objects.len(), + image_path, + threshold + ); + + executor + .run( + "clip_classifier.py", + &args.iter().map(|s| s.as_str()).collect::>(), + None, + "CLIP", + Some(CLIP_TIMEOUT), + ) + .await + .context("Failed to run CLIP object detection")?; + + let json_str = std::fs::read_to_string(&output_path) + .context("Failed to read CLIP output")?; + + let results: std::collections::HashMap> = + serde_json::from_str(&json_str) + .context("Failed to parse CLIP output")?; + + let detected = results + .get(image_path) + .cloned() + .unwrap_or_default(); + + if !detected.is_empty() { + tracing::info!( + "[CLIP] Detected {} objects: {}", + detected.len(), + detected.iter().map(|p| p.label.as_str()).collect::>().join(", ") + ); + } else { + tracing::info!("[CLIP] No objects detected above threshold {:.2}", threshold); + } + + Ok(detected) +} + +/// Batch classify multiple images +pub async fn classify_images( + image_paths: &[&str], + labels: &[&str], + top_k: Option, + model_name: Option<&str>, +) -> Result> { + let executor = PythonExecutor::new()?; + let script_path = executor.script_path("clip_classifier.py"); + + if !script_path.exists() { + anyhow::bail!("clip_classifier.py not found at {:?}", script_path); + } + + let top_k = top_k.unwrap_or(5); + let model = model_name.unwrap_or("openai/clip-vit-base-patch32"); + + // Create temp file with image paths + let temp_file = format!("/tmp/clip_batch_{}.txt", uuid::Uuid::new_v4()); + std::fs::write(&temp_file, image_paths.join("\n")) + .context("Failed to write batch file")?; + + let mut args = vec![ + temp_file.clone(), + "--batch".to_string(), + "--labels".to_string(), + labels.join(","), + "--top-k".to_string(), + top_k.to_string(), + "--model".to_string(), + model.to_string(), + ]; + + let output_path = format!("/tmp/clip_batch_{}.json", uuid::Uuid::new_v4()); + args.push("--output".to_string()); + args.push(output_path.clone()); + + tracing::info!( + "[CLIP] Batch classifying {} images with {} labels", + image_paths.len(), + labels.len() + ); + + executor + .run( + "clip_classifier.py", + &args.iter().map(|s| s.as_str()).collect::>(), + None, + "CLIP", + Some(CLIP_TIMEOUT), + ) + .await + .context("Failed to run batch CLIP classification")?; + + let json_str = std::fs::read_to_string(&output_path) + .context("Failed to read CLIP batch output")?; + + let results_map: std::collections::HashMap> = + serde_json::from_str(&json_str) + .context("Failed to parse CLIP batch output")?; + + let results: Vec = image_paths + .iter() + .map(|path| ClipImageResult { + image_path: path.to_string(), + predictions: results_map.get(*path).cloned().unwrap_or_default(), + }) + .collect(); + + // Cleanup temp files + let _ = std::fs::remove_file(&temp_file); + let _ = std::fs::remove_file(&output_path); + + Ok(results) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_clip_prediction_serialization() { + let pred = ClipPrediction { + label: "person in room".to_string(), + confidence: 0.876, + }; + let json = serde_json::to_string(&pred).unwrap(); + assert!(json.contains("person in room")); + assert!(json.contains("0.876")); + } + + #[test] + fn test_clip_prediction_deserialization() { + let json = r#"{"label":"outdoor scene","confidence":0.945}"#; + let pred: ClipPrediction = serde_json::from_str(json).unwrap(); + assert_eq!(pred.label, "outdoor scene"); + assert!((pred.confidence - 0.945).abs() < 0.001); + } + + #[test] + fn test_clip_image_result() { + let result = ClipImageResult { + image_path: "/test/image.jpg".to_string(), + predictions: vec![ + ClipPrediction { + label: "indoor".to_string(), + confidence: 0.92, + }, + ClipPrediction { + label: "outdoor".to_string(), + confidence: 0.08, + }, + ], + }; + assert_eq!(result.predictions.len(), 2); + assert_eq!(result.predictions[0].label, "indoor"); + } +} \ No newline at end of file diff --git a/src/core/processor/mod.rs b/src/core/processor/mod.rs index 402958a..ef44fb7 100644 --- a/src/core/processor/mod.rs +++ b/src/core/processor/mod.rs @@ -1,6 +1,8 @@ pub mod asr; pub mod asrx; pub mod caption; +pub mod cascade_vision; +pub mod clip; pub mod cut; pub mod executor; pub mod face; @@ -16,6 +18,8 @@ pub mod yolo; pub use asr::{process_asr, AsrResult, AsrSegment}; pub use asrx::{process_asrx, AsrxResult, AsrxSegment}; pub use caption::{process_caption, CaptionResult, CaptionSummary, FrameCaption}; +pub use cascade_vision::{CascadeDetectionResult, CascadeVisionProcessor}; +pub use clip::{classify_image, classify_images, detect_objects, ClipDetectionResult, ClipImageResult, ClipPrediction}; pub use cut::{process_cut, CutResult, CutScene}; pub use executor::{validate_python_env, PythonExecutor, RetryConfig}; pub use face::{process_face, Face, FaceFrame, FaceResult}; diff --git a/src/core/vision/mod.rs b/src/core/vision/mod.rs new file mode 100644 index 0000000..a87b77f --- /dev/null +++ b/src/core/vision/mod.rs @@ -0,0 +1 @@ +pub mod qwen_vl_manager; \ No newline at end of file diff --git a/src/core/vision/qwen_vl_manager.rs b/src/core/vision/qwen_vl_manager.rs new file mode 100644 index 0000000..36bfbbb --- /dev/null +++ b/src/core/vision/qwen_vl_manager.rs @@ -0,0 +1,218 @@ +use anyhow::{Context, Result}; +use std::path::PathBuf; +use std::process::Command; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Mutex; +use tracing::{debug, error, info, warn}; + +pub struct QwenVLManager { + port: u16, + model_path: PathBuf, + mmproj_path: PathBuf, + log_file: PathBuf, + pid_file: PathBuf, + start_script: PathBuf, + stop_script: PathBuf, + last_request_time: Arc>, + max_startup_time: Duration, +} + +impl QwenVLManager { + pub fn new() -> Self { + Self { + port: 8086, + model_path: PathBuf::from("/Users/accusys/models/Qwen3VL-8B-Instruct-Q8_0.gguf"), + mmproj_path: PathBuf::from("/Users/accusys/models/mmproj-Qwen3VL-8B-Instruct-F16.gguf"), + log_file: PathBuf::from("logs/qwen3vl_8086.log"), + pid_file: PathBuf::from("/tmp/qwen3vl.pid"), + start_script: PathBuf::from("scripts/start_qwen3vl.sh"), + stop_script: PathBuf::from("scripts/stop_qwen3vl.sh"), + last_request_time: Arc::new(Mutex::new(Instant::now())), + max_startup_time: Duration::from_secs(60), + } + } + + pub fn with_port(port: u16) -> Self { + let mut manager = Self::new(); + manager.port = port; + manager.pid_file = PathBuf::from(format!("/tmp/qwen3vl_{}.pid", port)); + manager.log_file = PathBuf::from(format!("logs/qwen3vl_{}.log", port)); + manager + } + + pub fn get_port(&self) -> u16 { + self.port + } + + pub async fn is_running(&self) -> Result { + let health_url = format!("http://localhost:{}/health", self.port); + + let client = reqwest::Client::new(); + let response = client + .get(&health_url) + .timeout(Duration::from_secs(5)) + .send() + .await; + + match response { + Ok(resp) => { + let status = resp.status(); + let body = resp.text().await?; + if status.is_success() && body.contains("\"status\":\"ok\"") { + debug!("Qwen3-VL is running on port {}", self.port); + return Ok(true); + } + debug!("Qwen3-VL health check failed: {}", status); + Ok(false) + } + Err(e) => { + debug!("Qwen3-VL not reachable: {}", e); + Ok(false) + } + } + } + + pub async fn ensure_running(&self) -> Result<()> { + if self.is_running().await? { + debug!("Qwen3-VL already running"); + self.update_last_request_time().await; + return Ok(()); + } + + info!("Starting Qwen3-VL server on port {}", self.port); + self.start_server().await?; + self.wait_for_ready().await?; + self.update_last_request_time().await; + + info!("Qwen3-VL server started successfully"); + Ok(()) + } + + pub async fn start_server(&self) -> Result<()> { + let script_path = self.start_script.canonicalize() + .context("Failed to resolve start script path")?; + + debug!("Running start script: {}", script_path.display()); + + let output = Command::new("bash") + .arg(&script_path) + .output() + .context("Failed to execute start script")?; + + if !output.status.success() { + error!("Start script failed: {}", String::from_utf8_lossy(&output.stderr)); + anyhow::bail!("Failed to start Qwen3-VL server"); + } + + debug!("Start script output: {}", String::from_utf8_lossy(&output.stdout)); + Ok(()) + } + + pub async fn stop_server(&self) -> Result<()> { + let script_path = self.stop_script.canonicalize() + .context("Failed to resolve stop script path")?; + + debug!("Running stop script: {}", script_path.display()); + + let output = Command::new("bash") + .arg(&script_path) + .output() + .context("Failed to execute stop script")?; + + if !output.status.success() { + warn!("Stop script returned error: {}", String::from_utf8_lossy(&output.stderr)); + } + + debug!("Stop script output: {}", String::from_utf8_lossy(&output.stdout)); + + tokio::time::sleep(Duration::from_secs(2)).await; + + if self.is_running().await? { + warn!("Qwen3-VL still running after stop script"); + } + + info!("Qwen3-VL server stopped"); + Ok(()) + } + + pub async fn wait_for_ready(&self) -> Result<()> { + let health_url = format!("http://localhost:{}/health", self.port); + let client = reqwest::Client::new(); + + let start_time = Instant::now(); + + while start_time.elapsed() < self.max_startup_time { + let response = client + .get(&health_url) + .timeout(Duration::from_secs(2)) + .send() + .await; + + match response { + Ok(resp) => { + if resp.status().is_success() { + let body = resp.text().await?; + if body.contains("\"status\":\"ok\"") { + debug!("Qwen3-VL ready after {} seconds", start_time.elapsed().as_secs()); + return Ok(()); + } + } + } + Err(_) => {} + } + + tokio::time::sleep(Duration::from_secs(2)).await; + } + + error!("Qwen3-VL failed to start within {} seconds", self.max_startup_time.as_secs()); + anyhow::bail!("Qwen3-VL startup timeout"); + } + + pub async fn update_last_request_time(&self) { + let mut last_request = self.last_request_time.lock().await; + *last_request = Instant::now(); + debug!("Updated last request time"); + } + + pub async fn get_status(&self) -> Result { + let is_running = self.is_running().await?; + let last_request = self.last_request_time.lock().await.clone(); + + Ok(QwenVLStatus { + running: is_running, + port: self.port, + model_path: self.model_path.display().to_string(), + last_request: last_request.elapsed().as_secs(), + pid_file: self.pid_file.display().to_string(), + log_file: self.log_file.display().to_string(), + }) + } + + pub async fn auto_stop_if_idle(&self, idle_timeout: Duration) -> Result<()> { + let last_request = self.last_request_time.lock().await.clone(); + + if last_request.elapsed() > idle_timeout && self.is_running().await? { + info!("Qwen3-VL idle for {} seconds, stopping server", last_request.elapsed().as_secs()); + self.stop_server().await?; + } + + Ok(()) + } +} + +#[derive(Debug, Clone, serde::Serialize)] +pub struct QwenVLStatus { + pub running: bool, + pub port: u16, + pub model_path: String, + pub last_request: u64, + pub pid_file: String, + pub log_file: String, +} + +impl Default for QwenVLManager { + fn default() -> Self { + Self::new() + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 3c8c70d..00e8d7d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -124,6 +124,17 @@ async fn main() -> Result<()> { } => { handle_n8n(action, api_key, label, expires_in_days).await?; } + Commands::Detect { + image, + objects, + cascade, + threshold, + } => { + cli::vision::handle_detect_command(image, objects, cascade, threshold).await?; + } + Commands::Vision(cmd) => { + cli::vision::handle_vision_command(cmd).await?; + } } Ok(()) diff --git a/src/worker/job_worker.rs b/src/worker/job_worker.rs index 55b0023..a0092b4 100644 --- a/src/worker/job_worker.rs +++ b/src/worker/job_worker.rs @@ -471,12 +471,19 @@ impl JobWorker { ); continue; } - - debug!("Output file not found, checking result_map for {}", processor_type.as_str()); + + debug!( + "Output file not found, checking result_map for {}", + processor_type.as_str() + ); // Check if processor already in terminal state if let Some(result) = result_map.get(processor_type) { - debug!("Found existing result for {}: status={:?}", processor_type.as_str(), result.status); + debug!( + "Found existing result for {}: status={:?}", + processor_type.as_str(), + result.status + ); match result.status { ProcessorJobStatus::Completed => { info!( @@ -606,7 +613,10 @@ impl JobWorker { } } - debug!("Checking capacity before starting {}", processor_type.as_str()); + debug!( + "Checking capacity before starting {}", + processor_type.as_str() + ); // Check capacity before starting processor if !self.processor_pool.can_start().await { info!( @@ -679,7 +689,11 @@ impl JobWorker { .upsert_processor_result(job.id, *processor_type, &job.uuid, "pending") .await?; - info!("Upserted processor_result for {}: id={}", processor_type.as_str(), processor_result_id); + info!( + "Upserted processor_result for {}: id={}", + processor_type.as_str(), + processor_result_id + ); self.redis .update_worker_processor_status( @@ -737,12 +751,10 @@ impl JobWorker { let fu = uuid; // Only check conditions relevant to the job's processors - let has_asr_or_asrx = job_processors.is_empty() - || job_processors.iter().any(|p| p == "asrx" || p == "asr"); - let has_cut = job_processors.is_empty() - || job_processors.iter().any(|p| p == "cut"); - let has_face = job_processors.is_empty() - || job_processors.iter().any(|p| p == "face"); + let has_asr_or_asrx = + job_processors.is_empty() || job_processors.iter().any(|p| p == "asrx" || p == "asr"); + let has_cut = job_processors.is_empty() || job_processors.iter().any(|p| p == "cut"); + let has_face = job_processors.is_empty() || job_processors.iter().any(|p| p == "face"); let rule1 = !has_asr_or_asrx || check!(&format!( @@ -852,11 +864,9 @@ impl JobWorker { if has_asrx { // Guard: only spawn Rule 1 if sentence chunks don't exist yet let chunk_t = schema::table_name("chunk"); - let already_spawned: bool = sqlx::query_scalar::<_, i64>( - &format!( - "SELECT 1 FROM {chunk_t} WHERE file_uuid = $1 AND chunk_type = 'sentence' LIMIT 1" - ), - ) + let already_spawned: bool = sqlx::query_scalar::<_, i64>(&format!( + "SELECT 1 FROM {chunk_t} WHERE file_uuid = $1 AND chunk_type = 'sentence' LIMIT 1" + )) .bind(uuid) .fetch_optional(self.db.pool()) .await? @@ -864,66 +874,70 @@ impl JobWorker { > 0; if already_spawned { - info!( - "✅ Rule 1 already completed for {}, skipping spawn", - uuid - ); + info!("✅ Rule 1 already completed for {}, skipping spawn", uuid); } else { info!("📝 Prerequisites met for Rule 1 Chunking. Starting ingestion..."); let db_clone = self.db.clone(); let uuid_clone = uuid.to_string(); tokio::spawn(async move { - match db_clone.get_video_by_uuid(&uuid_clone).await { - Ok(Some(video)) => { - let fps = video.fps; - match rule1_ingest::execute_rule1(&db_clone, &uuid_clone, fps).await { - Ok(count) => { - info!("✅ Rule 1 Ingestion completed: {} chunks inserted.", count); - if count > 0 { + match db_clone.get_video_by_uuid(&uuid_clone).await { + Ok(Some(video)) => { + let fps = video.fps; + match rule1_ingest::execute_rule1(&db_clone, &uuid_clone, fps).await { + Ok(count) => { info!( - "📝 Starting automatic vectorize for {} chunks...", + "✅ Rule 1 Ingestion completed: {} chunks inserted.", count ); - if let Err(e) = - Self::vectorize_chunks(&db_clone, &uuid_clone).await - { - error!( - "❌ Auto-vectorize failed for {}: {}", - uuid_clone, e + if count > 0 { + info!( + "📝 Starting automatic vectorize for {} chunks...", + count ); + if let Err(e) = + Self::vectorize_chunks(&db_clone, &uuid_clone).await + { + error!( + "❌ Auto-vectorize failed for {}: {}", + uuid_clone, e + ); + } + } + info!("📦 Phase 1 release packaging..."); + let executor = + match crate::core::processor::PythonExecutor::new() { + Ok(ex) => ex, + Err(e) => { + error!( + "Failed PythonExecutor for release pack: {}", + e + ); + return; + } + }; + match executor + .run( + "release_pack.py", + &["--phase", "1", "--file-uuid", &uuid_clone], + None, + "RELEASE_P1", + Some(std::time::Duration::from_secs(120)), + ) + .await + { + Ok(()) => { + info!("✅ Phase 1 release packaged for {}", uuid_clone) + } + Err(e) => error!("❌ Phase 1 release pack failed: {}", e), } } - info!("📦 Phase 1 release packaging..."); - let executor = match crate::core::processor::PythonExecutor::new() { - Ok(ex) => ex, - Err(e) => { - error!("Failed PythonExecutor for release pack: {}", e); - return; - } - }; - match executor - .run( - "release_pack.py", - &["--phase", "1", "--file-uuid", &uuid_clone], - None, - "RELEASE_P1", - Some(std::time::Duration::from_secs(120)), - ) - .await - { - Ok(()) => { - info!("✅ Phase 1 release packaged for {}", uuid_clone) - } - Err(e) => error!("❌ Phase 1 release pack failed: {}", e), - } + Err(e) => error!("❌ Rule 1 Ingestion failed: {}", e), } - Err(e) => error!("❌ Rule 1 Ingestion failed: {}", e), } + Ok(None) => error!("Video not found for chunking: {}", uuid_clone), + Err(e) => error!("Failed to get video info for chunking: {}", e), } - Ok(None) => error!("Video not found for chunking: {}", uuid_clone), - Err(e) => error!("Failed to get video info for chunking: {}", e), - } - }); + }); } } diff --git a/src/worker/processor.rs b/src/worker/processor.rs index 9702404..d629d53 100644 --- a/src/worker/processor.rs +++ b/src/worker/processor.rs @@ -1089,8 +1089,8 @@ impl ProcessorPool { segment.start_time, segment.end_time, segment.text.clone(), - None::, // chunk_id: unknown yet, filled later - 0.0, // confidence: updated after binding + None::, // chunk_id: unknown yet, filled later + 0.0, // confidence: updated after binding )); }