#!/usr/bin/env python3
"""
CLIP Zero-Shot Classifier

Uses OpenAI CLIP for scene and object classification.

Advantages over LLaVA Vision:
- Zero-shot classification (no prompt induction)
- A confidence score for every candidate label
- Fast inference
- Output is restricted to the supplied labels (no free-form text to hallucinate)

Note: confidences are softmax-normalized over the labels supplied in a single
call, so they are *relative* scores that sum to 1 across those labels.
"""

import argparse
import json
import sys
from operator import itemgetter
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

# The ML dependencies are optional at import time so the module (and
# ``--help``) stays usable without them; ``main`` and ``CLIPClassifier``
# report the missing packages when they are actually needed.
_IMPORT_ERROR: Optional[ImportError] = None
try:
    import torch
    from PIL import Image
    from transformers import CLIPProcessor, CLIPModel

    HAS_CLIP = True
except ImportError as e:
    # ``e`` is unbound after the except block, so keep a reference.
    _IMPORT_ERROR = e
    HAS_CLIP = False

# One prediction: {"label": <str>, "confidence": <float>}.
Prediction = Dict[str, Union[str, float]]


def _log(message: str) -> None:
    """Write a diagnostic message to stderr so stdout carries only JSON."""
    print(message, file=sys.stderr)


def parse_label_list(text: Optional[str]) -> List[str]:
    """Split a comma-separated CLI value into stripped, non-empty entries.

    >>> parse_label_list("indoor, outdoor,,")
    ['indoor', 'outdoor']
    """
    if not text:
        return []
    return [item.strip() for item in text.split(",") if item.strip()]


def rank_predictions(
    labels: Sequence[str], probs: Iterable[float], top_k: Optional[int] = None
) -> List[Prediction]:
    """Pair labels with probabilities and return them best-first.

    Args:
        labels: Candidate labels, in the same order as ``probs``.
        probs: One probability per label (any float-convertible values).
        top_k: Maximum number of predictions to return; ``None`` returns all,
            non-positive values return an empty list.

    Returns:
        List of {"label": str, "confidence": float} sorted by descending
        confidence; ties keep the input order (``sorted`` is stable).
    """
    ranked = sorted(
        (
            {"label": label, "confidence": float(prob)}
            for label, prob in zip(labels, probs)
        ),
        key=itemgetter("confidence"),
        reverse=True,
    )
    if top_k is None:
        return ranked
    # Clamp so a negative top_k does not silently mean "all but the last |k|".
    return ranked[: max(top_k, 0)]


class CLIPClassifier:
    """Zero-shot image classifier backed by a HuggingFace CLIP checkpoint."""

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None,
    ):
        """
        Initialize CLIP model.

        Args:
            model_name: HuggingFace model name (default: openai/clip-vit-base-patch32)
            device: Torch device string (e.g. "cpu", "cuda", "mps"). When None,
                the first available of CUDA, MPS, CPU is used.

        Raises:
            RuntimeError: if torch / pillow / transformers are not installed.
        """
        if not HAS_CLIP:
            raise RuntimeError(
                f"CLIP dependencies are not installed ({_IMPORT_ERROR}); "
                "install with: pip install transformers torch pillow"
            )
        _log(f"[CLIP] Loading model: {model_name}")
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.device = self._select_device(device)
        self.model.to(self.device)
        self.model.eval()
        _log(f"[CLIP] Model loaded on device: {self.device}")

    @staticmethod
    def _select_device(preferred: Optional[str] = None) -> "torch.device":
        """Return the requested device, or the best available accelerator."""
        if preferred:
            return torch.device(preferred)
        if torch.cuda.is_available():
            return torch.device("cuda")
        # ``torch.backends.mps`` is missing on older torch builds.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def classify_image(
        self, image_path: str, labels: List[str], top_k: int = 5
    ) -> List[Prediction]:
        """
        Classify a single image with given labels.

        Args:
            image_path: Path to image file
            labels: List of candidate labels (e.g., ["person in room", "outdoor scene", "snow landscape"])
            top_k: Number of top predictions to return

        Returns:
            List of {"label": str, "confidence": float} sorted by confidence.
            Confidences are a softmax over *all* ``labels`` (they sum to 1
            before truncation to ``top_k``). An empty list is returned when
            ``labels`` is empty or the image cannot be loaded.
        """
        if not labels:
            # Nothing to score; the processor would reject an empty text batch.
            return []

        try:
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except Exception as e:  # best-effort per image: log and skip it
            _log(f"[ERROR] Failed to load image {image_path}: {e}")
            return []

        inputs = self.processor(
            text=labels, images=image, return_tensors="pt", padding=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)
        # logits_per_image is (num_images, num_labels); a single image here.
        probs = outputs.logits_per_image.softmax(dim=1)[0].cpu().tolist()
        return rank_predictions(labels, probs, top_k)

    def classify_images(
        self, image_paths: List[str], labels: List[str], top_k: int = 5
    ) -> Dict[str, List[Prediction]]:
        """
        Classify multiple images with given labels.

        Args:
            image_paths: List of image paths
            labels: List of candidate labels
            top_k: Number of top predictions per image

        Returns:
            Dict mapping image_path -> predictions (an empty list for images
            that failed to load)
        """
        return {
            img_path: self.classify_image(img_path, labels, top_k)
            for img_path in image_paths
        }

    def detect_objects(
        self, image_path: str, objects: List[str], threshold: float = 0.15
    ) -> List[Prediction]:
        """
        Detect if specific objects are present in image.

        NOTE: scores are softmax-normalized across ``objects``, so they are
        relative, not independent presence probabilities. With a single object
        the score is always 1.0, and adding objects lowers every score; tune
        ``threshold`` together with the object list.

        Args:
            image_path: Path to image file
            objects: List of objects to detect (e.g., ["gun", "knife", "weapon"])
            threshold: Confidence threshold (default: 0.15)

        Returns:
            List of detected objects with confidence >= threshold, best first
        """
        predictions = self.classify_image(image_path, objects, top_k=len(objects))
        return [p for p in predictions if p["confidence"] >= threshold]

    def batch_detect_objects(
        self, image_paths: List[str], objects: List[str], threshold: float = 0.15
    ) -> Dict[str, List[Prediction]]:
        """
        Detect objects across multiple images.

        Args:
            image_paths: List of image paths
            objects: List of objects to detect
            threshold: Confidence threshold

        Returns:
            Dict mapping image_path -> detected objects; images with no
            detection (or that failed to load) are omitted.
        """
        results = {}
        for img_path in image_paths:
            detected = self.detect_objects(img_path, objects, threshold)
            if detected:
                results[img_path] = detected
        return results


