Files
momentry_core/scripts/caption_processor.py

306 lines
9.1 KiB
Python
Executable File

#!/opt/homebrew/bin/python3.11
"""
Caption Processor - Generate image captions
Uses AI vision models to analyze video frames and generate descriptions
"""
import sys
import json
import os
import argparse
import subprocess
from typing import Dict, List, Optional
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from redis_publisher import RedisPublisher
def extract_frames(video_path: str, max_frames: int = 30) -> List[Dict]:
    """Extract up to ``max_frames`` JPEG frames from a video at regular intervals.

    The duration is read with ``ffprobe``; when it cannot be determined a
    60-second duration is assumed. Frames are written by ``ffmpeg`` into a
    ``.caption_frames`` directory located next to the video file.

    Args:
        video_path: Path to the input video file.
        max_frames: Maximum number of frames to extract. Values <= 0 yield
            an empty list.

    Returns:
        One dict per frame file that exists after extraction, with keys
        ``index``, ``timestamp`` (seconds) and ``path``.
    """
    if max_frames <= 0:
        # Nothing requested; also avoids dividing by zero below.
        return []

    fallback_duration = 60.0

    # Get video duration
    probe_cmd = [
        "ffprobe",
        "-v",
        "quiet",
        "-print_format",
        "json",
        "-show_format",
        video_path,
    ]
    duration = 0.0
    try:
        probe = subprocess.run(probe_cmd, capture_output=True, text=True)
        if probe.returncode == 0:
            data = json.loads(probe.stdout)
            duration = float(data.get("format", {}).get("duration", 0))
    except (
        OSError,
        subprocess.SubprocessError,
        ValueError,
        TypeError,
        AttributeError,
    ):
        # ffprobe missing, or its output was not the expected JSON shape.
        duration = 0.0

    # The chained comparison is False for NaN and infinity as well as for
    # non-positive values, so all of those use the fallback duration.
    duration_known = 0 < duration < float("inf")
    if not duration_known:
        duration = fallback_duration

    # Calculate frame interval (never sample more often than once a second)
    interval = max(duration / max_frames, 1.0)

    temp_dir = os.path.join(os.path.dirname(video_path), ".caption_frames")
    os.makedirs(temp_dir, exist_ok=True)

    frames = []
    for i in range(max_frames):
        timestamp = i * interval
        if duration_known and timestamp >= duration:
            # Short video: the remaining timestamps lie past the end.
            break

        output_file = os.path.join(temp_dir, f"frame_{i:04d}.jpg")

        # Drop any leftover from an earlier run so that a failed extraction
        # is not reported as a success because of a stale file.
        try:
            os.remove(output_file)
        except OSError:
            pass

        cmd = [
            "ffmpeg",
            "-y",
            "-ss",
            str(timestamp),
            "-i",
            video_path,
            "-vframes",
            "1",
            "-q:v",
            "2",
            output_file,
        ]
        try:
            subprocess.run(cmd, capture_output=True, check=False)
        except (OSError, subprocess.SubprocessError):
            # Best effort: skip this frame if ffmpeg could not be launched.
            continue

        if os.path.exists(output_file):
            frames.append({"index": i, "timestamp": timestamp, "path": output_file})

    return frames
def generate_caption_with_llava(
    image_path: str, prompt: str = "Describe this image in detail."
) -> Optional[str]:
    """Return a LLaVA caption for an image, or None if the stack is unavailable.

    Currently this only verifies that ``transformers``, ``torch`` and
    ``PIL`` can be imported. No model is loaded, so the returned value is a
    placeholder string that names the image file. ``prompt`` is accepted
    but not used yet.
    """
    try:
        # Availability probe only; the imported names are not used below.
        from transformers import AutoProcessor, AutoModelForVision2Seq  # noqa: F401
        import torch  # noqa: F401
        from PIL import Image  # noqa: F401
    except ImportError:
        return None

    # Note: real inference requires llava-hf/llava-1.5-7b-hf or similar.
    # Until that is wired in, hand back a placeholder caption.
    placeholder = f"[LLaVA caption for {os.path.basename(image_path)}]"
    return placeholder
