ollama source for Momentry Core verification

This commit is contained in:
Accusys
2026-05-22 17:19:10 +08:00
commit 0b31ff9135
2020 changed files with 1413145 additions and 0 deletions

View File

@@ -0,0 +1,551 @@
// Package flux2 implements the FLUX.2 Klein diffusion transformer model.
// Klein is a 4B parameter distilled model that supports sub-second inference.
package flux2
import (
"context"
"encoding/json"
"fmt"
"image"
"math"
"time"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"golang.org/x/image/draw"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 4 for Klein)
GuidanceScale float32 // Guidance scale (default: 1.0, Klein doesn't need CFG)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
InputImages []image.Image // Reference images for image conditioning (already loaded)
}
// Model represents a FLUX.2 Klein model.
type Model struct {
ModelName string
Tokenizer *tokenizer.Tokenizer
TextEncoder *qwen3.TextEncoder
Transformer *Flux2Transformer2DModel
VAE *AutoencoderKLFlux2
SchedulerConfig *SchedulerConfig
}
// TextEncoderLayerIndices are the layers from which to extract text embeddings.
// Diffusers uses hidden_states[9, 18, 27]. In Python, hidden_states[0] is the embedding
// output before any layers, so hidden_states[9] = after layer 8 (0-indexed).
// Go's ForwardWithLayerOutputs captures after layer i runs, so we use [8, 17, 26].
var TextEncoderLayerIndices = []int{8, 17, 26}
// Load loads the FLUX.2 Klein model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading FLUX.2 Klein model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{}
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
tokConfig.SpecialTokensMapJSON = data
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder
m.TextEncoder = &qwen3.TextEncoder{}
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
// Load transformer
m.Transformer = &Flux2Transformer2DModel{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
// Load VAE
m.VAE = &AutoencoderKLFlux2{}
if err := m.VAE.Load(manifest); err != nil {
return fmt.Errorf("VAE: %w", err)
}
// Evaluate all weights in a single batch (reduces GPU sync overhead)
fmt.Print(" Evaluating weights... ")
allWeights := mlx.Collect(m.TextEncoder)
allWeights = append(allWeights, mlx.Collect(m.Transformer)...)
allWeights = append(allWeights, mlx.Collect(m.VAE)...)
mlx.Eval(allWeights...)
fmt.Println("✓")
// Load scheduler config
m.SchedulerConfig = DefaultSchedulerConfig()
if schedData, err := manifest.ReadConfig("scheduler/scheduler_config.json"); err == nil {
if err := json.Unmarshal(schedData, m.SchedulerConfig); err != nil {
fmt.Printf(" Warning: failed to parse scheduler config: %v\n", err)
}
}
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
return result, nil
}
// GenerateImage implements runner.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateImageWithInputs implements runner.ImageEditModel interface.
// It generates an image conditioned on the provided input images for image editing.
func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
InputImages: inputImages,
Progress: progress,
})
}
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048
// MaxRefPixels is the maximum resolution for reference images (smaller to reduce attention memory)
const MaxRefPixels = 728 * 728
// generate is the internal denoising pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Enable MLX compilation for fused kernels
mlx.EnableCompile()
// Apply defaults
if cfg.Steps <= 0 {
cfg.Steps = 4 // Klein default: 4 steps for distilled model
}
if cfg.GuidanceScale <= 0 {
cfg.GuidanceScale = 1.0 // Klein doesn't need guidance
}
// Determine output dimensions
if len(cfg.InputImages) > 0 {
// With input images, compute missing dimension from aspect ratio
// Images are already EXIF-rotated by the caller
bounds := cfg.InputImages[0].Bounds()
imgW, imgH := bounds.Dx(), bounds.Dy()
aspectRatio := float64(imgH) / float64(imgW)
if cfg.Width > 0 && cfg.Height <= 0 {
// Width specified, compute height
cfg.Height = int32(math.Round(float64(cfg.Width)*aspectRatio/16) * 16)
} else if cfg.Height > 0 && cfg.Width <= 0 {
// Height specified, compute width
cfg.Width = int32(math.Round(float64(cfg.Height)/aspectRatio/16) * 16)
} else if cfg.Width <= 0 && cfg.Height <= 0 {
// Neither specified, use input dimensions
cfg.Width = int32(imgW)
cfg.Height = int32(imgH)
}
}
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
// Cap to max pixels, preserve aspect ratio, round to multiple of 16
pixels := int(cfg.Width) * int(cfg.Height)
if pixels > MaxOutputPixels {
scale := math.Sqrt(float64(MaxOutputPixels) / float64(pixels))
cfg.Width = int32(math.Round(float64(cfg.Width) * scale / 16) * 16)
cfg.Height = int32(math.Round(float64(cfg.Height) * scale / 16) * 16)
}
cfg.Height = int32((cfg.Height + 8) / 16 * 16) // round to nearest 16
cfg.Width = int32((cfg.Width + 8) / 16 * 16)
fmt.Printf(" Output: %dx%d\n", cfg.Width, cfg.Height)
tcfg := m.Transformer.TransformerConfig
patchSize := m.VAE.Config.PatchSize
// Latent dimensions: image / 8 (VAE downscale) / patch_size
latentH := cfg.Height / 8
latentW := cfg.Width / 8
patchH := latentH / patchSize[0]
patchW := latentW / patchSize[1]
imgSeqLen := patchH * patchW
// Text encoding with multi-layer extraction (no padding, use true sequence length)
fmt.Print(" Encoding prompt... ")
promptEmbeds, textLen := m.TextEncoder.EncodePromptWithLayers(m.Tokenizer, cfg.Prompt, 512, TextEncoderLayerIndices, false)
fmt.Println("✓")
// Encode reference images if provided
var refTokens *ImageCondTokens
var refHeights, refWidths []int32
if len(cfg.InputImages) > 0 {
fmt.Printf(" Encoding %d reference image(s):\n", len(cfg.InputImages))
var err error
refTokens, err = m.EncodeImageRefs(cfg.InputImages)
if err != nil {
return nil, fmt.Errorf("encode reference images: %w", err)
}
// Extract heights/widths for RoPE computation (same limits as EncodeImageRefs)
limitPixels := MaxRefPixels
if len(cfg.InputImages) > 1 {
limitPixels = MaxRefPixels / 2
}
for _, img := range cfg.InputImages {
_, w, h := PrepareImage(img, limitPixels)
refHeights = append(refHeights, int32(h/16))
refWidths = append(refWidths, int32(w/16))
}
}
// Scheduler
scheduler := NewFlowMatchScheduler(m.SchedulerConfig)
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(imgSeqLen, cfg.Steps))
// Init latents in packed form [B, C*4, H/2, W/2] like diffusers
// diffusers creates noise in [B, 128, 64, 64] and packs to [B, 4096, 128]
latentChannels := m.VAE.Config.LatentChannels
packedChannels := latentChannels * 4 // 32 * 4 = 128
latents := scheduler.InitNoise([]int32{1, packedChannels, patchH, patchW}, cfg.Seed)
// Pack latents (transpose): [B, C, H, W] -> [B, H*W, C]
// This matches diffusers' _pack_latents
patches := packLatents(latents)
noiseSeqLen := patches.Shape()[1]
// RoPE cache - includes reference images if present
rope := PrepareRoPECache(textLen, patchH, patchW, tcfg.AxesDimsRoPE, tcfg.RopeTheta, refHeights, refWidths, ImageRefScale)
// Cleanup setup arrays when done
defer func() {
rope.Cos.Free()
rope.Sin.Free()
promptEmbeds.Free()
if refTokens != nil {
refTokens.Tokens.Free()
}
}()
// Pre-compute all timesteps before the loop to avoid per-step tensor creation
timesteps := make([]*mlx.Array, cfg.Steps)
for i := 0; i < cfg.Steps; i++ {
tCurr := scheduler.Timesteps[i] / float32(m.SchedulerConfig.NumTrainTimesteps)
timesteps[i] = mlx.ToBFloat16(mlx.NewArray([]float32{tCurr}, []int32{1}))
}
// Evaluate setup arrays
fmt.Print(" Evaluating setup... ")
setupStart := time.Now()
toEval := []*mlx.Array{promptEmbeds, patches, rope.Cos, rope.Sin}
toEval = append(toEval, timesteps...)
if refTokens != nil {
toEval = append(toEval, refTokens.Tokens)
}
mlx.Eval(toEval...)
mlx.MetalResetPeakMemory() // Reset peak to measure generation separately
fmt.Printf("✓ (%.2fs, %.1f GB)\n", time.Since(setupStart).Seconds(),
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
if cfg.Progress != nil {
cfg.Progress(0, cfg.Steps)
}
loopStart := time.Now()
stepStart := time.Now()
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
}
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStartCapture(cfg.CapturePath)
}
timestep := timesteps[i]
// Prepare input - concatenate noise patches with reference tokens if present
imgInput := patches
if refTokens != nil {
imgInput = mlx.Concatenate([]*mlx.Array{patches, refTokens.Tokens}, 1)
}
// Transformer forward pass
output := m.Transformer.Forward(imgInput, promptEmbeds, timestep, rope)
// If we concatenated reference tokens, slice to only get noise portion
if refTokens != nil {
output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, noiseSeqLen, output.Shape()[2]})
}
// Scheduler step (keep reference to old patches for the computation graph)
newPatches := scheduler.Step(output, patches, i)
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStopCapture()
}
mlx.Eval(newPatches)
patches = newPatches
elapsed := time.Since(stepStart).Seconds()
peakGB := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
if i == 0 {
fmt.Printf(" step %d: %.2fs (JIT warmup), peak %.1f GB\n", i+1, elapsed, peakGB)
} else {
fmt.Printf(" step %d: %.2fs, peak %.1f GB\n", i+1, elapsed, peakGB)
}
stepStart = time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
}
loopTime := time.Since(loopStart).Seconds()
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Denoised %d steps in %.2fs (%.2fs/step), peak %.1f GB\n",
cfg.Steps, loopTime, loopTime/float64(cfg.Steps), peakMem)
// Free timesteps now that denoising is done
for _, ts := range timesteps {
ts.Free()
}
// VAE decode with tiling for larger images
fmt.Print(" Decoding VAE... ")
vaeStart := time.Now()
// Enable tiling for images > 512x512 (latent > 64x64)
// VAE attention is O(n²) on latent pixels, tiling reduces memory significantly
if patchH*2 > 64 || patchW*2 > 64 {
m.VAE.Tiling = DefaultTilingConfig()
}
decoded := m.VAE.Decode(patches, patchH, patchW)
mlx.Eval(decoded)
// Free patches now that decode is done
patches.Free()
fmt.Printf("✓ (%.2fs, peak %.1f GB)\n", time.Since(vaeStart).Seconds(),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// packLatents converts [B, C, H, W] to [B, H*W, C] (matches diffusers _pack_latents)
func packLatents(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
// [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
x = mlx.Reshape(x, B, C, H*W)
return mlx.Transpose(x, 0, 2, 1)
}
// LoadPersistent loads the model and keeps it in memory for repeated use.
func LoadPersistent(modelName string) (*Model, error) {
m := &Model{}
if err := m.Load(modelName); err != nil {
return nil, err
}
return m, nil
}
// ImageRefScale is the time coordinate offset between reference images (matches diffusers scale=10)
const ImageRefScale = 10
// PrepareImage resizes and crops an image to be a multiple of 16, with optional pixel limit.
// Returns the processed image and its dimensions.
func PrepareImage(img image.Image, limitPixels int) (image.Image, int, int) {
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
// Cap pixels if needed (like diffusers cap_pixels)
if limitPixels > 0 && w*h > limitPixels {
scale := math.Sqrt(float64(limitPixels) / float64(w*h))
w = int(float64(w) * scale)
h = int(float64(h) * scale)
}
// Round down to multiple of 16
w = (w / 16) * 16
h = (h / 16) * 16
if w < 16 {
w = 16
}
if h < 16 {
h = 16
}
// Resize using high-quality bicubic interpolation (matches diffusers' default lanczos)
resized := image.NewRGBA(image.Rect(0, 0, w, h))
draw.CatmullRom.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
return resized, w, h
}
// ImageToTensor converts an image to a tensor in [-1, 1] range with shape [1, C, H, W].
func ImageToTensor(img image.Image) *mlx.Array {
bounds := img.Bounds()
w, h := bounds.Dx(), bounds.Dy()
// Convert to float32 array in NCHW format [1, 3, H, W] with values in [-1, 1]
data := make([]float32, 3*h*w)
for y := 0; y < h; y++ {
for x := 0; x < w; x++ {
r, g, b, _ := img.At(x+bounds.Min.X, y+bounds.Min.Y).RGBA()
// RGBA returns 16-bit values, convert to [-1, 1]
data[0*h*w+y*w+x] = float32(r>>8)/127.5 - 1.0
data[1*h*w+y*w+x] = float32(g>>8)/127.5 - 1.0
data[2*h*w+y*w+x] = float32(b>>8)/127.5 - 1.0
}
}
arr := mlx.NewArrayFloat32(data, []int32{1, 3, int32(h), int32(w)})
return arr
}
// ImageCondTokens holds encoded reference image tokens.
type ImageCondTokens struct {
Tokens *mlx.Array // [1, total_tokens, C] - concatenated reference tokens
}
// EncodeImageRefs encodes reference images using the VAE.
func (m *Model) EncodeImageRefs(images []image.Image) (*ImageCondTokens, error) {
if len(images) == 0 {
return nil, nil
}
// Limit reference images to reduce attention memory
limitPixels := MaxRefPixels
if len(images) > 1 {
limitPixels = MaxRefPixels / 2
}
var allTokens []*mlx.Array
for _, img := range images {
// Prepare image (resize, crop to multiple of 16)
prepared, prepW, prepH := PrepareImage(img, limitPixels)
fmt.Printf(" Encoding %dx%d image... ", prepW, prepH)
// Convert to tensor [-1, 1]
tensor := ImageToTensor(prepared)
// Encode with VAE - returns [1, L, 128]
encoded := m.VAE.EncodeImage(tensor)
squeezed := mlx.Squeeze(encoded, 0) // [L, C]
// Defer eval - will be done with other setup arrays
allTokens = append(allTokens, squeezed)
fmt.Println("✓")
}
// For single image, just add batch dimension directly
// For multiple images, concatenate first
var tokens *mlx.Array
if len(allTokens) == 1 {
tokens = mlx.ExpandDims(allTokens[0], 0) // [1, L, C]
} else {
tokens = mlx.Concatenate(allTokens, 0) // [total_L, C]
tokens = mlx.ExpandDims(tokens, 0) // [1, total_L, C]
}
return &ImageCondTokens{Tokens: tokens}, nil
}

View File

@@ -0,0 +1,222 @@
package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// RoPEConfig holds 4D RoPE configuration for Flux2
type RoPEConfig struct {
Theta int32 // 2000 for Klein
AxesDims []int32 // [32, 32, 32, 32] - dimensions for T, H, W, L axes
}
// RoPECache holds precomputed RoPE cos/sin values
type RoPECache struct {
Cos *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
Sin *mlx.Array // [1, TotalSeqLen, 1, head_dim/2]
TextLen int32 // Length of text sequence
ImageLen int32 // Length of image sequence
}
// PrepareTextIDs creates position IDs for text tokens.
// Text tokens use: T=0, H=0, W=0, L=0..seqLen-1
// Returns: [seqLen, 4]
func PrepareTextIDs(seqLen int32) *mlx.Array {
ids := make([]float32, seqLen*4)
for i := int32(0); i < seqLen; i++ {
idx := i * 4
ids[idx+0] = 0 // T = 0
ids[idx+1] = 0 // H = 0
ids[idx+2] = 0 // W = 0
ids[idx+3] = float32(i) // L = sequence position
}
return mlx.NewArray(ids, []int32{seqLen, 4})
}
// PrepareLatentIDs creates position IDs for image latent tokens.
// Latent tokens use: T=0, H=0..height-1, W=0..width-1, L=0
// The latents are in row-major order (H then W).
// Returns: [height*width, 4]
func PrepareLatentIDs(height, width int32) *mlx.Array {
seqLen := height * width
ids := make([]float32, seqLen*4)
idx := 0
for h := int32(0); h < height; h++ {
for w := int32(0); w < width; w++ {
ids[idx*4+0] = 0 // T = 0
ids[idx*4+1] = float32(h) // H = row
ids[idx*4+2] = float32(w) // W = column
ids[idx*4+3] = 0 // L = 0
idx++
}
}
return mlx.NewArray(ids, []int32{seqLen, 4})
}
// PrepareImageIDs creates position IDs for reference image tokens (used in editing).
// Reference images use: T=scale*(i+1), H=0..h-1, W=0..w-1, L=0
// where i is the image index (0, 1, 2, ...) and scale separates images in T dimension.
// Returns: [total_tokens, 4]
func PrepareImageIDs(imageHeights, imageWidths []int32, scale int32) *mlx.Array {
// Calculate total tokens
totalTokens := int32(0)
for i := range imageHeights {
totalTokens += imageHeights[i] * imageWidths[i]
}
ids := make([]float32, totalTokens*4)
idx := int32(0)
for imgIdx, h := range imageHeights {
w := imageWidths[imgIdx]
tValue := float32(scale * int32(imgIdx+1))
for hi := int32(0); hi < h; hi++ {
for wi := int32(0); wi < w; wi++ {
ids[idx*4+0] = tValue // T = scale * (imgIdx + 1)
ids[idx*4+1] = float32(hi) // H = row
ids[idx*4+2] = float32(wi) // W = column
ids[idx*4+3] = 0 // L = 0
idx++
}
}
}
return mlx.NewArray(ids, []int32{totalTokens, 4})
}
// ComputeRoPE computes cos and sin for 4D rotary position embeddings.
// ids: [L, 4] with (T, H, W, L) coordinates
// axesDims: [32, 32, 32, 32] - each axis has this many dimensions (total = head_dim = 128)
// theta: base frequency (2000 for Klein)
// Returns: cos, sin each [1, L, 1, head_dim] with repeat_interleave applied
func ComputeRoPE(ids *mlx.Array, axesDims []int32, theta int32) (*mlx.Array, *mlx.Array) {
shape := ids.Shape()
seqLen := shape[0]
// Compute total head dim (sum of all axes dims)
headDim := int32(0)
for _, d := range axesDims {
headDim += d
}
// Extract each coordinate dimension
// ids[:, 0] = T, ids[:, 1] = H, ids[:, 2] = W, ids[:, 3] = L
posT := mlx.Slice(ids, []int32{0, 0}, []int32{seqLen, 1}) // [L, 1]
posH := mlx.Slice(ids, []int32{0, 1}, []int32{seqLen, 2}) // [L, 1]
posW := mlx.Slice(ids, []int32{0, 2}, []int32{seqLen, 3}) // [L, 1]
posL := mlx.Slice(ids, []int32{0, 3}, []int32{seqLen, 4}) // [L, 1]
// Compute frequencies for each axis
logTheta := float32(math.Log(float64(theta)))
cosArrs := make([]*mlx.Array, 4)
sinArrs := make([]*mlx.Array, 4)
positions := []*mlx.Array{posT, posH, posW, posL}
for i, axisDim := range axesDims {
half := axisDim / 2
// Create frequency array for this axis: theta^(-2j/dim) for j=0..half-1
// This matches diffusers: 1.0 / (theta ** (torch.arange(0, dim, 2) / dim))
freqs := make([]float32, half)
for j := int32(0); j < half; j++ {
freqs[j] = float32(math.Exp(float64(-logTheta * float32(2*j) / float32(axisDim))))
}
freqArr := mlx.NewArray(freqs, []int32{1, half})
// Compute pos * freq -> [L, half]
posExpanded := positions[i] // [L, 1]
args := mlx.Mul(posExpanded, freqArr) // [L, half]
// Compute cos and sin for this axis
cosAxis := mlx.Cos(args) // [L, half]
sinAxis := mlx.Sin(args) // [L, half]
// repeat_interleave(2): [c0, c1, ...] -> [c0, c0, c1, c1, ...]
// Reshape [L, half] -> [L, half, 1], tile to [L, half, 2], reshape to [L, axisDim]
cosAxis = mlx.ExpandDims(cosAxis, 2) // [L, half, 1]
cosAxis = mlx.Tile(cosAxis, []int32{1, 1, 2}) // [L, half, 2]
cosAxis = mlx.Reshape(cosAxis, seqLen, axisDim) // [L, axisDim]
sinAxis = mlx.ExpandDims(sinAxis, 2)
sinAxis = mlx.Tile(sinAxis, []int32{1, 1, 2})
sinAxis = mlx.Reshape(sinAxis, seqLen, axisDim)
cosArrs[i] = cosAxis
sinArrs[i] = sinAxis
}
// Concatenate all axes: [L, headDim]
cos := mlx.Concatenate(cosArrs, 1)
sin := mlx.Concatenate(sinArrs, 1)
// Reshape to [1, L, 1, headDim] for broadcasting with attention
cos = mlx.Reshape(cos, 1, seqLen, 1, headDim)
sin = mlx.Reshape(sin, 1, seqLen, 1, headDim)
return cos, sin
}
// ApplyRoPE4D applies 4D rotary position embeddings to queries and keys.
// x: [B, L, nheads, head_dim]
// cos, sin: [1, L, 1, head_dim] (with repeat_interleave applied)
// Returns: x with RoPE applied
// Matches diffusers apply_rotary_emb with use_real=True, use_real_unbind_dim=-1
func ApplyRoPE4D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
half := headDim / 2
// Reshape x to [B, L, nheads, half, 2] and split into real/imag
xReshaped := mlx.Reshape(x, B, L, nheads, half, 2)
// Extract real (index 0) and imag (index 1) parts
xReal := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, half, 1})
xImag := mlx.Slice(xReshaped, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, half, 2})
xReal = mlx.Squeeze(xReal, 4) // [B, L, nheads, half]
xImag = mlx.Squeeze(xImag, 4) // [B, L, nheads, half]
// x_rotated = stack([-x_imag, x_real], dim=-1).flatten(-2)
// This creates [-x_imag[0], x_real[0], -x_imag[1], x_real[1], ...]
negXImag := mlx.Neg(xImag)
negXImag = mlx.ExpandDims(negXImag, 4) // [B, L, nheads, half, 1]
xReal = mlx.ExpandDims(xReal, 4) // [B, L, nheads, half, 1]
xRotated := mlx.Concatenate([]*mlx.Array{negXImag, xReal}, 4) // [B, L, nheads, half, 2]
xRotated = mlx.Reshape(xRotated, B, L, nheads, headDim) // [B, L, nheads, headDim]
// out = x * cos + x_rotated * sin
return mlx.Add(mlx.Mul(x, cos), mlx.Mul(xRotated, sin))
}
// PrepareRoPECache creates RoPE cache for text + noise, optionally with reference images.
// textLen: number of text tokens
// noiseH, noiseW: dimensions of the noise latent in patch tokens
// axesDims: [32, 32, 32, 32]
// theta: 2000
// refHeights, refWidths: optional reference image dimensions (pass nil/empty for no images)
// scale: time coordinate offset between reference images (e.g., 10)
func PrepareRoPECache(textLen, noiseH, noiseW int32, axesDims []int32, theta int32, refHeights, refWidths []int32, scale int32) *RoPECache {
textIDs := PrepareTextIDs(textLen)
noiseIDs := PrepareLatentIDs(noiseH, noiseW)
var allIDs *mlx.Array
imageLen := noiseH * noiseW
if len(refHeights) > 0 {
refIDs := PrepareImageIDs(refHeights, refWidths, scale)
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs, refIDs}, 0)
for i := range refHeights {
imageLen += refHeights[i] * refWidths[i]
}
} else {
allIDs = mlx.Concatenate([]*mlx.Array{textIDs, noiseIDs}, 0)
}
cos, sin := ComputeRoPE(allIDs, axesDims, theta)
cos = mlx.ToBFloat16(cos)
sin = mlx.ToBFloat16(sin)
return &RoPECache{Cos: cos, Sin: sin, TextLen: textLen, ImageLen: imageLen}
}

