- Add Qwen3-VL dynamic management (start/stop/status CLI) - Add CLIP + Qwen3-VL cascade detection strategy - Add Vision CLI commands (vision start/stop/status, detect) - Add cascade_vision processor module - Add clip processor module - Add qwen_vl_manager module Changes: - scripts/start_qwen3vl.sh, stop_qwen3vl.sh: Qwen3-VL management scripts - src/core/vision/: Qwen3-VL manager module - src/core/processor/cascade_vision.rs: CLIP + Qwen3-VL cascade logic - src/core/processor/clip.rs: CLIP classification and detection - src/api/clip_api.rs: CLIP API endpoints - src/cli/vision.rs: Vision CLI implementation - src/cli/args.rs: Add Vision and Detect commands - src/main.rs: Integrate Vision CLI - src/core/mod.rs: Add vision module - src/core/processor/mod.rs: Add cascade_vision module
232 lines
7.5 KiB
Python
232 lines
7.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
CLIP Zero-Shot Classifier
|
|
Uses OpenAI CLIP for reliable scene and object classification.
|
|
|
|
Advantages over LLaVA Vision:
|
|
- Zero-shot classification (no prompt induction)
|
|
- Reliable confidence scores
|
|
- Fast inference
|
|
- No hallucinations
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
try:
    import torch
    from PIL import Image
    from transformers import CLIPProcessor, CLIPModel

    HAS_CLIP = True
except ImportError as e:
    print(f"[ERROR] Required packages not found: {e}", file=sys.stderr)
    print("[ERROR] Install with: pip install transformers torch pillow", file=sys.stderr)
    # Do not sys.exit() here: exiting at import time terminates any process
    # that merely imports this module (and makes main()'s HAS_CLIP check dead
    # code). Callers check HAS_CLIP instead; main() exits with status 1.
    HAS_CLIP = False
|
|
|
|
|
|
class CLIPClassifier:
    """
    Zero-shot image classifier backed by a HuggingFace CLIP model.

    Every method scores an image against caller-supplied text labels and
    reports softmax probabilities over exactly those labels, so the
    confidences for one image always sum to 1 across the label list.
    """

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        device: Optional[str] = None,
    ):
        """
        Load the CLIP model and processor and move the model to a device.

        Args:
            model_name: HuggingFace model name (default: openai/clip-vit-base-patch32)
            device: Optional torch device string ("cpu", "cuda", "mps", ...).
                When None, CUDA is used if available, then Apple MPS, then CPU.

        Raises:
            ImportError: If torch / transformers / Pillow could not be imported.
        """
        if not HAS_CLIP:
            raise ImportError("CLIPClassifier requires transformers, torch and pillow")
        # Progress messages go to stderr so stdout stays clean for JSON output.
        print(f"[CLIP] Loading model: {model_name}", file=sys.stderr)
        self.model = CLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.device = self._resolve_device(device)
        self.model.to(self.device)
        self.model.eval()  # inference only; make the mode explicit
        print(f"[CLIP] Model loaded on device: {self.device}", file=sys.stderr)

    @staticmethod
    def _resolve_device(device: Optional[str] = None) -> "torch.device":
        """Return the requested device, or the best available one if None."""
        if device:
            return torch.device(device)
        if torch.cuda.is_available():
            return torch.device("cuda")
        # torch.backends.mps does not exist on older torch builds.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def _rank(labels, probs, top_k: int) -> List[Dict[str, float]]:
        """Pair labels with probabilities, sort descending (stable), keep top_k."""
        ranked = sorted(
            ({"label": label, "confidence": float(prob)} for label, prob in zip(labels, probs)),
            key=lambda item: item["confidence"],
            reverse=True,
        )
        return ranked[:top_k]

    def classify_image(
        self,
        image_path: str,
        labels: List[str],
        top_k: int = 5
    ) -> List[Dict[str, float]]:
        """
        Classify a single image with given labels.

        Args:
            image_path: Path to image file
            labels: List of candidate labels (e.g., ["person in room", "outdoor scene", "snow landscape"])
            top_k: Number of top predictions to return

        Returns:
            List of {"label": str, "confidence": float} sorted by confidence
            (highest first). Empty if `labels` is empty or the image cannot
            be loaded (the load error is reported on stderr).
        """
        if not labels:
            return []

        try:
            # Context manager closes the underlying file; convert() returns an
            # independent, fully loaded image.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
        except Exception as e:
            print(f"[ERROR] Failed to load image {image_path}: {e}", file=sys.stderr)
            return []

        # truncation=True keeps over-long labels within CLIP's text length
        # limit instead of failing inside the model.
        inputs = self.processor(
            text=labels,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Softmax over the label axis: one probability per label for image 0.
        probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]
        return self._rank(labels, probs, top_k)

    def classify_images(
        self,
        image_paths: List[str],
        labels: List[str],
        top_k: int = 5
    ) -> Dict[str, List[Dict[str, float]]]:
        """
        Classify multiple images with given labels.

        Args:
            image_paths: List of image paths
            labels: List of candidate labels
            top_k: Number of top predictions per image

        Returns:
            Dict mapping image_path -> predictions (an empty list for images
            that failed to load).
        """
        return {
            img_path: self.classify_image(img_path, labels, top_k)
            for img_path in image_paths
        }

    def detect_objects(
        self,
        image_path: str,
        objects: List[str],
        threshold: float = 0.15
    ) -> List[Dict[str, float]]:
        """
        Report which of `objects` score at or above `threshold` for an image.

        NOTE: confidences are a softmax over the supplied `objects`, so they
        are relative to each other and sum to 1. With a single object its
        confidence is always 1.0, so it is always reported. Add contrasting
        labels if you need an absolute "present / absent" decision.

        Args:
            image_path: Path to image file
            objects: List of objects to detect (e.g., ["gun", "knife", "weapon"])
            threshold: Confidence threshold (default: 0.15)

        Returns:
            List of detected objects with confidence >= threshold, highest first
        """
        predictions = self.classify_image(image_path, objects, top_k=len(objects))
        return [p for p in predictions if p["confidence"] >= threshold]

    def batch_detect_objects(
        self,
        image_paths: List[str],
        objects: List[str],
        threshold: float = 0.15
    ) -> Dict[str, List[Dict[str, float]]]:
        """
        Detect objects across multiple images.

        Args:
            image_paths: List of image paths
            objects: List of objects to detect
            threshold: Confidence threshold

        Returns:
            Dict mapping image_path -> detected objects. Images with no
            detections (including images that failed to load) are omitted.
        """
        results = {}
        for img_path in image_paths:
            detected = self.detect_objects(img_path, objects, threshold)
            if detected:
                results[img_path] = detected
        return results