def generate_caption_with_gpt4v(image_path: str, api_key: str = None) -> Optional[str]:
    """Caption an image with the OpenAI chat completions API.

    Args:
        image_path: Path of the JPEG frame to describe.
        api_key: OpenAI API key; when falsy, the ``OPENAI_API_KEY``
            environment variable is used instead.

    Returns:
        The caption text from the first choice, or None when no key is
        available or any step (import, file read, API call) fails.
    """
    import base64

    key = api_key or os.environ.get("OPENAI_API_KEY")
    if not key:
        return None

    try:
        from openai import OpenAI

        client = OpenAI(api_key=key)

        # Encode image as base64 so it can be embedded in a data URL
        with open(image_path, "rb") as image_file:
            img_data = base64.b64encode(image_file.read()).decode()

        image_part = {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{img_data}"},
        }
        text_part = {
            "type": "text",
            "text": "Describe what you see in this image in one sentence.",
        }
        user_message = {"role": "user", "content": [image_part, text_part]}

        response = client.chat.completions.create(
            model="gpt-4o",  # or gpt-4-turbo for vision
            messages=[user_message],
            max_tokens=100,
        )
        return response.choices[0].message.content
    except Exception:
        # Best effort: any failure makes the caller fall back to another source.
        return None
def generate_caption_fallback(image_path: str, existing_data: Dict = None) -> str:
    """Generate a basic caption from already-available detection metadata.

    Args:
        image_path: Path of the frame being captioned (not used for the
            caption text).
        existing_data: Optional dict with an ``objects`` list (entries with
            a ``class`` key) and a ``texts`` list (entries with a ``text``
            key).

    Returns:
        ``"Contains: ..."`` and/or ``"On-screen text: ..."`` joined with
        ``" | "``, or ``"Video frame at timestamp"`` when no metadata is
        usable.
    """
    data = existing_data or {}
    caption_parts = []

    # Object classes: entries without a class are skipped instead of raising.
    labels = [
        str(obj["class"])
        for obj in data.get("objects") or []
        if obj.get("class") is not None
    ]
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # caption is the same on every run (set ordering of strings is not).
    unique_labels = list(dict.fromkeys(labels))[:5]
    if unique_labels:
        caption_parts.append(f"Contains: {', '.join(unique_labels)}")

    # On-screen text: keep non-empty entries, show at most the first three.
    texts = [str(t["text"]) for t in data.get("texts") or [] if t.get("text")]
    if texts:
        caption_parts.append(f"On-screen text: {' '.join(texts[:3])}")

    if caption_parts:
        return " | ".join(caption_parts)
    return "Video frame at timestamp"
def process_frame(
    frame_info: Dict,
    yolo_data: List = None,
    ocr_data: List = None,
    timestamp_tolerance: float = 0.5,
) -> Dict:
    """Process a single frame and generate a caption for it.

    Caption sources are tried in order: GPT-4V, then LLaVA, then a metadata
    fallback built from the YOLO/OCR entries whose timestamp lies near the
    frame's timestamp.

    Args:
        frame_info: Dict with ``index``, ``timestamp`` and ``path`` keys.
        yolo_data: Optional flat list of object dicts carrying ``timestamp``.
        ocr_data: Optional flat list of text dicts carrying ``timestamp``.
        timestamp_tolerance: Maximum distance in seconds between an entry's
            timestamp and the frame's timestamp for the entry to count as
            belonging to the frame. Exact float equality would rarely match
            because the frame timestamps are computed as ``i * interval``.

    Returns:
        Dict with ``index``, ``timestamp``, ``caption`` and ``source``
        (``"gpt-4v"``, ``"llava"`` or ``"metadata"``).
    """
    frame_path = frame_info["path"]
    timestamp = frame_info["timestamp"]

    def _near_frame(entries):
        """Return the entries whose timestamp is within the tolerance."""
        matched = []
        for entry in entries or []:
            entry_ts = entry.get("timestamp")
            if entry_ts == timestamp:
                matched.append(entry)
            elif (
                isinstance(entry_ts, (int, float))
                and isinstance(timestamp, (int, float))
                and abs(entry_ts - timestamp) <= timestamp_tolerance
            ):
                matched.append(entry)
        return matched

    # Try GPT-4V first
    caption = generate_caption_with_gpt4v(frame_path)
    source = "gpt-4v"

    if not caption:
        # Try LLaVA
        caption = generate_caption_with_llava(frame_path)
        source = "llava"

    if not caption:
        # Use fallback with YOLO/OCR data
        combined_data = {
            "objects": _near_frame(yolo_data),
            "texts": _near_frame(ocr_data),
        }
        caption = generate_caption_fallback(frame_path, combined_data)
        source = "metadata"

    return {
        "index": frame_info["index"],
        "timestamp": timestamp,
        "caption": caption,
        "source": source,
    }
