52 lines
1.8 KiB
Python
52 lines
1.8 KiB
Python
"""PaliGemma judge: Vision-Language frame description"""
|
|
import torch
|
|
from PIL import Image
|
|
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
|
|
|
|
MODEL_ID = "google/paligemma2-3b-ft-docci-448"
|
|
PROMPT = "en Describe the location and setting of this scene in one sentence. Is it indoor or outdoor?"
|
|
|
|
_model = None
|
|
_processor = None
|
|
|
|
def load():
    """Populate the module-level ``_processor`` and ``_model`` on first use.

    Once ``_model`` has been set, later calls return immediately, so the
    checkpoint ``MODEL_ID`` is normally loaded only once. The model is loaded
    with ``torch_dtype=torch.bfloat16``, switched to eval mode, and moved to
    the ``"mps"`` device when ``torch.backends.mps.is_available()`` is true;
    otherwise it is not moved.
    """
    global _model, _processor

    # Guard clause: the model is already in memory, nothing to do.
    if _model is not None:
        return

    _processor = AutoProcessor.from_pretrained(MODEL_ID)
    _model = PaliGemmaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
    ).eval()

    if torch.backends.mps.is_available():
        _model = _model.to("mps")
|
|
|
|
def _describe_frame(img):
    """Return PaliGemma's generated description of one frame for ``PROMPT``.

    Uses greedy decoding (``do_sample=False``) with at most 80 new tokens.
    Requires ``load()`` to have populated ``_processor`` and ``_model``.
    """
    inputs = _processor(text=PROMPT, images=img, return_tensors="pt")
    # Follow the model rather than re-probing backends: cast floating inputs
    # (pixel values) to the model dtype and move everything to its device.
    # BatchFeature.to(dtype) leaves integer tensors such as input_ids alone.
    inputs = inputs.to(_model.dtype).to(_model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generated = _model.generate(**inputs, max_new_tokens=80, do_sample=False)
    # generate() returns prompt + continuation. Decode only the new tokens so
    # the instruction text can never leak into the description (and from
    # there into keyword scoring), even if decoding doesn't round-trip it.
    return _processor.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()


def _keyword_overlap_score(prompt, description):
    """Return a 50-100 heuristic for how many prompt keywords *description* mentions.

    Starts at a neutral 50 and adds 10 for each *distinct* prompt word longer
    than three characters that occurs as a whole word in *description*,
    capped at 100. Matching is case-insensitive and ignores punctuation
    ("beach." in the prompt matches "beach"), but uses exact word forms only:
    there is no stemming, so "beach" does not match "beaches".
    """
    import re

    def words(text):
        return re.findall(r"\w+", text.lower())

    described = set(words(description))
    keywords = {w for w in words(prompt) if len(w) > 3}
    return min(100, 50 + 10 * len(keywords & described))


def score(frames, prompt):
    """Describe each frame with PaliGemma and score the descriptions against *prompt*.

    ``load()`` is always called first, even when *frames* is empty.

    Args:
        frames: Iterable of images, each passed to the PaliGemma processor
            (presumably PIL images — confirm against callers).
        prompt: Text whose keywords are looked for in the generated
            descriptions.

    Returns:
        A dict with keys ``"agent"`` (``"PaliGemma"``), ``"score"`` (int in
        50..100, see ``_keyword_overlap_score``), ``"reasoning"`` (all
        descriptions joined by ``" | "``), and ``"details"``
        (``{"descriptions": [...]}``, one string per frame, in order).
    """
    load()
    descriptions = [_describe_frame(img) for img in frames]
    combined = " | ".join(descriptions)
    return {
        "agent": "PaliGemma",
        "score": _keyword_overlap_score(prompt, combined),
        "reasoning": combined,
        "details": {"descriptions": descriptions},
    }
|