ollama source for Momentry Core verification
This commit is contained in:
287
x/imagegen/cmd/engine/main.go
Normal file
287
x/imagegen/cmd/engine/main.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// stringSlice is a flag type that accumulates multiple values
|
||||
type stringSlice []string
|
||||
|
||||
func (s *stringSlice) String() string {
|
||||
return fmt.Sprintf("%v", *s)
|
||||
}
|
||||
|
||||
func (s *stringSlice) Set(value string) error {
|
||||
*s = append(*s, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
modelPath := flag.String("model", "", "Model directory")
|
||||
prompt := flag.String("prompt", "Hello", "Prompt")
|
||||
|
||||
// Text generation params
|
||||
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
|
||||
temperature := flag.Float64("temperature", 0.7, "Temperature")
|
||||
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
|
||||
topK := flag.Int("top-k", 40, "Top-k sampling")
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
|
||||
height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
|
||||
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
// Utility flags
|
||||
listTensors := flag.Bool("list", false, "List tensors only")
|
||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
||||
var inputImages stringSlice
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
||||
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
|
||||
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
|
||||
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *modelPath == "" {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if MLX initialized successfully
|
||||
if !mlx.IsMLXAvailable() {
|
||||
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
|
||||
}
|
||||
|
||||
// Restore strict error handling now that we know MLX is working.
|
||||
// During init(), a safe handler prevented exit(-1) on GPU errors.
|
||||
mlx.RestoreDefaultErrorHandler()
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Handle legacy mode flags that aren't unified yet
|
||||
switch {
|
||||
case *zimageFlag:
|
||||
m := &zimage.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
TeaCache: *teaCache,
|
||||
TeaCacheThreshold: float32(*teaCacheThreshold),
|
||||
FusedQKV: *fusedQKV,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *flux2Flag:
|
||||
m := &flux2.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// Load input images with EXIF orientation correction
|
||||
var loadedImages []image.Image
|
||||
for _, path := range inputImages {
|
||||
img, loadErr := loadImageWithEXIF(path)
|
||||
if loadErr != nil {
|
||||
log.Fatalf("Failed to load image %s: %v", path, loadErr)
|
||||
}
|
||||
loadedImages = append(loadedImages, img)
|
||||
}
|
||||
// When input images provided and user didn't override dimensions, use 0 to match input
|
||||
fluxWidth := int32(*width)
|
||||
fluxHeight := int32(*height)
|
||||
if len(loadedImages) > 0 && *width == 0 && *height == 0 {
|
||||
// Both unset, will auto-detect from input
|
||||
} else if len(loadedImages) > 0 && *width == 0 {
|
||||
fluxWidth = 0 // Compute from height + aspect ratio
|
||||
} else if len(loadedImages) > 0 && *height == 0 {
|
||||
fluxHeight = 0 // Compute from width + aspect ratio
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: fluxWidth,
|
||||
Height: fluxHeight,
|
||||
Steps: *steps,
|
||||
GuidanceScale: float32(*cfgScale),
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
InputImages: loadedImages,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *listTensors:
|
||||
err = listModelTensors(*modelPath)
|
||||
default:
|
||||
// llm path
|
||||
m, err := load(*modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Load image if provided and model supports it.
|
||||
var image *mlx.Array
|
||||
if *imagePath != "" {
|
||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||
image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize())
|
||||
if err != nil {
|
||||
log.Fatal("load image:", err)
|
||||
}
|
||||
} else {
|
||||
log.Fatal("model does not support image input")
|
||||
}
|
||||
}
|
||||
|
||||
err = generate(context.Background(), m, input{
|
||||
Prompt: *prompt,
|
||||
Image: image,
|
||||
MaxTokens: *maxTokens,
|
||||
Temperature: float32(*temperature),
|
||||
TopP: float32(*topP),
|
||||
TopK: *topK,
|
||||
WiredLimitGB: *wiredLimitGB,
|
||||
}, func(out output) {
|
||||
if out.Text != "" {
|
||||
fmt.Print(out.Text)
|
||||
}
|
||||
if out.Done {
|
||||
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func listModelTensors(modelPath string) error {
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range weights.ListTensors() {
|
||||
info, _ := weights.GetTensorInfo(name)
|
||||
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModel builds and evaluates a model using the common load pattern.
|
||||
// Release safetensors BEFORE eval - lazy arrays have captured their data,
|
||||
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
|
||||
func loadModel[T Model](build func() T, cleanup func()) T {
|
||||
m := build()
|
||||
weights := mlx.Collect(m)
|
||||
cleanup()
|
||||
mlx.Eval(weights...)
|
||||
return m
|
||||
}
|
||||
|
||||
func load(modelPath string) (Model, error) {
|
||||
kind, err := detectModelKind(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect model kind: %w", err)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
default:
|
||||
return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind)
|
||||
}
|
||||
}
|
||||
|
||||
func detectModelKind(modelPath string) (string, error) {
|
||||
indexPath := filepath.Join(modelPath, "model_index.json")
|
||||
if _, err := os.Stat(indexPath); err == nil {
|
||||
data, err := os.ReadFile(indexPath)
|
||||
if err != nil {
|
||||
return "zimage", nil
|
||||
}
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err == nil {
|
||||
switch index.ClassName {
|
||||
case "FluxPipeline", "ZImagePipeline":
|
||||
return "zimage", nil
|
||||
case "Flux2KleinPipeline":
|
||||
return "flux2", nil
|
||||
}
|
||||
}
|
||||
return "zimage", nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(modelPath, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", fmt.Errorf("parse config.json: %w", err)
|
||||
}
|
||||
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
|
||||
// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
|
||||
func loadImageWithEXIF(path string) (image.Image, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
return imagegen.DecodeImage(data)
|
||||
}
|
||||
Reference in New Issue
Block a user