View File

@@ -0,0 +1,147 @@
package flux2
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds Flow-Match scheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
Shift float32 `json:"shift"` // 3.0 for Klein
UseDynamicShifting bool `json:"use_dynamic_shifting"` // true
TimeShiftType string `json:"time_shift_type"` // "exponential" or "linear"
}
// DefaultSchedulerConfig returns default config for Klein
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
Shift: 3.0, // Klein uses 3.0
UseDynamicShifting: true,
TimeShiftType: "exponential",
}
}
// FlowMatchScheduler implements the Flow-Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32 // Discretized timesteps (t from 1 to 0)
Sigmas []float32 // Noise levels at each timestep
NumSteps int // Number of inference steps
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// SetTimesteps sets up the scheduler for the given number of inference steps
func (s *FlowMatchScheduler) SetTimesteps(numSteps int) {
s.SetTimestepsWithMu(numSteps, 0)
}
// SetTimestepsWithMu sets up scheduler matching diffusers set_timesteps(sigmas=..., mu=...)
func (s *FlowMatchScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
s.NumSteps = numSteps
// diffusers: sigmas = linspace(1, 1/num_steps, num_steps)
// Then applies time shift, appends 0.0 at end
s.Sigmas = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
// linspace(1, 1/num_steps, num_steps)
var sigma float32
if numSteps == 1 {
sigma = 1.0
} else {
sigma = 1.0 - float32(i)/float32(numSteps-1)*(1.0-1.0/float32(numSteps))
}
// Apply time shift if using dynamic shifting
if s.Config.UseDynamicShifting && mu != 0 {
sigma = s.timeShift(mu, sigma)
} else {
// If not dynamic shifting, apply fixed shift scaling like diffusers
shift := s.Config.Shift
sigma = shift * sigma / (1 + (shift-1)*sigma)
}
s.Sigmas[i] = sigma
}
// Append terminal zero
s.Sigmas[numSteps] = 0.0
// Timesteps scaled to training range (matches diffusers: timesteps = sigmas * num_train_timesteps)
s.Timesteps = make([]float32, numSteps+1)
for i, v := range s.Sigmas {
s.Timesteps[i] = v * float32(s.Config.NumTrainTimesteps)
}
}
// timeShift applies the dynamic time shift
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
if s.Config.TimeShiftType == "linear" {
return mu / (mu + (1.0/t-1.0))
}
// Default: exponential
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 for precision (matches diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
outputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(outputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to bfloat16
return mlx.ToBFloat16(result)
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// CalculateShift computes the mu shift value for dynamic scheduling
// Matches diffusers compute_empirical_mu function
func CalculateShift(imgSeqLen int32, numSteps int) float32 {
a1, b1 := float32(8.73809524e-05), float32(1.89833333)
a2, b2 := float32(0.00016927), float32(0.45666666)
seqLen := float32(imgSeqLen)
if imgSeqLen > 4300 {
return a2*seqLen + b2
}
m200 := a2*seqLen + b2
m10 := a1*seqLen + b1
a := (m200 - m10) / 190.0
b := m200 - 200.0*a
return a*float32(numSteps) + b
}

View File

@@ -0,0 +1,560 @@
package flux2
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
AttentionHeadDim int32 `json:"attention_head_dim"` // 128
AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
Eps float32 `json:"eps"` // 1e-6
GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
InChannels int32 `json:"in_channels"` // 128
JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
MLPRatio float32 `json:"mlp_ratio"` // 3.0
NumAttentionHeads int32 `json:"num_attention_heads"` // 24
NumLayers int32 `json:"num_layers"` // 5
NumSingleLayers int32 `json:"num_single_layers"` // 20
PatchSize int32 `json:"patch_size"` // 1
RopeTheta int32 `json:"rope_theta"` // 2000
TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
func (c *TransformerConfig) InnerDim() int32 {
return c.NumAttentionHeads * c.AttentionHeadDim // 24 * 128 = 3072
}
func (c *TransformerConfig) MLPHiddenDim() int32 {
return int32(float32(c.InnerDim()) * c.MLPRatio) // 3072 * 3.0 = 9216
}
// TimestepEmbedder creates timestep embeddings
// Weight names: time_guidance_embed.timestep_embedder.linear_1.weight, linear_2.weight
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"linear_1"`
Linear2 nn.LinearLayer `weight:"linear_2"`
EmbedDim int32 // 256
}
// Forward creates sinusoidal embeddings and projects them
func (t *TimestepEmbedder) Forward(timesteps *mlx.Array) *mlx.Array {
half := t.EmbedDim / 2
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
// timesteps: [B] -> [B, 1]
tExpanded := mlx.ExpandDims(timesteps, 1)
// args: [B, half]
args := mlx.Mul(tExpanded, freqsArr)
// [cos(args), sin(args)] -> [B, embed_dim]
sinEmbed := mlx.Concatenate([]*mlx.Array{mlx.Cos(args), mlx.Sin(args)}, 1)
// MLP: linear_1 -> silu -> linear_2
h := t.Linear1.Forward(sinEmbed)
h = mlx.SiLU(h)
return t.Linear2.Forward(h)
}
// TimeGuidanceEmbed wraps the timestep embedder
// Weight names: time_guidance_embed.timestep_embedder.*
type TimeGuidanceEmbed struct {
TimestepEmbedder *TimestepEmbedder `weight:"timestep_embedder"`
}
// Forward computes timestep embeddings
func (t *TimeGuidanceEmbed) Forward(timesteps *mlx.Array) *mlx.Array {
return t.TimestepEmbedder.Forward(timesteps)
}
// Modulation computes adaptive modulation parameters
// Weight names: double_stream_modulation_img.linear.weight, etc.
type Modulation struct {
Linear nn.LinearLayer `weight:"linear"`
}
// Forward computes modulation parameters
func (m *Modulation) Forward(temb *mlx.Array) *mlx.Array {
h := mlx.SiLU(temb)
return m.Linear.Forward(h)
}
// TransformerBlockAttn implements dual-stream attention
// Weight names: transformer_blocks.N.attn.*
type TransformerBlockAttn struct {
// Image stream (separate Q, K, V projections)
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
// Note: to_out has .0 suffix in weights, handled specially
ToOut0 nn.LinearLayer `weight:"to_out.0"`
// Text stream (add_ projections)
AddQProj nn.LinearLayer `weight:"add_q_proj"`
AddKProj nn.LinearLayer `weight:"add_k_proj"`
AddVProj nn.LinearLayer `weight:"add_v_proj"`
ToAddOut nn.LinearLayer `weight:"to_add_out"`
// QK norms for image stream
NormQ *mlx.Array `weight:"norm_q.weight"`
NormK *mlx.Array `weight:"norm_k.weight"`
// QK norms for text stream (added)
NormAddedQ *mlx.Array `weight:"norm_added_q.weight"`
NormAddedK *mlx.Array `weight:"norm_added_k.weight"`
}
// FeedForward implements SwiGLU MLP
// Weight names: transformer_blocks.N.ff.linear_in.weight, linear_out.weight
type FeedForward struct {
LinearIn nn.LinearLayer `weight:"linear_in"`
LinearOut nn.LinearLayer `weight:"linear_out"`
}
// Forward applies SwiGLU MLP
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
// LinearIn outputs 2x hidden dim for SwiGLU
h := ff.LinearIn.Forward(x)
shape := h.Shape()
half := shape[len(shape)-1] / 2
// Split into gate and up
gate := mlx.Slice(h, []int32{0, 0, 0}, []int32{shape[0], shape[1], half})
up := mlx.Slice(h, []int32{0, 0, half}, []int32{shape[0], shape[1], shape[2]})
// SwiGLU: silu(gate) * up
h = mlx.Mul(mlx.SiLU(gate), up)
return ff.LinearOut.Forward(h)
}
// TransformerBlock implements a dual-stream transformer block
// Weight names: transformer_blocks.N.*
type TransformerBlock struct {
Attn *TransformerBlockAttn `weight:"attn"`
FF *FeedForward `weight:"ff"`
FFContext *FeedForward `weight:"ff_context"`
// Config (set after loading)
NHeads int32
HeadDim int32
Scale float32
}
// Forward applies the dual-stream block
// imgHidden: [B, imgLen, dim]
// txtHidden: [B, txtLen, dim]
// imgMod, txtMod: modulation params [B, 6*dim] each
// cos, sin: RoPE values
func (block *TransformerBlock) Forward(imgHidden, txtHidden *mlx.Array, imgMod, txtMod *mlx.Array, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
imgShape := imgHidden.Shape()
B := imgShape[0]
imgLen := imgShape[1]
dim := imgShape[2]
txtLen := txtHidden.Shape()[1]
// Parse modulation: 6 params each (shift1, scale1, gate1, shift2, scale2, gate2)
imgShift1, imgScale1, imgGate1 := parseModulation3(imgMod, dim, 0)
imgShift2, imgScale2, imgGate2 := parseModulation3(imgMod, dim, 3)
txtShift1, txtScale1, txtGate1 := parseModulation3(txtMod, dim, 0)
txtShift2, txtScale2, txtGate2 := parseModulation3(txtMod, dim, 3)
// === Attention branch ===
// Modulate inputs
imgNorm := modulateLayerNorm(imgHidden, imgShift1, imgScale1)
txtNorm := modulateLayerNorm(txtHidden, txtShift1, txtScale1)
// Compute Q, K, V for image stream (separate projections)
imgQ := block.Attn.ToQ.Forward(imgNorm)
imgK := block.Attn.ToK.Forward(imgNorm)
imgV := block.Attn.ToV.Forward(imgNorm)
// Compute Q, K, V for text stream (add_ projections)
txtQ := block.Attn.AddQProj.Forward(txtNorm)
txtK := block.Attn.AddKProj.Forward(txtNorm)
txtV := block.Attn.AddVProj.Forward(txtNorm)
// Reshape for attention: [B, L, dim] -> [B, L, nheads, headDim]
imgQ = mlx.Reshape(imgQ, B, imgLen, block.NHeads, block.HeadDim)
imgK = mlx.Reshape(imgK, B, imgLen, block.NHeads, block.HeadDim)
imgV = mlx.Reshape(imgV, B, imgLen, block.NHeads, block.HeadDim)
txtQ = mlx.Reshape(txtQ, B, txtLen, block.NHeads, block.HeadDim)
txtK = mlx.Reshape(txtK, B, txtLen, block.NHeads, block.HeadDim)
txtV = mlx.Reshape(txtV, B, txtLen, block.NHeads, block.HeadDim)
// Apply QK norm (RMSNorm with learned scale)
imgQ = applyQKNorm(imgQ, block.Attn.NormQ)
imgK = applyQKNorm(imgK, block.Attn.NormK)
txtQ = applyQKNorm(txtQ, block.Attn.NormAddedQ)
txtK = applyQKNorm(txtK, block.Attn.NormAddedK)
// Concatenate for joint attention: text first, then image
q := mlx.Concatenate([]*mlx.Array{txtQ, imgQ}, 1)
k := mlx.Concatenate([]*mlx.Array{txtK, imgK}, 1)
v := mlx.Concatenate([]*mlx.Array{txtV, imgV}, 1)
// Apply RoPE
q = ApplyRoPE4D(q, cos, sin)
k = ApplyRoPE4D(k, cos, sin)
// Transpose for SDPA: [B, nheads, L, headDim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Scaled dot-product attention
out := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
// Transpose back: [B, L, nheads, headDim]
out = mlx.Transpose(out, 0, 2, 1, 3)
// Split back into txt and img
totalLen := txtLen + imgLen
txtOut := mlx.Slice(out, []int32{0, 0, 0, 0}, []int32{B, txtLen, block.NHeads, block.HeadDim})
imgOut := mlx.Slice(out, []int32{0, txtLen, 0, 0}, []int32{B, totalLen, block.NHeads, block.HeadDim})
// Reshape and project
txtOut = mlx.Reshape(txtOut, B, txtLen, dim)
imgOut = mlx.Reshape(imgOut, B, imgLen, dim)
txtOut = block.Attn.ToAddOut.Forward(txtOut)
imgOut = block.Attn.ToOut0.Forward(imgOut)
// Apply gates and residual
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate1, imgOut))
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate1, txtOut))
// === MLP branch ===
imgNorm = modulateLayerNorm(imgHidden, imgShift2, imgScale2)
txtNorm = modulateLayerNorm(txtHidden, txtShift2, txtScale2)
imgFFOut := block.FF.Forward(imgNorm)
txtFFOut := block.FFContext.Forward(txtNorm)
imgHidden = mlx.Add(imgHidden, mlx.Mul(imgGate2, imgFFOut))
txtHidden = mlx.Add(txtHidden, mlx.Mul(txtGate2, txtFFOut))
return imgHidden, txtHidden
}
// SingleTransformerBlockAttn implements attention for single-stream blocks
// Weight names: single_transformer_blocks.N.attn.*
type SingleTransformerBlockAttn struct {
ToQKVMlpProj nn.LinearLayer `weight:"to_qkv_mlp_proj"` // Fused QKV + MLP input
ToOut nn.LinearLayer `weight:"to_out"` // Fused attn_out + MLP out
NormQ *mlx.Array `weight:"norm_q.weight"`
NormK *mlx.Array `weight:"norm_k.weight"`
}
// SingleTransformerBlock implements a single-stream transformer block
// Weight names: single_transformer_blocks.N.*
type SingleTransformerBlock struct {
Attn *SingleTransformerBlockAttn `weight:"attn"`
// Config
NHeads int32
HeadDim int32
InnerDim int32
MLPHidDim int32
Scale float32
}
// Forward applies the single-stream block
// x: [B, L, dim] concatenated text+image
// mod: modulation [B, 3*dim]
func (block *SingleTransformerBlock) Forward(x *mlx.Array, mod *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
dim := shape[2]
// Parse modulation: (shift, scale, gate)
shift, scale, gate := parseModulation3(mod, dim, 0)
// Modulate input
h := modulateLayerNorm(x, shift, scale)
// Fused projection: QKV + MLP gate/up
// linear1 outputs: [q, k, v, mlp_gate, mlp_up] = [dim, dim, dim, mlpHid, mlpHid]
qkvMlp := block.Attn.ToQKVMlpProj.Forward(h)
// Split: first 3*dim is QKV, rest is MLP
qkvDim := 3 * block.InnerDim
qkv := mlx.Slice(qkvMlp, []int32{0, 0, 0}, []int32{B, L, qkvDim})
mlpIn := mlx.Slice(qkvMlp, []int32{0, 0, qkvDim}, []int32{B, L, qkvMlp.Shape()[2]})
// Split QKV
q, k, v := splitQKV(qkv, B, L, block.InnerDim)
// Reshape for attention
q = mlx.Reshape(q, B, L, block.NHeads, block.HeadDim)
k = mlx.Reshape(k, B, L, block.NHeads, block.HeadDim)
v = mlx.Reshape(v, B, L, block.NHeads, block.HeadDim)
// QK norm
q = applyQKNorm(q, block.Attn.NormQ)
k = applyQKNorm(k, block.Attn.NormK)
// Apply RoPE
q = ApplyRoPE4D(q, cos, sin)
k = ApplyRoPE4D(k, cos, sin)
// Transpose for SDPA
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// SDPA
attnOut := mlx.ScaledDotProductAttention(q, k, v, block.Scale, false)
// Transpose back and reshape
attnOut = mlx.Transpose(attnOut, 0, 2, 1, 3)
attnOut = mlx.Reshape(attnOut, B, L, block.InnerDim)
// MLP: SwiGLU
mlpShape := mlpIn.Shape()
half := mlpShape[2] / 2
mlpGate := mlx.Slice(mlpIn, []int32{0, 0, 0}, []int32{B, L, half})
mlpUp := mlx.Slice(mlpIn, []int32{0, 0, half}, []int32{B, L, mlpShape[2]})
mlpOut := mlx.Mul(mlx.SiLU(mlpGate), mlpUp)
// Concatenate attention and MLP for fused output
combined := mlx.Concatenate([]*mlx.Array{attnOut, mlpOut}, 2)
// Output projection
out := block.Attn.ToOut.Forward(combined)
// Apply gate and residual
return mlx.Add(x, mlx.Mul(gate, out))
}
// NormOut implements the output normalization with modulation
// Weight names: norm_out.linear.weight
type NormOut struct {
Linear nn.LinearLayer `weight:"linear"`
}
// Forward computes final modulated output
func (n *NormOut) Forward(x *mlx.Array, temb *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
dim := shape[2]
// Modulation: temb -> silu -> linear -> [shift, scale]
mod := mlx.SiLU(temb)
mod = n.Linear.Forward(mod)
// Split into scale and shift (diffusers order: scale first, shift second)
scale := mlx.Slice(mod, []int32{0, 0}, []int32{B, dim})
shift := mlx.Slice(mod, []int32{0, dim}, []int32{B, 2 * dim})
shift = mlx.ExpandDims(shift, 1)
scale = mlx.ExpandDims(scale, 1)
// Modulate with RMSNorm
return modulateLayerNorm(x, shift, scale)
}
// Flux2Transformer2DModel is the main Flux2 transformer
// Weight names at top level: time_guidance_embed.*, double_stream_modulation_*.*, etc.
type Flux2Transformer2DModel struct {
// Timestep embedding
TimeGuidanceEmbed *TimeGuidanceEmbed `weight:"time_guidance_embed"`
// Shared modulation
DoubleStreamModulationImg *Modulation `weight:"double_stream_modulation_img"`
DoubleStreamModulationTxt *Modulation `weight:"double_stream_modulation_txt"`
SingleStreamModulation *Modulation `weight:"single_stream_modulation"`
// Embedders
XEmbedder nn.LinearLayer `weight:"x_embedder"`
ContextEmbedder nn.LinearLayer `weight:"context_embedder"`
// Transformer blocks
TransformerBlocks []*TransformerBlock `weight:"transformer_blocks"`
SingleTransformerBlocks []*SingleTransformerBlock `weight:"single_transformer_blocks"`
// Output
NormOut *NormOut `weight:"norm_out"`
ProjOut nn.LinearLayer `weight:"proj_out"`
*TransformerConfig
}
// Load loads the Flux2 transformer from ollama blob storage.
func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
// Initialize slices
m.TransformerBlocks = make([]*TransformerBlock, cfg.NumLayers)
m.SingleTransformerBlocks = make([]*SingleTransformerBlock, cfg.NumSingleLayers)
// Initialize TimeGuidanceEmbed with embed dim
m.TimeGuidanceEmbed = &TimeGuidanceEmbed{
TimestepEmbedder: &TimestepEmbedder{EmbedDim: cfg.TimestepGuidanceChannels},
}
// Load weights from tensor blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Flux2Transformer2DModel) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Flux2Transformer2DModel) initComputedFields() {
cfg := m.TransformerConfig
innerDim := cfg.InnerDim()
scale := float32(1.0 / math.Sqrt(float64(cfg.AttentionHeadDim)))
// Initialize transformer blocks
for _, block := range m.TransformerBlocks {
block.NHeads = cfg.NumAttentionHeads
block.HeadDim = cfg.AttentionHeadDim
block.Scale = scale
}
// Initialize single transformer blocks
for _, block := range m.SingleTransformerBlocks {
block.NHeads = cfg.NumAttentionHeads
block.HeadDim = cfg.AttentionHeadDim
block.InnerDim = innerDim
block.MLPHidDim = cfg.MLPHiddenDim()
block.Scale = scale
}
}
// Forward runs the Flux2 transformer
func (m *Flux2Transformer2DModel) Forward(patches, txtEmbeds *mlx.Array, timesteps *mlx.Array, rope *RoPECache) *mlx.Array {
patchShape := patches.Shape()
B := patchShape[0]
imgLen := patchShape[1]
txtLen := txtEmbeds.Shape()[1]
// Scale timestep to 0-1000 range (diffusers multiplies by 1000)
scaledTimesteps := mlx.MulScalar(timesteps, 1000.0)
// Compute timestep embedding
temb := m.TimeGuidanceEmbed.Forward(scaledTimesteps)
// Embed patches and text
imgHidden := m.XEmbedder.Forward(patches)
txtHidden := m.ContextEmbedder.Forward(txtEmbeds)
// Compute shared modulation
imgMod := m.DoubleStreamModulationImg.Forward(temb)
txtMod := m.DoubleStreamModulationTxt.Forward(temb)
singleMod := m.SingleStreamModulation.Forward(temb)
// Double (dual-stream) blocks
for _, block := range m.TransformerBlocks {
imgHidden, txtHidden = block.Forward(imgHidden, txtHidden, imgMod, txtMod, rope.Cos, rope.Sin)
}
// Concatenate for single-stream: text first, then image
hidden := mlx.Concatenate([]*mlx.Array{txtHidden, imgHidden}, 1)
// Single-stream blocks
for _, block := range m.SingleTransformerBlocks {
hidden = block.Forward(hidden, singleMod, rope.Cos, rope.Sin)
}
// Extract image portion
totalLen := txtLen + imgLen
imgOut := mlx.Slice(hidden, []int32{0, txtLen, 0}, []int32{B, totalLen, hidden.Shape()[2]})
// Final norm and projection
imgOut = m.NormOut.Forward(imgOut, temb)
return m.ProjOut.Forward(imgOut)
}
// Note: QK normalization uses mlx.RMSNorm (the fast version) directly
// See applyQKNorm function below
// compiledSwiGLU fuses: silu(gate) * up
// Called 30x per step (10 in dual-stream + 20 in single-stream blocks)
var compiledSwiGLU *mlx.CompiledFunc
func getCompiledSwiGLU() *mlx.CompiledFunc {
if compiledSwiGLU == nil {
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
gate, up := inputs[0], inputs[1]
return []*mlx.Array{mlx.Mul(mlx.SiLU(gate), up)}
}, true)
}
return compiledSwiGLU
}
// Helper functions
// parseModulation3 extracts 3 modulation params (shift, scale, gate) starting at offset
func parseModulation3(mod *mlx.Array, dim int32, offset int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
B := mod.Shape()[0]
start := offset * dim
shift := mlx.Slice(mod, []int32{0, start}, []int32{B, start + dim})
scale := mlx.Slice(mod, []int32{0, start + dim}, []int32{B, start + 2*dim})
gate := mlx.Slice(mod, []int32{0, start + 2*dim}, []int32{B, start + 3*dim})
// Expand for broadcasting [B, dim] -> [B, 1, dim]
shift = mlx.ExpandDims(shift, 1)
scale = mlx.ExpandDims(scale, 1)
gate = mlx.ExpandDims(gate, 1)
return shift, scale, gate
}
// modulateLayerNorm applies LayerNorm then shift/scale modulation
// Diffusers uses LayerNorm(elementwise_affine=False) which centers the data
func modulateLayerNorm(x *mlx.Array, shift, scale *mlx.Array) *mlx.Array {
// Fast LayerNorm without learnable params
x = mlx.LayerNorm(x, 1e-6)
// Modulate: x * (1 + scale) + shift
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
return mlx.Add(x, shift)
}
// splitQKV splits a fused QKV tensor into Q, K, V
func splitQKV(qkv *mlx.Array, B, L, dim int32) (*mlx.Array, *mlx.Array, *mlx.Array) {
q := mlx.Slice(qkv, []int32{0, 0, 0}, []int32{B, L, dim})
k := mlx.Slice(qkv, []int32{0, 0, dim}, []int32{B, L, 2 * dim})
v := mlx.Slice(qkv, []int32{0, 0, 2 * dim}, []int32{B, L, 3 * dim})
return q, k, v
}
// applyQKNorm applies RMSNorm with learned scale (no bias)
// Uses the optimized mlx_fast_rms_norm
func applyQKNorm(x *mlx.Array, scale *mlx.Array) *mlx.Array {
return mlx.RMSNorm(x, scale, 1e-6)
}

View File

@@ -0,0 +1,802 @@
package flux2
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
)
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
ActFn string `json:"act_fn"` // "silu"
BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
ForceUpcast bool `json:"force_upcast"` // true
InChannels int32 `json:"in_channels"` // 3
LatentChannels int32 `json:"latent_channels"` // 32
LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
NormNumGroups int32 `json:"norm_num_groups"` // 32
OutChannels int32 `json:"out_channels"` // 3
PatchSize []int32 `json:"patch_size"` // [2, 2]
SampleSize int32 `json:"sample_size"` // 1024
UsePostQuantConv bool `json:"use_post_quant_conv"` // true
UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
type BatchNorm2D struct {
RunningMean *mlx.Array // [C]
RunningVar *mlx.Array // [C]
Weight *mlx.Array // [C] gamma
Bias *mlx.Array // [C] beta
Eps float32
Momentum float32
}
// Forward applies batch normalization (inference mode - uses running stats)
// Input and output are in NHWC format [B, H, W, C]
func (bn *BatchNorm2D) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[3]
// Reshape stats for broadcasting [1, 1, 1, C]
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
// Normalize: (x - mean) / sqrt(var + eps)
xNorm := mlx.Sub(x, mean)
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
// Scale and shift (only if affine=True)
if bn.Weight != nil {
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if bn.Bias != nil {
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Denormalize inverts the batch normalization
// Used when decoding latents
func (bn *BatchNorm2D) Denormalize(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[3]
// Reshape stats for broadcasting [1, 1, 1, C]
mean := mlx.Reshape(bn.RunningMean, 1, 1, 1, C)
variance := mlx.Reshape(bn.RunningVar, 1, 1, 1, C)
// Inverse: first undo affine, then undo normalization
// For affine=False: x_denorm = x * sqrt(var + eps) + mean
if bn.Bias != nil {
bias := mlx.Reshape(bn.Bias, 1, 1, 1, C)
x = mlx.Sub(x, bias)
}
if bn.Weight != nil {
weight := mlx.Reshape(bn.Weight, 1, 1, 1, C)
x = mlx.Div(x, weight)
}
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, bn.Eps)))
x = mlx.Add(x, mean)
return x
}
// GroupNormLayer implements group normalization
// Reused from zimage package pattern
type GroupNormLayer struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias"`
NumGroups int32
Eps float32
}
// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// Reshape to [B, H, W, groups, C/groups]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
// Compute mean and variance per group
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
sq := mlx.Square(xCentered)
variance := mlx.Mean(sq, 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, H, W, C]
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// Conv2D represents a 2D convolution layer (reused pattern)
type Conv2D struct {
Weight *mlx.Array `weight:"weight"`
Bias *mlx.Array `weight:"bias,optional"`
Stride int32
Padding int32
}
// Transform implements safetensors.Transformer to transpose weights from PyTorch's OIHW to MLX's OHWI.
func (conv *Conv2D) Transform(field string, arr *mlx.Array) *mlx.Array {
if field == "Weight" {
return mlx.Transpose(arr, 0, 2, 3, 1)
}
return arr
}
// Forward applies convolution (NHWC format)
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
if conv.Bias != nil {
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
out = mlx.Add(out, bias)
}
return out
}
// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
Norm1 *GroupNormLayer `weight:"norm1"`
Conv1 *Conv2D `weight:"conv1"`
Norm2 *GroupNormLayer `weight:"norm2"`
Conv2 *Conv2D `weight:"conv2"`
ConvShortcut *Conv2D `weight:"conv_shortcut,optional"`
}
// Forward applies the ResNet block
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
h := rb.Norm1.Forward(x)
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
h = rb.Norm2.Forward(h)
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
if rb.ConvShortcut != nil {
x = rb.ConvShortcut.Forward(x)
}
return mlx.Add(h, x)
}
// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
GroupNorm *GroupNormLayer `weight:"group_norm"`
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
}
// Forward applies attention (NHWC format)
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
h := ab.GroupNorm.Forward(x)
h = mlx.Reshape(h, B, H*W, C)
q := ab.ToQ.Forward(h)
k := ab.ToK.Forward(h)
v := ab.ToV.Forward(h)
q = mlx.ExpandDims(q, 1)
k = mlx.ExpandDims(k, 1)
v = mlx.ExpandDims(v, 1)
scale := float32(1.0 / math.Sqrt(float64(C)))
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Squeeze(out, 1)
out = ab.ToOut.Forward(out)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Add(out, residual)
return out
}
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Upsample *Conv2D
}
// Forward applies the up decoder block
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range ub.ResnetBlocks {
x = resnet.Forward(x)
}
if ub.Upsample != nil {
x = upsample2x(x)
x = ub.Upsample.Forward(x)
}
return x
}
// upsample2x performs 2x nearest neighbor upsampling
func upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
hIdx = mlx.Reshape(hIdx, H, 1)
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
hIdx = mlx.Reshape(hIdx, H*2)
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
wIdx = mlx.Reshape(wIdx, W, 1)
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
wIdx = mlx.Reshape(wIdx, W*2)
x = mlx.Take(x, hIdx, 1)
x = mlx.Take(x, wIdx, 2)
return x
}
// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
Resnet1 *ResnetBlock2D
Attention *VAEAttentionBlock
Resnet2 *ResnetBlock2D
}
// Forward applies the mid block
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
x = mb.Resnet1.Forward(x)
x = mb.Attention.Forward(x)
x = mb.Resnet2.Forward(x)
return x
}
// DefaultTilingConfig returns reasonable defaults for tiled decoding
// Matches diffusers: tile_latent_min_size=64, tile_overlap_factor=0.25
func DefaultTilingConfig() *vae.TilingConfig {
return vae.DefaultTilingConfig()
}
// AutoencoderKLFlux2 is the Flux2 VAE with BatchNorm
type AutoencoderKLFlux2 struct {
Config *VAEConfig
// Encoder components (for image editing)
EncoderConvIn *Conv2D
EncoderMid *VAEMidBlock
EncoderDown []*DownEncoderBlock2D
EncoderNormOut *GroupNormLayer
EncoderConvOut *Conv2D
// Decoder components
DecoderConvIn *Conv2D
DecoderMid *VAEMidBlock
DecoderUp []*UpDecoderBlock2D
DecoderNormOut *GroupNormLayer
DecoderConvOut *Conv2D
// Quant conv layers
QuantConv *Conv2D
PostQuantConv *Conv2D
// BatchNorm for latent normalization
LatentBN *BatchNorm2D
// Tiling configuration (nil = no tiling)
Tiling *vae.TilingConfig
}
// DownEncoderBlock2D implements a downsampling encoder block
type DownEncoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Downsample *Conv2D
}
// Forward applies the down encoder block
func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range db.ResnetBlocks {
x = resnet.Forward(x)
}
if db.Downsample != nil {
// Pad then conv with stride 2
x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0})
x = db.Downsample.Forward(x)
}
return x
}
// Load loads the Flux2 VAE from ollama blob storage.
func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights, &cfg)
}
// loadWeights loads VAE weights from any WeightSource
func (m *AutoencoderKLFlux2) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load encoder components (for image conditioning)
if err := m.loadEncoderWeights(weights, cfg); err != nil {
return fmt.Errorf("encoder: %w", err)
}
// Load decoder conv_in
m.DecoderConvIn = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.DecoderConvIn, weights, "decoder.conv_in"); err != nil {
return fmt.Errorf("decoder.conv_in: %w", err)
}
// Load mid block
m.DecoderMid, err = loadVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
if err != nil {
return fmt.Errorf("decoder.mid_block: %w", err)
}
// Load up blocks
numBlocks := len(cfg.BlockOutChannels)
m.DecoderUp = make([]*UpDecoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
hasUpsample := i < numBlocks-1
m.DecoderUp[i], err = loadUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
if err != nil {
return fmt.Errorf("%s: %w", prefix, err)
}
}
// Load decoder conv_norm_out and conv_out
m.DecoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
if err := safetensors.LoadModule(m.DecoderNormOut, weights, "decoder.conv_norm_out"); err != nil {
return fmt.Errorf("decoder.conv_norm_out: %w", err)
}
m.DecoderConvOut = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.DecoderConvOut, weights, "decoder.conv_out"); err != nil {
return fmt.Errorf("decoder.conv_out: %w", err)
}
// Load post_quant_conv
if cfg.UsePostQuantConv {
m.PostQuantConv = &Conv2D{Stride: 1, Padding: 0}
if err := safetensors.LoadModule(m.PostQuantConv, weights, "post_quant_conv"); err != nil {
return fmt.Errorf("post_quant_conv: %w", err)
}
}
// Load latent BatchNorm (affine=False, so no weight/bias)
bnMean, err := weights.GetTensor("bn.running_mean")
if err != nil {
return fmt.Errorf("bn.running_mean: %w", err)
}
bnVar, err := weights.GetTensor("bn.running_var")
if err != nil {
return fmt.Errorf("bn.running_var: %w", err)
}
m.LatentBN = &BatchNorm2D{
RunningMean: bnMean,
RunningVar: bnVar,
Weight: nil, // affine=False
Bias: nil, // affine=False
Eps: cfg.BatchNormEps,
Momentum: cfg.BatchNormMomentum,
}
fmt.Println("✓")
return nil
}
// loadVAEMidBlock loads the mid block.
func loadVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := loadResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
}
attention, err := loadVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
if err != nil {
return nil, err
}
resnet2, err := loadResnetBlock2D(weights, prefix+".resnets.1", numGroups)
if err != nil {
return nil, err
}
return &VAEMidBlock{
Resnet1: resnet1,
Attention: attention,
Resnet2: resnet2,
}, nil
}
// loadResnetBlock2D loads a ResNet block.
func loadResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
block := &ResnetBlock2D{
Norm1: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
Conv1: &Conv2D{Stride: 1, Padding: 1},
Norm2: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
Conv2: &Conv2D{Stride: 1, Padding: 1},
ConvShortcut: &Conv2D{Stride: 1, Padding: 0}, // Pre-allocate for optional loading
}
if err := safetensors.LoadModule(block, weights, prefix); err != nil {
return nil, err
}
// If ConvShortcut wasn't loaded (no weights found), nil it out
if block.ConvShortcut.Weight == nil {
block.ConvShortcut = nil
}
return block, nil
}
// loadVAEAttentionBlock loads an attention block using LoadModule.
func loadVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
ab := &VAEAttentionBlock{
GroupNorm: &GroupNormLayer{NumGroups: numGroups, Eps: 1e-5},
}
if err := safetensors.LoadModule(ab, weights, prefix); err != nil {
return nil, err
}
return ab, nil
}
// loadUpDecoderBlock2D loads an up decoder block.
func loadUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var upsample *Conv2D
if hasUpsample {
upsample = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(upsample, weights, prefix+".upsamplers.0.conv"); err != nil {
return nil, err
}
}
return &UpDecoderBlock2D{
ResnetBlocks: resnets,
Upsample: upsample,
}, nil
}
// Patchify converts latents [B, C, H, W] to patches [B, H*W/4, C*4] using 2x2 patches
// This is the inverse of the VAE's patchify for feeding to transformer
func (vae *AutoencoderKLFlux2) Patchify(latents *mlx.Array) *mlx.Array {
shape := latents.Shape()
B := shape[0]
C := shape[1]
H := shape[2]
W := shape[3]
patchH := vae.Config.PatchSize[0]
patchW := vae.Config.PatchSize[1]
pH := H / patchH
pW := W / patchW
// [B, C, H, W] -> [B, C, pH, patchH, pW, patchW]
x := mlx.Reshape(latents, B, C, pH, patchH, pW, patchW)
// [B, C, pH, patchH, pW, patchW] -> [B, pH, pW, C, patchH, patchW]
x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
// [B, pH, pW, C, patchH, patchW] -> [B, pH*pW, C*patchH*patchW]
return mlx.Reshape(x, B, pH*pW, C*patchH*patchW)
}
// Unpatchify converts patches [B, L, C*4] back to [B, C, H, W]
func (vae *AutoencoderKLFlux2) Unpatchify(patches *mlx.Array, pH, pW, C int32) *mlx.Array {
shape := patches.Shape()
B := shape[0]
patchH := vae.Config.PatchSize[0]
patchW := vae.Config.PatchSize[1]
// [B, pH*pW, C*patchH*patchW] -> [B, pH, pW, C, patchH, patchW]
x := mlx.Reshape(patches, B, pH, pW, C, patchH, patchW)
// [B, pH, pW, C, patchH, patchW] -> [B, C, pH, patchH, pW, patchW]
x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
// [B, C, pH, patchH, pW, patchW] -> [B, C, H, W]
H := pH * patchH
W := pW * patchW
return mlx.Reshape(x, B, C, H, W)
}
// denormalizePatchified applies inverse batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] denormalized
func (vae *AutoencoderKLFlux2) denormalizePatchified(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[2] // 128
// Reshape stats for broadcasting [1, 1, C]
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
// Inverse BN (affine=False): x_denorm = x * sqrt(var + eps) + mean
if vae.LatentBN.Bias != nil {
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
x = mlx.Sub(x, bias)
}
if vae.LatentBN.Weight != nil {
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
x = mlx.Div(x, weight)
}
x = mlx.Mul(x, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
x = mlx.Add(x, mean)
return x
}
// Decode decodes latent patches to images.
// If Tiling is set, uses tiled decoding to reduce memory for large images.
// latents: [B, L, C*4] patchified latents from transformer
// pH, pW: patch grid dimensions
// Returns: [B, 3, H, W] image tensor
func (v *AutoencoderKLFlux2) Decode(latents *mlx.Array, pH, pW int32) *mlx.Array {
// Denormalize patchified latents
z := v.denormalizePatchified(latents)
// Unpatchify: [B, L, C*4] -> [B, C, H, W]
z = v.Unpatchify(z, pH, pW, v.Config.LatentChannels)
// Convert NCHW -> NHWC for processing
z = mlx.Transpose(z, 0, 2, 3, 1)
// Use tiled decoding if enabled
if v.Tiling != nil {
mlx.Eval(z)
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
}
// Direct decode (no tiling)
h := v.decodeTile(z)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
h = mlx.Transpose(h, 0, 3, 1, 2)
return h
}
// decodeTile decodes a single latent tile to pixels (internal helper)
// z: [B, H, W, C] latent tile in NHWC format
// Returns: [B, H*8, W*8, 3] pixel tile in NHWC format (before clipping)
func (vae *AutoencoderKLFlux2) decodeTile(z *mlx.Array) *mlx.Array {
// Post-quant conv
if vae.PostQuantConv != nil {
z = vae.PostQuantConv.Forward(z)
}
// Decoder
h := vae.DecoderConvIn.Forward(z)
h = vae.DecoderMid.Forward(h)
for _, upBlock := range vae.DecoderUp {
h = upBlock.Forward(h)
}
h = vae.DecoderNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.DecoderConvOut.Forward(h)
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
h = mlx.AddScalar(h, 0.5)
return h
}
// loadEncoderWeights loads the encoder components for image conditioning
func (m *AutoencoderKLFlux2) loadEncoderWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load encoder conv_in
m.EncoderConvIn = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.EncoderConvIn, weights, "encoder.conv_in"); err != nil {
return fmt.Errorf("encoder.conv_in: %w", err)
}
// Load encoder down blocks
numBlocks := len(cfg.BlockOutChannels)
m.EncoderDown = make([]*DownEncoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("encoder.down_blocks.%d", i)
hasDownsample := i < numBlocks-1
m.EncoderDown[i], err = loadDownEncoderBlock2D(weights, prefix, cfg.LayersPerBlock, cfg.NormNumGroups, hasDownsample)
if err != nil {
return fmt.Errorf("%s: %w", prefix, err)
}
}
// Load encoder mid block
m.EncoderMid, err = loadVAEMidBlock(weights, "encoder.mid_block", cfg.NormNumGroups)
if err != nil {
return fmt.Errorf("encoder.mid_block: %w", err)
}
// Load encoder conv_norm_out and conv_out
m.EncoderNormOut = &GroupNormLayer{NumGroups: cfg.NormNumGroups, Eps: 1e-5}
if err := safetensors.LoadModule(m.EncoderNormOut, weights, "encoder.conv_norm_out"); err != nil {
return fmt.Errorf("encoder.conv_norm_out: %w", err)
}
m.EncoderConvOut = &Conv2D{Stride: 1, Padding: 1}
if err := safetensors.LoadModule(m.EncoderConvOut, weights, "encoder.conv_out"); err != nil {
return fmt.Errorf("encoder.conv_out: %w", err)
}
// Load quant_conv (for encoding)
if cfg.UseQuantConv {
m.QuantConv = &Conv2D{Stride: 1, Padding: 0}
if err := safetensors.LoadModule(m.QuantConv, weights, "quant_conv"); err != nil {
return fmt.Errorf("quant_conv: %w", err)
}
}
return nil
}
// loadDownEncoderBlock2D loads a down encoder block.
func loadDownEncoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasDownsample bool) (*DownEncoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := loadResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var downsample *Conv2D
if hasDownsample {
downsample = &Conv2D{Stride: 2, Padding: 0}
if err := safetensors.LoadModule(downsample, weights, prefix+".downsamplers.0.conv"); err != nil {
return nil, err
}
}
return &DownEncoderBlock2D{
ResnetBlocks: resnets,
Downsample: downsample,
}, nil
}
// EncodeImage encodes an image to normalized latents.
// image: [B, 3, H, W] image tensor in [-1, 1]
// Returns: [B, L, C*4] patchified normalized latents
func (vae *AutoencoderKLFlux2) EncodeImage(image *mlx.Array) *mlx.Array {
// Convert NCHW -> NHWC
x := mlx.Transpose(image, 0, 2, 3, 1)
// Encoder
h := vae.EncoderConvIn.Forward(x)
for _, downBlock := range vae.EncoderDown {
h = downBlock.Forward(h)
}
h = vae.EncoderMid.Forward(h)
h = vae.EncoderNormOut.Forward(h)
h = mlx.SiLU(h)
h = vae.EncoderConvOut.Forward(h)
// Quant conv outputs [B, H, W, 2*latent_channels] (mean + logvar)
if vae.QuantConv != nil {
h = vae.QuantConv.Forward(h)
}
// Take only the mean (first latent_channels) - deterministic encoding
// h is [B, H, W, 64] -> take first 32 channels for mean
shape := h.Shape()
latentChannels := vae.Config.LatentChannels // 32
h = mlx.Slice(h, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], latentChannels})
// Convert NHWC -> NCHW for patchifying
h = mlx.Transpose(h, 0, 3, 1, 2)
// Patchify: [B, C, H, W] -> [B, L, C*4]
h = vae.Patchify(h)
// Apply BatchNorm on patchified latents [B, L, 128]
// The BatchNorm has 128 channels matching the patchified dimension
h = vae.normalizePatchified(h)
return h
}
// normalizePatchified applies batch normalization to patchified latents.
// Input: [B, L, 128] where 128 = 32 latent channels * 4 (2x2 patch)
// Output: [B, L, 128] normalized
func (vae *AutoencoderKLFlux2) normalizePatchified(x *mlx.Array) *mlx.Array {
shape := x.Shape()
C := shape[2] // 128
// Reshape stats for broadcasting [1, 1, C]
mean := mlx.Reshape(vae.LatentBN.RunningMean, 1, 1, C)
variance := mlx.Reshape(vae.LatentBN.RunningVar, 1, 1, C)
// Normalize: (x - mean) / sqrt(var + eps)
xNorm := mlx.Sub(x, mean)
xNorm = mlx.Div(xNorm, mlx.Sqrt(mlx.AddScalar(variance, vae.LatentBN.Eps)))
// Scale and shift (only if affine=True)
if vae.LatentBN.Weight != nil {
weight := mlx.Reshape(vae.LatentBN.Weight, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if vae.LatentBN.Bias != nil {
bias := mlx.Reshape(vae.LatentBN.Bias, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}

View File

@@ -0,0 +1,388 @@
// Package qwen3 provides a shared Qwen3 text encoder used by multiple image generation models.
package qwen3
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Config holds Qwen3 text encoder configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
}
// Attention implements Qwen3 attention with QK norms
type Attention struct {
QProj nn.LinearLayer `weight:"q_proj"`
KProj nn.LinearLayer `weight:"k_proj"`
VProj nn.LinearLayer `weight:"v_proj"`
OProj nn.LinearLayer `weight:"o_proj"`
QNorm *nn.RMSNorm `weight:"q_norm"`
KNorm *nn.RMSNorm `weight:"k_norm"`
// Computed fields
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
RopeTheta float32
}
// applyRoPEQwen3 applies the custom RoPE for Qwen3 text encoder
func applyRoPEQwen3(x *mlx.Array, seqLen int32, theta float32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
H := shape[2]
D := shape[3]
half := D / 2
freqsArr := make([]float32, half)
logTheta := float32(math.Log(float64(theta)))
for i := int32(0); i < half; i++ {
freqsArr[i] = float32(math.Exp(float64(-logTheta * float32(i) / float32(half))))
}
freqs := mlx.NewArray(freqsArr, []int32{half})
posArr := make([]float32, seqLen)
for i := int32(0); i < seqLen; i++ {
posArr[i] = float32(i)
}
pos := mlx.NewArray(posArr, []int32{seqLen})
posExpanded := mlx.Reshape(pos, seqLen, 1)
freqsExpanded := mlx.Reshape(freqs, 1, half)
args := mlx.Mul(posExpanded, freqsExpanded)
cosVals := mlx.Cos(args)
sinVals := mlx.Sin(args)
cosVals = mlx.Reshape(cosVals, seqLen, 1, half)
sinVals = mlx.Reshape(sinVals, seqLen, 1, half)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, L, H, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, L, H, D})
part1 := mlx.Sub(mlx.Mul(x1, cosVals), mlx.Mul(x2, sinVals))
part2 := mlx.Add(mlx.Mul(x1, sinVals), mlx.Mul(x2, cosVals))
return mlx.Concatenate([]*mlx.Array{part1, part2}, 3)
}
// Forward computes attention with causal masking and optional padding mask
func (attn *Attention) Forward(x *mlx.Array, mask *mlx.Array, maskMode string) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := attn.QProj.Forward(x)
k := attn.KProj.Forward(x)
v := attn.VProj.Forward(x)
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
// QK norm uses 1e-6 hardcoded (Qwen3 specific)
q = attn.QNorm.Forward(q, 1e-6)
k = attn.KNorm.Forward(k, 1e-6)
q = applyRoPEQwen3(q, L, attn.RopeTheta)
k = applyRoPEQwen3(k, L, attn.RopeTheta)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, attn.Scale, maskMode, mask, nil)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
out = attn.OProj.Forward(out)
return out
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
// MLP implements Qwen3 SwiGLU MLP
type MLP struct {
GateProj nn.LinearLayer `weight:"gate_proj"`
UpProj nn.LinearLayer `weight:"up_proj"`
DownProj nn.LinearLayer `weight:"down_proj"`
}
// Forward applies the MLP
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := m.GateProj.Forward(x)
gate = mlx.SiLU(gate)
up := m.UpProj.Forward(x)
h := mlx.Mul(gate, up)
return m.DownProj.Forward(h)
}
// Block represents a single Qwen3 transformer block
type Block struct {
Attention *Attention `weight:"self_attn"`
MLP *MLP `weight:"mlp"`
InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
// Forward applies the Qwen3 block
func (qb *Block) Forward(x *mlx.Array, eps float32, mask *mlx.Array, maskMode string) *mlx.Array {
h := qb.InputLayerNorm.Forward(x, eps)
attnOut := qb.Attention.Forward(h, mask, maskMode)
x = mlx.Add(x, attnOut)
h = qb.PostAttnLayerNorm.Forward(x, eps)
mlpOut := qb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// TextEncoder is the full Qwen3 encoder
type TextEncoder struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Block `weight:"model.layers"`
FinalNorm *nn.RMSNorm `weight:"model.norm"`
*Config
}
// Load loads the Qwen3 text encoder from ollama blob storage.
func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *TextEncoder) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *TextEncoder) initComputedFields() {
cfg := m.Config
m.FinalNorm.Eps = cfg.RMSNormEps
for _, block := range m.Layers {
// Attention
block.Attention.NHeads = cfg.NumAttentionHeads
block.Attention.NKVHeads = cfg.NumKeyValueHeads
block.Attention.HeadDim = cfg.HeadDim
block.Attention.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
block.Attention.RopeTheta = cfg.RopeTheta
block.Attention.QNorm.Eps = cfg.RMSNormEps
block.Attention.KNorm.Eps = cfg.RMSNormEps
// Block norms
block.InputLayerNorm.Eps = cfg.RMSNormEps
block.PostAttnLayerNorm.Eps = cfg.RMSNormEps
}
}
// Forward encodes text tokens with provided attention mask (LxL) and mask mode.
func (te *TextEncoder) Forward(tokens *mlx.Array, attnMask *mlx.Array, maskMode string) *mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
for _, layer := range te.Layers {
h = layer.Forward(h, eps, attnMask, maskMode)
}
// Apply final RMS norm
h = te.FinalNorm.Forward(h, eps)
return h
}
// ForwardWithLayerOutputs encodes text tokens and returns hidden states from specified layers.
// This is used by Flux2 which needs embeddings from specific intermediate layers.
func (te *TextEncoder) ForwardWithLayerOutputs(tokens *mlx.Array, layerIndices []int, attnMask *mlx.Array, maskMode string) []*mlx.Array {
h := te.EmbedTokens.Forward(tokens)
eps := te.RMSNormEps
outputs := make([]*mlx.Array, len(layerIndices))
layerSet := make(map[int]int)
for i, idx := range layerIndices {
layerSet[idx] = i
}
for i, layer := range te.Layers {
h = layer.Forward(h, eps, attnMask, maskMode)
if outIdx, ok := layerSet[i]; ok {
outputs[outIdx] = h
}
}
return outputs
}
// ApplyChatTemplate wraps prompt in Qwen3 chat format.
// If think is true, adds the <think></think> block after the assistant tag
// (matches tokenizer.apply_chat_template with enable_thinking=False in Python).
func ApplyChatTemplate(prompt string, think bool) string {
base := "<|im_start|>user\n" + prompt + "<|im_end|>\n<|im_start|>assistant\n"
if think {
return base + "<think>\n\n</think>\n\n"
}
return base
}
// EncodePrompt encodes a text prompt using the tokenizer and encoder.
// If think is true, includes the <think></think> block in the chat template.
func (te *TextEncoder) EncodePrompt(tok *tokenizer.Tokenizer, prompt string, maxLen int, think bool) (*mlx.Array, *mlx.Array) {
formattedPrompt := ApplyChatTemplate(prompt, think)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
maskData := make([]float32, maxLen)
for i := 0; i < len(tokens); i++ {
maskData[i] = 1.0
}
// Get PAD token (different from EOS for Qwen3)
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
paddedTokens := make([]int32, maxLen)
copy(paddedTokens, tokens)
for i := len(tokens); i < maxLen; i++ {
paddedTokens[i] = padToken
}
tokensArr := mlx.NewArrayInt32(paddedTokens, []int32{1, int32(maxLen)})
maskArr := mlx.NewArray(maskData, []int32{1, int32(maxLen)})
// Build combined causal + PAD mask [L, L]
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
L := int32(maxLen)
validLen := int32(len(tokens))
combinedMaskData := make([]float32, L*L)
negInf := float32(-1e9)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
idx := i*L + j
if j <= i && j < validLen {
combinedMaskData[idx] = 0
} else {
combinedMaskData[idx] = negInf
}
}
}
maskMat := mlx.NewArray(combinedMaskData, []int32{L, L})
embeddings := te.Forward(tokensArr, maskMat, "")
return embeddings, maskArr
}
// EncodePromptWithLayers encodes a text prompt and returns embeddings from specified layers.
// Used by Flux2 which concatenates embeddings from multiple intermediate layers.
// If think is true, includes the <think></think> block in the chat template.
// Returns embeddings and padded sequence length.
func (te *TextEncoder) EncodePromptWithLayers(tok *tokenizer.Tokenizer, prompt string, maxLen int, layerIndices []int, think bool) (*mlx.Array, int32) {
formattedPrompt := ApplyChatTemplate(prompt, think)
tokens := tok.Encode(formattedPrompt, false)
if len(tokens) > maxLen {
tokens = tokens[:maxLen]
}
// Pad to maxLen
padToken := tok.PAD()
if padToken < 0 {
padToken = tok.EOS() // fallback
}
padded := make([]int32, maxLen)
copy(padded, tokens)
for i := len(tokens); i < maxLen; i++ {
padded[i] = padToken
}
tokensArr := mlx.NewArrayInt32(padded, []int32{1, int32(maxLen)})
// Build combined causal + PAD mask [L, L]
// mask[i,j] = 0 if (j <= i AND valid[j]) else -inf
// This combines causal masking with PAD token masking
L := int32(maxLen)
validLen := int32(len(tokens))
maskData := make([]float32, L*L)
negInf := float32(-1e9)
for i := int32(0); i < L; i++ {
for j := int32(0); j < L; j++ {
idx := i*L + j
if j <= i && j < validLen {
maskData[idx] = 0 // allowed: causal OK and not PAD
} else {
maskData[idx] = negInf // blocked: future or PAD
}
}
}
maskMat := mlx.NewArray(maskData, []int32{L, L})
layerOutputs := te.ForwardWithLayerOutputs(tokensArr, layerIndices, maskMat, "")
// Concatenate layer outputs along the hidden dimension
// Each output is [B, L, hidden_dim], result is [B, L, num_layers * hidden_dim]
embeddings := mlx.Concatenate(layerOutputs, 2)
// Return embeddings and padded length
return embeddings, int32(maxLen)
}

