M4: trace API, portal embed client, EmbeddingGemma sync, release plan
This commit is contained in:
68
scripts/embeddinggemma_server.py
Normal file
68
scripts/embeddinggemma_server.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
"""EmbeddingGemma HTTP server - Metal GPU (MPS) accelerated, compatible with M4/M5."""
|
||||
|
||||
import argparse, json, time, torch
|
||||
from flask import Flask, request, jsonify
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import numpy as np
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
MODEL = None
|
||||
TOKENIZER = None
|
||||
DEVICE = None
|
||||
|
||||
def load_model(model_path: str = "google/embeddinggemma-300m"):
|
||||
global MODEL, TOKENIZER, DEVICE
|
||||
if MODEL is not None:
|
||||
return
|
||||
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
dtype = torch.float32
|
||||
print(f"[EmbeddingGemma] Loading model on {DEVICE} (dtype={dtype})...")
|
||||
t0 = time.time()
|
||||
MODEL = AutoModel.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=True,
|
||||
).eval().to(DEVICE)
|
||||
TOKENIZER = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
print(f"[EmbeddingGemma] Loaded in {time.time()-t0:.1f}s on {DEVICE}")
|
||||
|
||||
def embed(texts: list[str]) -> list[list[float]]:
|
||||
inputs = TOKENIZER(texts, padding=True, truncation=True, return_tensors="pt").to(DEVICE)
|
||||
with torch.no_grad():
|
||||
outputs = MODEL(**inputs)
|
||||
mask = inputs["attention_mask"].unsqueeze(-1).to(outputs.last_hidden_state.dtype)
|
||||
pooled = (outputs.last_hidden_state * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
|
||||
pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
|
||||
return pooled.cpu().numpy().tolist()
|
||||
|
||||
@app.route("/v1/embeddings", methods=["POST"])
|
||||
def embeddings():
|
||||
data = request.get_json()
|
||||
texts = data.get("input", [])
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
if not texts:
|
||||
return jsonify({"error": "empty input"}), 400
|
||||
try:
|
||||
emb = embed(texts)
|
||||
result = {
|
||||
"data": [{"embedding": e, "index": i} for i, e in enumerate(emb)],
|
||||
"model": "embeddinggemma-300m",
|
||||
}
|
||||
return jsonify(result)
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
|
||||
@app.route("/health", methods=["GET"])
|
||||
def health():
|
||||
return jsonify({"status": "ok", "device": str(DEVICE)})
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port", type=int, default=11436)
|
||||
parser.add_argument("--model", type=str, default="google/embeddinggemma-300m")
|
||||
args = parser.parse_args()
|
||||
load_model(args.model)
|
||||
app.run(host="0.0.0.0", port=args.port, threaded=True)
|
||||
Reference in New Issue
Block a user