#!/usr/bin/env python3
"""EmbeddingGemma HTTP server - Metal GPU (MPS) accelerated, compatible with M4/M5.

Endpoints:
  POST /v1/embeddings  OpenAI-style embeddings request: {"input": str | [str, ...]}
  GET  /health         liveness probe reporting the device and whether the model is loaded

Each embedding is the attention-mask-weighted mean of the model's last hidden
state (padding tokens excluded), L2-normalised.
"""
import argparse
import json
import threading
import time
from typing import Optional

import numpy as np
import torch
from flask import Flask, jsonify, request
from transformers import AutoModel, AutoTokenizer

app = Flask(__name__)

MODEL = None
TOKENIZER = None
DEVICE = None
# Name reported in the "model" field of responses; set from the loaded model path.
MODEL_NAME = "embeddinggemma-300m"
# Max texts per forward pass; None means the whole request in one pass.
BATCH_SIZE: Optional[int] = None
# The server runs with threaded=True. Fast tokenizers mutate their padding /
# truncation state on every call and the MPS backend is not documented as
# thread-safe, so tokenization + inference are serialised behind this lock.
_INFERENCE_LOCK = threading.Lock()


def load_model(model_path: str = "google/embeddinggemma-300m") -> None:
    """Load the model and tokenizer into the module globals (no-op if already loaded).

    Args:
        model_path: Hugging Face hub id or local directory of the model.

    The Metal (MPS) device is used when available, otherwise the CPU.
    """
    global MODEL, TOKENIZER, DEVICE, MODEL_NAME
    if MODEL is not None:
        return
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    # NOTE(review): kept at float32; EmbeddingGemma reportedly does not support
    # float16 activations, so do not switch to fp16 without checking the model card.
    dtype = torch.float32
    print(f"[EmbeddingGemma] Loading model on {device} (dtype={dtype})...")
    t0 = time.time()
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).eval().to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Publish the globals only once both loads succeeded, so a failed tokenizer
    # load cannot leave MODEL set (which would make later load_model calls no-ops).
    MODEL, TOKENIZER, DEVICE = model, tokenizer, device
    MODEL_NAME = model_path.rstrip("/").rsplit("/", 1)[-1] or model_path
    print(f"[EmbeddingGemma] Loaded in {time.time() - t0:.1f}s on {device}")


def _encode(texts: list[str], batch_size: Optional[int] = None) -> tuple[list[list[float]], int]:
    """Embed ``texts`` and also count the non-padding tokens processed.

    Args:
        texts: strings to embed.
        batch_size: max texts per forward pass; None or <= 0 means a single pass.

    Returns:
        (embeddings in input order, total attention-mask tokens across all texts).

    Raises:
        RuntimeError: if load_model() has not been called.
    """
    if MODEL is None or TOKENIZER is None:
        raise RuntimeError("model not loaded; call load_model() first")
    if not texts:
        return [], 0
    step = batch_size if batch_size and batch_size > 0 else len(texts)
    vectors: list[list[float]] = []
    n_tokens = 0
    with _INFERENCE_LOCK, torch.no_grad():
        for start in range(0, len(texts), step):
            chunk = texts[start:start + step]
            inputs = TOKENIZER(chunk, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
            outputs = MODEL(**inputs)
            hidden = outputs.last_hidden_state
            # Zero out padding positions so they do not contribute to the mean.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
            # NOTE(review): plain mean pooling of the base model; the official
            # sentence-transformers pipeline for EmbeddingGemma may add projection
            # layers / task prompts, so vectors may differ from it -- confirm.
            vectors.extend(pooled.cpu().numpy().tolist())
            n_tokens += int(inputs["attention_mask"].sum().item())
    return vectors, n_tokens


def embed(texts: list[str], batch_size: Optional[int] = None) -> list[list[float]]:
    """Return one L2-normalised, mean-pooled embedding per input text, in order.

    Args:
        texts: strings to embed.
        batch_size: max texts per forward pass; None (default) means a single pass.
    """
    return _encode(texts, batch_size)[0]


@app.route("/v1/embeddings", methods=["POST"])
def embeddings():
    """Handle an OpenAI-style embeddings request.

    Returns 400 for a malformed body or input, 503 if the model is not loaded,
    and 500 (with the exception message) if inference fails.
    """
    # force=True: accept JSON even without a JSON Content-Type;
    # silent=True: yield None instead of raising on a parse failure.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400
    texts = data.get("input", [])
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return jsonify({"error": "empty input"}), 400
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return jsonify({"error": "input must be a string or a list of strings"}), 400
    if MODEL is None:
        return jsonify({"error": "model not loaded"}), 503
    try:
        vectors, n_tokens = _encode(texts, BATCH_SIZE)
    except Exception as e:  # top-level request boundary: log and report
        app.logger.exception("embedding request failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({
        "object": "list",
        "data": [
            {"object": "embedding", "embedding": vec, "index": i}
            for i, vec in enumerate(vectors)
        ],
        "model": MODEL_NAME,
        "usage": {"prompt_tokens": n_tokens, "total_tokens": n_tokens},
    })


@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: report the device and whether the model is loaded."""
    return jsonify({
        "status": "ok",
        "device": str(DEVICE),
        "model": MODEL_NAME,
        "model_loaded": MODEL is not None,
    })


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="EmbeddingGemma HTTP embedding server")
    parser.add_argument("--port", type=int, default=11436)
    parser.add_argument("--model", type=str, default="google/embeddinggemma-300m")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="interface to bind (default 0.0.0.0 = all interfaces; use 127.0.0.1 for local only)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=0,
        help="max texts per forward pass; 0 (default) embeds each request in one pass",
    )
    args = parser.parse_args()
    BATCH_SIZE = args.batch_size if args.batch_size > 0 else None
    load_model(args.model)
    app.run(host=args.host, port=args.port, threaded=True)