Files
momentry_core/scripts/clip_classifier.py
Accusys 17e4e15860 feat: add Vision LLM integration (CLIP + Qwen3-VL cascade)
- Add Qwen3-VL dynamic management (start/stop/status CLI)
- Add CLIP + Qwen3-VL cascade detection strategy
- Add Vision CLI commands (vision start/stop/status, detect)
- Add cascade_vision processor module
- Add clip processor module
- Add qwen_vl_manager module

Changes:
- scripts/start_qwen3vl.sh, stop_qwen3vl.sh: Qwen3-VL management scripts
- src/core/vision/: Qwen3-VL manager module
- src/core/processor/cascade_vision.rs: CLIP + Qwen3-VL cascade logic
- src/core/processor/clip.rs: CLIP classification and detection
- src/api/clip_api.rs: CLIP API endpoints
- src/cli/vision.rs: Vision CLI implementation
- src/cli/args.rs: Add Vision and Detect commands
- src/main.rs: Integrate Vision CLI
- src/core/mod.rs: Add vision module
- src/core/processor/mod.rs: Add cascade_vision module
2026-06-13 16:25:52 +08:00

232 lines
7.5 KiB
Python

#!/usr/bin/env python3
"""
CLIP Zero-Shot Classifier
Uses OpenAI CLIP for reliable scene and object classification.
Advantages over LLaVA Vision:
- Zero-shot classification (no prompt induction)
- Reliable confidence scores
- Fast inference
- No hallucinations
"""
import argparse
import json
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple
try:
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
HAS_CLIP = True
except ImportError as e:
print(f"[ERROR] Required packages not found: {e}", file=sys.stderr)
print("[ERROR] Install with: pip install transformers torch pillow", file=sys.stderr)
HAS_CLIP = False
sys.exit(1)
class CLIPClassifier:
    """Zero-shot image classifier built on a HuggingFace CLIP checkpoint.

    Every image is scored against a caller-supplied list of text labels.
    Confidences are a softmax over *those labels only*, so they always sum
    to 1 across the candidates and are relative, not absolute: a single
    candidate label always receives confidence 1.0.

    Status messages are written to stderr so that stdout stays free for
    machine-readable output produced by callers.
    """

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None,
    ):
        """
        Initialize CLIP model.

        Args:
            model_name: HuggingFace model name (default: openai/clip-vit-base-patch32)
            device: Optional torch device string (e.g. "cpu", "cuda", "mps").
                When omitted, the best available device is picked
                automatically (cuda, then mps, then cpu).
        """
        print(f"[CLIP] Loading model: {model_name}", file=sys.stderr)
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.device = self._select_device(device)
        self.model.to(self.device)
        print(f"[CLIP] Model loaded on device: {self.device}", file=sys.stderr)

    @staticmethod
    def _select_device(device: Optional[str] = None):
        """Return the torch device to run on, honouring an explicit request."""
        if device is not None:
            return torch.device(device)
        if torch.cuda.is_available():
            return torch.device("cuda")
        # Older torch builds have no `torch.backends.mps` attribute at all.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def classify_image(
        self,
        image_path: str,
        labels: List[str],
        top_k: int = 5
    ) -> List[Dict[str, float]]:
        """
        Classify a single image with given labels.

        Args:
            image_path: Path to image file
            labels: List of candidate labels (e.g., ["person in room", "outdoor scene", "snow landscape"])
            top_k: Number of top predictions to return

        Returns:
            List of {"label": str, "confidence": float} sorted by descending
            confidence and truncated to top_k entries. Returns an empty list
            when labels is empty, top_k is not positive, or the image cannot
            be loaded (the load error is reported on stderr).
        """
        # Nothing to score against / nothing requested: skip the model call.
        if not labels or top_k <= 0:
            return []

        try:
            # The context manager releases the file handle once the pixel
            # data has been copied by convert().
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            # Deliberate best-effort: one unreadable image must not abort a
            # whole batch, so report it and return no predictions.
            print(f"[ERROR] Failed to load image {image_path}: {e}", file=sys.stderr)
            return []

        # Prepare inputs
        inputs = self.processor(
            text=labels,
            images=image,
            return_tensors="pt",
            padding=True
        ).to(self.device)

        # Get predictions
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Softmax across the candidate labels for the single input image.
        probs = outputs.logits_per_image.softmax(dim=1)[0].cpu().tolist()

        # Sort by confidence
        results = [
            {"label": label, "confidence": float(prob)}
            for label, prob in zip(labels, probs)
        ]
        results.sort(key=lambda x: x["confidence"], reverse=True)
        return results[:top_k]

    def classify_images(
        self,
        image_paths: List[str],
        labels: List[str],
        top_k: int = 5
    ) -> Dict[str, List[Dict[str, float]]]:
        """
        Classify multiple images with given labels.

        Args:
            image_paths: List of image paths
            labels: List of candidate labels
            top_k: Number of top predictions per image

        Returns:
            Dict mapping image_path -> predictions. Images that fail to load
            map to an empty list.
        """
        return {
            img_path: self.classify_image(img_path, labels, top_k)
            for img_path in image_paths
        }

    def detect_objects(
        self,
        image_path: str,
        objects: List[str],
        threshold: float = 0.15
    ) -> List[Dict[str, float]]:
        """
        Detect if specific objects are present in image.

        NOTE: confidences come from a softmax over `objects` only, so the
        best-scoring candidate has confidence >= 1 / len(objects). With few
        candidates (or a single one) something will always pass a low
        threshold; include contrasting labels among the candidates if an
        absolute presence decision is needed.

        Args:
            image_path: Path to image file
            objects: List of objects to detect (e.g., ["gun", "knife", "weapon"])
            threshold: Confidence threshold (default: 0.15)

        Returns:
            List of detected objects with confidence >= threshold, sorted by
            descending confidence. Empty when objects is empty.
        """
        if not objects:
            return []
        predictions = self.classify_image(image_path, objects, top_k=len(objects))
        return [p for p in predictions if p["confidence"] >= threshold]

    def batch_detect_objects(
        self,
        image_paths: List[str],
        objects: List[str],
        threshold: float = 0.15
    ) -> Dict[str, List[Dict[str, float]]]:
        """
        Detect objects across multiple images.

        Args:
            image_paths: List of image paths
            objects: List of objects to detect
            threshold: Confidence threshold

        Returns:
            Dict mapping image_path -> detected objects. Images with no
            detections (including images that failed to load) are omitted.
        """
        results = {}
        for img_path in image_paths:
            detected = self.detect_objects(img_path, objects, threshold)
            if detected:
                results[img_path] = detected
        return results
