Files
momentry_core/scripts/embeddinggemma_server.py

69 lines
2.4 KiB
Python

#!/usr/bin/env python3
"""EmbeddingGemma HTTP server - Metal GPU (MPS) accelerated, compatible with M4/M5."""
import argparse, json, time, torch
from flask import Flask, request, jsonify
from transformers import AutoModel, AutoTokenizer
import numpy as np
app = Flask(__name__)
# Process-wide model state, filled in by load_model(); None until it succeeds.
MODEL = TOKENIZER = DEVICE = None
def load_model(model_path: str = "google/embeddinggemma-300m"):
    """Load the embedding model and tokenizer into the module globals.

    Once a model has been loaded, later calls return immediately and their
    ``model_path`` argument is ignored.

    Args:
        model_path: Hugging Face hub id or local directory passed to
            ``from_pretrained`` for both the model and the tokenizer.
    """
    global MODEL, TOKENIZER, DEVICE
    if MODEL is not None:
        return
    # Use Apple's Metal backend when available, otherwise fall back to CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    dtype = torch.float32
    print(f"[EmbeddingGemma] Loading model on {device} (dtype={dtype})...")
    t0 = time.perf_counter()
    # Load into locals first and publish the globals only after BOTH loads
    # succeed. Assigning MODEL before the tokenizer loaded meant a tokenizer
    # failure left MODEL set, so every later call skipped loading while
    # TOKENIZER stayed None.
    # NOTE(review): trust_remote_code=True executes Python shipped with the
    # model repository; confirm it is actually required for this model.
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).eval().to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    MODEL, TOKENIZER, DEVICE = model, tokenizer, device
    print(f"[EmbeddingGemma] Loaded in {time.perf_counter()-t0:.1f}s on {device}")
def embed(texts: list[str]) -> list[list[float]]:
    """Return one mean-pooled, L2-normalised embedding per input string.

    The strings are tokenized as one padded batch with truncation enabled.
    This function relies on ``MODEL``, ``TOKENIZER`` and ``DEVICE`` having
    been populated beforehand.

    Args:
        texts: Strings to embed.

    Returns:
        One embedding (a list of floats) per element of ``texts``, in input
        order. An empty ``texts`` yields ``[]``.
    """
    if not texts:
        # Nothing to embed: skip tokenizing/running the model on an empty batch.
        return []
    inputs = TOKENIZER(texts, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        outputs = MODEL(**inputs)
    hidden = outputs.last_hidden_state
    # Add a trailing axis so the mask broadcasts over the hidden dimension.
    # Padded positions then contribute nothing to the sum or the count.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero
    pooled = torch.nn.functional.normalize(summed / counts, p=2, dim=1)
    return pooled.cpu().tolist()
@app.route("/v1/embeddings", methods=["POST"])
def embeddings():
    """Handle an OpenAI-style embeddings request.

    Expects a JSON object whose ``"input"`` is a string or a list of strings.
    Responds with ``{"object": "list", "data": [...], "model": ...}``, where
    each data item holds one embedding and its position in the input.

    Status codes:
        400: the body is not a JSON object, ``input`` is empty, or ``input``
            is not a string / list of strings.
        500: embedding raised; the exception message is returned.
    """
    # silent=True yields None instead of raising on a missing or malformed
    # JSON body. Previously data.get() could then fail outside the try
    # block and surface as an unhandled server error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400
    texts = data.get("input", [])
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return jsonify({"error": "empty input"}), 400
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return jsonify({"error": "input must be a string or a list of strings"}), 400
    try:
        emb = embed(texts)
    except Exception as e:  # request boundary: log and report instead of crashing
        app.logger.exception("embedding failed")
        return jsonify({"error": str(e)}), 500
    return jsonify(
        {
            "object": "list",
            "data": [
                {"object": "embedding", "embedding": e, "index": i}
                for i, e in enumerate(emb)
            ],
            "model": "embeddinggemma-300m",
        }
    )
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe; also reports the module-level DEVICE as a string."""
    payload = {"status": "ok"}
    payload["device"] = str(DEVICE)
    return jsonify(payload)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Serve EmbeddingGemma embeddings over HTTP.")
    # Default keeps the historical all-interfaces bind; pass --host 127.0.0.1
    # to keep this unauthenticated endpoint local to the machine.
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="interface to bind (default: 0.0.0.0, all interfaces)")
    parser.add_argument("--port", type=int, default=11436,
                        help="TCP port to listen on")
    parser.add_argument("--model", type=str, default="google/embeddinggemma-300m",
                        help="Hugging Face model id or local path")
    args = parser.parse_args()
    # Load before serving so the first request does not pay the load cost.
    load_model(args.model)
    # NOTE(review): threaded=True lets concurrent requests call the shared
    # MODEL at the same time; confirm that is safe on the MPS backend.
    app.run(host=args.host, port=args.port, threaded=True)