69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
#!/usr/bin/env python3
|
|
"""EmbeddingGemma HTTP server - Metal GPU (MPS) accelerated, compatible with M4/M5."""
|
|
|
|
import argparse, json, time, torch
|
|
from flask import Flask, request, jsonify
|
|
from transformers import AutoModel, AutoTokenizer
|
|
import numpy as np
|
|
|
|
# Process-wide singletons. All three start out unset and are filled in by
# load_model(); the request handlers read them as module globals.
MODEL = TOKENIZER = DEVICE = None

# The Flask application object that the route decorators attach to.
app = Flask(__name__)
|
|
|
|
def load_model(model_path: str = "google/embeddinggemma-300m"):
    """Load the model and tokenizer once and publish them as module globals.

    The call is a no-op when ``MODEL`` is already set. The device is "mps"
    when ``torch.backends.mps.is_available()`` reports True, otherwise "cpu".
    Weights are loaded as float32 and the model is put in eval mode.

    Both artifacts are loaded into locals first and the globals are only
    assigned after *both* loads succeeded. This way a failing tokenizer load
    cannot leave ``MODEL`` set with ``TOKENIZER`` still ``None`` (which would
    make every later call return early and never repair the state).

    Args:
        model_path: Model id or local path passed to ``from_pretrained``.

    Raises:
        Any exception raised by the ``from_pretrained`` calls or the device
        transfer; in that case none of the globals are modified.
    """
    global MODEL, TOKENIZER, DEVICE
    if MODEL is not None:
        return

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    dtype = torch.float32
    print(f"[EmbeddingGemma] Loading model on {device} (dtype={dtype})...")
    started = time.perf_counter()

    # NOTE(review): trust_remote_code=True lets the model repository run its
    # own Python code at load time; only point --model at sources you trust.
    model = AutoModel.from_pretrained(
        model_path,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).eval().to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Publish only after everything loaded; MODEL goes last because it is the
    # "already loaded" sentinel checked at the top of this function.
    DEVICE = device
    TOKENIZER = tokenizer
    MODEL = model
    print(f"[EmbeddingGemma] Loaded in {time.perf_counter()-started:.1f}s on {DEVICE}")
|
|
|
|
def embed(texts: list[str]) -> list[list[float]]:
    """Return one L2-normalized, mask-weighted mean embedding per input text.

    The texts are tokenized as a padded, truncated batch and moved to
    ``DEVICE``; ``MODEL`` is run without gradient tracking. The last hidden
    state is multiplied by the attention mask, summed over dim 1 and divided
    by the mask sum, so padded positions do not contribute to the average.

    Relies on the module globals ``TOKENIZER``, ``MODEL`` and ``DEVICE``
    having been set beforehand.

    Args:
        texts: Batch of input strings.

    Returns:
        The pooled vectors as nested Python lists, in input order.
    """
    batch = TOKENIZER(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)

    with torch.no_grad():
        model_out = MODEL(**batch)

    hidden = model_out.last_hidden_state
    # Add a trailing axis so the mask broadcasts against the hidden state
    # (presumably (batch, seq, hidden) - TODO confirm against the model).
    weights = batch["attention_mask"].unsqueeze(-1).to(hidden.dtype)

    summed = (hidden * weights).sum(dim=1)
    # The small epsilon keeps the division defined if a mask row sums to zero.
    denom = weights.sum(dim=1) + 1e-9
    unit = torch.nn.functional.normalize(summed / denom, p=2, dim=1)

    return unit.cpu().numpy().tolist()
|
|
|
|
@app.route("/v1/embeddings", methods=["POST"])
def embeddings():
    """Embed the ``input`` field of a JSON request body.

    ``input`` may be a single string or a list of strings. Malformed requests
    are answered with HTTP 400 and a JSON ``{"error": ...}`` body instead of
    surfacing as unhandled exceptions:

    * the body is not a JSON object (e.g. ``null``, a list, a number),
    * ``input`` is missing or empty,
    * ``input`` is neither a string nor a list of strings.

    Returns:
        On success, JSON with ``data`` (a list of ``{"embedding", "index"}``
        entries in input order) and ``model``. If ``embed`` raises, HTTP 500
        with the exception text under ``error``.
    """
    data = request.get_json()
    # get_json() can legitimately return None / list / scalar for valid JSON
    # that is not an object; calling .get on those would raise AttributeError.
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    texts = data.get("input", [])
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return jsonify({"error": "empty input"}), 400
    # Reject non-text payloads here so client mistakes are reported as 400
    # rather than as a tokenizer failure turned into a 500.
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return jsonify({"error": "input must be a string or a list of strings"}), 400

    try:
        emb = embed(texts)
    except Exception as e:
        # Top-level boundary: keep the traceback in the server log and hand
        # the message back to the caller.
        app.logger.exception("embedding request failed")
        return jsonify({"error": str(e)}), 500

    result = {
        "data": [{"embedding": e, "index": i} for i, e in enumerate(emb)],
        "model": "embeddinggemma-300m",
    }
    return jsonify(result)
|
|
|
|
@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: always reports ``ok`` plus ``str(DEVICE)``.

    ``device`` is the string ``"None"`` for as long as ``DEVICE`` is unset.
    """
    payload = {"status": "ok", "device": str(DEVICE)}
    return jsonify(payload)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI flags, load the model eagerly so the first
    # request does not pay the load cost, then start the Flask server.
    parser = argparse.ArgumentParser(description="EmbeddingGemma HTTP server")
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="interface to bind; the default listens on all interfaces, "
             "use 127.0.0.1 to accept local connections only",
    )
    parser.add_argument("--port", type=int, default=11436, help="TCP port to listen on")
    parser.add_argument(
        "--model",
        type=str,
        default="google/embeddinggemma-300m",
        help="model id or local path passed to load_model",
    )
    args = parser.parse_args()

    load_model(args.model)
    app.run(host=args.host, port=args.port, threaded=True)
|