182 lines
4.6 KiB
Go
182 lines
4.6 KiB
Go
// Package imagegen provides a unified MLX runner for both LLM and image generation models.
|
|
package imagegen
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/envconfig"
|
|
"github.com/ollama/ollama/x/imagegen/mlx"
|
|
"github.com/ollama/ollama/x/internal/mlxthread"
|
|
)
|
|
|
|
// Execute is the entry point for the unified MLX runner subprocess.
|
|
func Execute(args []string) error {
|
|
// Set up logging with appropriate level from environment
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
|
|
|
|
fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
|
|
modelName := fs.String("model", "", "path to model")
|
|
port := fs.Int("port", 0, "port to listen on")
|
|
|
|
if err := fs.Parse(args); err != nil {
|
|
return err
|
|
}
|
|
|
|
if *modelName == "" {
|
|
return fmt.Errorf("--model is required")
|
|
}
|
|
if *port == 0 {
|
|
return fmt.Errorf("--port is required")
|
|
}
|
|
|
|
// Detect model type from capabilities
|
|
mode := detectModelMode(*modelName)
|
|
slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
|
|
|
|
if mode != ModeImageGen {
|
|
return fmt.Errorf("imagegen runner only supports image generation models")
|
|
}
|
|
|
|
worker, err := mlxthread.Start("imagegen", func() error {
|
|
if err := mlx.InitMLX(); err != nil {
|
|
slog.Error("unable to initialize MLX", "error", err)
|
|
return err
|
|
}
|
|
slog.Info("MLX library initialized")
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Create and start server
|
|
var server *server
|
|
if err := worker.Do(context.Background(), func() error {
|
|
var err error
|
|
server, err = newServer(*modelName, *port)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create server: %w", err)
|
|
}
|
|
server.mlxThread = worker
|
|
return nil
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set up HTTP handlers
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/health", server.healthHandler)
|
|
mux.HandleFunc("/completion", server.completionHandler)
|
|
|
|
httpServer := &http.Server{
|
|
Addr: fmt.Sprintf("127.0.0.1:%d", *port),
|
|
Handler: mux,
|
|
}
|
|
|
|
// Handle shutdown
|
|
done := make(chan struct{})
|
|
go func() {
|
|
sigCh := make(chan os.Signal, 1)
|
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
<-sigCh
|
|
slog.Info("shutting down mlx runner")
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
if err := httpServer.Shutdown(ctx); err != nil {
|
|
slog.Warn("graceful shutdown timed out", "error", err)
|
|
if err := httpServer.Close(); err != nil {
|
|
slog.Warn("failed to close http server", "error", err)
|
|
}
|
|
}
|
|
if err := worker.Stop(ctx, func() {
|
|
mlx.ClearCache()
|
|
}); err != nil {
|
|
slog.Warn("failed to stop mlx worker", "error", err)
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
slog.Info("mlx runner listening", "addr", httpServer.Addr)
|
|
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
|
|
return err
|
|
}
|
|
|
|
<-done
|
|
return nil
|
|
}
|
|
|
|
// detectModelMode determines whether a model is an LLM or image generation model.
|
|
func detectModelMode(modelName string) ModelMode {
|
|
// Check for image generation model by looking at model_index.json
|
|
modelType := DetectModelType(modelName)
|
|
if modelType != "" {
|
|
// Known image generation model types
|
|
switch modelType {
|
|
case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
|
|
return ModeImageGen
|
|
}
|
|
}
|
|
|
|
// Default to LLM mode for safetensors models without known image gen types
|
|
return ModeLLM
|
|
}
|
|
|
|
// server holds the model and handles HTTP requests.
|
|
type server struct {
|
|
modelName string
|
|
port int
|
|
mlxThread *mlxthread.Thread
|
|
|
|
// Image generation model.
|
|
imageModel ImageModel
|
|
}
|
|
|
|
// newServer creates a new server instance for image generation models.
|
|
func newServer(modelName string, port int) (*server, error) {
|
|
s := &server{
|
|
modelName: modelName,
|
|
port: port,
|
|
}
|
|
|
|
if err := s.loadImageModel(); err != nil {
|
|
return nil, fmt.Errorf("failed to load image model: %w", err)
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
|
|
resp := HealthResponse{Status: "ok"}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
|
|
func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
var req Request
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if err := s.mlxThread.Do(r.Context(), func() error {
|
|
s.handleImageCompletion(w, r, req)
|
|
return nil
|
|
}); err != nil && r.Context().Err() == nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
}
|