ollama source for Momentry Core verification
This commit is contained in:
25
x/imagegen/cmd/engine/README.md
Normal file
25
x/imagegen/cmd/engine/README.md
Normal file
@@ -0,0 +1,25 @@
|
||||
# MLX Engine
|
||||
|
||||
Experimental MLX backend for running models on Apple Silicon and CUDA.
|
||||
|
||||
## Build
|
||||
|
||||
```bash
|
||||
go build -o engine ./x/imagegen/cmd/engine
|
||||
```
|
||||
|
||||
## Text Generation
|
||||
|
||||
Text generation models are no longer supported by this engine.
|
||||
|
||||
## Image Generation
|
||||
|
||||
```bash
|
||||
./engine -zimage -model /path/to/z-image -prompt "a cat" -output cat.png
|
||||
```
|
||||
|
||||
Options:
|
||||
|
||||
- `-width`, `-height` - image dimensions (default 1024x1024)
|
||||
- `-steps` - denoising steps (default 9)
|
||||
- `-seed` - random seed (default 42)
|
||||
357
x/imagegen/cmd/engine/generate.go
Normal file
357
x/imagegen/cmd/engine/generate.go
Normal file
@@ -0,0 +1,357 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/cache"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/tokenizer"
|
||||
)
|
||||
|
||||
// Dedicated stream for generation (like mlx-lm's generation_stream)
|
||||
var generationStream *mlx.Stream
|
||||
|
||||
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
|
||||
// This handles cases where tokenizers output partial multi-byte sequences.
|
||||
type utf8Streamer struct {
|
||||
buffer []byte
|
||||
}
|
||||
|
||||
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
|
||||
func (s *utf8Streamer) Write(text string) string {
|
||||
s.buffer = append(s.buffer, text...)
|
||||
|
||||
// Find the last position that ends with a complete UTF-8 character
|
||||
validLen := 0
|
||||
for i := 0; i < len(s.buffer); {
|
||||
r, size := utf8.DecodeRune(s.buffer[i:])
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
// Invalid or incomplete UTF-8 sequence at this position
|
||||
// Check if it could be a valid start of a multi-byte sequence
|
||||
if len(s.buffer)-i < 4 {
|
||||
// Might be incomplete, keep it in buffer
|
||||
break
|
||||
}
|
||||
// Definitely invalid, skip this byte
|
||||
i++
|
||||
validLen = i
|
||||
} else {
|
||||
i += size
|
||||
validLen = i
|
||||
}
|
||||
}
|
||||
|
||||
if validLen == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := string(s.buffer[:validLen])
|
||||
s.buffer = s.buffer[validLen:]
|
||||
return result
|
||||
}
|
||||
|
||||
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
|
||||
func (s *utf8Streamer) Flush() string {
|
||||
if len(s.buffer) == 0 {
|
||||
return ""
|
||||
}
|
||||
result := string(s.buffer)
|
||||
s.buffer = nil
|
||||
return result
|
||||
}
|
||||
|
||||
// withStream runs fn with the generation stream as default
|
||||
func withStream(fn func()) {
|
||||
// Lazy initialization of generationStream
|
||||
if generationStream == nil {
|
||||
generationStream = mlx.NewStream()
|
||||
}
|
||||
orig := mlx.GetDefaultStream()
|
||||
mlx.SetDefaultStream(generationStream)
|
||||
fn()
|
||||
mlx.SetDefaultStream(orig)
|
||||
}
|
||||
|
||||
type Model interface {
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
VocabSize() int32
|
||||
NewCache(maxSeqLen int32) []cache.Cache
|
||||
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
}
|
||||
|
||||
// ChatModel is an optional interface for models that support chat formatting
|
||||
type ChatModel interface {
|
||||
FormatPrompt(prompt string) string
|
||||
}
|
||||
|
||||
// MultimodalModel is for models that support image input
|
||||
type MultimodalModel interface {
|
||||
Model
|
||||
FormatPromptWithImage(prompt string) string
|
||||
ExpandImageTokens(tokens []int32) []int32
|
||||
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
ImageSize() int32 // Returns expected image size for preprocessing
|
||||
}
|
||||
|
||||
// ImageLoader loads and preprocesses an image for multimodal models
|
||||
// Returns nil if path is empty
|
||||
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
|
||||
|
||||
type input struct {
|
||||
Prompt string
|
||||
Image *mlx.Array // Optional preprocessed image for multimodal models
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
TopP float32
|
||||
TopK int
|
||||
WiredLimitGB int // Metal wired memory limit in GB (default 32)
|
||||
}
|
||||
|
||||
type output struct {
|
||||
Text string
|
||||
Done bool
|
||||
PrefillTokSec float64
|
||||
GenTokSec float64
|
||||
}
|
||||
|
||||
// Decoder wraps model + cache for autoregressive generation.
|
||||
type Decoder struct {
|
||||
model Model
|
||||
caches []cache.Cache
|
||||
vocabSize int32
|
||||
temp float32
|
||||
topK int
|
||||
topP float32
|
||||
token *mlx.Array // Current token (kept across pools)
|
||||
oldCacheState []*mlx.Array // Preallocated slice for old cache state
|
||||
image *mlx.Array // Optional image for multimodal prefill
|
||||
}
|
||||
|
||||
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
|
||||
caches := m.NewCache(0)
|
||||
return &Decoder{
|
||||
model: m,
|
||||
caches: caches,
|
||||
vocabSize: m.VocabSize(),
|
||||
temp: temp,
|
||||
topK: topK,
|
||||
topP: topP,
|
||||
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
|
||||
}
|
||||
}
|
||||
|
||||
// SetImage sets the image for multimodal prefill (call before prefill)
|
||||
func (d *Decoder) SetImage(img *mlx.Array) {
|
||||
d.image = img
|
||||
}
|
||||
|
||||
func (d *Decoder) prefill(inputIDs []int32) int {
|
||||
processed := 0
|
||||
|
||||
// Track old cache state to free after each chunk
|
||||
var oldCacheState []*mlx.Array
|
||||
|
||||
// For multimodal models with an image, we need to process all tokens together
|
||||
// in the first forward pass so the image embeddings can be inserted properly.
|
||||
// Skip chunking for multimodal prefill.
|
||||
isMultimodal := d.image != nil
|
||||
|
||||
// Process all-but-1 tokens in chunks, eval cache state for memory management
|
||||
// Skip chunking for multimodal - process everything in the final step
|
||||
if !isMultimodal {
|
||||
for len(inputIDs) > 1 {
|
||||
chunkSize := min(2048, len(inputIDs)-1)
|
||||
if chunkSize <= 0 {
|
||||
break
|
||||
}
|
||||
chunk := inputIDs[:chunkSize]
|
||||
|
||||
// Save old cache state before forward
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
var cacheState []*mlx.Array
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
|
||||
d.model.Forward(x, d.caches)
|
||||
for _, c := range d.caches {
|
||||
cacheState = append(cacheState, c.State()...)
|
||||
}
|
||||
})
|
||||
mlx.Eval(cacheState...)
|
||||
|
||||
// Free old cache state
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
inputIDs = inputIDs[chunkSize:]
|
||||
processed += chunkSize
|
||||
}
|
||||
}
|
||||
|
||||
// Save old cache state before final step
|
||||
oldCacheState = oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
oldCacheState = append(oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
// Final token + sampling (or all tokens for multimodal)
|
||||
withStream(func() {
|
||||
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
|
||||
mlx.Eval(x) // Materialize before any other evals
|
||||
|
||||
var logits *mlx.Array
|
||||
// Use ForwardWithImage if we have an image and model supports it
|
||||
if d.image != nil {
|
||||
if mm, ok := d.model.(MultimodalModel); ok {
|
||||
logits = mm.ForwardWithImage(x, d.image, d.caches)
|
||||
d.image = nil // Only use image for first forward
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
} else {
|
||||
logits = d.model.Forward(x, d.caches)
|
||||
}
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep cache state (token auto-kept by AsyncEval)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Free old cache state from before final step
|
||||
for _, arr := range oldCacheState {
|
||||
if arr != nil {
|
||||
arr.Free()
|
||||
}
|
||||
}
|
||||
|
||||
mlx.ClearCache()
|
||||
|
||||
return processed + len(inputIDs)
|
||||
}
|
||||
|
||||
func (d *Decoder) step() int32 {
|
||||
prevToken := d.token
|
||||
|
||||
// Save old cache state (reuse preallocated slice)
|
||||
d.oldCacheState = d.oldCacheState[:0]
|
||||
for _, c := range d.caches {
|
||||
d.oldCacheState = append(d.oldCacheState, c.State()...)
|
||||
}
|
||||
|
||||
withStream(func() {
|
||||
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
|
||||
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
|
||||
})
|
||||
// Keep token and new cache state so they survive cleanup
|
||||
mlx.Keep(d.token)
|
||||
for _, c := range d.caches {
|
||||
mlx.Keep(c.State()...)
|
||||
}
|
||||
mlx.AsyncEval(d.token)
|
||||
|
||||
// Sync on previous token (GPU already working on next step)
|
||||
val := prevToken.ItemInt32()
|
||||
|
||||
// Free old token and old cache state
|
||||
prevToken.Free()
|
||||
for _, arr := range d.oldCacheState {
|
||||
arr.Free()
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
|
||||
mlx.EnableCompile()
|
||||
wiredLimit := in.WiredLimitGB
|
||||
if wiredLimit <= 0 {
|
||||
wiredLimit = 32 // default 32GB
|
||||
}
|
||||
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
|
||||
|
||||
temp := in.Temperature
|
||||
if temp < 0 {
|
||||
temp = 0.7
|
||||
}
|
||||
|
||||
tok := m.Tokenizer()
|
||||
dec := NewDecoder(m, temp, in.TopK, in.TopP)
|
||||
|
||||
// Apply chat template - use image template if we have an image
|
||||
prompt := in.Prompt
|
||||
var tokens []int32
|
||||
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
|
||||
prompt = mm.FormatPromptWithImage(prompt)
|
||||
tokens = tok.Encode(prompt, true)
|
||||
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
|
||||
dec.SetImage(in.Image)
|
||||
} else if cm, ok := m.(ChatModel); ok {
|
||||
prompt = cm.FormatPrompt(prompt)
|
||||
tokens = tok.Encode(prompt, true)
|
||||
} else {
|
||||
tokens = tok.Encode(prompt, true)
|
||||
}
|
||||
|
||||
prefillStart := time.Now()
|
||||
prefillTokens := dec.prefill(tokens)
|
||||
// Prefill measurement should include time to first token (like mlx-lm)
|
||||
// Step() waits for prefill to complete and returns first token
|
||||
firstToken := dec.step()
|
||||
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
|
||||
|
||||
genStart := time.Now()
|
||||
maxTokens := max(in.MaxTokens, 100)
|
||||
var genTokens int
|
||||
|
||||
// UTF-8 streamer to handle partial multi-byte characters
|
||||
streamer := &utf8Streamer{}
|
||||
|
||||
// Handle first token
|
||||
genTokens++
|
||||
if tok.IsEOS(firstToken) {
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
|
||||
return nil
|
||||
}
|
||||
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
for n := 1; n < maxTokens; n++ {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
token := dec.step()
|
||||
genTokens++
|
||||
|
||||
if tok.IsEOS(token) {
|
||||
break
|
||||
}
|
||||
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
if n%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any remaining buffered bytes
|
||||
if text := streamer.Flush(); text != "" {
|
||||
cb(output{Text: text})
|
||||
}
|
||||
|
||||
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
|
||||
cb(output{Done: true, PrefillTokSec: prefillTokSec,
|
||||
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
|
||||
return nil
|
||||
}
|
||||
87
x/imagegen/cmd/engine/image.go
Normal file
87
x/imagegen/cmd/engine/image.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"image/png"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
)
|
||||
|
||||
// saveImageArray saves an MLX array as a PNG image.
|
||||
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
|
||||
func saveImageArray(arr *mlx.Array, path string) error {
|
||||
img, err := arrayToImage(arr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return savePNG(img, path)
|
||||
}
|
||||
|
||||
func savePNG(img *image.RGBA, path string) error {
|
||||
if filepath.Ext(path) != ".png" {
|
||||
path = path + ".png"
|
||||
}
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
return png.Encode(f, img)
|
||||
}
|
||||
|
||||
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
|
||||
shape := arr.Shape()
|
||||
if len(shape) != 4 {
|
||||
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
|
||||
}
|
||||
|
||||
// Transform to [H, W, C] for image conversion
|
||||
img := mlx.Squeeze(arr, 0)
|
||||
arr.Free()
|
||||
img = mlx.Transpose(img, 1, 2, 0)
|
||||
img = mlx.Contiguous(img)
|
||||
mlx.Eval(img)
|
||||
|
||||
imgShape := img.Shape()
|
||||
H := int(imgShape[0])
|
||||
W := int(imgShape[1])
|
||||
C := int(imgShape[2])
|
||||
|
||||
if C != 3 {
|
||||
img.Free()
|
||||
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
|
||||
}
|
||||
|
||||
// Copy to CPU and free GPU memory
|
||||
data := img.Data()
|
||||
img.Free()
|
||||
|
||||
// Write directly to Pix slice (faster than SetRGBA)
|
||||
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
|
||||
pix := goImg.Pix
|
||||
for y := 0; y < H; y++ {
|
||||
for x := 0; x < W; x++ {
|
||||
srcIdx := (y*W + x) * C
|
||||
dstIdx := (y*W + x) * 4
|
||||
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
|
||||
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
|
||||
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
|
||||
pix[dstIdx+3] = 255
|
||||
}
|
||||
}
|
||||
|
||||
return goImg, nil
|
||||
}
|
||||
|
||||
func clampF(v, min, max float32) float32 {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
287
x/imagegen/cmd/engine/main.go
Normal file
287
x/imagegen/cmd/engine/main.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime/pprof"
|
||||
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/mlx"
|
||||
"github.com/ollama/ollama/x/imagegen/models/flux2"
|
||||
"github.com/ollama/ollama/x/imagegen/models/zimage"
|
||||
"github.com/ollama/ollama/x/imagegen/safetensors"
|
||||
)
|
||||
|
||||
// stringSlice is a flag type that accumulates multiple values
|
||||
type stringSlice []string
|
||||
|
||||
func (s *stringSlice) String() string {
|
||||
return fmt.Sprintf("%v", *s)
|
||||
}
|
||||
|
||||
func (s *stringSlice) Set(value string) error {
|
||||
*s = append(*s, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
modelPath := flag.String("model", "", "Model directory")
|
||||
prompt := flag.String("prompt", "Hello", "Prompt")
|
||||
|
||||
// Text generation params
|
||||
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
|
||||
temperature := flag.Float64("temperature", 0.7, "Temperature")
|
||||
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
|
||||
topK := flag.Int("top-k", 40, "Top-k sampling")
|
||||
imagePath := flag.String("image", "", "Image path for multimodal models")
|
||||
|
||||
// Image generation params
|
||||
width := flag.Int("width", 0, "Image width (0 = auto from input or 1024)")
|
||||
height := flag.Int("height", 0, "Image height (0 = auto from input or 1024)")
|
||||
steps := flag.Int("steps", 0, "Denoising steps (0 = model default)")
|
||||
seed := flag.Int64("seed", 42, "Random seed")
|
||||
out := flag.String("output", "output.png", "Output path")
|
||||
|
||||
// Utility flags
|
||||
listTensors := flag.Bool("list", false, "List tensors only")
|
||||
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
|
||||
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
|
||||
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
|
||||
|
||||
// Legacy mode flags
|
||||
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
|
||||
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
|
||||
var inputImages stringSlice
|
||||
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
|
||||
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
|
||||
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
|
||||
teaCache := flag.Bool("teacache", false, "Enable TeaCache for faster inference")
|
||||
teaCacheThreshold := flag.Float64("teacache-threshold", 0.1, "TeaCache threshold (lower = more aggressive caching)")
|
||||
fusedQKV := flag.Bool("fused-qkv", false, "Enable fused QKV projection for faster attention")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *modelPath == "" {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if MLX initialized successfully
|
||||
if !mlx.IsMLXAvailable() {
|
||||
log.Fatalf("MLX initialization failed: %v", mlx.GetMLXInitError())
|
||||
}
|
||||
|
||||
// Restore strict error handling now that we know MLX is working.
|
||||
// During init(), a safe handler prevented exit(-1) on GPU errors.
|
||||
mlx.RestoreDefaultErrorHandler()
|
||||
|
||||
// CPU profiling
|
||||
if *cpuProfile != "" {
|
||||
f, err := os.Create(*cpuProfile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
if err := pprof.StartCPUProfile(f); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer pprof.StopCPUProfile()
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Handle legacy mode flags that aren't unified yet
|
||||
switch {
|
||||
case *zimageFlag:
|
||||
m := &zimage.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &zimage.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
NegativePrompt: *negativePrompt,
|
||||
CFGScale: float32(*cfgScale),
|
||||
Width: int32(*width),
|
||||
Height: int32(*height),
|
||||
Steps: *steps,
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
TeaCache: *teaCache,
|
||||
TeaCacheThreshold: float32(*teaCacheThreshold),
|
||||
FusedQKV: *fusedQKV,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *flux2Flag:
|
||||
m := &flux2.Model{}
|
||||
if loadErr := m.Load(*modelPath); loadErr != nil {
|
||||
log.Fatal(loadErr)
|
||||
}
|
||||
// Load input images with EXIF orientation correction
|
||||
var loadedImages []image.Image
|
||||
for _, path := range inputImages {
|
||||
img, loadErr := loadImageWithEXIF(path)
|
||||
if loadErr != nil {
|
||||
log.Fatalf("Failed to load image %s: %v", path, loadErr)
|
||||
}
|
||||
loadedImages = append(loadedImages, img)
|
||||
}
|
||||
// When input images provided and user didn't override dimensions, use 0 to match input
|
||||
fluxWidth := int32(*width)
|
||||
fluxHeight := int32(*height)
|
||||
if len(loadedImages) > 0 && *width == 0 && *height == 0 {
|
||||
// Both unset, will auto-detect from input
|
||||
} else if len(loadedImages) > 0 && *width == 0 {
|
||||
fluxWidth = 0 // Compute from height + aspect ratio
|
||||
} else if len(loadedImages) > 0 && *height == 0 {
|
||||
fluxHeight = 0 // Compute from width + aspect ratio
|
||||
}
|
||||
var img *mlx.Array
|
||||
img, err = m.GenerateFromConfig(context.Background(), &flux2.GenerateConfig{
|
||||
Prompt: *prompt,
|
||||
Width: fluxWidth,
|
||||
Height: fluxHeight,
|
||||
Steps: *steps,
|
||||
GuidanceScale: float32(*cfgScale),
|
||||
Seed: *seed,
|
||||
CapturePath: *gpuCapture,
|
||||
InputImages: loadedImages,
|
||||
})
|
||||
if err == nil {
|
||||
err = saveImageArray(img, *out)
|
||||
}
|
||||
case *listTensors:
|
||||
err = listModelTensors(*modelPath)
|
||||
default:
|
||||
// llm path
|
||||
m, err := load(*modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Load image if provided and model supports it.
|
||||
var image *mlx.Array
|
||||
if *imagePath != "" {
|
||||
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
|
||||
image, err = imagegen.ProcessImage(*imagePath, mm.ImageSize())
|
||||
if err != nil {
|
||||
log.Fatal("load image:", err)
|
||||
}
|
||||
} else {
|
||||
log.Fatal("model does not support image input")
|
||||
}
|
||||
}
|
||||
|
||||
err = generate(context.Background(), m, input{
|
||||
Prompt: *prompt,
|
||||
Image: image,
|
||||
MaxTokens: *maxTokens,
|
||||
Temperature: float32(*temperature),
|
||||
TopP: float32(*topP),
|
||||
TopK: *topK,
|
||||
WiredLimitGB: *wiredLimitGB,
|
||||
}, func(out output) {
|
||||
if out.Text != "" {
|
||||
fmt.Print(out.Text)
|
||||
}
|
||||
if out.Done {
|
||||
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func listModelTensors(modelPath string) error {
|
||||
weights, err := safetensors.LoadModelWeights(modelPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, name := range weights.ListTensors() {
|
||||
info, _ := weights.GetTensorInfo(name)
|
||||
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadModel builds and evaluates a model using the common load pattern.
|
||||
// Release safetensors BEFORE eval - lazy arrays have captured their data,
|
||||
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
|
||||
func loadModel[T Model](build func() T, cleanup func()) T {
|
||||
m := build()
|
||||
weights := mlx.Collect(m)
|
||||
cleanup()
|
||||
mlx.Eval(weights...)
|
||||
return m
|
||||
}
|
||||
|
||||
func load(modelPath string) (Model, error) {
|
||||
kind, err := detectModelKind(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("detect model kind: %w", err)
|
||||
}
|
||||
|
||||
switch kind {
|
||||
default:
|
||||
return nil, fmt.Errorf("model type %q is not supported by x/imagegen/cmd/engine", kind)
|
||||
}
|
||||
}
|
||||
|
||||
func detectModelKind(modelPath string) (string, error) {
|
||||
indexPath := filepath.Join(modelPath, "model_index.json")
|
||||
if _, err := os.Stat(indexPath); err == nil {
|
||||
data, err := os.ReadFile(indexPath)
|
||||
if err != nil {
|
||||
return "zimage", nil
|
||||
}
|
||||
var index struct {
|
||||
ClassName string `json:"_class_name"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &index); err == nil {
|
||||
switch index.ClassName {
|
||||
case "FluxPipeline", "ZImagePipeline":
|
||||
return "zimage", nil
|
||||
case "Flux2KleinPipeline":
|
||||
return "flux2", nil
|
||||
}
|
||||
}
|
||||
return "zimage", nil
|
||||
}
|
||||
|
||||
configPath := filepath.Join(modelPath, "config.json")
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
|
||||
}
|
||||
|
||||
var cfg struct {
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return "", fmt.Errorf("parse config.json: %w", err)
|
||||
}
|
||||
|
||||
return cfg.ModelType, nil
|
||||
}
|
||||
|
||||
// loadImageWithEXIF loads an image from a file path with EXIF orientation correction.
|
||||
func loadImageWithEXIF(path string) (image.Image, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
return imagegen.DecodeImage(data)
|
||||
}
|
||||
47
x/imagegen/cmd/engine/sample.go
Normal file
47
x/imagegen/cmd/engine/sample.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package main
|
||||
|
||||
import "github.com/ollama/ollama/x/imagegen/mlx"
|
||||
|
||||
// sampleTopK samples from top-k logits using global random state
|
||||
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
|
||||
neg := mlx.Neg(scaledLogits)
|
||||
indices := mlx.Argpartition(neg, k-1, -1)
|
||||
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
|
||||
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
|
||||
sampled := mlx.RandomCategorical(values, -1, 1)
|
||||
return mlx.Take(topKIdx, sampled, -1)
|
||||
}
|
||||
|
||||
// sampleTopP samples using nucleus sampling with global random state
|
||||
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
|
||||
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
|
||||
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
|
||||
probs := mlx.Softmax(sortedLogits, -1)
|
||||
cumProbs := mlx.Cumsum(probs, -1)
|
||||
mask := mlx.LessScalar(cumProbs, p)
|
||||
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
|
||||
masked := mlx.Where(mask, sortedLogits, negInf)
|
||||
sampled := mlx.RandomCategorical(masked, -1, 1)
|
||||
return mlx.Take(sorted, sampled, -1)
|
||||
}
|
||||
|
||||
// sample samples from logits at the last position
|
||||
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
|
||||
// Get last position logits: [1, L, vocab] -> [vocab]
|
||||
shape := logits.Shape()
|
||||
seqLen := shape[1]
|
||||
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
|
||||
lastLogits = mlx.Reshape(lastLogits, vocab)
|
||||
|
||||
if temp == 0 {
|
||||
return mlx.Argmax(lastLogits, -1, false)
|
||||
}
|
||||
scaled := mlx.DivScalar(lastLogits, temp)
|
||||
if topK > 0 && topK < int(vocab) {
|
||||
return sampleTopK(scaled, topK)
|
||||
}
|
||||
if topP > 0 && topP < 1.0 {
|
||||
return sampleTopP(scaled, topP, vocab)
|
||||
}
|
||||
return mlx.RandomCategorical(scaled, -1, 1)
|
||||
}
|
||||
Reference in New Issue
Block a user