Files
momentry_core/src/api/visual_chunk_search.rs

505 lines
18 KiB
Rust

//! Visual chunk search functionality.
//!
//! This module provides search capabilities for visual chunks based on:
//! - Object classes (e.g., "person", "car", "envelope")
//! - Confidence thresholds
//! - Object counts
//! - Spatial density
//! - Object combinations (several classes, each with a minimum count)
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
use crate::core::db::PostgresDb;
use anyhow::Result;
use serde_json::Value;
use std::collections::HashMap;
/// Criteria for searching visual chunks
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct VisualChunkSearchCriteria {
/// Minimum average confidence across frames
pub min_avg_confidence: Option<f32>,
/// Minimum number of frames with objects
pub min_frames_with_objects: Option<u32>,
/// Minimum number of unique object classes
pub min_unique_classes: Option<u32>,
/// Specific object classes to include (empty means all)
pub required_classes: Vec<String>,
/// Object class counts to filter by
pub class_counts: HashMap<String, (u32, u32)>,
/// Time range (optional)
pub time_range: Option<(f64, f64)>,
}
impl Default for VisualChunkSearchCriteria {
fn default() -> Self {
Self {
min_avg_confidence: None,
min_frames_with_objects: None,
min_unique_classes: None,
required_classes: Vec::new(),
class_counts: HashMap::new(),
time_range: None,
}
}
}
/// Search visual chunks based on criteria
pub async fn search_visual_chunks(
db: &PostgresDb,
uuid: &str,
criteria: &VisualChunkSearchCriteria,
) -> Result<Vec<Chunk>> {
// First, get all visual chunks for this video
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
// Apply filters
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check min avg confidence
if let Some(min_avg_confidence) = criteria.min_avg_confidence {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(avg_confidence) = metadata.get("avg_confidence") {
if let Some(conf) = avg_confidence.as_f64() {
if conf < min_avg_confidence as f64 {
return false;
}
}
}
}
}
}
// Check min frames with objects
if let Some(min_frames) = criteria.min_frames_with_objects {
if let Some(stats) = &chunk.visual_stats {
if let Some(frames_with_objects) = stats.get("frames_with_objects") {
if let Some(count) = frames_with_objects.as_u64() {
if count < min_frames as u64 {
return false;
}
}
}
}
}
// Check min unique classes
if let Some(min_unique_classes) = criteria.min_unique_classes {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(unique_classes) = metadata.get("unique_classes") {
if let Some(classes) = unique_classes.as_array() {
if (classes.len() as u32) < min_unique_classes {
return false;
}
}
}
}
}
}
// Check required classes
if !criteria.required_classes.is_empty() {
if let Some(content) = &chunk.content.as_object() {
if let Some(keyframe_objects) = content.get("keyframe_objects") {
if let Some(objects) = keyframe_objects.as_array() {
let mut found_all = true;
for required_class in &criteria.required_classes {
let mut found = false;
for obj in objects {
if let Some(class_name) = obj.get("class_name") {
if let Some(class_str) = class_name.as_str() {
if class_str == required_class {
found = true;
break;
}
}
}
}
if !found {
found_all = false;
break;
}
}
if !found_all {
return false;
}
}
}
}
}
// Check class counts
if !criteria.class_counts.is_empty() {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(object_counts) = metadata.get("object_counts") {
for (class, (min, max)) in &criteria.class_counts {
if let Some(count_value) = object_counts.get(class) {
if let Some(count) = count_value.as_u64() {
if *min > 0 && count < *min as u64 {
return false;
}
if *max < u32::MAX && count > *max as u64 {
return false;
}
}
} else if *min > 0 {
return false;
}
}
} else if criteria.class_counts.values().any(|(min, _)| *min > 0) {
return false;
}
}
}
}
// Check time range
if let Some((start_time, end_time)) = criteria.time_range {
// Calculate chunk time from frames
let chunk_start_time = chunk.start_frame as f64 / chunk.fps;
let chunk_end_time = chunk.end_frame as f64 / chunk.fps;
if chunk_start_time < start_time || chunk_end_time > end_time {
return false;
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Load every visual chunk stored for the video `uuid`, ordered by `start_frame`
/// ascending.
///
/// `uuid` is sent as a bound query parameter (`$1`) rather than spliced into the SQL
/// text, so arbitrary input cannot alter the statement.
///
/// Fields not selected from the table get fixed values: `rule` is always
/// `ChunkRule::Rule2`, `pre_chunk_ids` / `child_chunk_ids` are empty and
/// `parent_chunk_id` is `None`. `frame_count` is computed as `end_frame - start_frame`.
///
/// # Errors
/// Returns an error if the query fails or a row does not decode into the expected
/// column types.
async fn get_visual_chunks_by_uuid(db: &PostgresDb, uuid: &str) -> Result<Vec<Chunk>> {
    const SQL: &str = "SELECT file_id, uuid, chunk_id, chunk_index, chunk_type, fps, \
                       start_frame, end_frame, text_content, content, metadata, vector_id, \
                       visual_stats FROM chunks WHERE uuid = $1 AND chunk_type = 'visual' \
                       ORDER BY start_frame ASC";

    // Tuple order must match the SELECT list above.
    type Row = (
        i32,            // file_id
        String,         // uuid
        String,         // chunk_id
        i32,            // chunk_index
        String,         // chunk_type
        f64,            // fps
        i64,            // start_frame
        i64,            // end_frame
        Option<String>, // text_content
        Value,          // content
        Option<Value>,  // metadata
        Option<String>, // vector_id
        Option<Value>,  // visual_stats
    );

    let rows: Vec<Row> = sqlx::query_as(SQL).bind(uuid).fetch_all(db.pool()).await?;

    let chunks = rows
        .into_iter()
        .map(
            |(
                file_id,
                row_uuid,
                chunk_id,
                chunk_index,
                chunk_type,
                fps,
                start_frame,
                end_frame,
                text_content,
                content,
                metadata,
                vector_id,
                visual_stats,
            )| {
                // The query only selects 'visual' rows, but the stored type is still
                // decoded; unrecognised values fall back to `TimeBased`.
                let chunk_type = match chunk_type.as_str() {
                    "visual" => ChunkType::Visual,
                    "sentence" => ChunkType::Sentence,
                    "time_based" => ChunkType::TimeBased,
                    "cut" => ChunkType::Cut,
                    "trace" => ChunkType::Trace,
                    "story" => ChunkType::Story,
                    _ => ChunkType::TimeBased,
                };
                Chunk {
                    file_id,
                    uuid: row_uuid,
                    chunk_id,
                    // NOTE(review): `as` wraps a negative index and truncates a frame
                    // span wider than i32; kept to preserve existing behavior.
                    chunk_index: chunk_index as u32,
                    chunk_type,
                    rule: ChunkRule::Rule2, // Visual chunks use Rule2
                    fps,
                    start_frame,
                    end_frame,
                    text_content,
                    content,
                    metadata,
                    vector_id,
                    frame_count: (end_frame - start_frame) as i32,
                    pre_chunk_ids: Vec::new(),
                    parent_chunk_id: None,
                    child_chunk_ids: Vec::new(),
                    visual_stats,
                }
            },
        )
        .collect();
    Ok(chunks)
}
/// Search visual chunks by object class
pub async fn search_visual_chunks_by_class(
db: &PostgresDb,
uuid: &str,
object_class: &str,
min_count: Option<u32>,
max_count: Option<u32>,
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check if chunk contains the object class
let mut contains_class = false;
if let Some(content) = &chunk.content.as_object() {
if let Some(keyframe_objects) = content.get("keyframe_objects") {
if let Some(objects) = keyframe_objects.as_array() {
for obj in objects {
if let Some(class_name) = obj.get("class_name") {
if let Some(class_str) = class_name.as_str() {
if class_str == object_class {
contains_class = true;
break;
}
}
}
}
}
}
}
if !contains_class {
return false;
}
// Check count in visual_stats
if let Some(stats) = &chunk.visual_stats {
if let Some(count) = stats.get(object_class) {
if let Some(c) = count.as_u64() {
if let Some(min) = min_count {
if c < min as u64 {
return false;
}
}
if let Some(max) = max_count {
if c > max as u64 {
return false;
}
}
}
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Search visual chunks by spatial density
pub async fn search_visual_chunks_by_density(
db: &PostgresDb,
uuid: &str,
min_density: f32,
max_density: Option<f32>,
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
if let Some(content) = &chunk.content.as_object() {
if let Some(metadata) = content.get("metadata") {
if let Some(density_value) = metadata.get("spatial_density") {
if let Some(density) = density_value.as_f64() {
if density < min_density as f64 {
return false;
}
if let Some(max_dens) = max_density {
if density > max_dens as f64 {
return false;
}
}
return true;
}
}
}
}
false
})
.collect();
Ok(filtered_chunks)
}
/// Find chunks containing specific object combinations
pub async fn search_visual_chunks_by_combination(
db: &PostgresDb,
uuid: &str,
combination: &[(&str, u32)],
) -> Result<Vec<Chunk>> {
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
let filtered_chunks: Vec<Chunk> = all_chunks
.into_iter()
.filter(|chunk| {
// Check if all required combinations are present
for (object_class, min_count) in combination {
let mut found = false;
if let Some(stats) = &chunk.visual_stats {
if let Some(object_counts) = stats.get("object_counts") {
if let Some(count_value) = object_counts.get(*object_class) {
if let Some(count) = count_value.as_u64() {
if count >= *min_count as u64 {
found = true;
}
}
}
}
}
if !found {
return false;
}
}
true
})
.collect();
Ok(filtered_chunks)
}
/// Aggregate statistics over all visual chunks of the video `uuid`.
///
/// The returned map has these keys:
/// - `total_chunks`: the number of visual chunks;
/// - `avg_confidence`, `min_confidence`, `max_confidence`: aggregates of
///   `content.metadata.avg_confidence`;
/// - `total_objects`: the sum of `content.metadata.object_count`;
/// - `avg_density`: the average of `content.metadata.spatial_density`.
///
/// Aggregates that SQL reports as `NULL` (e.g. when the video has no visual chunks)
/// are returned as `0` / `0.0`.
///
/// # Errors
/// Returns an error if the query fails, including when a stored metadata value cannot
/// be cast to the numeric type the query uses.
pub async fn get_visual_chunk_statistics(
    db: &PostgresDb,
    uuid: &str,
) -> Result<HashMap<String, Value>> {
    // `uuid` is bound as `$1` rather than formatted into the SQL text.
    const SQL: &str = "SELECT
            COUNT(*) as total_chunks,
            AVG((content->'metadata'->>'avg_confidence')::float) as avg_confidence,
            MIN((content->'metadata'->>'avg_confidence')::float) as min_confidence,
            MAX((content->'metadata'->>'avg_confidence')::float) as max_confidence,
            SUM((content->'metadata'->>'object_count')::int) as total_objects,
            AVG((content->'metadata'->>'spatial_density')::float) as avg_density
        FROM chunks
        WHERE uuid = $1
        AND chunk_type = 'visual'";

    // `SUM` over zero rows (or only NULLs) is NULL, so `total_objects` must decode as
    // `Option<i64>`; decoding it as plain `i64` fails for videos without chunks.
    type StatsRow = (
        i64,         // total_chunks
        Option<f64>, // avg_confidence
        Option<f64>, // min_confidence
        Option<f64>, // max_confidence
        Option<i64>, // total_objects
        Option<f64>, // avg_density
    );

    let (total_chunks, avg_confidence, min_confidence, max_confidence, total_objects, avg_density): StatsRow =
        sqlx::query_as(SQL).bind(uuid).fetch_one(db.pool()).await?;

    let stats = HashMap::from([
        ("total_chunks".to_string(), Value::from(total_chunks)),
        (
            "avg_confidence".to_string(),
            Value::from(avg_confidence.unwrap_or(0.0)),
        ),
        (
            "min_confidence".to_string(),
            Value::from(min_confidence.unwrap_or(0.0)),
        ),
        (
            "max_confidence".to_string(),
            Value::from(max_confidence.unwrap_or(0.0)),
        ),
        (
            "total_objects".to_string(),
            Value::from(total_objects.unwrap_or(0)),
        ),
        (
            "avg_density".to_string(),
            Value::from(avg_density.unwrap_or(0.0)),
        ),
    ]);
    Ok(stats)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default criteria value applies no filters at all.
    #[test]
    fn test_visual_chunk_search_criteria_default() {
        let VisualChunkSearchCriteria {
            min_avg_confidence,
            min_frames_with_objects,
            min_unique_classes,
            required_classes,
            class_counts,
            time_range,
        } = VisualChunkSearchCriteria::default();
        assert_eq!(min_avg_confidence, None);
        assert_eq!(min_frames_with_objects, None);
        assert_eq!(min_unique_classes, None);
        assert!(required_classes.is_empty());
        assert!(class_counts.is_empty());
        assert_eq!(time_range, None);
    }

    /// Explicitly set fields keep their values; the rest come from `Default`.
    #[test]
    fn test_visual_chunk_search_criteria_with_values() {
        let criteria = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.8),
            min_frames_with_objects: Some(10),
            min_unique_classes: Some(3),
            required_classes: ["person", "car"].iter().map(|s| s.to_string()).collect(),
            time_range: Some((0.0, 60.0)),
            ..Default::default()
        };
        assert_eq!(criteria.min_avg_confidence, Some(0.8));
        assert_eq!(criteria.min_frames_with_objects, Some(10));
        assert_eq!(criteria.min_unique_classes, Some(3));
        assert_eq!(criteria.required_classes.len(), 2);
        assert_eq!(criteria.time_range, Some((0.0, 60.0)));
    }

    /// Criteria survive a JSON round trip.
    #[test]
    fn test_visual_chunk_search_criteria_serialization() {
        let original = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.85),
            min_frames_with_objects: Some(5),
            min_unique_classes: Some(2),
            required_classes: vec!["person".to_string()],
            class_counts: HashMap::new(),
            time_range: Some((10.0, 30.0)),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        for key in ["min_avg_confidence", "required_classes"] {
            assert!(encoded.contains(key));
        }
        let decoded: VisualChunkSearchCriteria = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.min_avg_confidence, Some(0.85));
        assert_eq!(decoded.required_classes.len(), 1);
    }

    /// Per-class `(min, max)` bounds are stored keyed by class name.
    #[test]
    fn test_visual_chunk_search_criteria_with_class_counts() {
        let criteria = VisualChunkSearchCriteria {
            class_counts: HashMap::from([
                ("person".to_string(), (5, 20)),
                ("car".to_string(), (1, 10)),
            ]),
            ..Default::default()
        };
        assert_eq!(criteria.class_counts.len(), 2);
        assert_eq!(criteria.class_counts.get("person"), Some(&(5, 20)));
        assert_eq!(criteria.class_counts.get("car"), Some(&(1, 10)));
    }

    /// String-to-`ChunkType` mapping, including the `TimeBased` fallback.
    ///
    /// NOTE(review): this re-implements the mapping used in `get_visual_chunks_by_uuid`
    /// rather than calling it, so it only checks its own copy of the match.
    #[test]
    fn test_chunk_type_conversion() {
        fn parse(input: &str) -> ChunkType {
            match input {
                "visual" => ChunkType::Visual,
                "sentence" => ChunkType::Sentence,
                "time_based" => ChunkType::TimeBased,
                "cut" => ChunkType::Cut,
                "trace" => ChunkType::Trace,
                "story" => ChunkType::Story,
                _ => ChunkType::TimeBased,
            }
        }

        let cases = [
            ("visual", ChunkType::Visual),
            ("sentence", ChunkType::Sentence),
            ("time_based", ChunkType::TimeBased),
            ("cut", ChunkType::Cut),
            ("trace", ChunkType::Trace),
            ("story", ChunkType::Story),
            ("unknown", ChunkType::TimeBased), // Default fallback
        ];
        for (input, expected) in cases {
            assert_eq!(parse(input), expected);
        }
    }
}