def main():
    """Command-line entry point: classify or detect, then emit JSON."""
    parser = argparse.ArgumentParser(
        description="CLIP Zero-Shot Classifier",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Scene classification
  python clip_classifier.py image.jpg --labels "indoor room,outdoor scene,person in room" --top-k 3

  # Object detection
  python clip_classifier.py image.jpg --detect "gun,weapon,knife" --threshold 0.2

  # Batch processing
  python clip_classifier.py images.txt --batch --labels "indoor,outdoor"
        """,
    )
    parser.add_argument("input", help="Image path or text file with image paths (for batch)")
    parser.add_argument("--labels", help="Comma-separated labels for classification")
    parser.add_argument("--detect", help="Comma-separated objects to detect")
    parser.add_argument("--threshold", type=float, default=0.15, help="Detection threshold (default: 0.15)")
    parser.add_argument("--top-k", type=int, default=5, help="Top-k predictions (default: 5)")
    parser.add_argument("--batch", action="store_true", help="Batch mode (input is text file)")
    parser.add_argument("--output", help="Output JSON file (default: stdout)")
    parser.add_argument("--model", default="openai/clip-vit-base-patch32", help="CLIP model name")
    parser.add_argument("--device", default=None, help="Torch device (default: auto-detect cuda/mps/cpu)")
    args = parser.parse_args()

    if not HAS_CLIP:
        print(f"[ERROR] Required packages not found: {_IMPORT_ERROR}", file=sys.stderr)
        print("[ERROR] Install with: pip install transformers torch pillow", file=sys.stderr)
        sys.exit(1)

    # Validate the requested mode before paying for model loading.
    # --detect takes precedence over --labels when both are given.
    if args.detect:
        candidates = parse_label_list(args.detect)
    elif args.labels:
        candidates = parse_label_list(args.labels)
    else:
        _log("[ERROR] Must specify --labels or --detect")
        sys.exit(1)
    if not candidates:
        _log("[ERROR] --labels/--detect must contain at least one non-empty entry")
        sys.exit(1)

    # Prepare image paths
    if args.batch:
        lines = Path(args.input).read_text(encoding="utf-8").splitlines()
        image_paths = [line.strip() for line in lines if line.strip()]
    else:
        image_paths = [args.input]

    classifier = CLIPClassifier(args.model, device=args.device)

    if args.detect:
        _log(f"[CLIP] Detecting objects: {candidates}")
        results = classifier.batch_detect_objects(image_paths, candidates, args.threshold)
    else:
        _log(f"[CLIP] Classifying with {len(candidates)} labels")
        results = classifier.classify_images(image_paths, candidates, args.top_k)

    # Diagnostics go to stderr, so stdout (when used) is pure JSON.
    output_json = json.dumps(results, indent=2, ensure_ascii=False)
    if args.output:
        Path(args.output).write_text(output_json + "\n", encoding="utf-8")
        _log(f"[CLIP] Results saved to {args.output}")
    else:
        print(output_json)


if __name__ == "__main__":
    main()