View File

@@ -0,0 +1,141 @@
package zimage
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// FlowMatchSchedulerConfig holds scheduler configuration
type FlowMatchSchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
Shift float32 `json:"shift"` // 3.0
UseDynamicShifting bool `json:"use_dynamic_shifting"` // false
}
// DefaultFlowMatchSchedulerConfig returns default config
func DefaultFlowMatchSchedulerConfig() *FlowMatchSchedulerConfig {
return &FlowMatchSchedulerConfig{
NumTrainTimesteps: 1000,
Shift: 3.0,
UseDynamicShifting: true, // Z-Image-Turbo uses dynamic shifting
}
}
// FlowMatchEulerScheduler implements the Flow Match Euler discrete scheduler
// This is used in Z-Image-Turbo for fast sampling
type FlowMatchEulerScheduler struct {
Config *FlowMatchSchedulerConfig
Timesteps []float32 // Discretized timesteps
Sigmas []float32 // Noise levels at each timestep
NumSteps int // Number of inference steps
}
// NewFlowMatchEulerScheduler creates a new scheduler
func NewFlowMatchEulerScheduler(cfg *FlowMatchSchedulerConfig) *FlowMatchEulerScheduler {
return &FlowMatchEulerScheduler{
Config: cfg,
}
}
// SetTimesteps sets up the scheduler for the given number of inference steps
func (s *FlowMatchEulerScheduler) SetTimesteps(numSteps int) {
s.SetTimestepsWithMu(numSteps, 0)
}
// SetTimestepsWithMu sets up the scheduler with dynamic mu shift
func (s *FlowMatchEulerScheduler) SetTimestepsWithMu(numSteps int, mu float32) {
s.NumSteps = numSteps
// Create evenly spaced timesteps from 1.0 to 0.0 (flow matching goes t=1 to t=0)
// Match Python: np.linspace(1.0, 0.0, num_inference_steps + 1)
s.Timesteps = make([]float32, numSteps+1)
s.Sigmas = make([]float32, numSteps+1)
for i := 0; i <= numSteps; i++ {
t := 1.0 - float32(i)/float32(numSteps)
// Apply time shift if using dynamic shifting
if s.Config.UseDynamicShifting && mu != 0 {
t = s.timeShift(mu, t)
}
s.Timesteps[i] = t
s.Sigmas[i] = t
}
}
// timeShift applies the dynamic time shift (match Python)
func (s *FlowMatchEulerScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
// exp(mu) / (exp(mu) + (1/t - 1))
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
// Step performs one denoising step
// modelOutput: predicted velocity/noise from the model
// timestepIdx: current timestep index
// sample: current noisy sample
// Returns: denoised sample for next step
func (s *FlowMatchEulerScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
// where v_t is the velocity predicted by the model
dt := sigmaNext - sigma // This is negative (going from noise to clean)
// x_next = x + dt * velocity
scaledOutput := mlx.MulScalar(modelOutput, dt)
return mlx.Add(sample, scaledOutput)
}
// ScaleSample scales the sample for model input (identity for flow matching)
func (s *FlowMatchEulerScheduler) ScaleSample(sample *mlx.Array, timestepIdx int) *mlx.Array {
// Flow matching doesn't need scaling
return sample
}
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchEulerScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// GetTimesteps returns all timesteps (implements Scheduler interface)
func (s *FlowMatchEulerScheduler) GetTimesteps() []float32 {
return s.Timesteps
}
// AddNoise adds noise to clean samples for a given timestep
// Used for img2img or inpainting
func (s *FlowMatchEulerScheduler) AddNoise(cleanSample, noise *mlx.Array, timestepIdx int) *mlx.Array {
// In flow matching: x_t = (1-t) * x_0 + t * noise
t := s.Timesteps[timestepIdx]
oneMinusT := 1.0 - t
scaledClean := mlx.MulScalar(cleanSample, oneMinusT)
scaledNoise := mlx.MulScalar(noise, t)
return mlx.Add(scaledClean, scaledNoise)
}
// InitNoise creates initial noise for sampling (BFloat16 for GPU efficiency)
func (s *FlowMatchEulerScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormalWithDtype(shape, uint64(seed), mlx.DtypeBFloat16)
}
// GetLatentShape returns the latent shape for a given image size
func GetLatentShape(batchSize, height, width, latentChannels int32, patchSize int32) []int32 {
// Latent is 8x smaller than image (VAE downscale)
latentH := height / 8
latentW := width / 8
return []int32{batchSize, latentChannels, latentH, latentW}
}

View File

@@ -0,0 +1,17 @@
package zimage
import (
"github.com/ollama/ollama/x/imagegen/models/qwen3"
)
// Re-export types from shared qwen3 package for backwards compatibility
type (
Qwen3Config = qwen3.Config
Qwen3Attention = qwen3.Attention
Qwen3MLP = qwen3.MLP
Qwen3Block = qwen3.Block
Qwen3TextEncoder = qwen3.TextEncoder
)
// ApplyChatTemplate wraps prompt in Qwen3 chat format
var ApplyChatTemplate = qwen3.ApplyChatTemplate

View File

@@ -0,0 +1,759 @@
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// TransformerConfig holds Z-Image transformer configuration
type TransformerConfig struct {
Dim int32 `json:"dim"`
NHeads int32 `json:"n_heads"`
NKVHeads int32 `json:"n_kv_heads"`
NLayers int32 `json:"n_layers"`
NRefinerLayers int32 `json:"n_refiner_layers"`
InChannels int32 `json:"in_channels"`
PatchSize int32 `json:"-"` // Computed from AllPatchSize
CapFeatDim int32 `json:"cap_feat_dim"`
NormEps float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope_theta"`
TScale float32 `json:"t_scale"`
QKNorm bool `json:"qk_norm"`
AxesDims []int32 `json:"axes_dims"`
AxesLens []int32 `json:"axes_lens"`
AllPatchSize []int32 `json:"all_patch_size"` // JSON array, PatchSize = first element
}
// TimestepEmbedder creates sinusoidal timestep embeddings
// Output dimension is 256 (fixed), used for AdaLN modulation
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
FreqEmbedSize int32 // 256 (computed)
}
// Forward computes timestep embeddings -> [B, 256]
func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
// t: [B] timesteps
// Create sinusoidal embedding
half := te.FreqEmbedSize / 2
// freqs = exp(-log(10000) * arange(half) / half)
freqs := make([]float32, half)
for i := int32(0); i < half; i++ {
freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
}
freqsArr := mlx.NewArray(freqs, []int32{1, half})
// t[:, None] * freqs[None, :] -> [B, half]
tExpanded := mlx.ExpandDims(t, 1) // [B, 1]
args := mlx.Mul(tExpanded, freqsArr)
// embedding = [cos(args), sin(args)] -> [B, 256]
cosArgs := mlx.Cos(args)
sinArgs := mlx.Sin(args)
embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1)
// MLP: linear1 -> silu -> linear2
h := te.Linear1.Forward(embedding)
h = mlx.SiLU(h)
h = te.Linear2.Forward(h)
return h
}
// XEmbedder embeds image patches to model dimension
type XEmbedder struct {
Linear nn.LinearLayer `weight:"2-1"`
}
// Forward embeds patchified image latents
func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// x: [B, L, in_channels * 4] -> [B, L, dim]
return xe.Linear.Forward(x)
}
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
Norm *nn.RMSNorm `weight:"0"`
Linear nn.LinearLayer `weight:"1"`
PadToken *mlx.Array // loaded separately at root level
}
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
func (ce *CapEmbedder) Forward(capFeats *mlx.Array) *mlx.Array {
// RMSNorm on last axis (uses 1e-6)
h := ce.Norm.Forward(capFeats, 1e-6)
// Linear projection
return ce.Linear.Forward(h)
}
// FeedForward implements SwiGLU FFN
type FeedForward struct {
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
OutDim int32 // computed from W2
}
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// Reshape for matmul
x = mlx.Reshape(x, B*L, D)
gate := ff.W1.Forward(x)
gate = mlx.SiLU(gate)
up := ff.W3.Forward(x)
h := mlx.Mul(gate, up)
out := ff.W2.Forward(h)
return mlx.Reshape(out, B, L, ff.OutDim)
}
// Attention implements multi-head attention with QK norm
type Attention struct {
ToQ nn.LinearLayer `weight:"to_q"`
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
NormK *mlx.Array `weight:"norm_k.weight"`
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
Dim int32 `weight:"-"`
Scale float32 `weight:"-"`
}
// FuseQKV creates a fused QKV projection by concatenating weights.
// This reduces 3 matmuls to 1 for a ~5-10% speedup.
// Note: Fusion is skipped for quantized weights as it would require complex
// dequant-concat-requant operations. The FP8 memory bandwidth savings outweigh
// the ~5% fusion benefit.
func (attn *Attention) FuseQKV() {
if attn.ToQ == nil || attn.ToK == nil || attn.ToV == nil {
return
}
// Skip fusion for quantized weights - type assert to check
toQ, qOk := attn.ToQ.(*nn.Linear)
toK, kOk := attn.ToK.(*nn.Linear)
toV, vOk := attn.ToV.(*nn.Linear)
if !qOk || !kOk || !vOk {
// One or more are QuantizedLinear, skip fusion
return
}
if toQ.Weight == nil || toK.Weight == nil || toV.Weight == nil {
return
}
// Concatenate weights: [dim, dim] x 3 -> [3*dim, dim]
// Weight shapes: ToQ.Weight [out_dim, in_dim], etc.
qWeight := toQ.Weight
kWeight := toK.Weight
vWeight := toV.Weight
// Concatenate along output dimension (axis 0)
fusedWeight := mlx.Concatenate([]*mlx.Array{qWeight, kWeight, vWeight}, 0)
// Evaluate fused weight to ensure it's materialized
mlx.Eval(fusedWeight)
// Create fused linear layer
fusedLinear := &nn.Linear{Weight: fusedWeight}
// Handle bias if present
if toQ.Bias != nil && toK.Bias != nil && toV.Bias != nil {
fusedBias := mlx.Concatenate([]*mlx.Array{toQ.Bias, toK.Bias, toV.Bias}, 0)
mlx.Eval(fusedBias)
fusedLinear.Bias = fusedBias
}
attn.ToQKV = fusedLinear
attn.Fused = true
}
// Forward computes attention
func (attn *Attention) Forward(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
xFlat := mlx.Reshape(x, B*L, D)
var q, k, v *mlx.Array
if attn.Fused && attn.ToQKV != nil {
// Fused QKV path: single matmul then split
qkv := attn.ToQKV.Forward(xFlat) // [B*L, 3*dim]
// Split into Q, K, V along last dimension
// Each has shape [B*L, dim]
q = mlx.Slice(qkv, []int32{0, 0}, []int32{B * L, attn.Dim})
k = mlx.Slice(qkv, []int32{0, attn.Dim}, []int32{B * L, 2 * attn.Dim})
v = mlx.Slice(qkv, []int32{0, 2 * attn.Dim}, []int32{B * L, 3 * attn.Dim})
} else {
// Separate Q, K, V projections
q = attn.ToQ.Forward(xFlat)
k = attn.ToK.Forward(xFlat)
v = attn.ToV.Forward(xFlat)
}
// Reshape to [B, L, nheads, head_dim]
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NHeads, attn.HeadDim)
// QK norm
q = mlx.RMSNorm(q, attn.NormQ, 1e-5)
k = mlx.RMSNorm(k, attn.NormK, 1e-5)
// Apply RoPE if provided
if cos != nil && sin != nil {
q = applyRoPE3D(q, cos, sin)
k = applyRoPE3D(k, cos, sin)
}
// Transpose to [B, nheads, L, head_dim]
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// SDPA
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
// Transpose back and reshape
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B*L, attn.Dim)
out = attn.ToOut.Forward(out)
return mlx.Reshape(out, B, L, attn.Dim)
}
// applyRoPE3D applies 3-axis rotary position embeddings
// x: [B, L, nheads, head_dim]
// cos, sin: [B, L, 1, head_dim/2]
func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
nheads := shape[2]
headDim := shape[3]
half := headDim / 2
// Create even/odd index arrays
evenIdx := make([]int32, half)
oddIdx := make([]int32, half)
for i := int32(0); i < half; i++ {
evenIdx[i] = i * 2
oddIdx[i] = i*2 + 1
}
evenIndices := mlx.NewArrayInt32(evenIdx, []int32{half})
oddIndices := mlx.NewArrayInt32(oddIdx, []int32{half})
// Extract x1 (even indices) and x2 (odd indices) along last axis
x1 := mlx.Take(x, evenIndices, 3) // [B, L, nheads, half]
x2 := mlx.Take(x, oddIndices, 3) // [B, L, nheads, half]
// Apply rotation: [x1*cos - x2*sin, x1*sin + x2*cos]
r1 := mlx.Sub(mlx.Mul(x1, cos), mlx.Mul(x2, sin))
r2 := mlx.Add(mlx.Mul(x1, sin), mlx.Mul(x2, cos))
// Stack and reshape to interleave: [r1_0, r2_0, r1_1, r2_1, ...]
r1 = mlx.ExpandDims(r1, 4) // [B, L, nheads, half, 1]
r2 = mlx.ExpandDims(r2, 4) // [B, L, nheads, half, 1]
stacked := mlx.Concatenate([]*mlx.Array{r1, r2}, 4) // [B, L, nheads, half, 2]
return mlx.Reshape(stacked, B, L, nheads, headDim)
}
// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
Attention *Attention `weight:"attention"`
FeedForward *FeedForward `weight:"feed_forward"`
AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
}
// Forward applies the transformer block
func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *mlx.Array, eps float32) *mlx.Array {
if tb.AdaLN != nil && adaln != nil {
// Compute modulation: [B, 256] -> [B, 4*dim]
chunks := tb.AdaLN.Forward(adaln)
// Split into 4 parts: scale_msa, gate_msa, scale_mlp, gate_mlp
chunkShape := chunks.Shape()
chunkDim := chunkShape[1] / 4
scaleMSA := mlx.Slice(chunks, []int32{0, 0}, []int32{chunkShape[0], chunkDim})
gateMSA := mlx.Slice(chunks, []int32{0, chunkDim}, []int32{chunkShape[0], chunkDim * 2})
scaleMLP := mlx.Slice(chunks, []int32{0, chunkDim * 2}, []int32{chunkShape[0], chunkDim * 3})
gateMLP := mlx.Slice(chunks, []int32{0, chunkDim * 3}, []int32{chunkShape[0], chunkDim * 4})
// Expand for broadcasting: [B, 1, dim]
scaleMSA = mlx.ExpandDims(scaleMSA, 1)
gateMSA = mlx.ExpandDims(gateMSA, 1)
scaleMLP = mlx.ExpandDims(scaleMLP, 1)
gateMLP = mlx.ExpandDims(gateMLP, 1)
// Attention with modulation
normX := tb.AttentionNorm1.Forward(x, eps)
normX = mlx.Mul(normX, mlx.AddScalar(scaleMSA, 1.0))
attnOut := tb.Attention.Forward(normX, cos, sin)
attnOut = tb.AttentionNorm2.Forward(attnOut, eps)
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMSA), attnOut))
// FFN with modulation
normFFN := tb.FFNNorm1.Forward(x, eps)
normFFN = mlx.Mul(normFFN, mlx.AddScalar(scaleMLP, 1.0))
ffnOut := tb.FeedForward.Forward(normFFN)
ffnOut = tb.FFNNorm2.Forward(ffnOut, eps)
x = mlx.Add(x, mlx.Mul(mlx.Tanh(gateMLP), ffnOut))
} else {
// No modulation (context refiner)
attnOut := tb.Attention.Forward(tb.AttentionNorm1.Forward(x, eps), cos, sin)
x = mlx.Add(x, tb.AttentionNorm2.Forward(attnOut, eps))
ffnOut := tb.FeedForward.Forward(tb.FFNNorm1.Forward(x, eps))
x = mlx.Add(x, tb.FFNNorm2.Forward(ffnOut, eps))
}
return x
}
// FinalLayer outputs the denoised patches
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
OutDim int32 // computed from Output
}
// Forward computes final output
func (fl *FinalLayer) Forward(x *mlx.Array, c *mlx.Array) *mlx.Array {
// c: [B, 256] -> scale: [B, dim]
scale := mlx.SiLU(c)
scale = fl.AdaLN.Forward(scale)
scale = mlx.ExpandDims(scale, 1) // [B, 1, dim]
// LayerNorm (affine=False) then scale
x = layerNormNoAffine(x, 1e-6)
x = mlx.Mul(x, mlx.AddScalar(scale, 1.0))
// Output projection
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
x = mlx.Reshape(x, B*L, D)
x = fl.Output.Forward(x)
return mlx.Reshape(x, B, L, fl.OutDim)
}
// layerNormNoAffine applies layer norm without learnable parameters
func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
ndim := x.Ndim()
lastAxis := ndim - 1
mean := mlx.Mean(x, lastAxis, true)
xCentered := mlx.Sub(x, mean)
variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
}
// Transformer is the full Z-Image DiT model
type Transformer struct {
TEmbed *TimestepEmbedder `weight:"t_embedder"`
XEmbed *XEmbedder `weight:"all_x_embedder"`
CapEmbed *CapEmbedder `weight:"cap_embedder"`
NoiseRefiners []*TransformerBlock `weight:"noise_refiner"`
ContextRefiners []*TransformerBlock `weight:"context_refiner"`
Layers []*TransformerBlock `weight:"layers"`
FinalLayer *FinalLayer `weight:"all_final_layer.2-1"`
XPadToken *mlx.Array `weight:"x_pad_token"`
CapPadToken *mlx.Array `weight:"cap_pad_token"`
*TransformerConfig
}
// Load loads the Z-Image transformer from ollama blob storage.
func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
if len(cfg.AllPatchSize) > 0 {
cfg.PatchSize = cfg.AllPatchSize[0]
}
m.TransformerConfig = &cfg
m.NoiseRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights)
}
// loadWeights loads weights from any WeightSource into the model
func (m *Transformer) loadWeights(weights safetensors.WeightSource) error {
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return fmt.Errorf("load module: %w", err)
}
m.initComputedFields()
fmt.Println("✓")
return nil
}
// initComputedFields initializes computed fields after loading weights
func (m *Transformer) initComputedFields() {
cfg := m.TransformerConfig
m.TEmbed.FreqEmbedSize = 256
m.FinalLayer.OutDim = m.FinalLayer.Output.OutputDim()
m.CapEmbed.Norm.Eps = 1e-6
for _, block := range m.NoiseRefiners {
initTransformerBlock(block, cfg)
}
for _, block := range m.ContextRefiners {
initTransformerBlock(block, cfg)
}
for _, block := range m.Layers {
initTransformerBlock(block, cfg)
}
}
// FuseAllQKV fuses QKV projections in all attention layers for efficiency.
// This reduces 3 matmuls to 1 per attention layer, providing ~5-10% speedup.
func (m *Transformer) FuseAllQKV() {
for _, block := range m.NoiseRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.ContextRefiners {
block.Attention.FuseQKV()
}
for _, block := range m.Layers {
block.Attention.FuseQKV()
}
}
// initTransformerBlock sets computed fields on a transformer block
func initTransformerBlock(block *TransformerBlock, cfg *TransformerConfig) {
block.Dim = cfg.Dim
block.HasModulation = block.AdaLN != nil
// Init attention computed fields
attn := block.Attention
attn.NHeads = cfg.NHeads
attn.HeadDim = cfg.Dim / cfg.NHeads
attn.Dim = cfg.Dim
attn.Scale = float32(1.0 / math.Sqrt(float64(attn.HeadDim)))
// Init feedforward OutDim
block.FeedForward.OutDim = block.FeedForward.W2.OutputDim()
// Set eps on all RMSNorm layers
block.AttentionNorm1.Eps = cfg.NormEps
block.AttentionNorm2.Eps = cfg.NormEps
block.FFNNorm1.Eps = cfg.NormEps
block.FFNNorm2.Eps = cfg.NormEps
}
// RoPECache holds precomputed RoPE values
type RoPECache struct {
ImgCos *mlx.Array
ImgSin *mlx.Array
CapCos *mlx.Array
CapSin *mlx.Array
UnifiedCos *mlx.Array
UnifiedSin *mlx.Array
ImgLen int32
CapLen int32
GridH int32 // Image token grid height
GridW int32 // Image token grid width
}
// PrepareRoPECache precomputes RoPE values for the given image and caption lengths.
// hTok and wTok are the number of tokens in each dimension (latentH/patchSize, latentW/patchSize).
func (m *Transformer) PrepareRoPECache(hTok, wTok, capLen int32) *RoPECache {
imgLen := hTok * wTok
// Image positions: grid over (1, H, W) starting at (capLen+1, 0, 0)
imgPos := createCoordinateGrid(1, hTok, wTok, capLen+1, 0, 0)
imgPos = mlx.ToBFloat16(imgPos)
// Caption positions: grid over (capLen, 1, 1) starting at (1, 0, 0)
capPos := createCoordinateGrid(capLen, 1, 1, 1, 0, 0)
capPos = mlx.ToBFloat16(capPos)
// Compute RoPE from UNIFIED positions
unifiedPos := mlx.Concatenate([]*mlx.Array{imgPos, capPos}, 1)
unifiedCos, unifiedSin := prepareRoPE3D(unifiedPos, m.TransformerConfig.AxesDims)
// Slice RoPE for image and caption parts
imgCos := mlx.Slice(unifiedCos, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
imgSin := mlx.Slice(unifiedSin, []int32{0, 0, 0, 0}, []int32{1, imgLen, 1, 64})
capCos := mlx.Slice(unifiedCos, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
capSin := mlx.Slice(unifiedSin, []int32{0, imgLen, 0, 0}, []int32{1, imgLen + capLen, 1, 64})
return &RoPECache{
ImgCos: imgCos,
ImgSin: imgSin,
CapCos: capCos,
CapSin: capSin,
UnifiedCos: unifiedCos,
UnifiedSin: unifiedSin,
ImgLen: imgLen,
CapLen: capLen,
GridH: hTok,
GridW: wTok,
}
}
// Forward runs the Z-Image transformer with precomputed RoPE
func (m *Transformer) Forward(x *mlx.Array, t *mlx.Array, capFeats *mlx.Array, rope *RoPECache) *mlx.Array {
imgLen := rope.ImgLen
// Timestep embedding -> [B, 256]
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
// Embed image patches -> [B, L_img, dim]
x = m.XEmbed.Forward(x)
// Embed caption features -> [B, L_cap, dim]
capEmb := m.CapEmbed.Forward(capFeats)
eps := m.NormEps
// Noise refiner: refine image patches with modulation
for _, refiner := range m.NoiseRefiners {
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
}
// Context refiner: refine caption (no modulation)
for _, refiner := range m.ContextRefiners {
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
}
// Concatenate image and caption for joint attention
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
// Main transformer layers use full unified RoPE
for _, layer := range m.Layers {
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
}
// Extract image tokens only
unifiedShape := unified.Shape()
B := unifiedShape[0]
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
// Final layer
return m.FinalLayer.Forward(imgOut, temb)
}
// ForwardWithCache runs the transformer with layer caching for faster inference.
// On refresh steps (step % cacheInterval == 0), all layers are computed and cached.
// On other steps, shallow layers (0 to cacheLayers-1) reuse cached outputs.
func (m *Transformer) ForwardWithCache(
x *mlx.Array,
t *mlx.Array,
capFeats *mlx.Array,
rope *RoPECache,
stepCache *cache.StepCache,
step int,
cacheInterval int,
) *mlx.Array {
imgLen := rope.ImgLen
cacheLayers := stepCache.NumLayers()
eps := m.NormEps
// Timestep embedding -> [B, 256]
temb := m.TEmbed.Forward(mlx.MulScalar(t, m.TransformerConfig.TScale))
// Embed image patches -> [B, L_img, dim]
x = m.XEmbed.Forward(x)
// Context refiners: compute once on step 0, reuse forever
// (caption embedding doesn't depend on timestep or latents)
var capEmb *mlx.Array
if stepCache.GetConstant() != nil {
capEmb = stepCache.GetConstant()
} else {
capEmb = m.CapEmbed.Forward(capFeats)
for _, refiner := range m.ContextRefiners {
capEmb = refiner.Forward(capEmb, nil, rope.CapCos, rope.CapSin, eps)
}
stepCache.SetConstant(capEmb)
}
// Noise refiners: always compute (depend on x which changes each step)
for _, refiner := range m.NoiseRefiners {
x = refiner.Forward(x, temb, rope.ImgCos, rope.ImgSin, eps)
}
// Concatenate image and caption for joint attention
unified := mlx.Concatenate([]*mlx.Array{x, capEmb}, 1)
// Determine if this is a cache refresh step
refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
// Main transformer layers with caching
for i, layer := range m.Layers {
if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
// Use cached output for shallow layers
unified = stepCache.Get(i)
} else {
// Compute layer
unified = layer.Forward(unified, temb, rope.UnifiedCos, rope.UnifiedSin, eps)
// Cache shallow layer outputs on refresh steps
if i < cacheLayers && refreshCache {
stepCache.Set(i, unified)
}
}
}
// Extract image tokens only
unifiedShape := unified.Shape()
B := unifiedShape[0]
imgOut := mlx.Slice(unified, []int32{0, 0, 0}, []int32{B, imgLen, unifiedShape[2]})
// Final layer
return m.FinalLayer.Forward(imgOut, temb)
}
// createCoordinateGrid creates 3D position grid [1, d0*d1*d2, 3]
func createCoordinateGrid(d0, d1, d2, s0, s1, s2 int32) *mlx.Array {
// Create meshgrid and stack
total := d0 * d1 * d2
coords := make([]float32, total*3)
idx := 0
for i := int32(0); i < d0; i++ {
for j := int32(0); j < d1; j++ {
for k := int32(0); k < d2; k++ {
coords[idx*3+0] = float32(s0 + i)
coords[idx*3+1] = float32(s1 + j)
coords[idx*3+2] = float32(s2 + k)
idx++
}
}
}
return mlx.NewArray(coords, []int32{1, total, 3})
}
// prepareRoPE3D computes cos/sin for 3-axis RoPE
// positions: [B, L, 3] with (h, w, t) coordinates
// axesDims: [32, 48, 48] - dimensions for each axis
// Returns: cos, sin each [B, L, 1, head_dim/2]
func prepareRoPE3D(positions *mlx.Array, axesDims []int32) (*mlx.Array, *mlx.Array) {
// Compute frequencies for each axis
// dims = [32, 48, 48], so halves = [16, 24, 24]
ropeTheta := float32(256.0)
freqs := make([]*mlx.Array, 3)
for axis := 0; axis < 3; axis++ {
half := axesDims[axis] / 2
f := make([]float32, half)
for i := int32(0); i < half; i++ {
f[i] = float32(math.Exp(-math.Log(float64(ropeTheta)) * float64(i) / float64(half)))
}
freqs[axis] = mlx.NewArray(f, []int32{1, 1, 1, half})
}
// Extract position coordinates
shape := positions.Shape()
B := shape[0]
L := shape[1]
// positions[:, :, 0] -> h positions
posH := mlx.Slice(positions, []int32{0, 0, 0}, []int32{B, L, 1})
posW := mlx.Slice(positions, []int32{0, 0, 1}, []int32{B, L, 2})
posT := mlx.Slice(positions, []int32{0, 0, 2}, []int32{B, L, 3})
// Compute args: pos * freqs for each axis
posH = mlx.ExpandDims(posH, 3) // [B, L, 1, 1]
posW = mlx.ExpandDims(posW, 3)
posT = mlx.ExpandDims(posT, 3)
argsH := mlx.Mul(posH, freqs[0]) // [B, L, 1, 16]
argsW := mlx.Mul(posW, freqs[1]) // [B, L, 1, 24]
argsT := mlx.Mul(posT, freqs[2]) // [B, L, 1, 24]
// Concatenate: [B, L, 1, 16+24+24=64]
args := mlx.Concatenate([]*mlx.Array{argsH, argsW, argsT}, 3)
// Compute cos and sin
return mlx.Cos(args), mlx.Sin(args)
}
// PatchifyLatents converts latents [B, C, H, W] to patches [B, L, C*patch^2]
// Matches Python: x.reshape(C, 1, 1, H_tok, 2, W_tok, 2).transpose(1,2,3,5,4,6,0).reshape(1,-1,C*4)
func PatchifyLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
shape := latents.Shape()
C := shape[1]
H := shape[2]
W := shape[3]
pH := H / patchSize // H_tok
pW := W / patchSize // W_tok
// Match Python exactly: reshape treating B=1 as part of contiguous data
// [1, C, H, W] -> [C, 1, 1, pH, 2, pW, 2]
x := mlx.Reshape(latents, C, 1, 1, pH, patchSize, pW, patchSize)
// Python: transpose(1, 2, 3, 5, 4, 6, 0)
// [C, 1, 1, pH, 2, pW, 2] -> [1, 1, pH, pW, 2, 2, C]
x = mlx.Transpose(x, 1, 2, 3, 5, 4, 6, 0)
// [1, 1, pH, pW, 2, 2, C] -> [1, pH*pW, C*4]
return mlx.Reshape(x, 1, pH*pW, C*patchSize*patchSize)
}
// UnpatchifyLatents converts patches [B, L, C*patch^2] back to [B, C, H, W]
// Matches Python: out.reshape(1,1,H_tok,W_tok,2,2,C).transpose(6,0,1,2,4,3,5).reshape(1,C,H,W)
func UnpatchifyLatents(patches *mlx.Array, patchSize, H, W, C int32) *mlx.Array {
pH := H / patchSize
pW := W / patchSize
// [1, L, C*4] -> [1, 1, pH, pW, 2, 2, C]
x := mlx.Reshape(patches, 1, 1, pH, pW, patchSize, patchSize, C)
// Python: transpose(6, 0, 1, 2, 4, 3, 5)
// [1, 1, pH, pW, 2, 2, C] -> [C, 1, 1, pH, 2, pW, 2]
x = mlx.Transpose(x, 6, 0, 1, 2, 4, 3, 5)
// [C, 1, 1, pH, 2, pW, 2] -> [1, C, H, W]
return mlx.Reshape(x, 1, C, H, W)
}