def _split_csv(value: str) -> List[str]:
    """Split a comma-separated CLI value, trimming and dropping blank entries."""
    return [item.strip() for item in value.split(",") if item.strip()]


def main():
    """Command-line entry point: classify or detect objects in image(s).

    Parses the arguments, validates them *before* the model is loaded,
    runs either object detection (--detect, which takes precedence) or
    label classification (--labels), and emits the results as JSON to
    --output or stdout. Status and error messages go to stderr so stdout
    contains only the JSON document. Exits with status 1 on invalid
    arguments or an unreadable batch file.
    """
    parser = argparse.ArgumentParser(
        description="CLIP Zero-Shot Classifier",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Scene classification
  python clip_classifier.py image.jpg --labels "indoor room,outdoor scene,person in room" --top-k 3

  # Object detection
  python clip_classifier.py image.jpg --detect "gun,weapon,knife" --threshold 0.2

  # Batch processing
  python clip_classifier.py images.txt --batch --labels "indoor,outdoor"
"""
    )
    parser.add_argument("input", help="Image path or text file with image paths (for batch)")
    parser.add_argument("--labels", help="Comma-separated labels for classification")
    parser.add_argument("--detect", help="Comma-separated objects to detect")
    parser.add_argument("--threshold", type=float, default=0.15, help="Detection threshold (default: 0.15)")
    parser.add_argument("--top-k", type=int, default=5, help="Top-k predictions (default: 5)")
    parser.add_argument("--batch", action="store_true", help="Batch mode (input is text file)")
    parser.add_argument("--output", help="Output JSON file (default: stdout)")
    parser.add_argument("--model", default="openai/clip-vit-base-patch32", help="CLIP model name")
    args = parser.parse_args()

    # Validate the request first: loading the model is slow, so a usage
    # error should be reported without paying for it.
    detect_mode = bool(args.detect)
    if detect_mode:
        candidates = _split_csv(args.detect)
    elif args.labels:
        candidates = _split_csv(args.labels)
    else:
        print("[ERROR] Must specify --labels or --detect", file=sys.stderr)
        sys.exit(1)
    if not candidates:
        print("[ERROR] --labels/--detect must contain at least one non-empty entry", file=sys.stderr)
        sys.exit(1)

    # Prepare image paths
    if args.batch:
        try:
            with open(args.input, "r", encoding="utf-8") as f:
                image_paths = [line.strip() for line in f if line.strip()]
        except OSError as e:
            print(f"[ERROR] Failed to read batch file {args.input}: {e}", file=sys.stderr)
            sys.exit(1)
    else:
        image_paths = [args.input]

    if not HAS_CLIP:
        sys.exit(1)

    # Initialize classifier
    classifier = CLIPClassifier(args.model)

    # Run classification
    if detect_mode:
        # Object detection mode
        print(f"[CLIP] Detecting objects: {candidates}", file=sys.stderr)
        results = classifier.batch_detect_objects(image_paths, candidates, args.threshold)
    else:
        # Scene classification mode
        print(f"[CLIP] Classifying with {len(candidates)} labels", file=sys.stderr)
        results = classifier.classify_images(image_paths, candidates, args.top_k)

    # Output results
    output_json = json.dumps(results, indent=2, ensure_ascii=False)
    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(output_json)
        print(f"[CLIP] Results saved to {args.output}", file=sys.stderr)
    else:
        print(output_json)


if __name__ == "__main__":
    main()