505 lines
18 KiB
Rust
505 lines
18 KiB
Rust
//! Visual chunk search functionality.
//!
//! This module provides search capabilities for visual chunks based on:
//! - Object classes (e.g., "person", "car", "envelope")
//! - Confidence thresholds
//! - Object counts
//! - Spatial density
//! - Object combinations (several classes, each with a minimum count)
//! - Time ranges
|
|
|
|
use crate::core::chunk::types::{Chunk, ChunkRule, ChunkType};
|
|
use crate::core::db::PostgresDb;
|
|
use anyhow::Result;
|
|
use serde_json::Value;
|
|
use std::collections::HashMap;
|
|
|
|
/// Criteria for searching visual chunks
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct VisualChunkSearchCriteria {
|
|
/// Minimum average confidence across frames
|
|
pub min_avg_confidence: Option<f32>,
|
|
/// Minimum number of frames with objects
|
|
pub min_frames_with_objects: Option<u32>,
|
|
/// Minimum number of unique object classes
|
|
pub min_unique_classes: Option<u32>,
|
|
/// Specific object classes to include (empty means all)
|
|
pub required_classes: Vec<String>,
|
|
/// Object class counts to filter by
|
|
pub class_counts: HashMap<String, (u32, u32)>,
|
|
/// Time range (optional)
|
|
pub time_range: Option<(f64, f64)>,
|
|
}
|
|
|
|
impl Default for VisualChunkSearchCriteria {
|
|
fn default() -> Self {
|
|
Self {
|
|
min_avg_confidence: None,
|
|
min_frames_with_objects: None,
|
|
min_unique_classes: None,
|
|
required_classes: Vec::new(),
|
|
class_counts: HashMap::new(),
|
|
time_range: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Search visual chunks based on criteria
|
|
pub async fn search_visual_chunks(
|
|
db: &PostgresDb,
|
|
uuid: &str,
|
|
criteria: &VisualChunkSearchCriteria,
|
|
) -> Result<Vec<Chunk>> {
|
|
// First, get all visual chunks for this video
|
|
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
|
|
|
// Apply filters
|
|
let filtered_chunks: Vec<Chunk> = all_chunks
|
|
.into_iter()
|
|
.filter(|chunk| {
|
|
// Check min avg confidence
|
|
if let Some(min_avg_confidence) = criteria.min_avg_confidence {
|
|
if let Some(content) = &chunk.content.as_object() {
|
|
if let Some(metadata) = content.get("metadata") {
|
|
if let Some(avg_confidence) = metadata.get("avg_confidence") {
|
|
if let Some(conf) = avg_confidence.as_f64() {
|
|
if conf < min_avg_confidence as f64 {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check min frames with objects
|
|
if let Some(min_frames) = criteria.min_frames_with_objects {
|
|
if let Some(stats) = &chunk.visual_stats {
|
|
if let Some(frames_with_objects) = stats.get("frames_with_objects") {
|
|
if let Some(count) = frames_with_objects.as_u64() {
|
|
if count < min_frames as u64 {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check min unique classes
|
|
if let Some(min_unique_classes) = criteria.min_unique_classes {
|
|
if let Some(content) = &chunk.content.as_object() {
|
|
if let Some(metadata) = content.get("metadata") {
|
|
if let Some(unique_classes) = metadata.get("unique_classes") {
|
|
if let Some(classes) = unique_classes.as_array() {
|
|
if (classes.len() as u32) < min_unique_classes {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check required classes
|
|
if !criteria.required_classes.is_empty() {
|
|
if let Some(content) = &chunk.content.as_object() {
|
|
if let Some(keyframe_objects) = content.get("keyframe_objects") {
|
|
if let Some(objects) = keyframe_objects.as_array() {
|
|
let mut found_all = true;
|
|
for required_class in &criteria.required_classes {
|
|
let mut found = false;
|
|
for obj in objects {
|
|
if let Some(class_name) = obj.get("class_name") {
|
|
if let Some(class_str) = class_name.as_str() {
|
|
if class_str == required_class {
|
|
found = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if !found {
|
|
found_all = false;
|
|
break;
|
|
}
|
|
}
|
|
if !found_all {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check class counts
|
|
if !criteria.class_counts.is_empty() {
|
|
if let Some(content) = &chunk.content.as_object() {
|
|
if let Some(metadata) = content.get("metadata") {
|
|
if let Some(object_counts) = metadata.get("object_counts") {
|
|
for (class, (min, max)) in &criteria.class_counts {
|
|
if let Some(count_value) = object_counts.get(class) {
|
|
if let Some(count) = count_value.as_u64() {
|
|
if *min > 0 && count < *min as u64 {
|
|
return false;
|
|
}
|
|
if *max < u32::MAX && count > *max as u64 {
|
|
return false;
|
|
}
|
|
}
|
|
} else if *min > 0 {
|
|
return false;
|
|
}
|
|
}
|
|
} else if criteria.class_counts.values().any(|(min, _)| *min > 0) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check time range
|
|
if let Some((start_time, end_time)) = criteria.time_range {
|
|
// Calculate chunk time from frames
|
|
let chunk_start_time = chunk.start_frame as f64 / chunk.fps;
|
|
let chunk_end_time = chunk.end_frame as f64 / chunk.fps;
|
|
|
|
if chunk_start_time < start_time || chunk_end_time > end_time {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
true
|
|
})
|
|
.collect();
|
|
|
|
Ok(filtered_chunks)
|
|
}
|
|
|
|
/// Get all visual chunks for a video UUID
|
|
async fn get_visual_chunks_by_uuid(db: &PostgresDb, uuid: &str) -> Result<Vec<Chunk>> {
|
|
let sql = format!(
|
|
"SELECT file_id, uuid, chunk_id, chunk_index, chunk_type, fps, start_frame, end_frame, text_content, content, metadata, vector_id, visual_stats FROM chunks WHERE uuid = '{}' AND chunk_type = 'visual' ORDER BY start_frame ASC",
|
|
uuid.replace('\'', "''")
|
|
);
|
|
|
|
let rows: Vec<(
|
|
i32, // file_id
|
|
String, // uuid
|
|
String, // chunk_id
|
|
i32, // chunk_index
|
|
String, // chunk_type
|
|
f64, // fps
|
|
i64, // start_frame
|
|
i64, // end_frame
|
|
Option<String>, // text_content
|
|
Value, // content
|
|
Option<Value>, // metadata
|
|
Option<String>, // vector_id
|
|
Option<Value>, // visual_stats
|
|
)> = sqlx::query_as(&sql).fetch_all(db.pool()).await?;
|
|
|
|
let mut chunks = Vec::new();
|
|
for row in rows {
|
|
let chunk_type = match row.4.as_str() {
|
|
"visual" => ChunkType::Visual,
|
|
"sentence" => ChunkType::Sentence,
|
|
"time_based" => ChunkType::TimeBased,
|
|
"cut" => ChunkType::Cut,
|
|
"trace" => ChunkType::Trace,
|
|
"story" => ChunkType::Story,
|
|
_ => ChunkType::TimeBased,
|
|
};
|
|
|
|
// Calculate frame_count
|
|
let frame_count = (row.7 - row.6) as i32;
|
|
|
|
chunks.push(Chunk {
|
|
file_id: row.0,
|
|
uuid: row.1,
|
|
chunk_id: row.2,
|
|
chunk_index: row.3 as u32,
|
|
chunk_type,
|
|
rule: ChunkRule::Rule2, // Visual chunks use Rule2
|
|
fps: row.5,
|
|
start_frame: row.6,
|
|
end_frame: row.7,
|
|
text_content: row.8,
|
|
content: row.9,
|
|
metadata: row.10,
|
|
vector_id: row.11,
|
|
frame_count,
|
|
pre_chunk_ids: Vec::new(),
|
|
parent_chunk_id: None,
|
|
child_chunk_ids: Vec::new(),
|
|
visual_stats: row.12,
|
|
});
|
|
}
|
|
|
|
Ok(chunks)
|
|
}
|
|
|
|
/// Search visual chunks by object class
|
|
pub async fn search_visual_chunks_by_class(
|
|
db: &PostgresDb,
|
|
uuid: &str,
|
|
object_class: &str,
|
|
min_count: Option<u32>,
|
|
max_count: Option<u32>,
|
|
) -> Result<Vec<Chunk>> {
|
|
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
|
|
|
let filtered_chunks: Vec<Chunk> = all_chunks
|
|
.into_iter()
|
|
.filter(|chunk| {
|
|
// Check if chunk contains the object class
|
|
let mut contains_class = false;
|
|
if let Some(content) = &chunk.content.as_object() {
|
|
if let Some(keyframe_objects) = content.get("keyframe_objects") {
|
|
if let Some(objects) = keyframe_objects.as_array() {
|
|
for obj in objects {
|
|
if let Some(class_name) = obj.get("class_name") {
|
|
if let Some(class_str) = class_name.as_str() {
|
|
if class_str == object_class {
|
|
contains_class = true;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if !contains_class {
|
|
return false;
|
|
}
|
|
|
|
// Check count in visual_stats
|
|
if let Some(stats) = &chunk.visual_stats {
|
|
if let Some(count) = stats.get(object_class) {
|
|
if let Some(c) = count.as_u64() {
|
|
if let Some(min) = min_count {
|
|
if c < min as u64 {
|
|
return false;
|
|
}
|
|
}
|
|
if let Some(max) = max_count {
|
|
if c > max as u64 {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
true
|
|
})
|
|
.collect();
|
|
|
|
Ok(filtered_chunks)
|
|
}
|
|
|
|
/// Search visual chunks by spatial density
|
|
pub async fn search_visual_chunks_by_density(
|
|
db: &PostgresDb,
|
|
uuid: &str,
|
|
min_density: f32,
|
|
max_density: Option<f32>,
|
|
) -> Result<Vec<Chunk>> {
|
|
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
|
|
|
let filtered_chunks: Vec<Chunk> = all_chunks
|
|
.into_iter()
|
|
.filter(|chunk| {
|
|
if let Some(content) = &chunk.content.as_object() {
|
|
if let Some(metadata) = content.get("metadata") {
|
|
if let Some(density_value) = metadata.get("spatial_density") {
|
|
if let Some(density) = density_value.as_f64() {
|
|
if density < min_density as f64 {
|
|
return false;
|
|
}
|
|
if let Some(max_dens) = max_density {
|
|
if density > max_dens as f64 {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
false
|
|
})
|
|
.collect();
|
|
|
|
Ok(filtered_chunks)
|
|
}
|
|
|
|
/// Find chunks containing specific object combinations
|
|
pub async fn search_visual_chunks_by_combination(
|
|
db: &PostgresDb,
|
|
uuid: &str,
|
|
combination: &[(&str, u32)],
|
|
) -> Result<Vec<Chunk>> {
|
|
let all_chunks = get_visual_chunks_by_uuid(db, uuid).await?;
|
|
|
|
let filtered_chunks: Vec<Chunk> = all_chunks
|
|
.into_iter()
|
|
.filter(|chunk| {
|
|
// Check if all required combinations are present
|
|
for (object_class, min_count) in combination {
|
|
let mut found = false;
|
|
if let Some(stats) = &chunk.visual_stats {
|
|
if let Some(object_counts) = stats.get("object_counts") {
|
|
if let Some(count_value) = object_counts.get(*object_class) {
|
|
if let Some(count) = count_value.as_u64() {
|
|
if count >= *min_count as u64 {
|
|
found = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if !found {
|
|
return false;
|
|
}
|
|
}
|
|
true
|
|
})
|
|
.collect();
|
|
|
|
Ok(filtered_chunks)
|
|
}
|
|
|
|
/// Get visual chunk statistics
|
|
pub async fn get_visual_chunk_statistics(
|
|
db: &PostgresDb,
|
|
uuid: &str,
|
|
) -> Result<HashMap<String, Value>> {
|
|
let sql = format!(
|
|
"SELECT
|
|
COUNT(*) as total_chunks,
|
|
AVG((content->'metadata'->>'avg_confidence')::float) as avg_confidence,
|
|
MIN((content->'metadata'->>'avg_confidence')::float) as min_confidence,
|
|
MAX((content->'metadata'->>'avg_confidence')::float) as max_confidence,
|
|
SUM((content->'metadata'->>'object_count')::int) as total_objects,
|
|
AVG((content->'metadata'->>'spatial_density')::float) as avg_density
|
|
FROM chunks
|
|
WHERE uuid = '{}'
|
|
AND chunk_type = 'visual'",
|
|
uuid.replace('\'', "''")
|
|
);
|
|
|
|
let row: (i64, Option<f64>, Option<f64>, Option<f64>, i64, Option<f64>) =
|
|
sqlx::query_as(&sql).fetch_one(db.pool()).await?;
|
|
|
|
let mut stats = HashMap::new();
|
|
stats.insert("total_chunks".to_string(), Value::from(row.0));
|
|
stats.insert(
|
|
"avg_confidence".to_string(),
|
|
Value::from(row.1.unwrap_or(0.0)),
|
|
);
|
|
stats.insert(
|
|
"min_confidence".to_string(),
|
|
Value::from(row.2.unwrap_or(0.0)),
|
|
);
|
|
stats.insert(
|
|
"max_confidence".to_string(),
|
|
Value::from(row.3.unwrap_or(0.0)),
|
|
);
|
|
stats.insert("total_objects".to_string(), Value::from(row.4));
|
|
stats.insert("avg_density".to_string(), Value::from(row.5.unwrap_or(0.0)));
|
|
|
|
Ok(stats)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_visual_chunk_search_criteria_default() {
        let criteria = VisualChunkSearchCriteria::default();

        assert_eq!(criteria.min_avg_confidence, None);
        assert_eq!(criteria.min_frames_with_objects, None);
        assert_eq!(criteria.min_unique_classes, None);
        assert!(criteria.required_classes.is_empty());
        assert!(criteria.class_counts.is_empty());
        assert_eq!(criteria.time_range, None);
    }

    #[test]
    fn test_visual_chunk_search_criteria_with_values() {
        // Struct-update syntax instead of mutating a default value
        // (clippy::field_reassign_with_default).
        let criteria = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.8),
            min_frames_with_objects: Some(10),
            min_unique_classes: Some(3),
            required_classes: vec!["person".to_string(), "car".to_string()],
            time_range: Some((0.0, 60.0)),
            ..Default::default()
        };

        assert_eq!(criteria.min_avg_confidence, Some(0.8));
        assert_eq!(criteria.min_frames_with_objects, Some(10));
        assert_eq!(criteria.min_unique_classes, Some(3));
        assert_eq!(criteria.required_classes.len(), 2);
        assert_eq!(criteria.time_range, Some((0.0, 60.0)));
        assert!(criteria.class_counts.is_empty());
    }

    #[test]
    fn test_visual_chunk_search_criteria_serialization() {
        let criteria = VisualChunkSearchCriteria {
            min_avg_confidence: Some(0.85),
            min_frames_with_objects: Some(5),
            min_unique_classes: Some(2),
            required_classes: vec!["person".to_string()],
            class_counts: HashMap::new(),
            time_range: Some((10.0, 30.0)),
        };

        let json = serde_json::to_string(&criteria).unwrap();
        assert!(json.contains("min_avg_confidence"));
        assert!(json.contains("required_classes"));

        let deserialized: VisualChunkSearchCriteria = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.min_avg_confidence, Some(0.85));
        assert_eq!(deserialized.required_classes.len(), 1);
        assert_eq!(deserialized.time_range, Some((10.0, 30.0)));
    }

    #[test]
    fn test_visual_chunk_search_criteria_with_class_counts() {
        let mut criteria = VisualChunkSearchCriteria::default();
        criteria.class_counts.insert("person".to_string(), (5, 20));
        criteria.class_counts.insert("car".to_string(), (1, 10));

        assert_eq!(criteria.class_counts.len(), 2);
        assert_eq!(criteria.class_counts.get("person"), Some(&(5, 20)));
        assert_eq!(criteria.class_counts.get("car"), Some(&(1, 10)));
    }

    /// NOTE(review): this re-implements the string-to-`ChunkType` match from
    /// `get_visual_chunks_by_uuid` rather than calling it, so it only documents
    /// the expected mapping. Keep both copies in sync.
    #[test]
    fn test_chunk_type_conversion() {
        let test_cases = [
            ("visual", ChunkType::Visual),
            ("sentence", ChunkType::Sentence),
            ("time_based", ChunkType::TimeBased),
            ("cut", ChunkType::Cut),
            ("trace", ChunkType::Trace),
            ("story", ChunkType::Story),
            ("unknown", ChunkType::TimeBased), // Default fallback
        ];

        for (input, expected) in test_cases {
            let chunk_type = match input {
                "visual" => ChunkType::Visual,
                "sentence" => ChunkType::Sentence,
                "time_based" => ChunkType::TimeBased,
                "cut" => ChunkType::Cut,
                "trace" => ChunkType::Trace,
                "story" => ChunkType::Story,
                _ => ChunkType::TimeBased,
            };
            assert_eq!(chunk_type, expected);
        }
    }
}
|