|
|
|
|
|
|
def parse_csv_arg(value: Optional[str]) -> List[str]:
    """
    Split a comma-separated CLI value into stripped, non-empty items.

    Empty segments (from "a,,b" or a trailing comma) are dropped so they
    never reach CLIP as blank labels. None or "" yields an empty list.
    """
    if not value:
        return []
    return [item.strip() for item in value.split(",") if item.strip()]


def main():
    """
    CLI entry point: classify (--labels) or detect (--detect), then emit JSON.

    --detect takes precedence when both are given. Progress and error
    messages are written to stderr, so without --output stdout carries only
    the JSON document. Exits with status 1 when dependencies are missing,
    when neither --labels nor --detect yields any item, or when the batch
    list file cannot be read.
    """
    parser = argparse.ArgumentParser(
        description="CLIP Zero-Shot Classifier",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Scene classification
  python clip_classifier.py image.jpg --labels "indoor room,outdoor scene,person in room" --top-k 3

  # Object detection
  python clip_classifier.py image.jpg --detect "gun,weapon,knife" --threshold 0.2

  # Batch processing
  python clip_classifier.py images.txt --batch --labels "indoor,outdoor"
"""
    )

    parser.add_argument("input", help="Image path or text file with image paths (for batch)")
    parser.add_argument("--labels", help="Comma-separated labels for classification")
    parser.add_argument("--detect", help="Comma-separated objects to detect")
    parser.add_argument("--threshold", type=float, default=0.15, help="Detection threshold (default: 0.15)")
    parser.add_argument("--top-k", type=int, default=5, help="Top-k predictions (default: 5)")
    parser.add_argument("--batch", action="store_true", help="Batch mode (input is text file)")
    parser.add_argument("--output", help="Output JSON file (default: stdout)")
    parser.add_argument("--model", default="openai/clip-vit-base-patch32", help="CLIP model name")

    args = parser.parse_args()

    if not HAS_CLIP:
        sys.exit(1)

    # Validate the request before loading the model, which is the slow part.
    objects = parse_csv_arg(args.detect)
    labels = parse_csv_arg(args.labels)
    if not objects and not labels:
        print("[ERROR] Must specify --labels or --detect", file=sys.stderr)
        sys.exit(1)

    # Prepare image paths
    if args.batch:
        try:
            with open(args.input, "r", encoding="utf-8") as f:
                image_paths = [line.strip() for line in f if line.strip()]
        except OSError as e:
            print(f"[ERROR] Failed to read image list {args.input}: {e}", file=sys.stderr)
            sys.exit(1)
    else:
        image_paths = [args.input]

    classifier = CLIPClassifier(args.model)

    # Diagnostics go to stderr so stdout remains parseable JSON.
    if objects:
        print(f"[CLIP] Detecting objects: {objects}", file=sys.stderr)
        results = classifier.batch_detect_objects(image_paths, objects, args.threshold)
    else:
        print(f"[CLIP] Classifying with {len(labels)} labels", file=sys.stderr)
        results = classifier.classify_images(image_paths, labels, args.top_k)

    output_json = json.dumps(results, indent=2, ensure_ascii=False)

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(output_json)
        print(f"[CLIP] Results saved to {args.output}", file=sys.stderr)
    else:
        print(output_json)


if __name__ == "__main__":
    main()