def _load_frame_annotations(path: str, item_key: str) -> List[Dict]:
    """Flatten the per-frame ``item_key`` lists of a sidecar JSON file.

    Each returned item is stamped with its frame's ``timestamp`` (0 when the
    frame has none). A missing file yields an empty list.
    """
    if not os.path.exists(path):
        return []
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    items = []
    for frame in data.get("frames", []):
        frame_ts = frame.get("timestamp", 0)
        for item in frame.get(item_key, []):
            item["timestamp"] = frame_ts
            items.append(item)
    return items


def run_caption(
    video_path: str, output_path: str, uuid: str = "", max_frames: int = 30
):
    """Caption frames of a video and write the result as JSON.

    Extracts frames, loads optional ``<name>.yolo.json`` / ``<name>.ocr.json``
    files from the output directory as context (``<name>`` is the output file
    name up to its first dot), captions every frame and writes a summary to
    ``output_path``. Progress is published through ``RedisPublisher`` when
    ``uuid`` is non-empty.

    Args:
        video_path: Path to the input video.
        output_path: Path of the JSON file to write.
        uuid: Identifier for progress tracking; empty disables publishing.
        max_frames: Maximum number of frames to caption.

    Returns:
        The result dict that was written to ``output_path``.
    """
    publisher = RedisPublisher(uuid) if uuid else None

    if publisher:
        publisher.info("caption", "CAPTION_START")
        publisher.info("caption", "Extracting frames from video...")

    # Extract frames
    frames = extract_frames(video_path, max_frames)
    temp_dir = os.path.join(os.path.dirname(video_path), ".caption_frames")

    captions = []
    try:
        if publisher:
            publisher.info("caption", f"Extracted {len(frames)} frames")

        # Load YOLO and OCR data for context
        base_path = os.path.dirname(output_path)
        uuid_name = os.path.basename(output_path).split(".")[0]
        yolo_objects = _load_frame_annotations(
            os.path.join(base_path, f"{uuid_name}.yolo.json"), "objects"
        )
        ocr_texts = _load_frame_annotations(
            os.path.join(base_path, f"{uuid_name}.ocr.json"), "texts"
        )

        # Process each frame
        for i, frame in enumerate(frames):
            if publisher and i % 5 == 0:
                publisher.progress(
                    "caption", i, len(frames), f"Frame {i + 1}/{len(frames)}"
                )
            captions.append(process_frame(frame, yolo_objects, ocr_texts))
    finally:
        # Always clean up the temporary frames, even when captioning or
        # sidecar loading fails, so they do not accumulate next to the video.
        for frame in frames:
            try:
                os.remove(frame["path"])
            except OSError:
                pass
        try:
            # Only succeeds when the directory is empty.
            os.rmdir(temp_dir)
        except OSError:
            pass

    result = {
        "video_path": video_path,
        "total_frames": len(frames),
        "captions": captions,
        "summary": {
            "avg_caption_length": sum(len(c.get("caption", "")) for c in captions)
            / max(len(captions), 1),
            "gpt4v_count": sum(1 for c in captions if c.get("source") == "gpt-4v"),
            "llava_count": sum(1 for c in captions if c.get("source") == "llava"),
            "metadata_count": sum(1 for c in captions if c.get("source") == "metadata"),
        },
    }

    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # is pinned to UTF-8 instead of depending on the locale default.
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, ensure_ascii=False)

    if publisher:
        publisher.complete("caption", f"{len(captions)} frames captioned")

    return result
if __name__ == "__main__":
    # Command-line entry point: caption a video and report the frame count.
    cli = argparse.ArgumentParser(description="Video Caption Generator")
    for positional, help_text in (
        ("video_path", "Path to video file"),
        ("output_path", "Output JSON path"),
    ):
        cli.add_argument(positional, help=help_text)
    cli.add_argument("--uuid", default="", help="UUID for progress tracking")
    cli.add_argument(
        "--max-frames", default=30, type=int, help="Maximum frames to caption"
    )
    opts = cli.parse_args()

    result = run_caption(
        video_path=opts.video_path,
        output_path=opts.output_path,
        uuid=opts.uuid,
        max_frames=opts.max_frames,
    )
    print(f"Caption generated: {result['total_frames']} frames")