View File

@@ -0,0 +1,820 @@
package zimage
import (
"fmt"
"math"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
)
// VAEConfig holds VAE decoder configuration
type VAEConfig struct {
InChannels int32 `json:"in_channels"`
OutChannels int32 `json:"out_channels"`
LatentChannels int32 `json:"latent_channels"`
BlockOutChannels []int32 `json:"block_out_channels"`
LayersPerBlock int32 `json:"layers_per_block"`
NormNumGroups int32 `json:"norm_num_groups"`
ScalingFactor float32 `json:"scaling_factor"`
ShiftFactor float32 `json:"shift_factor"`
}
// GroupNormLayer implements group normalization
type GroupNormLayer struct {
Weight *mlx.Array
Bias *mlx.Array
NumGroups int32
Eps float32
}
// NewGroupNorm creates a group norm layer
func NewGroupNorm(weight, bias *mlx.Array, numGroups int32) *GroupNormLayer {
return &GroupNormLayer{
Weight: weight,
Bias: bias,
NumGroups: numGroups,
Eps: 1e-5,
}
}
// Forward applies group normalization
// Input and output are in NHWC format [B, H, W, C]
func (gn *GroupNormLayer) Forward(x *mlx.Array) *mlx.Array {
// x: [B, H, W, C] (NHWC format)
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
// For large spatial sizes, use tiled computation to avoid CUDA grid limits
// CUDA grid.y max is 65535, so H*W/16 must be <= 65535, meaning H*W <= ~1M
// To be safe, tile when H*W > 512*512 = 262144
if H*W > 512*512 {
return gn.forwardTiled(x, B, H, W, C)
}
return gn.forwardSmall(x, B, H, W, C)
}
// forwardSmall is the standard GroupNorm for tensors that fit within CUDA grid limits
func (gn *GroupNormLayer) forwardSmall(x *mlx.Array, B, H, W, C int32) *mlx.Array {
// Reshape to [B, H, W, groups, C/groups]
groupSize := C / gn.NumGroups
x = mlx.Reshape(x, B, H, W, gn.NumGroups, groupSize)
// Compute mean and variance per group (over H, W, and C/groups dimensions)
mean := mlx.Mean(x, 1, true)
mean = mlx.Mean(mean, 2, true)
mean = mlx.Mean(mean, 4, true)
xCentered := mlx.Sub(x, mean)
// Variance over same axes
sq := mlx.Square(xCentered)
variance := mlx.Mean(sq, 1, true)
variance = mlx.Mean(variance, 2, true)
variance = mlx.Mean(variance, 4, true)
// Normalize
xNorm := mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, gn.Eps)))
// Reshape back to [B, H, W, C]
xNorm = mlx.Reshape(xNorm, B, H, W, C)
// Scale and shift (weight and bias are [C])
if gn.Weight != nil {
weight := mlx.Reshape(gn.Weight, 1, 1, 1, C)
xNorm = mlx.Mul(xNorm, weight)
}
if gn.Bias != nil {
bias := mlx.Reshape(gn.Bias, 1, 1, 1, C)
xNorm = mlx.Add(xNorm, bias)
}
return xNorm
}
// forwardTiled handles large tensors by processing in H-tiles to avoid CUDA grid limits
func (gn *GroupNormLayer) forwardTiled(x *mlx.Array, B, H, W, C int32) *mlx.Array {
groupSize := C / gn.NumGroups
// Keep the input - we need it for slicing tiles later
// Track if we were the ones who kept it, so we can restore state after
wasKept := x.Kept()
mlx.Keep(x)
// Compute per-group mean and variance using flattened spatial dimensions
// Build the entire compute graph first, then eval once
// Reshape to [B, H*W, groups, groupSize]
xFlat := mlx.Reshape(x, B, H*W, gn.NumGroups, groupSize)
// Mean over spatial (axis 1) and groupSize (axis 3) dimensions
// Result shape: [B, 1, groups, 1]
mean1 := mlx.Mean(xFlat, 1, true)
mean := mlx.Mean(mean1, 3, true)
// Variance using E[X^2] - E[X]^2
xSq := mlx.Square(xFlat)
meanSq1 := mlx.Mean(xSq, 1, true)
meanSq := mlx.Mean(meanSq1, 3, true)
meanSquared := mlx.Square(mean)
variance := mlx.Sub(meanSq, meanSquared)
// invStd = 1/sqrt(var + eps)
varPlusEps := mlx.AddScalar(variance, gn.Eps)
stdDev := mlx.Sqrt(varPlusEps)
one := mlx.Full(1.0, 1)
invStd := mlx.Div(one, stdDev)
// Eval mean and invStd together - these are what we need for the tile loop
mlx.Keep(mean, invStd)
mlx.Eval(mean, invStd)
// Tile along H dimension
tileH := int32(512 * 512 / W)
if tileH < 1 {
tileH = 1
}
if tileH > H {
tileH = H
}
// Prepare weight and bias reshaped for 4D broadcast [1, 1, groups, groupSize]
var weightGN, biasGN *mlx.Array
if gn.Weight != nil {
weightGN = mlx.Reshape(gn.Weight, 1, 1, gn.NumGroups, groupSize)
mlx.Keep(weightGN)
mlx.Eval(weightGN)
}
if gn.Bias != nil {
biasGN = mlx.Reshape(gn.Bias, 1, 1, gn.NumGroups, groupSize)
mlx.Keep(biasGN)
mlx.Eval(biasGN)
}
var tiles []*mlx.Array
for hStart := int32(0); hStart < H; hStart += tileH {
hEnd := hStart + tileH
if hEnd > H {
hEnd = H
}
tileHeight := hEnd - hStart
spatialSize := tileHeight * W
// Build the compute graph for this tile (no intermediate Evals)
// Extract tile and flatten spatial dims: [B, tileH*W, groups, groupSize]
tile := mlx.Slice(x, []int32{0, hStart, 0, 0}, []int32{B, hEnd, W, C})
tileFlat := mlx.Reshape(tile, B, spatialSize, gn.NumGroups, groupSize)
// Normalize: (x - mean) * invStd
tileCentered := mlx.Sub(tileFlat, mean)
tileNorm := mlx.Mul(tileCentered, invStd)
// Apply scale and shift in 4D space
if weightGN != nil {
tileNorm = mlx.Mul(tileNorm, weightGN)
}
if biasGN != nil {
tileNorm = mlx.Add(tileNorm, biasGN)
}
// Reshape back to [B, tileH, W, C]
tileOut := mlx.Reshape(tileNorm, B, tileHeight, W, C)
// Now eval and keep this tile
mlx.Keep(tileOut)
mlx.Eval(tileOut)
tiles = append(tiles, tileOut)
}
// Concatenate tiles along H axis
var result *mlx.Array
if len(tiles) == 1 {
result = tiles[0]
} else {
result = mlx.Concatenate(tiles, 1)
mlx.Eval(result)
// Free the individual tiles now that they're concatenated
for _, t := range tiles {
t.Free()
}
}
// Clean up kept arrays
// Restore x's kept state - only free if we were the ones who kept it
if !wasKept {
x.Free()
}
mean.Free()
invStd.Free()
if weightGN != nil {
weightGN.Free()
}
if biasGN != nil {
biasGN.Free()
}
return result
}
// Conv2D represents a 2D convolution layer
// Works natively in NHWC format (MLX's native format)
type Conv2D struct {
Weight *mlx.Array // [out_channels, kH, kW, in_channels] (OHWI for MLX)
Bias *mlx.Array // [out_channels]
Stride int32
Padding int32
}
// NewConv2D creates a Conv2D layer
// weight comes in as [out_channels, in_channels, kH, kW] (OIHW from PyTorch)
// we transpose to [out_channels, kH, kW, in_channels] (OHWI for MLX)
func NewConv2D(weight, bias *mlx.Array, stride, padding int32) *Conv2D {
// Transpose weight from OIHW to OHWI
// [O, I, H, W] -> [O, H, W, I]
weightOHWI := mlx.Transpose(weight, 0, 2, 3, 1)
return &Conv2D{
Weight: weightOHWI,
Bias: bias,
Stride: stride,
Padding: padding,
}
}
// Forward applies convolution
// Input and output are in NHWC format [N, H, W, C]
func (conv *Conv2D) Forward(x *mlx.Array) *mlx.Array {
// Conv in NHWC format (MLX native)
out := mlx.Conv2d(x, conv.Weight, conv.Stride, conv.Padding)
if conv.Bias != nil {
// Bias is [C], reshape to [1, 1, 1, C] for NHWC broadcast
bias := mlx.Reshape(conv.Bias, 1, 1, 1, conv.Bias.Dim(0))
out = mlx.Add(out, bias)
}
return out
}
// ResnetBlock2D implements a ResNet block for VAE
type ResnetBlock2D struct {
Norm1 *GroupNormLayer
Conv1 *Conv2D
Norm2 *GroupNormLayer
Conv2 *Conv2D
ConvShortcut *Conv2D // nil if in_channels == out_channels
}
// NewResnetBlock2D creates a ResNet block
func NewResnetBlock2D(weights safetensors.WeightSource, prefix string, numGroups int32) (*ResnetBlock2D, error) {
norm1Weight, err := weights.GetTensor(prefix + ".norm1.weight")
if err != nil {
return nil, err
}
norm1Bias, err := weights.GetTensor(prefix + ".norm1.bias")
if err != nil {
return nil, err
}
conv1Weight, err := weights.GetTensor(prefix + ".conv1.weight")
if err != nil {
return nil, err
}
conv1Bias, err := weights.GetTensor(prefix + ".conv1.bias")
if err != nil {
return nil, err
}
norm2Weight, err := weights.GetTensor(prefix + ".norm2.weight")
if err != nil {
return nil, err
}
norm2Bias, err := weights.GetTensor(prefix + ".norm2.bias")
if err != nil {
return nil, err
}
conv2Weight, err := weights.GetTensor(prefix + ".conv2.weight")
if err != nil {
return nil, err
}
conv2Bias, err := weights.GetTensor(prefix + ".conv2.bias")
if err != nil {
return nil, err
}
block := &ResnetBlock2D{
Norm1: NewGroupNorm(norm1Weight, norm1Bias, numGroups),
Conv1: NewConv2D(conv1Weight, conv1Bias, 1, 1),
Norm2: NewGroupNorm(norm2Weight, norm2Bias, numGroups),
Conv2: NewConv2D(conv2Weight, conv2Bias, 1, 1),
}
if weights.HasTensor(prefix + ".conv_shortcut.weight") {
shortcutWeight, err := weights.GetTensor(prefix + ".conv_shortcut.weight")
if err != nil {
return nil, err
}
shortcutBias, err := weights.GetTensor(prefix + ".conv_shortcut.bias")
if err != nil {
return nil, err
}
block.ConvShortcut = NewConv2D(shortcutWeight, shortcutBias, 1, 0)
}
return block, nil
}
// Forward applies the ResNet block with staged evaluation
func (rb *ResnetBlock2D) Forward(x *mlx.Array) *mlx.Array {
var h *mlx.Array
// Stage 1: norm1
{
h = rb.Norm1.Forward(x)
mlx.Eval(h)
}
// Stage 2: silu + conv1
{
prev := h
h = mlx.SiLU(h)
h = rb.Conv1.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Stage 3: norm2
{
prev := h
h = rb.Norm2.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Stage 4: silu + conv2
{
prev := h
h = mlx.SiLU(h)
h = rb.Conv2.Forward(h)
prev.Free()
mlx.Eval(h)
}
// Residual connection
{
prev := h
if rb.ConvShortcut != nil {
shortcut := rb.ConvShortcut.Forward(x)
h = mlx.Add(h, shortcut)
} else {
h = mlx.Add(h, x)
}
prev.Free()
mlx.Eval(h)
}
return h
}
// VAEAttentionBlock implements self-attention for VAE
type VAEAttentionBlock struct {
GroupNorm *GroupNormLayer
ToQWeight *mlx.Array
ToQBias *mlx.Array
ToKWeight *mlx.Array
ToKBias *mlx.Array
ToVWeight *mlx.Array
ToVBias *mlx.Array
ToOutWeight *mlx.Array
ToOutBias *mlx.Array
NumHeads int32
}
// NewVAEAttentionBlock creates an attention block
func NewVAEAttentionBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEAttentionBlock, error) {
normWeight, err := weights.GetTensor(prefix + ".group_norm.weight")
if err != nil {
return nil, err
}
normBias, err := weights.GetTensor(prefix + ".group_norm.bias")
if err != nil {
return nil, err
}
toQWeight, err := weights.GetTensor(prefix + ".to_q.weight")
if err != nil {
return nil, err
}
toQBias, err := weights.GetTensor(prefix + ".to_q.bias")
if err != nil {
return nil, err
}
toKWeight, err := weights.GetTensor(prefix + ".to_k.weight")
if err != nil {
return nil, err
}
toKBias, err := weights.GetTensor(prefix + ".to_k.bias")
if err != nil {
return nil, err
}
toVWeight, err := weights.GetTensor(prefix + ".to_v.weight")
if err != nil {
return nil, err
}
toVBias, err := weights.GetTensor(prefix + ".to_v.bias")
if err != nil {
return nil, err
}
toOutWeight, err := weights.GetTensor(prefix + ".to_out.0.weight")
if err != nil {
return nil, err
}
toOutBias, err := weights.GetTensor(prefix + ".to_out.0.bias")
if err != nil {
return nil, err
}
return &VAEAttentionBlock{
GroupNorm: NewGroupNorm(normWeight, normBias, numGroups),
ToQWeight: mlx.Transpose(toQWeight, 1, 0),
ToQBias: toQBias,
ToKWeight: mlx.Transpose(toKWeight, 1, 0),
ToKBias: toKBias,
ToVWeight: mlx.Transpose(toVWeight, 1, 0),
ToVBias: toVBias,
ToOutWeight: mlx.Transpose(toOutWeight, 1, 0),
ToOutBias: toOutBias,
NumHeads: 1,
}, nil
}
// Forward applies attention with staged evaluation
// Input and output are in NHWC format [B, H, W, C]
func (ab *VAEAttentionBlock) Forward(x *mlx.Array) *mlx.Array {
residual := x
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
var h *mlx.Array
// Stage 1: GroupNorm + reshape to [B, H*W, C]
{
h = ab.GroupNorm.Forward(x)
h = mlx.Reshape(h, B, H*W, C)
mlx.Eval(h)
}
var out *mlx.Array
// Stage 2: Q, K, V projections + attention
{
q := mlx.Linear(h, ab.ToQWeight)
q = mlx.Add(q, ab.ToQBias)
k := mlx.Linear(h, ab.ToKWeight)
k = mlx.Add(k, ab.ToKBias)
v := mlx.Linear(h, ab.ToVWeight)
v = mlx.Add(v, ab.ToVBias)
h.Free()
q = mlx.ExpandDims(q, 1)
k = mlx.ExpandDims(k, 1)
v = mlx.ExpandDims(v, 1)
scale := float32(1.0 / math.Sqrt(float64(C)))
out = mlx.ScaledDotProductAttention(q, k, v, scale, false)
out = mlx.Squeeze(out, 1)
mlx.Eval(out)
}
// Stage 3: Output projection + reshape + residual
{
prev := out
out = mlx.Linear(out, ab.ToOutWeight)
out = mlx.Add(out, ab.ToOutBias)
out = mlx.Reshape(out, B, H, W, C)
out = mlx.Add(out, residual)
prev.Free()
mlx.Eval(out)
}
return out
}
// UpDecoderBlock2D implements an upsampling decoder block
type UpDecoderBlock2D struct {
ResnetBlocks []*ResnetBlock2D
Upsample *Conv2D
}
// NewUpDecoderBlock2D creates an up decoder block
func NewUpDecoderBlock2D(weights safetensors.WeightSource, prefix string, numLayers, numGroups int32, hasUpsample bool) (*UpDecoderBlock2D, error) {
resnets := make([]*ResnetBlock2D, numLayers)
for i := int32(0); i < numLayers; i++ {
resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
resnet, err := NewResnetBlock2D(weights, resPrefix, numGroups)
if err != nil {
return nil, err
}
resnets[i] = resnet
}
var upsample *Conv2D
if hasUpsample {
upWeight, err := weights.GetTensor(prefix + ".upsamplers.0.conv.weight")
if err != nil {
return nil, err
}
upBias, err := weights.GetTensor(prefix + ".upsamplers.0.conv.bias")
if err != nil {
return nil, err
}
upsample = NewConv2D(upWeight, upBias, 1, 1)
}
return &UpDecoderBlock2D{
ResnetBlocks: resnets,
Upsample: upsample,
}, nil
}
// Forward applies the up decoder block with staged evaluation to reduce peak memory
func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
for _, resnet := range ub.ResnetBlocks {
prev := x
x = resnet.Forward(x) // ResNet handles its own pools
prev.Free()
}
if ub.Upsample != nil {
// Stage 1: Upsample2x (nearest neighbor)
{
prev := x
x = Upsample2x(x)
prev.Free()
mlx.Eval(x)
}
// Stage 2: Upsample conv
{
prev := x
x = ub.Upsample.Forward(x)
prev.Free()
mlx.Eval(x)
}
}
return x
}
// VAEMidBlock is the middle block with attention
type VAEMidBlock struct {
Resnet1 *ResnetBlock2D
Attention *VAEAttentionBlock
Resnet2 *ResnetBlock2D
}
// NewVAEMidBlock creates the mid block
func NewVAEMidBlock(weights safetensors.WeightSource, prefix string, numGroups int32) (*VAEMidBlock, error) {
resnet1, err := NewResnetBlock2D(weights, prefix+".resnets.0", numGroups)
if err != nil {
return nil, err
}
attention, err := NewVAEAttentionBlock(weights, prefix+".attentions.0", numGroups)
if err != nil {
return nil, err
}
resnet2, err := NewResnetBlock2D(weights, prefix+".resnets.1", numGroups)
if err != nil {
return nil, err
}
return &VAEMidBlock{
Resnet1: resnet1,
Attention: attention,
Resnet2: resnet2,
}, nil
}
// Forward applies the mid block with staged evaluation
func (mb *VAEMidBlock) Forward(x *mlx.Array) *mlx.Array {
prev := x
x = mb.Resnet1.Forward(x) // ResNet handles its own pools
prev.Free()
// Attention handles its own pools
prev = x
x = mb.Attention.Forward(x)
prev.Free()
prev = x
x = mb.Resnet2.Forward(x) // ResNet handles its own pools
prev.Free()
return x
}
// VAEDecoder is the full VAE decoder
type VAEDecoder struct {
Config *VAEConfig
ConvIn *Conv2D
MidBlock *VAEMidBlock
UpBlocks []*UpDecoderBlock2D
ConvNormOut *GroupNormLayer
ConvOut *Conv2D
// Tiling configuration (nil = no tiling)
Tiling *vae.TilingConfig
}
// Load loads the VAE decoder from ollama blob storage.
func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
// Load config from blob
var cfg VAEConfig
if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
if err := weights.Load(0); err != nil {
return fmt.Errorf("load weights: %w", err)
}
defer weights.ReleaseAll()
return m.loadWeights(weights, &cfg)
}
// loadWeights loads VAE weights from any WeightSource
func (m *VAEDecoder) loadWeights(weights safetensors.WeightSource, cfg *VAEConfig) error {
var err error
// Load conv_in
fmt.Print(" Loading conv_in... ")
convInWeight, err := weights.GetTensor("decoder.conv_in.weight")
if err != nil {
return err
}
convInBias, err := weights.GetTensor("decoder.conv_in.bias")
if err != nil {
return err
}
m.ConvIn = NewConv2D(convInWeight, convInBias, 1, 1)
fmt.Println("✓")
// Load mid block
fmt.Print(" Loading mid block... ")
m.MidBlock, err = NewVAEMidBlock(weights, "decoder.mid_block", cfg.NormNumGroups)
if err != nil {
return err
}
fmt.Println("✓")
// Load up blocks
fmt.Print(" Loading up blocks... ")
numBlocks := len(cfg.BlockOutChannels)
m.UpBlocks = make([]*UpDecoderBlock2D, numBlocks)
for i := 0; i < numBlocks; i++ {
prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
hasUpsample := i < numBlocks-1
m.UpBlocks[i], err = NewUpDecoderBlock2D(weights, prefix, cfg.LayersPerBlock+1, cfg.NormNumGroups, hasUpsample)
if err != nil {
return err
}
}
fmt.Printf("✓ [%d blocks]\n", numBlocks)
// Load conv_norm_out
fmt.Print(" Loading conv_norm_out... ")
normWeight, err := weights.GetTensor("decoder.conv_norm_out.weight")
if err != nil {
return err
}
normBias, err := weights.GetTensor("decoder.conv_norm_out.bias")
if err != nil {
return err
}
m.ConvNormOut = NewGroupNorm(normWeight, normBias, cfg.NormNumGroups)
fmt.Println("✓")
// Load conv_out
fmt.Print(" Loading conv_out... ")
convOutWeight, err := weights.GetTensor("decoder.conv_out.weight")
if err != nil {
return err
}
convOutBias, err := weights.GetTensor("decoder.conv_out.bias")
if err != nil {
return err
}
m.ConvOut = NewConv2D(convOutWeight, convOutBias, 1, 1)
fmt.Println("✓")
return nil
}
// Decode decodes latents to images.
// Input latents are in NCHW format, output is in NCHW format.
// If Tiling is set, uses tiled decoding to reduce memory for large images.
func (v *VAEDecoder) Decode(latents *mlx.Array) *mlx.Array {
// Scale latents
z := mlx.DivScalar(latents, v.Config.ScalingFactor)
z = mlx.AddScalar(z, v.Config.ShiftFactor)
// Convert NCHW -> NHWC for internal processing
z = mlx.Transpose(z, 0, 2, 3, 1)
// Use tiled decoding if enabled
if v.Tiling != nil {
mlx.Eval(z)
return vae.DecodeTiled(z, v.Tiling, v.decodeTile)
}
// Direct decode
h := v.decodeTile(z)
h = mlx.ClipScalar(h, 0.0, 1.0, true, true)
// Convert NHWC -> NCHW for output
h = mlx.Transpose(h, 0, 3, 1, 2)
mlx.Eval(h)
return h
}
// decodeTile decodes a single latent tile to pixels.
// Input: [B, H, W, C] latent tile in NHWC format (already scaled)
// Output: [B, H*8, W*8, 3] pixel tile in NHWC format
func (v *VAEDecoder) decodeTile(z *mlx.Array) *mlx.Array {
h := v.ConvIn.Forward(z)
mlx.Eval(h)
prev := h
h = v.MidBlock.Forward(h)
prev.Free()
for _, upBlock := range v.UpBlocks {
prev = h
h = upBlock.Forward(h)
prev.Free()
}
prev = h
h = v.ConvNormOut.Forward(h)
mlx.Eval(h) // Eval after GroupNorm to avoid grid dimension issues
prev.Free()
prev = h
h = mlx.SiLU(h)
h = v.ConvOut.Forward(h)
mlx.Eval(h)
prev.Free()
// VAE outputs [-1, 1], convert to [0, 1]
h = mlx.MulScalar(h, 0.5)
h = mlx.AddScalar(h, 0.5)
return h
}
// Upsample2x performs 2x nearest neighbor upsampling using Take.
// Input and output are in NHWC format: [B, H, W, C] -> [B, H*2, W*2, C]
// Uses Take with repeated indices to produce contiguous output.
func Upsample2x(x *mlx.Array) *mlx.Array {
shape := x.Shape()
H := shape[1]
W := shape[2]
// Create indices [0, 0, 1, 1, 2, 2, ...] for nearest neighbor
// For H dimension
hIdx := mlx.ArangeInt(0, H, 1, mlx.DtypeInt32)
hIdx = mlx.Reshape(hIdx, H, 1)
hIdx = mlx.BroadcastTo(hIdx, []int32{H, 2})
hIdx = mlx.Reshape(hIdx, H*2)
// For W dimension
wIdx := mlx.ArangeInt(0, W, 1, mlx.DtypeInt32)
wIdx = mlx.Reshape(wIdx, W, 1)
wIdx = mlx.BroadcastTo(wIdx, []int32{W, 2})
wIdx = mlx.Reshape(wIdx, W*2)
// Take along H axis (axis 1 in NHWC)
x = mlx.Take(x, hIdx, 1)
// Take along W axis (axis 2 in NHWC)
x = mlx.Take(x, wIdx, 2)
return x
}

