//! Visual chunk search functionality. //! //! This module provides search capabilities for visual chunks based on: //! - Object classes (e.g., "person", "car", "envelope") //! - Confidence thresholds //! - Object counts //! - Spatial density //! - Object relationships use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType}; use crate::core::db::PostgresDb; use anyhow::Result; use serde_json::Value; use std::collections::HashMap; /// Criteria for searching visual chunks #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct VisualChunkSearchCriteria { /// Minimum average confidence across frames pub min_avg_confidence: Option, /// Minimum number of frames with objects pub min_frames_with_objects: Option, /// Minimum number of unique object classes pub min_unique_classes: Option, /// Specific object classes to include (empty means all) pub required_classes: Vec, /// Object class counts to filter by pub class_counts: HashMap, /// Time range (optional) pub time_range: Option<(f64, f64)>, } impl Default for VisualChunkSearchCriteria { fn default() -> Self { Self { min_avg_confidence: None, min_frames_with_objects: None, min_unique_classes: None, required_classes: Vec::new(), class_counts: HashMap::new(), time_range: None, } } } /// Search visual chunks based on criteria pub async fn search_visual_chunks( db: &PostgresDb, uuid: &str, criteria: &VisualChunkSearchCriteria, ) -> Result> { // First, get all visual chunks for this video let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?; // Apply filters let filtered_chunks: Vec = all_chunks .into_iter() .filter(|chunk| { // Check min avg confidence if let Some(min_avg_confidence) = criteria.min_avg_confidence { if let Some(content) = &chunk.content.as_object() { if let Some(metadata) = content.get("metadata") { if let Some(avg_confidence) = metadata.get("avg_confidence") { if let Some(conf) = avg_confidence.as_f64() { if conf < min_avg_confidence as f64 { return false; } } } } } } // Check min frames with objects if let Some(min_frames) = criteria.min_frames_with_objects { if let Some(stats) = &chunk.visual_stats { if let Some(frames_with_objects) = stats.get("frames_with_objects") { if let Some(count) = frames_with_objects.as_u64() { if count < min_frames as u64 { return false; } } } } } // Check min unique classes if let Some(min_unique_classes) = criteria.min_unique_classes { if let Some(content) = &chunk.content.as_object() { if let Some(metadata) = content.get("metadata") { if let Some(unique_classes) = metadata.get("unique_classes") { if let Some(classes) = unique_classes.as_array() { if (classes.len() as u32) < min_unique_classes { return false; } } } } } } // Check required classes if !criteria.required_classes.is_empty() { if let Some(content) = &chunk.content.as_object() { if let Some(keyframe_objects) = content.get("keyframe_objects") { if let Some(objects) = keyframe_objects.as_array() { let mut found_all = true; for required_class in &criteria.required_classes { let mut found = false; for obj in objects { if let Some(class_name) = obj.get("class_name") { if let Some(class_str) = class_name.as_str() { if class_str == required_class { found = true; break; } } } } if !found { found_all = false; break; } } if !found_all { return false; } } } } } // Check class counts if !criteria.class_counts.is_empty() { if let Some(content) = &chunk.content.as_object() { if let Some(metadata) = content.get("metadata") { if let Some(object_counts) = metadata.get("object_counts") { for (class, (min, max)) in &criteria.class_counts { if let Some(count_value) = object_counts.get(class) { if let Some(count) = count_value.as_u64() { if *min > 0 && count < *min as u64 { return false; } if *max < u32::MAX && count > *max as u64 { return false; } } } else if *min > 0 { return false; } } } else if criteria.class_counts.values().any(|(min, _)| *min > 0) { return false; } } } } // Check time range if let Some((start_time, end_time)) = criteria.time_range { // Calculate chunk time from frames let chunk_start_time = chunk.start_frame as f64 / chunk.fps; let chunk_end_time = chunk.end_frame as f64 / chunk.fps; if chunk_start_time < start_time || chunk_end_time > end_time { return false; } } true }) .collect(); Ok(filtered_chunks) } /// Get all visual chunks for a video UUID async fn get_visual_chunks_by_uuid(db: &PostgresDb, uuid: &str) -> Result> { let sql = format!( "SELECT file_id, uuid, chunk_id, chunk_index, chunk_type, fps, start_frame, end_frame, text_content, content, metadata, vector_id, visual_stats FROM chunks WHERE uuid = '{}' AND chunk_type = 'visual' ORDER BY start_frame ASC", uuid.replace('\'', "''") ); let rows: Vec<( i32, // file_id String, // uuid String, // chunk_id i32, // chunk_index String, // chunk_type f64, // fps i64, // start_frame i64, // end_frame Option, // text_content Value, // content Option, // metadata Option, // vector_id Option, // visual_stats )> = sqlx::query_as(&sql).fetch_all(db.pool()).await?; let mut chunks = Vec::new(); for row in rows { let chunk_type = match row.4.as_str() { "visual" => ChunkType::Visual, "sentence" => ChunkType::Sentence, "time_based" => ChunkType::TimeBased, "cut" => ChunkType::Cut, "trace" => ChunkType::Trace, "story" => ChunkType::Story, _ => ChunkType::TimeBased, }; // Calculate frame_count let frame_count = (row.7 - row.6) as i32; chunks.push(Chunk { file_id: row.0, uuid: row.1, chunk_id: row.2, chunk_index: row.3 as u32, chunk_type, rule: ChunkRule::Rule2, // Visual chunks use Rule2 fps: row.5, start_frame: row.6, end_frame: row.7, text_content: row.8, content: row.9, metadata: row.10, vector_id: row.11, frame_count, pre_chunk_ids: Vec::new(), parent_chunk_id: None, child_chunk_ids: Vec::new(), visual_stats: row.12, }); } Ok(chunks) } /// Search visual chunks by object class pub async fn search_visual_chunks_by_class( db: &PostgresDb, uuid: &str, object_class: &str, min_count: Option, max_count: Option, ) -> Result> { let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?; let filtered_chunks: Vec = all_chunks .into_iter() .filter(|chunk| { // Check if chunk contains the object class let mut contains_class = false; if let Some(content) = &chunk.content.as_object() { if let Some(keyframe_objects) = content.get("keyframe_objects") { if let Some(objects) = keyframe_objects.as_array() { for obj in objects { if let Some(class_name) = obj.get("class_name") { if let Some(class_str) = class_name.as_str() { if class_str == object_class { contains_class = true; break; } } } } } } } if !contains_class { return false; } // Check count in visual_stats if let Some(stats) = &chunk.visual_stats { if let Some(count) = stats.get(object_class) { if let Some(c) = count.as_u64() { if let Some(min) = min_count { if c < min as u64 { return false; } } if let Some(max) = max_count { if c > max as u64 { return false; } } } } } true }) .collect(); Ok(filtered_chunks) } /// Search visual chunks by spatial density pub async fn search_visual_chunks_by_density( db: &PostgresDb, uuid: &str, min_density: f32, max_density: Option, ) -> Result> { let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?; let filtered_chunks: Vec = all_chunks .into_iter() .filter(|chunk| { if let Some(content) = &chunk.content.as_object() { if let Some(metadata) = content.get("metadata") { if let Some(density_value) = metadata.get("spatial_density") { if let Some(density) = density_value.as_f64() { if density < min_density as f64 { return false; } if let Some(max_dens) = max_density { if density > max_dens as f64 { return false; } } return true; } } } } false }) .collect(); Ok(filtered_chunks) } /// Find chunks containing specific object combinations pub async fn search_visual_chunks_by_combination( db: &PostgresDb, uuid: &str, combination: &[(&str, u32)], ) -> Result> { let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?; let filtered_chunks: Vec = all_chunks .into_iter() .filter(|chunk| { // Check if all required combinations are present for (object_class, min_count) in combination { let mut found = false; if let Some(stats) = &chunk.visual_stats { if let Some(object_counts) = stats.get("object_counts") { if let Some(count_value) = object_counts.get(*object_class) { if let Some(count) = count_value.as_u64() { if count >= *min_count as u64 { found = true; } } } } } if !found { return false; } } true }) .collect(); Ok(filtered_chunks) } /// Get visual chunk statistics pub async fn get_visual_chunk_statistics( db: &PostgresDb, uuid: &str, ) -> Result> { let sql = format!( "SELECT COUNT(*) as total_chunks, AVG((content->'metadata'->>'avg_confidence')::float) as avg_confidence, MIN((content->'metadata'->>'avg_confidence')::float) as min_confidence, MAX((content->'metadata'->>'avg_confidence')::float) as max_confidence, SUM((content->'metadata'->>'object_count')::int) as total_objects, AVG((content->'metadata'->>'spatial_density')::float) as avg_density FROM chunks WHERE uuid = '{}' AND chunk_type = 'visual'", uuid.replace('\'', "''") ); let row: (i64, Option, Option, Option, i64, Option) = sqlx::query_as(&sql).fetch_one(db.pool()).await?; let mut stats = HashMap::new(); stats.insert("total_chunks".to_string(), Value::from(row.0)); stats.insert( "avg_confidence".to_string(), Value::from(row.1.unwrap_or(0.0)), ); stats.insert( "min_confidence".to_string(), Value::from(row.2.unwrap_or(0.0)), ); stats.insert( "max_confidence".to_string(), Value::from(row.3.unwrap_or(0.0)), ); stats.insert("total_objects".to_string(), Value::from(row.4)); stats.insert("avg_density".to_string(), Value::from(row.5.unwrap_or(0.0))); Ok(stats) } #[cfg(test)] mod tests { use super::*; #[test] fn test_visual_chunk_search_criteria_default() { let criteria = VisualChunkSearchCriteria::default(); assert_eq!(criteria.min_avg_confidence, None); assert_eq!(criteria.min_frames_with_objects, None); assert_eq!(criteria.min_unique_classes, None); assert!(criteria.required_classes.is_empty()); assert!(criteria.class_counts.is_empty()); assert_eq!(criteria.time_range, None); } #[test] fn test_visual_chunk_search_criteria_with_values() { let mut criteria = VisualChunkSearchCriteria::default(); criteria.min_avg_confidence = Some(0.8); criteria.min_frames_with_objects = Some(10); criteria.min_unique_classes = Some(3); criteria.required_classes = vec!["person".to_string(), "car".to_string()]; criteria.time_range = Some((0.0, 60.0)); assert_eq!(criteria.min_avg_confidence, Some(0.8)); assert_eq!(criteria.min_frames_with_objects, Some(10)); assert_eq!(criteria.min_unique_classes, Some(3)); assert_eq!(criteria.required_classes.len(), 2); assert_eq!(criteria.time_range, Some((0.0, 60.0))); } #[test] fn test_visual_chunk_search_criteria_serialization() { let criteria = VisualChunkSearchCriteria { min_avg_confidence: Some(0.85), min_frames_with_objects: Some(5), min_unique_classes: Some(2), required_classes: vec!["person".to_string()], class_counts: HashMap::new(), time_range: Some((10.0, 30.0)), }; let json = serde_json::to_string(&criteria).unwrap(); assert!(json.contains("min_avg_confidence")); assert!(json.contains("required_classes")); let deserialized: VisualChunkSearchCriteria = serde_json::from_str(&json).unwrap(); assert_eq!(deserialized.min_avg_confidence, Some(0.85)); assert_eq!(deserialized.required_classes.len(), 1); } #[test] fn test_visual_chunk_search_criteria_with_class_counts() { let mut criteria = VisualChunkSearchCriteria::default(); criteria.class_counts.insert("person".to_string(), (5, 20)); criteria.class_counts.insert("car".to_string(), (1, 10)); assert_eq!(criteria.class_counts.len(), 2); assert_eq!(criteria.class_counts.get("person"), Some(&(5, 20))); assert_eq!(criteria.class_counts.get("car"), Some(&(1, 10))); } #[test] fn test_chunk_type_conversion() { // Test chunk type string to enum conversion logic let test_cases = vec![ ("visual", ChunkType::Visual), ("sentence", ChunkType::Sentence), ("time_based", ChunkType::TimeBased), ("cut", ChunkType::Cut), ("trace", ChunkType::Trace), ("story", ChunkType::Story), ("unknown", ChunkType::TimeBased), // Default fallback ]; for (input, expected) in test_cases { let chunk_type = match input { "visual" => ChunkType::Visual, "sentence" => ChunkType::Sentence, "time_based" => ChunkType::TimeBased, "cut" => ChunkType::Cut, "trace" => ChunkType::Trace, "story" => ChunkType::Story, _ => ChunkType::TimeBased, }; assert_eq!(chunk_type, expected); } } }