View File

@@ -0,0 +1,488 @@
// Package zimage implements the Z-Image diffusion transformer model.
package zimage
import (
"context"
"fmt"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/vae"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 9 for turbo)
Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
TeaCacheThreshold float32 // Threshold for cache reuse (default: 0.1, lower = more aggressive)
// Fused QKV (fuse Q/K/V projections into single matmul)
FusedQKV bool // Enable fused QKV projection (default: false)
}
// Model represents a Z-Image diffusion model.
type Model struct {
ModelName string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen3TextEncoder
Transformer *Transformer
VAEDecoder *VAEDecoder
qkvFused bool // Track if QKV has been fused (do only once)
}
// Load loads the Z-Image model from ollama blob storage.
func (m *Model) Load(modelName string) error {
fmt.Printf("Loading Z-Image model from manifest: %s...\n", modelName)
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelName = modelName
// Load manifest
manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
// Load tokenizer from manifest with config
fmt.Print(" Loading tokenizer... ")
tokData, err := manifest.ReadConfig("tokenizer/tokenizer.json")
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
// Try to read tokenizer config files from manifest
tokConfig := &tokenizer.TokenizerConfig{}
if data, err := manifest.ReadConfig("tokenizer/tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = data
}
if data, err := manifest.ReadConfig("tokenizer/special_tokens_map.json"); err == nil {
tokConfig.SpecialTokensMapJSON = data
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder
m.TextEncoder = &Qwen3TextEncoder{}
if err := m.TextEncoder.Load(manifest, "text_encoder/config.json"); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(manifest); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(manifest); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB VRAM)\n", time.Since(start).Seconds(), float64(mem)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
return m.GenerateFromConfig(context.Background(), &GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(ctx, cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements runner.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error) {
return m.GenerateFromConfig(ctx, &GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// generate is the internal denoising pipeline.
func (m *Model) generate(ctx context.Context, cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 9 // Z-Image turbo default
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
// TeaCache enabled by default
cfg.TeaCache = true
if cfg.TeaCacheThreshold <= 0 {
cfg.TeaCacheThreshold = 0.15
}
// Enable fused QKV if requested (only fuse once)
if cfg.FusedQKV && !m.qkvFused {
m.Transformer.FuseAllQKV()
m.qkvFused = true
fmt.Println(" Fused QKV enabled")
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.TransformerConfig
latentH := cfg.Height / 8
latentW := cfg.Width / 8
hTok := latentH / tcfg.PatchSize
wTok := latentW / tcfg.PatchSize
// Text encoding with padding to multiple of 32
var posEmb, negEmb *mlx.Array
{
posEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt, 512, false)
if useCFG {
negEmb, _ = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt, 512, false)
}
// Pad both to same length (multiple of 32)
maxLen := posEmb.Shape()[1]
if useCFG && negEmb.Shape()[1] > maxLen {
maxLen = negEmb.Shape()[1]
}
if pad := (32 - (maxLen % 32)) % 32; pad > 0 {
maxLen += pad
}
posEmb = padToLength(posEmb, maxLen)
if useCFG {
negEmb = padToLength(negEmb, maxLen)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Scheduler
scheduler := NewFlowMatchEulerScheduler(DefaultFlowMatchSchedulerConfig())
scheduler.SetTimestepsWithMu(cfg.Steps, CalculateShift(hTok*wTok))
// Init latents [B, C, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.InChannels, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = m.Transformer.PrepareRoPECache(hTok, wTok, posEmb.Shape()[1])
mlx.Keep(ropeCache.ImgCos, ropeCache.ImgSin, ropeCache.CapCos, ropeCache.CapSin,
ropeCache.UnifiedCos, ropeCache.UnifiedSin)
mlx.Eval(ropeCache.UnifiedCos)
}
// Pre-compute batched embeddings for CFG (outside the loop for efficiency)
var batchedEmb *mlx.Array
if useCFG {
// Concatenate embeddings once: [1, L, D] + [1, L, D] -> [2, L, D]
batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
mlx.Keep(batchedEmb)
mlx.Eval(batchedEmb)
}
// TeaCache for timestep-aware caching
// For CFG mode, we cache pos/neg separately, skip early steps, and always compute CFG fresh
var teaCache *cache.TeaCache
if cfg.TeaCache {
skipEarly := 0
if useCFG {
skipEarly = 3 // Skip first 3 steps for CFG to preserve structure
}
teaCache = cache.NewTeaCache(&cache.TeaCacheConfig{
Threshold: cfg.TeaCacheThreshold,
RescaleFactor: 1.0,
SkipEarlySteps: skipEarly,
})
if useCFG {
fmt.Printf(" TeaCache enabled (CFG mode): threshold=%.2f, skip first %d steps\n", cfg.TeaCacheThreshold, skipEarly)
} else {
fmt.Printf(" TeaCache enabled: threshold=%.2f\n", cfg.TeaCacheThreshold)
}
}
// cleanup frees all kept arrays when we need to abort early
cleanup := func() {
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
teaCache.Free()
}
latents.Free()
}
// Denoising loop
if cfg.Progress != nil {
cfg.Progress(0, cfg.Steps) // Start at 0%
}
for i := 0; i < cfg.Steps; i++ {
// Check for cancellation
if ctx != nil {
select {
case <-ctx.Done():
cleanup()
return nil, ctx.Err()
default:
}
}
stepStart := time.Now()
// GPU capture on step 2 if requested
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStartCapture(cfg.CapturePath)
}
tCurr := scheduler.Timesteps[i]
var noisePred *mlx.Array
// TeaCache: check if we should compute or reuse cached output
shouldCompute := teaCache == nil || teaCache.ShouldCompute(i, tCurr)
if shouldCompute {
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{1.0 - tCurr}, []int32{1}))
patches := PatchifyLatents(latents, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// CFG Batching: single forward pass with batch=2
// Tile patches: [1, L, D] -> [2, L, D]
batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
// Tile timestep: [1] -> [2]
batchedTimestep := mlx.Tile(timestep, []int32{2})
// Single batched forward pass (RoPE broadcasts from [1,L,H,D] to [2,L,H,D])
batchedOutput := m.Transformer.Forward(batchedPatches, batchedTimestep, batchedEmb, ropeCache)
// Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
outputShape := batchedOutput.Shape()
L := outputShape[1]
D := outputShape[2]
posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
// Convert to noise predictions (unpatchify and negate)
posPred := UnpatchifyLatents(posOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
posPred = mlx.Neg(posPred)
negPred := UnpatchifyLatents(negOutput, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
negPred = mlx.Neg(negPred)
// Cache pos/neg separately for TeaCache
if teaCache != nil {
teaCache.UpdateCFGCache(posPred, negPred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
// Apply CFG: noisePred = neg + scale * (pos - neg)
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
} else {
// Non-CFG forward pass
output = m.Transformer.Forward(patches, timestep, posEmb, ropeCache)
noisePred = UnpatchifyLatents(output, tcfg.PatchSize, latentH, latentW, tcfg.InChannels)
noisePred = mlx.Neg(noisePred)
// Update TeaCache
if teaCache != nil {
teaCache.UpdateCache(noisePred, tCurr)
mlx.Keep(teaCache.Arrays()...)
}
}
} else if useCFG && teaCache != nil && teaCache.HasCFGCache() {
// CFG mode: get cached pos/neg and compute CFG fresh
posPred, negPred := teaCache.GetCFGCached()
diff := mlx.Sub(posPred, negPred)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
noisePred = mlx.Add(negPred, scaledDiff)
fmt.Printf(" [TeaCache: reusing cached pos/neg outputs]\n")
} else {
// Non-CFG mode: reuse cached noise prediction
noisePred = teaCache.GetCached()
fmt.Printf(" [TeaCache: reusing cached output]\n")
}
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
mlx.Eval(latents)
oldLatents.Free()
if cfg.CapturePath != "" && i == 1 {
mlx.MetalStopCapture()
}
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n",
i+1, cfg.Steps, tCurr, time.Since(stepStart).Seconds(), activeMem, peakMem)
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps) // Report completed step
}
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgCos.Free()
ropeCache.ImgSin.Free()
ropeCache.CapCos.Free()
ropeCache.CapSin.Free()
ropeCache.UnifiedCos.Free()
ropeCache.UnifiedSin.Free()
if batchedEmb != nil {
batchedEmb.Free()
}
if teaCache != nil {
hits, misses := teaCache.Stats()
fmt.Printf(" TeaCache stats: %d hits, %d misses (%.1f%% cache rate)\n",
hits, misses, float64(hits)/float64(hits+misses)*100)
teaCache.Free()
}
// VAE decode - enable tiling for larger images to reduce memory
// VAE attention is O(n²) on latent pixels, tiling helps significantly
if latentH > 64 || latentW > 64 {
m.VAEDecoder.Tiling = vae.DefaultTilingConfig()
}
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
return decoded, nil
}
// padToLength pads a sequence tensor to the target length by repeating the last token.
func padToLength(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
lastToken := mlx.Slice(x, []int32{0, currentLen - 1, 0}, []int32{shape[0], currentLen, shape[2]})
padding := mlx.Tile(lastToken, []int32{1, padLen, 1})
return mlx.Concatenate([]*mlx.Array{x, padding}, 1)
}
// CalculateShift computes the mu shift value for dynamic scheduling
func CalculateShift(imgSeqLen int32) float32 {
baseSeqLen := float32(256)
maxSeqLen := float32(4096)
baseShift := float32(0.5)
maxShift := float32(1.15)
m := (maxShift - baseShift) / (maxSeqLen - baseSeqLen)
b := baseShift - m*baseSeqLen
return float32(imgSeqLen)*m + b
}