ollama source for Momentry Core verification

This commit is contained in:
Accusys
2026-05-22 17:19:10 +08:00
commit 0b31ff9135
2020 changed files with 1413145 additions and 0 deletions

439
x/models/dflash/dflash.go Normal file
View File

@@ -0,0 +1,439 @@
// Package dflash implements DFlash block-diffusion draft models for MLX.
package dflash
import (
"encoding/json"
"fmt"
"math"
"sort"
"strings"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
)
func init() {
base.RegisterDraft("DFlashDraftModel", newModel)
base.RegisterDraft("dflash", newModel)
}
var _ base.DFlashDraftModel = (*Model)(nil)
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
IntermediateSize int32 `json:"intermediate_size"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling *nn.RopeParameters `json:"rope_scaling"`
RopeParameters *nn.RopeParameters `json:"rope_parameters"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
BlockSizeValue int32 `json:"block_size"`
NumTargetLayers int32 `json:"num_target_layers"`
LayerTypes []string `json:"layer_types"`
SlidingWindow int32 `json:"sliding_window"`
FinalLogitSoftcapping *float32 `json:"final_logit_softcapping"`
DFlash struct {
TargetLayerIDs []int `json:"target_layer_ids"`
MaskTokenID int32 `json:"mask_token_id"`
} `json:"dflash_config"`
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
Scale float32 `json:"-"`
RopeFreqs *mlx.Array `json:"-"`
RopeScale float32 `json:"-"`
}
type Model struct {
FC nn.LinearLayer
HiddenNorm *nn.RMSNorm
Layers []*Layer
Norm *nn.RMSNorm
target base.Model
targetEmbeddings base.MTPEmbeddingModel
tensorPrefix string
*Config
}
type Layer struct {
Attention *Attention
MLP *MLP
InputNorm *nn.RMSNorm
PostAttentionNorm *nn.RMSNorm
}
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
Sliding bool
}
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
func parseConfig(data []byte) (Config, error) {
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return Config{}, fmt.Errorf("parse dflash config: %w", err)
}
if cfg.HiddenSize <= 0 {
return Config{}, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumHiddenLayers <= 0 {
return Config{}, fmt.Errorf("invalid num_hidden_layers: %d", cfg.NumHiddenLayers)
}
if cfg.NumAttentionHeads <= 0 {
return Config{}, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return Config{}, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.RopeTheta == 0 {
ropeParams := cfg.RopeParameters
if ropeParams == nil {
ropeParams = cfg.RopeScaling
}
if ropeParams != nil && ropeParams.RopeTheta > 0 {
cfg.RopeTheta = ropeParams.RopeTheta
}
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.BlockSizeValue <= 0 {
return Config{}, fmt.Errorf("invalid block_size: %d", cfg.BlockSizeValue)
}
if len(cfg.DFlash.TargetLayerIDs) == 0 {
return Config{}, fmt.Errorf("dflash_config.target_layer_ids is required")
}
if !sort.IntsAreSorted(cfg.DFlash.TargetLayerIDs) {
return Config{}, fmt.Errorf("dflash_config.target_layer_ids must be sorted")
}
if len(cfg.LayerTypes) == 0 {
cfg.LayerTypes = make([]string, cfg.NumHiddenLayers)
for i := range cfg.LayerTypes {
cfg.LayerTypes[i] = "full_attention"
}
}
if len(cfg.LayerTypes) != int(cfg.NumHiddenLayers) {
return Config{}, fmt.Errorf("layer_types length %d does not match num_hidden_layers %d", len(cfg.LayerTypes), cfg.NumHiddenLayers)
}
for i, typ := range cfg.LayerTypes {
switch strings.ToLower(typ) {
case "full_attention":
case "sliding_attention":
if cfg.SlidingWindow <= 0 {
return Config{}, fmt.Errorf("layer %d uses sliding_attention but sliding_window is not set", i)
}
default:
return Config{}, fmt.Errorf("unsupported layer type %q", typ)
}
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
cfg.RopeScale = 1
ropeParams := cfg.RopeParameters
if ropeParams == nil {
ropeParams = cfg.RopeScaling
}
if ropeParams != nil && strings.EqualFold(ropeParams.TypeName(), "yarn") {
cfg.RopeFreqs, cfg.RopeScale = nn.BuildYarnRopeFreqs(int(cfg.HeadDim), cfg.RopeTheta, ropeParams)
}
return cfg, nil
}
func newModel(root *model.Root, target base.Model) (base.DraftModel, error) {
if root == nil || root.Draft == nil {
return nil, fmt.Errorf("draft metadata missing")
}
configPath := root.Draft.Config
if configPath == "" {
configPath = "draft/config.json"
}
configData, err := root.Manifest.ReadConfig(configPath)
if err != nil {
return nil, fmt.Errorf("load dflash config: %w", err)
}
cfg, err := parseConfig(configData)
if err != nil {
return nil, err
}
if target.NumLayers() < int(cfg.NumTargetLayers) {
return nil, fmt.Errorf("dflash target expects %d layers, target has %d", cfg.NumTargetLayers, target.NumLayers())
}
for _, layerID := range cfg.DFlash.TargetLayerIDs {
if layerID < 0 || layerID >= target.NumLayers() {
return nil, fmt.Errorf("dflash target layer id %d out of range for %d-layer target", layerID, target.NumLayers())
}
}
targetEmbeddings, ok := target.(base.MTPEmbeddingModel)
if !ok {
return nil, fmt.Errorf("dflash draft requires target token embeddings, got %T", target)
}
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
prefix := root.Draft.TensorPrefix
if prefix == "" {
prefix = "draft."
}
m := &Model{
Config: &cfg,
Layers: make([]*Layer, cfg.NumHiddenLayers),
target: target,
targetEmbeddings: targetEmbeddings,
tensorPrefix: prefix,
}
for i := range m.Layers {
m.Layers[i] = &Layer{Attention: &Attention{}, MLP: &MLP{}}
}
return m, nil
}
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
prefix := m.tensorPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
m.FC = linears.Make(prefix + "fc")
if m.FC == nil {
return fmt.Errorf("missing dflash fc weight")
}
if w := tensors[prefix+"hidden_norm.weight"]; w != nil {
m.HiddenNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[prefix+"norm.weight"]; w != nil {
m.Norm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if m.HiddenNorm == nil || m.Norm == nil {
return fmt.Errorf("missing dflash norm weights")
}
for i := range m.NumHiddenLayers {
layerPrefix := fmt.Sprintf("%slayers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{Sliding: strings.ToLower(m.LayerTypes[i]) == "sliding_attention"},
MLP: &MLP{
GateProj: linears.Make(layerPrefix + ".mlp.gate_proj"),
UpProj: linears.Make(layerPrefix + ".mlp.up_proj"),
DownProj: linears.Make(layerPrefix + ".mlp.down_proj"),
},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.PostAttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if layer.InputNorm == nil || layer.PostAttentionNorm == nil {
return fmt.Errorf("dflash layer %d: missing layer norms", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("dflash layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("dflash layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("dflash layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
return nil
}
func (m *Model) TargetLayerIDs() []int {
return append([]int(nil), m.DFlash.TargetLayerIDs...)
}
func (m *Model) BlockSize() int {
return int(m.BlockSizeValue)
}
func (m *Model) MaskTokenID() int32 {
return m.DFlash.MaskTokenID
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, typ := range m.LayerTypes {
if strings.ToLower(typ) == "sliding_attention" {
// RotatingKVCache.View returns maxSize-1 tokens so assistant
// paths can append the current query. DFlash uses that same
// view for target context, so allocate one extra slot to expose
// the draft model's sliding_window-1 context tokens.
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
func (m *Model) AppendContext(targetHidden *mlx.Array, caches []cache.Cache) {
if targetHidden == nil || targetHidden.Dim(1) == 0 {
return
}
hCtx := m.HiddenNorm.Forward(m.FC.Forward(targetHidden), m.RMSNormEps)
offset := int32(0)
if len(caches) > 0 && caches[0] != nil {
offset = int32(caches[0].Offset())
}
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, targetHidden.Dim(0), targetHidden.Dim(1)),
SeqOffsets: []int32{offset},
SeqQueryLens: []int32{int32(targetHidden.Dim(1))},
}
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
for i, layer := range m.Layers {
if i >= len(caches) || caches[i] == nil {
continue
}
layer.Attention.AppendContext(hCtx, b, positions, caches[i], m.Config)
}
}
func (m *Model) Draft(inputIDs *mlx.Array, caches []cache.Cache) *mlx.Array {
dims := inputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
offset := int32(0)
if len(caches) > 0 && caches[0] != nil {
offset = int32(caches[0].Offset())
}
b := &batch.Batch{
InputIDs: inputIDs,
SeqOffsets: []int32{offset},
SeqQueryLens: []int32{L},
}
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.targetEmbeddings.TokenEmbeddings(inputIDs)
for i, layer := range m.Layers {
var c cache.Cache
if i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, b, c, positions, B, L, m.Config)
}
logits := m.target.Unembed(m.Norm.Forward(h, m.RMSNormEps))
if m.FinalLogitSoftcapping != nil {
cap := mlx.FromValue(*m.FinalLogitSoftcapping).AsType(logits.DType())
logits = mlx.LogitSoftcap(logits, cap)
}
return logits
}
func (l *Layer) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.InputNorm.Forward(x, cfg.RMSNormEps), b, c, positions, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) AppendContext(xCtx *mlx.Array, b *batch.Batch, positions *mlx.Array, c cache.Cache, cfg *Config) {
B, L := int32(xCtx.Dim(0)), int32(xCtx.Dim(1))
k := a.KProj.Forward(xCtx)
v := a.VProj.Forward(xCtx)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = a.KNorm.Forward(k, cfg.RMSNormEps)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
k = nn.ScaleRotaryPart(mlx.RoPEWithFreqs(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions, cfg.RopeFreqs), int(cfg.HeadDim), cfg.RopeScale)
c.(cache.Attention).Update(b, k, v)
}
func (a *Attention) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
propK := a.KProj.Forward(x)
propV := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
propK = mlx.Reshape(propK, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
propV = mlx.Reshape(propV, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
q = a.QNorm.Forward(q, cfg.RMSNormEps)
propK = a.KNorm.Forward(propK, cfg.RMSNormEps)
q = mlx.Transpose(q, 0, 2, 1, 3)
propK = mlx.Transpose(propK, 0, 2, 1, 3)
propV = mlx.Transpose(propV, 0, 2, 1, 3)
q = nn.ScaleRotaryPart(mlx.RoPEWithFreqs(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions, cfg.RopeFreqs), int(cfg.HeadDim), cfg.RopeScale)
propK = nn.ScaleRotaryPart(mlx.RoPEWithFreqs(propK, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions, cfg.RopeFreqs), int(cfg.HeadDim), cfg.RopeScale)
k, v := propK, propV
if viewer, ok := c.(cache.Viewer); ok {
if history := viewer.View(b); history != nil {
k = history.K().Concatenate(2, propK)
v = history.V().Concatenate(2, propV)
}
}
mask := nn.AttentionMask{}
if a.Sliding {
mask = nn.CausalMask()
if int(cfg.SlidingWindow) > 0 && k.Dim(2) > int(cfg.SlidingWindow) {
mask = mask.Intersect(nn.SlidingWindowMask(b, k.Dim(2), int(cfg.SlidingWindow), q.DType()))
}
}
out := nn.ScaledDotProductAttention(b, q, cfg.Scale, nn.WithKV(k, v, []int32{int32(k.Dim(2))}), nn.WithMask(mask))
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
}

View File

@@ -0,0 +1,52 @@
package dflash
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseConfigYarnRopeScaling(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
data := []byte(`{
"hidden_size": 2048,
"num_hidden_layers": 8,
"num_attention_heads": 32,
"num_key_value_heads": 4,
"head_dim": 128,
"intermediate_size": 6144,
"vocab_size": 248320,
"rms_norm_eps": 0.000001,
"rope_theta": 10000000,
"rope_scaling": {
"beta_fast": 32.0,
"beta_slow": 1.0,
"factor": 64.0,
"original_max_position_embeddings": 4096,
"rope_type": "yarn"
},
"block_size": 16,
"num_target_layers": 40,
"layer_types": ["full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention", "full_attention"],
"dflash_config": {
"mask_token_id": 248070,
"target_layer_ids": [1, 10, 19, 28, 37]
}
}`)
cfg, err := parseConfig(data)
if err != nil {
t.Fatalf("parseConfig failed: %v", err)
}
if cfg.RopeFreqs == nil {
t.Fatalf("RopeFreqs is nil")
}
wantScale := float32(0.1*math.Log(64.0) + 1.0)
if math.Abs(float64(cfg.RopeScale-wantScale)) > 1e-6 {
t.Fatalf("RopeScale = %v, want %v", cfg.RopeScale, wantScale)
}
}

517
x/models/gemma3/gemma3.go Normal file
View File

@@ -0,0 +1,517 @@
// Package gemma3 provides the Gemma 3 text model implementation for MLX.
package gemma3
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("Gemma3ForCausalLM", newModel)
base.Register("Gemma3ForConditionalGeneration", newModel)
}
// TextConfig holds configuration for the Gemma 3 text model.
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
Scale float32 `json:"-"`
}
// Attention implements Gemma 3 attention with Q/K normalization.
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
// Precomputed (1 + weight) for Gemma-style RMSNorm.
QNormScaled *mlx.Array
KNormScaled *mlx.Array
}
// MLP is the feed-forward network with GELU activation.
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
// DecoderLayer is a single transformer block.
type DecoderLayer struct {
InputNorm *nn.RMSNorm
Attention *Attention
PostAttnNorm *nn.RMSNorm
PreFFNorm *nn.RMSNorm
MLP *MLP
PostFFNorm *nn.RMSNorm
// Precomputed (1 + weight) for Gemma-style RMSNorm.
InputNormScaled *mlx.Array
PostAttnNormScaled *mlx.Array
PreFFNormScaled *mlx.Array
PostFFNormScaled *mlx.Array
// Layer metadata.
IsSliding bool
LayerIdx int32
}
// Model is the Gemma 3 text-only model.
type Model struct {
EmbedTokens nn.EmbeddingLayer
Layers []*DecoderLayer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
// Precomputed (1 + weight) for Gemma-style RMSNorm.
NormScaled *mlx.Array
tok *tokenizer.Tokenizer
*TextConfig
weightPrefix string
}
func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) {
switch numLayers {
case 34:
return 8, 4
case 48:
return 16, 8
case 62:
return 32, 16
default:
return 8, 4
}
}
func parseTextConfig(configData []byte) (TextConfig, bool, error) {
var cfg TextConfig
if err := json.Unmarshal(configData, &cfg); err != nil {
return TextConfig{}, false, fmt.Errorf("parse config: %w", err)
}
var wrapped struct {
TextConfig *TextConfig `json:"text_config"`
}
if err := json.Unmarshal(configData, &wrapped); err != nil {
return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err)
}
fromConditional := wrapped.TextConfig != nil
if fromConditional {
cfg = *wrapped.TextConfig
if cfg.HeadDim == 0 {
cfg.HeadDim = 256
}
if cfg.NumAttentionHeads == 0 {
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.NumKeyValueHeads == 0 {
_, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 262208
}
if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 {
cfg.SlidingWindowPattern = 6
}
if cfg.MaxPositionEmbeddings == 0 {
cfg.MaxPositionEmbeddings = 131072
}
}
if cfg.HeadDim == 0 {
cfg.HeadDim = 256
}
if cfg.NumAttentionHeads == 0 {
cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
}
if cfg.NumKeyValueHeads == 0 {
cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RopeLocalBaseFreq == 0 {
cfg.RopeLocalBaseFreq = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.VocabSize == 0 {
cfg.VocabSize = 262208
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
return cfg, fromConditional, nil
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func isLayerSliding(layerIdx int32, cfg *TextConfig) bool {
if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) {
return cfg.LayerTypes[layerIdx] == "sliding_attention"
}
if cfg.SlidingWindowPattern <= 0 {
return false
}
return (layerIdx+1)%cfg.SlidingWindowPattern != 0
}
func precomputeGemmaScaledWeights(m *Model) {
if m.Norm != nil {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
}
var scaled []*mlx.Array
if m.NormScaled != nil {
scaled = append(scaled, m.NormScaled)
}
for _, layer := range m.Layers {
if layer == nil || layer.Attention == nil {
continue
}
if layer.InputNorm != nil {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
scaled = append(scaled, layer.InputNormScaled)
}
if layer.PostAttnNorm != nil {
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
scaled = append(scaled, layer.PostAttnNormScaled)
}
if layer.PreFFNorm != nil {
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
scaled = append(scaled, layer.PreFFNormScaled)
}
if layer.PostFFNorm != nil {
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
scaled = append(scaled, layer.PostFFNormScaled)
}
if layer.Attention.QNorm != nil {
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
scaled = append(scaled, layer.Attention.QNormScaled)
}
if layer.Attention.KNorm != nil {
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
scaled = append(scaled, layer.Attention.KNormScaled)
}
}
if len(scaled) > 0 {
mlx.Eval(scaled...)
}
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
cfg, _, err := parseTextConfig(configData)
if err != nil {
return nil, err
}
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
TextConfig: &cfg,
tok: tok,
}
for i := range m.Layers {
m.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), m.TextConfig),
}
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
if embedTokens == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = embedTokens
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Gemma usually ties output projection to embeddings.
m.LMHead = m.EmbedTokens.AsLinear()
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &DecoderLayer{
LayerIdx: i,
IsSliding: isLayerSliding(i, m.TextConfig),
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.InputNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.PostAttnNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.PreFFNorm == nil {
return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i)
}
if layer.PostFFNorm == nil {
return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
precomputeGemmaScaledWeights(m)
if m.NormScaled == nil {
return fmt.Errorf("missing precomputed final norm weight")
}
return nil
}
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, b, c, positions, B, L, m.TextConfig)
}
return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// NewCaches creates cache objects for all layers.
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, layer := range m.Layers {
if m.SlidingWindow > 0 && layer.IsSliding {
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
// FormatPrompt applies the Gemma 3 chat template.
func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (l *DecoderLayer) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *TextConfig) *mlx.Array {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, b, c, positions, B, L, l.IsSliding, cfg)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
mlpOut := l.MLP.Forward(normed)
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
}
func (a *Attention) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, positions)
// MLX SDPA supports grouped-query attention directly (Q heads can be a
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
var kv nn.SDPAOption
if c != nil {
history := c.(cache.Attention).Update(b, k, v)
kv = nn.WithKVHistory(history)
} else {
kv = nn.WithKV(k, v, b.SeqQueryLens)
}
out := nn.ScaledDotProductAttention(b, q, cfg.Scale, kv, nn.WithMask(nn.CausalMask()))
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.GELUApprox(m.GateProj.Forward(x))
up := m.UpProj.Forward(x)
return m.DownProj.Forward(mlx.Mul(gate, up))
}

View File

@@ -0,0 +1,390 @@
package gemma4
import (
"encoding/json"
"fmt"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
)
var (
_ base.DraftModel = (*AssistantModel)(nil)
_ base.MTPDraftModel = (*AssistantModel)(nil)
)
type AssistantConfig struct {
TextConfig TextConfig `json:"text_config"`
BackboneHiddenSize int32 `json:"backbone_hidden_size"`
UseOrderedEmbeddings bool `json:"use_ordered_embeddings"`
NumCentroids int32 `json:"num_centroids"`
CentroidIntermediateTopK int32 `json:"centroid_intermediate_top_k"`
}
type AssistantModel struct {
PreProjection nn.LinearLayer
PostProjection nn.LinearLayer
EmbedTokens nn.EmbeddingLayer
LMHead nn.LinearLayer
Centroids nn.LinearLayer
TokenOrdering *mlx.Array
Layers []*AssistantLayer
Norm *nn.RMSNorm
NormScaled *mlx.Array
*AssistantConfig
tensorPrefix string
QuantGroupSize int
QuantBits int
QuantMode string
TensorQuant map[string]*model.TensorQuantInfo
}
type AssistantLayer struct {
InputNorm *nn.RMSNorm
PostAttnNorm *nn.RMSNorm
PreFFNorm *nn.RMSNorm
PostFFNorm *nn.RMSNorm
InputNormScaled *mlx.Array
PostAttnNormScaled *mlx.Array
PreFFNormScaled *mlx.Array
PostFFNormScaled *mlx.Array
Attention *AssistantAttention
MLP *MLP
LayerScalar *mlx.Array
IsSliding bool
}
type AssistantAttention struct {
QProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
QNormScaled *mlx.Array
}
func parseAssistantConfig(configData []byte) (AssistantConfig, error) {
var raw struct {
TextConfig json.RawMessage `json:"text_config"`
BackboneHiddenSize int32 `json:"backbone_hidden_size"`
UseOrderedEmbeddings bool `json:"use_ordered_embeddings"`
NumCentroids int32 `json:"num_centroids"`
CentroidIntermediateTopK int32 `json:"centroid_intermediate_top_k"`
}
if err := json.Unmarshal(configData, &raw); err != nil {
return AssistantConfig{}, fmt.Errorf("parse assistant config: %w", err)
}
if len(raw.TextConfig) == 0 {
return AssistantConfig{}, fmt.Errorf("assistant config missing text_config")
}
text, err := parseTextConfig(raw.TextConfig)
if err != nil {
return AssistantConfig{}, err
}
if raw.NumCentroids == 0 {
raw.NumCentroids = 2048
}
if raw.CentroidIntermediateTopK == 0 {
raw.CentroidIntermediateTopK = 32
}
return AssistantConfig{
TextConfig: text,
BackboneHiddenSize: raw.BackboneHiddenSize,
UseOrderedEmbeddings: raw.UseOrderedEmbeddings,
NumCentroids: raw.NumCentroids,
CentroidIntermediateTopK: raw.CentroidIntermediateTopK,
}, nil
}
func newAssistantModel(root *model.Root, target base.Model) (base.DraftModel, error) {
if root == nil || root.Draft == nil {
return nil, fmt.Errorf("draft metadata missing")
}
configPath := root.Draft.Config
if configPath == "" {
configPath = "draft/config.json"
}
configData, err := root.Manifest.ReadConfig(configPath)
if err != nil {
return nil, fmt.Errorf("load draft config: %w", err)
}
cfg, err := parseAssistantConfig(configData)
if err != nil {
return nil, err
}
targetGemma, ok := target.(*Model)
if !ok {
return nil, fmt.Errorf("gemma4 assistant requires gemma4 target, got %T", target)
}
if cfg.BackboneHiddenSize != 0 && cfg.BackboneHiddenSize != targetGemma.HiddenSize {
return nil, fmt.Errorf("assistant backbone hidden size %d does not match target hidden size %d", cfg.BackboneHiddenSize, targetGemma.HiddenSize)
}
if cfg.TextConfig.VocabSize != targetGemma.VocabSize {
return nil, fmt.Errorf("assistant vocab size %d does not match target vocab size %d", cfg.TextConfig.VocabSize, targetGemma.VocabSize)
}
tensorPrefix := root.Draft.TensorPrefix
if tensorPrefix == "" {
tensorPrefix = "draft."
}
m := &AssistantModel{
AssistantConfig: &cfg,
tensorPrefix: tensorPrefix,
Layers: make([]*AssistantLayer, cfg.TextConfig.NumHiddenLayers),
TensorQuant: root.AllTensorQuant(),
}
if qt := root.QuantType(); qt != "" {
m.QuantGroupSize, m.QuantBits, m.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
m.QuantGroupSize = gs
}
}
for i := range m.Layers {
m.Layers[i] = &AssistantLayer{
IsSliding: isLayerSliding(int32(i), &m.TextConfig),
Attention: &AssistantAttention{},
MLP: &MLP{},
}
}
return m, nil
}
func (m *AssistantModel) LoadWeights(tensors map[string]*mlx.Array) error {
prefix := m.tensorPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
m.PreProjection = linears.Make(prefix + "pre_projection")
m.PostProjection = linears.Make(prefix + "post_projection")
if m.PreProjection == nil || m.PostProjection == nil {
return fmt.Errorf("missing assistant projection weights")
}
m.EmbedTokens = model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
if m.EmbedTokens == nil {
return fmt.Errorf("missing assistant embedding weight")
}
m.LMHead = m.EmbedTokens.AsLinear()
if m.UseOrderedEmbeddings {
m.Centroids = linears.Make(prefix + "masked_embedding.centroids")
m.TokenOrdering = tensors[prefix+"masked_embedding.token_ordering"]
if m.Centroids == nil || m.TokenOrdering == nil {
return fmt.Errorf("missing ordered embedding tensors: %smasked_embedding.centroids.weight and %smasked_embedding.token_ordering", prefix, prefix)
}
m.TokenOrdering = m.TokenOrdering.AsType(mlx.DTypeInt32)
}
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing assistant final norm")
}
m.Norm = nn.NewRMSNorm(normWeight, m.TextConfig.RMSNormEps)
for i := range m.TextConfig.NumHiddenLayers {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &AssistantLayer{
IsSliding: isLayerSliding(i, &m.TextConfig),
Attention: &AssistantAttention{
QProj: linears.Make(layerPrefix + ".self_attn.q_proj"),
OProj: linears.Make(layerPrefix + ".self_attn.o_proj"),
},
MLP: &MLP{
GateProj: linears.Make(layerPrefix + ".mlp.gate_proj"),
UpProj: linears.Make(layerPrefix + ".mlp.up_proj"),
DownProj: linears.Make(layerPrefix + ".mlp.down_proj"),
},
LayerScalar: tensors[layerPrefix+".layer_scalar"],
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.InputNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.PostAttnNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
layer.PreFFNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
layer.PostFFNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.TextConfig.RMSNormEps)
}
if layer.InputNorm == nil || layer.PostAttnNorm == nil || layer.PreFFNorm == nil || layer.PostFFNorm == nil {
return fmt.Errorf("assistant layer %d: missing norm weights", i)
}
if layer.Attention.QProj == nil || layer.Attention.OProj == nil || layer.Attention.QNorm == nil {
return fmt.Errorf("assistant layer %d: missing attention weights", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("assistant layer %d: missing mlp weights", i)
}
m.Layers[i] = layer
}
m.precomputeScaledWeights()
return nil
}
func (m *AssistantModel) precomputeScaledWeights() {
if m.Norm != nil {
m.NormScaled = m.Norm.Weight
}
for _, layer := range m.Layers {
if layer.InputNorm != nil {
layer.InputNormScaled = layer.InputNorm.Weight
}
if layer.PostAttnNorm != nil {
layer.PostAttnNormScaled = layer.PostAttnNorm.Weight
}
if layer.PreFFNorm != nil {
layer.PreFFNormScaled = layer.PreFFNorm.Weight
}
if layer.PostFFNorm != nil {
layer.PostFFNormScaled = layer.PostFFNorm.Weight
}
if layer.Attention != nil && layer.Attention.QNorm != nil {
layer.Attention.QNormScaled = layer.Attention.QNorm.Weight
}
}
}
func (m *AssistantModel) Draft(inputsEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array) {
dims := inputsEmbeds.Dims()
B, L := int32(dims[0]), int32(dims[1])
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, int(B), int(L)),
SeqOffsets: []int32{position},
SeqQueryLens: []int32{L},
}
sliding, full := m.sharedHistories(b, caches)
h := m.PreProjection.Forward(inputsEmbeds)
positions := mlx.FromValues([]int32{position}, 1)
for _, layer := range m.Layers {
h = layer.Forward(h, b, positions, B, L, &m.TextConfig, sliding, full)
}
hidden = mlx.RMSNormFn(h, m.NormScaled, m.TextConfig.RMSNormEps)
projected := m.PostProjection.Forward(hidden)
return m.unembed(hidden), projected
}
func (m *AssistantModel) sharedHistories(b *batch.Batch, caches []cache.Cache) (sliding, full *nn.KVHistory) {
if len(caches) < 2 {
return nil, nil
}
if v, ok := caches[len(caches)-2].(cache.Viewer); ok {
sliding = v.View(b)
}
if v, ok := caches[len(caches)-1].(cache.Viewer); ok {
full = v.View(b)
}
return sliding, full
}
func (m *AssistantModel) unembed(hidden *mlx.Array) *mlx.Array {
if m.UseOrderedEmbeddings {
return m.applyCentroidMasking(hidden)
}
return m.LMHead.Forward(hidden)
}
func (m *AssistantModel) applyCentroidMasking(hidden *mlx.Array) *mlx.Array {
B, L := hidden.Dim(0), hidden.Dim(1)
vocab := int(m.TextConfig.VocabSize)
numCentroids := int(m.NumCentroids)
vocabPerCentroid := vocab / numCentroids
topK := int(m.CentroidIntermediateTopK)
centroidLogits := m.Centroids.Forward(hidden)
topKIndices := centroidLogits.Negative().ArgpartitionAxis(topK-1, -1).Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, topK))
ordering := m.TokenOrdering.Reshape(numCentroids, vocabPerCentroid)
selectedCanonical := ordering.TakeAxis(topKIndices, 0)
selectedFlat := selectedCanonical.Reshape(B * L * topK * vocabPerCentroid)
embeddings := m.EmbedTokens.Forward(selectedFlat)
embeddings = embeddings.Reshape(B, L, topK*vocabPerCentroid, int(m.TextConfig.HiddenSize))
selectedLogits := hidden.ExpandDims(2).Matmul(embeddings.Transpose(0, 1, 3, 2)).Squeeze(2)
out := mlx.Zeros(selectedLogits.DType(), B, L, vocab)
out = mlx.AddScalar(out, -1.0e30)
return out.PutAlongAxis(selectedCanonical.Reshape(B, L, topK*vocabPerCentroid), selectedLogits, -1)
}
func (l *AssistantLayer) Forward(x *mlx.Array, b *batch.Batch, positions *mlx.Array, B, L int32, cfg *TextConfig, sliding, full *nn.KVHistory) *mlx.Array {
normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.Forward(normed, b, positions, B, L, l.IsSliding, cfg, sliding, full)
attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
mlpOut := l.MLP.Forward(normed)
mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
h = mlx.Add(h, mlpOut)
if l.LayerScalar != nil {
h = mlx.Mul(h, l.LayerScalar)
}
return h
}
func (a *AssistantAttention) Forward(x *mlx.Array, b *batch.Batch, positions *mlx.Array, B, L int32, isSliding bool, cfg *TextConfig, sliding, full *nn.KVHistory) *mlx.Array {
headDim := cfg.HeadDim
scale := cfg.SlidingScale
ropeDims := cfg.SlidingRopeDims
ropeBase := cfg.SlidingRopeBase
history := sliding
if !isSliding {
headDim = cfg.GlobalHeadDim
scale = cfg.FullScale
ropeDims = cfg.FullRopeDims
ropeBase = cfg.FullRopeBase
history = full
}
if history == nil {
panic("gemma4 assistant missing shared target KV history")
}
q := a.QProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
var ropeFreqs *mlx.Array
if !isSliding {
ropeFreqs = cfg.FullRopeFreqs
}
q = mlx.RoPEWithFreqs(q, ropeDims, false, ropeBase, 1.0, positions, ropeFreqs)
mask := nn.CausalMask()
if isSliding && cfg.SlidingWindow > 0 {
mask = mask.Intersect(nn.SlidingWindowMask(b, history.K().Dim(2), int(cfg.SlidingWindow), q.DType()))
}
out := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(history), nn.WithMask(mask))
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*headDim)
if !mlx.MetalIsAvailable() {
out = mlx.Contiguous(out, false)
}
return a.OProj.Forward(out)
}

1471
x/models/gemma4/gemma4.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,230 @@
package gemma4
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// onesLike creates a tensor of the given shape filled with a small constant.
func onesLike(shape ...int) *mlx.Array {
return mlx.AddScalar(mlx.Zeros(mlx.DTypeBFloat16, shape...), 0.01)
}
func TestMoEForward(t *testing.T) {
skipIfNoMLX(t)
// Small config matching 26b architecture pattern.
cfg := &TextConfig{
HiddenSize: 16, // tiny for testing
NumAttentionHeads: 2,
NumKeyValueHeads: 1,
NumGlobalKeyValueHeads: 1,
HeadDim: 8,
GlobalHeadDim: 8,
NumExperts: 4,
TopKExperts: 2,
ExpertIntermediateSize: 8,
EnableMoeBlock: true,
AttentionKEqV: false,
RMSNormEps: 1e-6,
SlidingScale: 1.0,
FullScale: 1.0,
}
B, L := int32(1), int32(3)
x := onesLike(int(B), int(L), int(cfg.HiddenSize))
// Test Router.Forward.
router := &Router{
Proj: linearFromWeight(onesLike(int(cfg.NumExperts), int(cfg.HiddenSize))),
Scale: onesLike(int(cfg.HiddenSize)),
}
t.Run("Router", func(t *testing.T) {
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
sDims := scores.Dims()
iDims := inds.Dims()
t.Logf("scores shape: %v, inds shape: %v", sDims, iDims)
if len(sDims) != 2 || sDims[0] != int(B*L) || sDims[1] != int(cfg.TopKExperts) {
t.Errorf("scores shape = %v, want [%d, %d]", sDims, B*L, cfg.TopKExperts)
}
if len(iDims) != 2 || iDims[0] != int(B*L) || iDims[1] != int(cfg.TopKExperts) {
t.Errorf("inds shape = %v, want [%d, %d]", iDims, B*L, cfg.TopKExperts)
}
})
// Test MoEBlock.Forward.
moe := &MoEBlock{
GateWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
UpWeight: onesLike(int(cfg.NumExperts), int(cfg.HiddenSize), int(cfg.ExpertIntermediateSize)),
DownWeight: onesLike(int(cfg.NumExperts), int(cfg.ExpertIntermediateSize), int(cfg.HiddenSize)),
PerExpertScale: onesLike(int(cfg.NumExperts)),
}
t.Run("MoEBlock", func(t *testing.T) {
scores, inds := router.Forward(x, cfg)
mlx.Eval(scores, inds)
out := moe.Forward(x, scores, inds, cfg)
mlx.Eval(out)
outDims := out.Dims()
t.Logf("MoE output shape: %v", outDims)
if len(outDims) != 3 || outDims[0] != int(B) || outDims[1] != int(L) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, B, L, cfg.HiddenSize)
}
})
// Test with larger batch to exercise the sorted GatherMM path (B*L >= 64).
t.Run("MoEBlock_sorted", func(t *testing.T) {
bigB, bigL := int32(1), int32(128)
bigX := onesLike(int(bigB), int(bigL), int(cfg.HiddenSize))
scores, inds := router.Forward(bigX, cfg)
mlx.Eval(scores, inds)
out := moe.Forward(bigX, scores, inds, cfg)
mlx.Eval(out)
outDims := out.Dims()
t.Logf("MoE sorted output shape: %v", outDims)
if len(outDims) != 3 || outDims[0] != int(bigB) || outDims[1] != int(bigL) || outDims[2] != int(cfg.HiddenSize) {
t.Errorf("output shape = %v, want [%d, %d, %d]", outDims, bigB, bigL, cfg.HiddenSize)
}
})
}
// TestRouterForwardMatchesLegacy verifies the optimized Router.Forward —
// which takes the top-k of the raw logits and softmaxes only the selected
// values — produces the same indices and (within tolerance) the same
// normalized scores as the legacy path that softmaxes over every expert
// first, gathers the top-k probabilities, then renormalizes.
func TestRouterForwardMatchesLegacy(t *testing.T) {
skipIfNoMLX(t)
cfg := &TextConfig{
HiddenSize: 8,
NumExperts: 4,
TopKExperts: 2,
RMSNormEps: 1e-6,
RouterScale: 0.5,
}
// Distinct per-expert weight rows so top-k has a well-defined ordering
// (tied scores would let argpartition pick either tied expert and make
// the index comparison below flaky).
projWeight := mlx.FromValues([]float32{
0.10, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, // expert 0
0.30, 0.29, 0.28, 0.27, 0.26, 0.25, 0.24, 0.23, // expert 1
-0.05, -0.06, -0.07, -0.08, -0.09, -0.10, -0.11, -0.12, // expert 2
0.50, 0.48, 0.46, 0.44, 0.42, 0.40, 0.38, 0.36, // expert 3
}, int(cfg.NumExperts), int(cfg.HiddenSize))
scale := mlx.FromValues([]float32{
1.0, 0.9, 1.1, 1.0, 1.2, 0.8, 1.0, 1.05,
}, int(cfg.HiddenSize))
r := &Router{
Proj: linearFromWeight(projWeight),
Scale: scale,
}
// Varied x so different positions potentially hit different top-k.
x := mlx.FromValues([]float32{
0.2, -0.1, 0.3, 0.0, 0.4, -0.2, 0.1, 0.05,
-0.3, 0.2, -0.1, 0.4, -0.05, 0.3, 0.0, 0.2,
0.5, 0.4, -0.2, 0.1, -0.3, 0.0, 0.3, -0.1,
}, 1, 3, int(cfg.HiddenSize))
gotScores, gotInds := r.Forward(x, cfg)
wantScores, wantInds := legacyRouterForward(r, x, cfg)
gotInds = gotInds.AsType(mlx.DTypeInt32)
wantInds = wantInds.AsType(mlx.DTypeInt32)
mlx.Eval(gotScores, gotInds, wantScores, wantInds)
if got, want := gotInds.Ints(), wantInds.Ints(); !intSlicesEqual(got, want) {
t.Fatalf("indices mismatch:\n got %v\n want %v", got, want)
}
if got, want := gotScores.Floats(), wantScores.Floats(); !floatSlicesClose(got, want, 1e-5) {
t.Fatalf("scores mismatch:\n got %v\n want %v", got, want)
}
}
// legacyRouterForward implements the pre-optimization router: full softmax
// over every expert, gather the top-k probabilities, then renormalize them
// to sum to 1. Algebraically identical to the fused form in Router.Forward.
func legacyRouterForward(r *Router, x *mlx.Array, cfg *TextConfig) (*mlx.Array, *mlx.Array) {
dims := x.Dims()
BL := int32(dims[0]) * int32(dims[1])
xFlat := mlx.Reshape(x, BL, cfg.HiddenSize)
normed := mlx.RMSNormFn(xFlat, nil, cfg.RMSNormEps)
normed = mlx.MulScalar(normed, cfg.RouterScale)
normed = mlx.Mul(normed, r.Scale)
expertScores := r.Proj.Forward(normed)
probs := mlx.SoftmaxAxis(expertScores, -1, true)
neg := mlx.Neg(expertScores)
inds := mlx.Argpartition(neg, int(cfg.TopKExperts)-1, -1)
inds = mlx.SliceStartStop(inds,
[]int32{0, 0},
[]int32{BL, cfg.TopKExperts},
)
scores := mlx.TakeAlongAxis(probs, inds, -1)
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
return scores, inds
}
func intSlicesEqual(a, b []int) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func floatSlicesClose(a, b []float32, tol float32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
d := a[i] - b[i]
if d < 0 {
d = -d
}
if d > tol {
return false
}
}
return true
}
// linearFromWeight creates a simple nn.LinearLayer from a weight tensor (no bias).
func linearFromWeight(w *mlx.Array) *simpleLinear {
return &simpleLinear{weight: w}
}
type simpleLinear struct {
weight *mlx.Array
}
func (l *simpleLinear) Forward(x *mlx.Array) *mlx.Array {
return x.Matmul(mlx.Transpose(l.weight, 1, 0))
}
func (l *simpleLinear) OutputDim() int32 {
return int32(l.weight.Dims()[0])
}

View File

@@ -0,0 +1,600 @@
package gemma4
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func TestParseTextConfigE2B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 1536,
"num_hidden_layers": 35,
"intermediate_size": 6144,
"num_attention_heads": 8,
"num_key_value_heads": 1,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 512,
"sliding_window_pattern": 5,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": true,
"num_kv_shared_layers": 20,
"hidden_size_per_layer_input": 256,
"vocab_size_per_layer_input": 262144,
"attention_k_eq_v": false,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
// Basic fields.
if cfg.HiddenSize != 1536 {
t.Errorf("HiddenSize = %d, want 1536", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 35 {
t.Errorf("NumHiddenLayers = %d, want 35", cfg.NumHiddenLayers)
}
if cfg.GlobalHeadDim != 512 {
t.Errorf("GlobalHeadDim = %d, want 512", cfg.GlobalHeadDim)
}
if cfg.FinalLogitSoftcapping != 30.0 {
t.Errorf("FinalLogitSoftcapping = %f, want 30.0", cfg.FinalLogitSoftcapping)
}
if cfg.NumKVSharedLayers != 20 {
t.Errorf("NumKVSharedLayers = %d, want 20", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 256 {
t.Errorf("HiddenSizePerLayer = %d, want 256", cfg.HiddenSizePerLayer)
}
// RoPE settings.
if cfg.SlidingRopeDims != 256 {
t.Errorf("SlidingRopeDims = %d, want 256", cfg.SlidingRopeDims)
}
if cfg.FullRopeDims != 512 {
t.Errorf("FullRopeDims = %d, want 512 (GlobalHeadDim, partial rotation handled via custom freqs)", cfg.FullRopeDims)
}
if cfg.SlidingRopeBase != 10000 {
t.Errorf("SlidingRopeBase = %f, want 10000", cfg.SlidingRopeBase)
}
if cfg.FullRopeBase != 1000000 {
t.Errorf("FullRopeBase = %f, want 1000000", cfg.FullRopeBase)
}
// Attention scale.
if cfg.SlidingScale == 0 || cfg.FullScale == 0 {
t.Error("attention scales should be non-zero")
}
// KV sharing map.
// First shared layer is 35 - 20 = 15.
if donor, ok := cfg.KVShareMap[15]; !ok || donor != 13 {
t.Errorf("KVShareMap[15] = %d, ok=%v; want 13, true", donor, ok)
}
if donor, ok := cfg.KVShareMap[19]; !ok || donor != 14 {
t.Errorf("KVShareMap[19] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
}
if donor, ok := cfg.KVShareMap[34]; !ok || donor != 14 {
t.Errorf("KVShareMap[34] = %d, ok=%v; want 14, true (full attn donor)", donor, ok)
}
// Layer 14 should not be shared.
if _, ok := cfg.KVShareMap[14]; ok {
t.Error("layer 14 should not be in KVShareMap (non-shared)")
}
// Donors.
if !cfg.KVDonors[13] {
t.Error("layer 13 should be a KV donor")
}
if !cfg.KVDonors[14] {
t.Error("layer 14 should be a KV donor")
}
}
func TestParseTextConfig26B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 2816,
"num_hidden_layers": 30,
"intermediate_size": 2112,
"num_attention_heads": 16,
"num_key_value_heads": 8,
"num_global_key_value_heads": 2,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 1024,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 0,
"hidden_size_per_layer_input": null,
"attention_k_eq_v": true,
"enable_moe_block": true,
"num_experts": 128,
"top_k_experts": 8,
"moe_intermediate_size": 704,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 2816 {
t.Errorf("HiddenSize = %d, want 2816", cfg.HiddenSize)
}
if !cfg.AttentionKEqV {
t.Error("AttentionKEqV should be true")
}
if cfg.NumGlobalKeyValueHeads != 2 {
t.Errorf("NumGlobalKeyValueHeads = %d, want 2", cfg.NumGlobalKeyValueHeads)
}
if !cfg.EnableMoeBlock {
t.Error("EnableMoeBlock should be true")
}
if cfg.NumExperts != 128 {
t.Errorf("NumExperts = %d, want 128", cfg.NumExperts)
}
if cfg.TopKExperts != 8 {
t.Errorf("TopKExperts = %d, want 8", cfg.TopKExperts)
}
if cfg.ExpertIntermediateSize != 704 {
t.Errorf("ExpertIntermediateSize = %d, want 704", cfg.ExpertIntermediateSize)
}
if cfg.HiddenSizePerLayer != 0 {
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
}
}
func TestParseTextConfig31B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 5376,
"num_hidden_layers": 60,
"intermediate_size": 21504,
"num_attention_heads": 32,
"num_key_value_heads": 16,
"num_global_key_value_heads": 4,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 1024,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 0,
"hidden_size_per_layer_input": null,
"attention_k_eq_v": true,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 5376 {
t.Errorf("HiddenSize = %d, want 5376", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 60 {
t.Errorf("NumHiddenLayers = %d, want 60", cfg.NumHiddenLayers)
}
if !cfg.AttentionKEqV {
t.Error("AttentionKEqV should be true")
}
if cfg.NumGlobalKeyValueHeads != 4 {
t.Errorf("NumGlobalKeyValueHeads = %d, want 4", cfg.NumGlobalKeyValueHeads)
}
if cfg.NumKeyValueHeads != 16 {
t.Errorf("NumKeyValueHeads = %d, want 16", cfg.NumKeyValueHeads)
}
if cfg.NumKVSharedLayers != 0 {
t.Errorf("NumKVSharedLayers = %d, want 0", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 0 {
t.Errorf("HiddenSizePerLayer = %d, want 0 (no PLE)", cfg.HiddenSizePerLayer)
}
if cfg.SlidingWindow != 1024 {
t.Errorf("SlidingWindow = %d, want 1024", cfg.SlidingWindow)
}
// KV sharing should be empty (no shared layers).
if len(cfg.KVShareMap) != 0 {
t.Errorf("KVShareMap should be empty, got %d entries", len(cfg.KVShareMap))
}
// Layer types: pattern is 5 sliding + 1 full, repeating 10 times.
if !isLayerSliding(0, &cfg) {
t.Error("layer 0 should be sliding")
}
if isLayerSliding(5, &cfg) {
t.Error("layer 5 should be full attention")
}
if !isLayerSliding(6, &cfg) {
t.Error("layer 6 should be sliding")
}
if isLayerSliding(59, &cfg) {
t.Error("layer 59 should be full attention")
}
}
func TestParseTextConfigE4B(t *testing.T) {
skipIfNoMLX(t)
data := []byte(`{
"architectures": ["Gemma4ForConditionalGeneration"],
"text_config": {
"hidden_size": 2560,
"num_hidden_layers": 42,
"intermediate_size": 10240,
"num_attention_heads": 8,
"num_key_value_heads": 2,
"head_dim": 256,
"global_head_dim": 512,
"vocab_size": 262144,
"rms_norm_eps": 1e-6,
"max_position_embeddings": 131072,
"sliding_window": 512,
"final_logit_softcapping": 30.0,
"use_double_wide_mlp": false,
"num_kv_shared_layers": 18,
"hidden_size_per_layer_input": 256,
"vocab_size_per_layer_input": 262144,
"attention_k_eq_v": false,
"enable_moe_block": false,
"tie_word_embeddings": true,
"layer_types": [
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention",
"sliding_attention","sliding_attention","sliding_attention","sliding_attention","sliding_attention","full_attention"
],
"rope_parameters": {
"full_attention": {
"partial_rotary_factor": 0.25,
"rope_theta": 1000000.0,
"rope_type": "proportional"
},
"sliding_attention": {
"rope_theta": 10000.0,
"rope_type": "default"
}
}
}
}`)
cfg, err := parseTextConfig(data)
if err != nil {
t.Fatalf("parseTextConfig failed: %v", err)
}
if cfg.HiddenSize != 2560 {
t.Errorf("HiddenSize = %d, want 2560", cfg.HiddenSize)
}
if cfg.NumHiddenLayers != 42 {
t.Errorf("NumHiddenLayers = %d, want 42", cfg.NumHiddenLayers)
}
if cfg.IntermediateSize != 10240 {
t.Errorf("IntermediateSize = %d, want 10240", cfg.IntermediateSize)
}
if cfg.NumKeyValueHeads != 2 {
t.Errorf("NumKeyValueHeads = %d, want 2", cfg.NumKeyValueHeads)
}
if cfg.UseDoubleWideMLP {
t.Error("UseDoubleWideMLP should be false")
}
if cfg.NumKVSharedLayers != 18 {
t.Errorf("NumKVSharedLayers = %d, want 18", cfg.NumKVSharedLayers)
}
if cfg.HiddenSizePerLayer != 256 {
t.Errorf("HiddenSizePerLayer = %d, want 256 (has PLE)", cfg.HiddenSizePerLayer)
}
if cfg.AttentionKEqV {
t.Error("AttentionKEqV should be false")
}
if cfg.EnableMoeBlock {
t.Error("EnableMoeBlock should be false")
}
if cfg.SlidingWindow != 512 {
t.Errorf("SlidingWindow = %d, want 512", cfg.SlidingWindow)
}
// Layer types: pattern is 5 sliding + 1 full, repeating 7 times = 42 layers.
if !isLayerSliding(0, &cfg) {
t.Error("layer 0 should be sliding")
}
if isLayerSliding(5, &cfg) {
t.Error("layer 5 should be full attention")
}
if !isLayerSliding(6, &cfg) {
t.Error("layer 6 should be sliding")
}
if isLayerSliding(41, &cfg) {
t.Error("layer 41 should be full attention")
}
// KV sharing: first shared = 42 - 18 = 24.
// Layer 24 is sliding, its donor should be the last non-shared sliding layer.
// Non-shared layers: 0-23. Last sliding in 0-23 is layer 22 (23=full).
if donor, ok := cfg.KVShareMap[24]; !ok {
t.Error("layer 24 should be in KVShareMap")
} else {
t.Logf("layer 24 donor = %d", donor)
}
// Layer 29 is full_attention (5th full), donor should be the last non-shared full layer.
// Non-shared full layers: 5, 11, 17, 23.
if donor, ok := cfg.KVShareMap[29]; !ok || donor != 23 {
t.Errorf("KVShareMap[29] = %d, ok=%v; want 23, true (full attn donor)", donor, ok)
}
// Layer 23 should NOT be shared (it's the last non-shared layer).
if _, ok := cfg.KVShareMap[23]; ok {
t.Error("layer 23 should not be in KVShareMap (non-shared)")
}
}
func TestLayerTypeDetection(t *testing.T) {
cfg := &TextConfig{
LayerTypes: []string{
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
},
}
if !isLayerSliding(0, cfg) {
t.Error("layer 0 should be sliding")
}
if !isLayerSliding(3, cfg) {
t.Error("layer 3 should be sliding")
}
if isLayerSliding(4, cfg) {
t.Error("layer 4 should be full attention")
}
}
func TestMTPDraftDefaults(t *testing.T) {
tests := []struct {
name string
cfg *TextConfig
wantInitial int
wantMax int
}{
{
name: "nil config",
wantInitial: 4,
wantMax: 16,
},
{
name: "31b bf16",
cfg: &TextConfig{HiddenSize: 5376, NumHiddenLayers: 60},
wantInitial: 14,
wantMax: 16,
},
{
name: "31b quantized",
cfg: &TextConfig{HiddenSize: 5376, NumHiddenLayers: 60, QuantBits: 4},
wantInitial: 14,
wantMax: 16,
},
{
name: "26b-a4b moe",
cfg: &TextConfig{HiddenSize: 2816, NumHiddenLayers: 30, EnableMoeBlock: true},
wantInitial: 8,
wantMax: 16,
},
{
name: "generic default",
cfg: &TextConfig{HiddenSize: 2560, NumHiddenLayers: 42, HiddenSizePerLayer: 256},
wantInitial: 4,
wantMax: 16,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := (&Model{TextConfig: tt.cfg}).MTPDraftDefaults(false)
if got.InitialDraftTokens != tt.wantInitial || got.MaxDraftTokens != tt.wantMax || !got.Enabled {
t.Fatalf("MTPDraftDefaults() = %+v, want initial=%d max=%d enabled=true", got, tt.wantInitial, tt.wantMax)
}
})
}
}
func TestNewCachesOmitsSharedKVLayers(t *testing.T) {
m := &Model{
Layers: []*DecoderLayer{
{IsSliding: true, KVShareDonor: -1},
{IsSliding: false, KVShareDonor: -1},
{IsSliding: true, KVShareDonor: 0},
{IsSliding: false, KVShareDonor: 1},
},
TextConfig: &TextConfig{SlidingWindow: 512},
}
caches := m.NewCaches()
if got, want := len(caches), 2; got != want {
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
}
}
func TestNewCachesIncludesAllNonSharedLayers(t *testing.T) {
m := &Model{
Layers: []*DecoderLayer{
{IsSliding: true, KVShareDonor: -1},
{IsSliding: false, KVShareDonor: -1},
{IsSliding: true, KVShareDonor: -1},
},
TextConfig: &TextConfig{SlidingWindow: 512},
}
caches := m.NewCaches()
if got, want := len(caches), len(m.Layers); got != want {
t.Fatalf("len(NewCaches()) = %d, want %d", got, want)
}
}
func TestNewCachesAssistantSharedHistoryOrdering(t *testing.T) {
cases := []struct {
name string
totalLayers int
slidingBeforeFull int
cacheLayers int
}{
{name: "31B", totalLayers: 60, slidingBeforeFull: 5, cacheLayers: 60},
{name: "26B-A4B", totalLayers: 30, slidingBeforeFull: 5, cacheLayers: 30},
{name: "E4B", totalLayers: 42, slidingBeforeFull: 5, cacheLayers: 24},
{name: "E2B", totalLayers: 35, slidingBeforeFull: 4, cacheLayers: 15},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
groupSize := tc.slidingBeforeFull + 1
layers := make([]*DecoderLayer, tc.totalLayers)
for i := range layers {
donor := int32(-1)
if i >= tc.cacheLayers {
donor = 0
}
layers[i] = &DecoderLayer{
IsSliding: i%groupSize < tc.slidingBeforeFull,
KVShareDonor: donor,
}
}
m := &Model{
Layers: layers,
TextConfig: &TextConfig{SlidingWindow: 512},
}
caches := m.NewCaches()
if got := len(caches); got != tc.cacheLayers {
t.Fatalf("len(NewCaches()) = %d, want %d", got, tc.cacheLayers)
}
gotSliding := len(caches) - 2
gotFull := len(caches) - 1
if !m.Layers[gotSliding].IsSliding {
t.Fatalf("cache %d should be sliding attention", gotSliding)
}
if m.Layers[gotFull].IsSliding {
t.Fatalf("cache %d should be full attention", gotFull)
}
})
}
}
func TestResolveWeightPrefix(t *testing.T) {
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
tests := []struct {
name string
key string
wantPfx string
}{
{"bare", "embed_tokens.weight", ""},
{"language_model", "model.language_model.embed_tokens.weight", "model.language_model."},
{"with_model", "model.embed_tokens.weight", "model."},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dummy := mlx.FromValue(float32(1.0))
mlx.Eval(dummy)
tensors := map[string]*mlx.Array{tt.key: dummy}
got := resolveWeightPrefix(tensors)
if got != tt.wantPfx {
t.Errorf("resolveWeightPrefix(%q) = %q, want %q", tt.key, got, tt.wantPfx)
}
})
}
}
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}

View File

@@ -0,0 +1,770 @@
// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("Glm4MoeLiteForCausalLM", newModel)
base.Register("GLM4MoeLite", newModel)
}
// RopeScaling holds RoPE scaling configuration
type RopeScaling struct {
Factor float32 `json:"factor"`
MscaleAllDim float32 `json:"mscale_all_dim"`
}
// Config holds GLM4-MoE-Lite model configuration
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
MoEIntermediateSize int32 `json:"moe_intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
AttentionBias bool `json:"attention_bias"`
// MLA (Multi-head Latent Attention) parameters
QLoraRank int32 `json:"q_lora_rank"`
KVLoraRank int32 `json:"kv_lora_rank"`
QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
VHeadDim int32 `json:"v_head_dim"`
// MoE parameters
NRoutedExperts int32 `json:"n_routed_experts"`
NSharedExperts int32 `json:"n_shared_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
RoutedScalingFactor float32 `json:"routed_scaling_factor"`
NormTopKProb bool `json:"norm_topk_prob"`
FirstKDenseReplace int32 `json:"first_k_dense_replace"`
NGroup int32 `json:"n_group"`
TopKGroup int32 `json:"topk_group"`
// RoPE scaling
RopeScaling *RopeScaling `json:"rope_scaling"`
// Quantization parameters (set during load based on model quantization)
QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
QuantBits int `json:"-"` // Bits per weight (4 or 8)
QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields
QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
}
// MLAAttention implements Multi-head Latent Attention with absorption.
type MLAAttention struct {
QAProj nn.LinearLayer
QALayerNorm *nn.RMSNorm
QBProj nn.LinearLayer
KVAProjWithMQA nn.LinearLayer
KVALayerNorm *nn.RMSNorm
EmbedQ *nn.MultiLinear
UnembedOut *nn.MultiLinear
OProj nn.LinearLayer
}
// Forward computes absorbed MLA attention output.
func (a *MLAAttention) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
q := a.QAProj.Forward(x)
q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
q = a.QBProj.Forward(q)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
qNope := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
qPE := mlx.SliceStartStop(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
compressedKV := a.KVAProjWithMQA.Forward(x)
kvCompressed := mlx.SliceStartStop(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
kPE := mlx.SliceStartStop(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
kvLatent = mlx.ExpandDims(kvLatent, 1)
qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, positions)
qLatent := a.EmbedQ.Forward(qNope)
keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
// MLA compresses K and V into a single tensor: the cache stores
// [kvLatent, kPE] concatenated along the last dim as its keys,
// and V is the kvLatent prefix (first KVLoraRank positions) of
// that same tensor. WithMLAHistory handles the slice on our
// behalf so the model never touches the history's K/V.
var kv nn.SDPAOption
if c != nil {
placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0})
history := c.(cache.Attention).Update(b, keys, placeholderValues)
kv = nn.WithMLAHistory(history, int(cfg.KVLoraRank))
} else {
values := mlx.SliceStartStop(keys, []int32{0, 0, 0, 0}, []int32{B, 1, L, cfg.KVLoraRank})
kv = nn.WithKV(keys, values, b.SeqQueryLens)
}
queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
out := nn.ScaledDotProductAttention(b, queries, cfg.Scale, kv, nn.WithMask(nn.CausalMask()))
out = a.UnembedOut.Forward(out)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
return a.OProj.Forward(out)
}
// DenseMLP implements the standard SwiGLU MLP for dense layers
type DenseMLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
// Forward applies the SwiGLU MLP
func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
}
// MoEGate implements the expert gating mechanism
type MoEGate struct {
Gate nn.LinearLayer
EScoreCorrectionBias *mlx.Array
}
// Forward computes expert selection indices and scores
func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
gates := g.Gate.Forward(x)
var origScores, negScores *mlx.Array
if g.EScoreCorrectionBias != nil {
origScores, negScores = mlx.SigmoidRouter(gates, g.EScoreCorrectionBias)
} else {
origScores = mlx.Sigmoid(gates)
negScores = mlx.Neg(origScores)
}
topK := cfg.NumExpertsPerTok
inds := mlx.Argpartition(negScores, int(topK)-1, -1)
dims := inds.Dims()
inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
scores := mlx.TakeAlongAxis(origScores, inds, -1)
if topK > 1 && cfg.NormTopKProb {
sumScores := mlx.Sum(scores, -1, true)
scores = mlx.Div(scores, sumScores)
}
scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
return inds, scores
}
// SwitchMLP implements the MoE expert computation using stacked weights
type SwitchMLP struct {
GateWeight *mlx.Array
UpWeight *mlx.Array
DownWeight *mlx.Array
GateWeightQ, GateScales, GateBiases *mlx.Array
UpWeightQ, UpScales, UpBiases *mlx.Array
DownWeightQ, DownScales, DownBiases *mlx.Array
GateBits int
UpBits int
DownBits int
GateGroupSize int
UpGroupSize int
DownGroupSize int
UseQuantized bool
}
// Forward applies the switched expert MLP
func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
dims := x.Dims()
B, L := int32(dims[0]), int32(dims[1])
topK := cfg.NumExpertsPerTok
xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
idxFlat := mlx.Reshape(indices, B*L, topK)
doSort := B*L >= 64
var invOrder *mlx.Array
n := B * L * topK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
}
var gate, up, hidden, down *mlx.Array
if s.UseQuantized {
gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
} else {
gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
hidden = mlx.SwiGLU(gate, up)
down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
}
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
}
// SharedExperts implements the shared expert MLP
type SharedExperts struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
// Forward applies the shared expert MLP
func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
return s.DownProj.Forward(mlx.SwiGLU(s.GateProj.Forward(x), s.UpProj.Forward(x)))
}
// MoE implements the full Mixture of Experts layer
type MoE struct {
Gate *MoEGate
SwitchMLP *SwitchMLP
SharedExperts *SharedExperts
}
// Forward applies the MoE layer
func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
dims := x.Dims()
B, L := int32(dims[0]), int32(dims[1])
inds, scores := m.Gate.Forward(x, cfg)
expertOut := m.SwitchMLP.Forward(x, inds, cfg)
scoresExpanded := mlx.ExpandDims(scores, -1)
y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
if m.SharedExperts != nil {
y = mlx.Add(y, m.SharedExperts.Forward(x))
}
return mlx.Reshape(y, B, L, cfg.HiddenSize)
}
// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
type DenseBlock struct {
Attention *MLAAttention
MLP *DenseMLP
InputLayerNorm *nn.RMSNorm
PostAttentionLayerNorm *nn.RMSNorm
}
// Forward applies the dense block
func (blk *DenseBlock) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
r := blk.Attention.Forward(blk.InputLayerNorm.Forward(x, cfg.RMSNormEps), b, c, positions, B, L, cfg)
h := mlx.Add(x, r)
r = blk.MLP.Forward(blk.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
return mlx.Add(h, r)
}
// MoEBlock represents a MoE transformer block
type MoEBlock struct {
Attention *MLAAttention
MoE *MoE
InputLayerNorm *nn.RMSNorm
PostAttentionLayerNorm *nn.RMSNorm
}
// Forward applies the MoE block
func (blk *MoEBlock) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
r := blk.Attention.Forward(blk.InputLayerNorm.Forward(x, cfg.RMSNormEps), b, c, positions, B, L, cfg)
h := mlx.Add(x, r)
r = blk.MoE.Forward(blk.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
return mlx.Add(h, r)
}
// Block interface for both dense and MoE blocks
type Block interface {
Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array
}
// Model represents the complete GLM4-MoE-Lite model
type Model struct {
EmbedTokens nn.EmbeddingLayer
Layers []Block
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
}
// computeScale computes the attention scale.
func computeScale(cfg *Config) float32 {
keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
scale := float32(1.0 / math.Sqrt(float64(keyLength)))
if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
scale *= s * s
}
return scale
}
// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
func supportsGatherQMM(mode string, bits int) bool {
return mode == "affine" && (bits == 4 || bits == 8)
}
// ExpertWeight holds a single expert's weight with optional quantization components.
type ExpertWeight struct {
Weight *mlx.Array
Scales *mlx.Array
Biases *mlx.Array
Bits int
GroupSize int
}
// loadExpertWeight loads an expert weight from the tensor map.
func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized bool, cfg *Config) *ExpertWeight {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
groupSize, bits, mode := model.ResolveLinearQuantParams(
cfg.QuantGroupSize,
cfg.QuantBits,
cfg.QuantMode,
cfg.TensorQuant,
path+".weight",
w,
scales,
)
if useQuantized && supportsGatherQMM(mode, bits) {
return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
}
return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
}
return &ExpertWeight{Weight: w}
}
// StackedExpertWeights holds stacked weights for all experts.
type StackedExpertWeights struct {
Weight *mlx.Array
Scales *mlx.Array
Biases *mlx.Array
Bits int
GroupSize int
}
// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
func collectAndStackExpertWeights(
tensors map[string]*mlx.Array,
prefix string,
projName string,
numExperts int32,
useQuantized bool,
cfg *Config,
) *StackedExpertWeights {
var w, s, b []*mlx.Array
var bits, groupSize int
for e := int32(0); e < numExperts; e++ {
path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
ew := loadExpertWeight(tensors, path, useQuantized, cfg)
if ew == nil {
continue
}
w = append(w, ew.Weight)
if ew.Scales != nil {
s = append(s, ew.Scales)
}
if ew.Biases != nil {
b = append(b, ew.Biases)
}
if e == 0 {
bits = ew.Bits
groupSize = ew.GroupSize
}
}
result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
if len(w) > 0 {
result.Weight = mlx.Stack(w, 0)
if len(s) > 0 {
result.Scales = mlx.Stack(s, 0)
}
if len(b) > 0 {
result.Biases = mlx.Stack(b, 0)
}
}
return result
}
// sanitizeExpertWeights stacks individual expert weights into tensors.
func sanitizeExpertWeights(tensors map[string]*mlx.Array, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
gate = collectAndStackExpertWeights(tensors, prefix, "gate_proj", numExperts, useQuantized, cfg)
up = collectAndStackExpertWeights(tensors, prefix, "up_proj", numExperts, useQuantized, cfg)
down = collectAndStackExpertWeights(tensors, prefix, "down_proj", numExperts, useQuantized, cfg)
return gate, up, down
}
// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
path := prefix + ".self_attn.kv_b_proj"
w := tensors[path+".weight"]
if w == nil {
return nil, nil
}
// Check if quantized and dequantize
if scales := tensors[path+".weight_scale"]; scales != nil {
qbiases := tensors[path+".weight_qbias"]
groupSize, bits, mode := model.ResolveLinearQuantParams(
cfg.QuantGroupSize,
cfg.QuantBits,
cfg.QuantMode,
cfg.TensorQuant,
path+".weight",
w,
scales,
)
w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
}
headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
wk := mlx.SliceStartStop(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
wv := mlx.SliceStartStop(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
embedQ := mlx.Transpose(wk, 0, 2, 1)
unembedOut := wv
return embedQ, unembedOut
}
// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
// no weights loaded yet). Called by the registry via base.New().
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
cfg.Scale = computeScale(&cfg)
// Set up quantization parameters from pre-scanned metadata
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
// Load tokenizer
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields. Handles MLA absorption, expert stacking, and quantized
// layer creation.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
cfg := m.Config
linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
if !useQuantized && cfg.TensorQuant != nil {
for _, tq := range cfg.TensorQuant {
if tq == nil {
continue
}
_, bits, mode := model.QuantizationParams(tq.QuantType)
if supportsGatherQMM(mode, bits) {
useQuantized = true
break
}
}
}
// Load embedding
m.EmbedTokens = model.MakeEmbeddingLayer(tensors, "model.embed_tokens", cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
// Load final norm
if w := tensors["model.norm.weight"]; w != nil {
m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
// Load LM head
m.LMHead = linears.Make("lm_head")
// Load layers
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
// Load attention (same for both block types)
attn := &MLAAttention{}
attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
}
attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
// Sanitize MLA weights for absorbed attention
embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
attn.EmbedQ = nn.NewMultiLinear(embedQ)
attn.UnembedOut = nn.NewMultiLinear(unembedOut)
inputLN := tensors[prefix+".input_layernorm.weight"]
postAttnLN := tensors[prefix+".post_attention_layernorm.weight"]
if i < cfg.FirstKDenseReplace {
// Dense block
block := &DenseBlock{Attention: attn}
if inputLN != nil {
block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
}
if postAttnLN != nil {
block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
}
block.MLP = &DenseMLP{
GateProj: linears.Make(prefix + ".mlp.gate_proj"),
UpProj: linears.Make(prefix + ".mlp.up_proj"),
DownProj: linears.Make(prefix + ".mlp.down_proj"),
}
m.Layers[i] = block
} else {
// MoE block
block := &MoEBlock{Attention: attn}
if inputLN != nil {
block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
}
if postAttnLN != nil {
block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
}
// Stack expert weights
gate, up, down := sanitizeExpertWeights(tensors, prefix, cfg.NRoutedExperts, useQuantized, cfg)
switchMLP := &SwitchMLP{UseQuantized: useQuantized}
if useQuantized {
switchMLP.GateWeightQ = gate.Weight
switchMLP.GateScales = gate.Scales
switchMLP.GateBiases = gate.Biases
switchMLP.GateBits = gate.Bits
switchMLP.GateGroupSize = gate.GroupSize
switchMLP.UpWeightQ = up.Weight
switchMLP.UpScales = up.Scales
switchMLP.UpBiases = up.Biases
switchMLP.UpBits = up.Bits
switchMLP.UpGroupSize = up.GroupSize
switchMLP.DownWeightQ = down.Weight
switchMLP.DownScales = down.Scales
switchMLP.DownBiases = down.Biases
switchMLP.DownBits = down.Bits
switchMLP.DownGroupSize = down.GroupSize
} else {
switchMLP.GateWeight = gate.Weight
switchMLP.UpWeight = up.Weight
switchMLP.DownWeight = down.Weight
}
moeGate := &MoEGate{}
moeGate.Gate = linears.Make(prefix + ".mlp.gate")
if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
moeGate.EScoreCorrectionBias = bias
}
block.MoE = &MoE{
Gate: moeGate,
SwitchMLP: switchMLP,
}
// Load shared experts if present
if cfg.NSharedExperts > 0 {
block.MoE.SharedExperts = &SharedExperts{
GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
}
}
m.Layers[i] = block
}
}
return nil
}
// Forward computes the forward pass of the model
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
h = layer.Forward(h, b, c, positions, B, L, m.Config)
}
h = m.Norm.Forward(h, m.RMSNormEps)
return h
}
// Unembed applies the LM head to get logits.
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
// NumLayers returns the number of transformer layers
func (m *Model) NumLayers() int { return len(m.Layers) }
// MaxContextLength returns the maximum context length
func (m *Model) MaxContextLength() int { return int(m.MaxPositionEmbeddings) }
// VocabSize returns the vocabulary size
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
// Tokenizer returns the model's tokenizer
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
// NewCache creates a new KV cache for the model
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
func (m *Model) FormatPrompt(prompt string) string {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
if think {
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|><think>"
}
return "[gMASK]<sop><|user|>" + prompt + "<|assistant|></think>"
}
// NewRenderer returns a new Renderer for formatting multi-turn conversations.
func (m *Model) NewRenderer() *Renderer {
return &Renderer{}
}
// NewParser returns a new Parser for extracting thinking and tool calls from output.
func (m *Model) NewParser() *Parser {
return &Parser{}
}

View File

@@ -0,0 +1,520 @@
package glm4_moe_lite
import (
"context"
"encoding/json"
"encoding/xml"
"fmt"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type parserState int
const (
parserState_LookingForThinkingOpen parserState = iota
parserState_ThinkingStartedEatingWhitespace
parserState_CollectingThinking
parserState_ThinkingDoneEatingWhitespace
parserState_CollectingContent
parserState_ToolStartedEatingWhitespace
parserState_CollectingToolContent
)
const (
thinkingOpenTag = "<think>"
thinkingCloseTag = "</think>"
toolOpenTag = "<tool_call>"
toolCloseTag = "</tool_call>"
)
// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
// GLM-4's prompt ends with <think> when thinking is enabled, so the parser
// must start in CollectingThinking state (the model outputs thinking content directly).
type Parser struct {
state parserState
buffer strings.Builder
tools []api.Tool
}
// HasToolSupport returns true as GLM4 supports tool calling.
func (p *Parser) HasToolSupport() bool {
return true
}
// HasThinkingSupport returns true as GLM4 supports thinking mode.
func (p *Parser) HasThinkingSupport() bool {
return true
}
// Init initializes the parser with tools and thinking configuration.
func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
// When thinking is enabled (nil or true), the prompt ends with <think>,
// so model output starts directly with thinking content (no opening tag).
if thinkValue == nil || thinkValue.Bool() {
p.state = parserState_CollectingThinking
}
return tools
}
type parserEvent interface {
isParserEvent()
}
type eventContent struct {
content string
}
func (eventContent) isParserEvent() {}
type eventRawToolCall struct {
raw string
}
func (eventRawToolCall) isParserEvent() {}
type eventThinkingContent struct {
content string
}
func (eventThinkingContent) isParserEvent() {}
// Add processes new output text and returns parsed content, thinking, and tool calls.
func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case eventRawToolCall:
toolCall, err := parseToolCall(event, p.tools)
if err != nil {
slog.Warn("glm-4 tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case eventThinkingContent:
thinkingSb.WriteString(event.content)
case eventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *Parser) parseEvents() []parserEvent {
var all []parserEvent
keepLooping := true
for keepLooping {
var events []parserEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
// and transitions to the next state. Returns (nil, false) if only whitespace remains
// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
if trimmed == "" {
return nil, false // Still only whitespace, keep waiting for more input
}
p.state = nextState
p.buffer.WriteString(trimmed)
return nil, true // Successfully transitioned
}
// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
// the content after (optionally trimmed of leading whitespace), and updates the buffer
func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after
}
func (p *Parser) eat() ([]parserEvent, bool) {
var events []parserEvent
switch p.state {
case parserState_LookingForThinkingOpen:
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
if strings.HasPrefix(trimmed, thinkingOpenTag) {
// Found <think> opening tag
after := strings.TrimPrefix(trimmed, thinkingOpenTag)
after = strings.TrimLeftFunc(after, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(after)
if after == "" {
p.state = parserState_ThinkingStartedEatingWhitespace
} else {
p.state = parserState_CollectingThinking
}
return events, true
} else if strings.HasPrefix(thinkingOpenTag, trimmed) {
// Partial opening tag seen, keep accumulating
return events, false
} else if trimmed == "" {
// Only whitespace, keep accumulating
return events, false
} else {
// No thinking tag found, skip to content collection
p.state = parserState_CollectingContent
// Don't trim - we want to keep the original content
return events, true
}
case parserState_ThinkingStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
case parserState_CollectingThinking:
acc := p.buffer.String()
if strings.Contains(acc, thinkingCloseTag) {
thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, eventThinkingContent{content: thinking})
}
if remaining == "" {
p.state = parserState_ThinkingDoneEatingWhitespace
} else {
p.state = parserState_CollectingContent
}
return events, true
} else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
// Partial closing tag - withhold it along with any trailing whitespace before it
beforePartialTag := acc[:len(acc)-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
} else {
// Pure thinking content - withhold trailing whitespace (might precede closing tag)
whitespaceLen := trailingWhitespaceLen(acc)
ambiguousStart := len(acc) - whitespaceLen
unambiguous := acc[:ambiguousStart]
ambiguous := acc[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventThinkingContent{content: unambiguous})
}
return events, false
}
case parserState_ThinkingDoneEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
case parserState_CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
before, after := p.splitAtTag(toolOpenTag, true)
if len(before) > 0 {
events = append(events, eventContent{content: before})
}
if after == "" {
p.state = parserState_ToolStartedEatingWhitespace
} else {
p.state = parserState_CollectingToolContent
}
return events, true
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWsLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWsLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, eventContent{content: unambiguous})
}
return events, false
}
case parserState_ToolStartedEatingWhitespace:
return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
case parserState_CollectingToolContent:
acc := p.buffer.String()
if strings.Contains(acc, toolCloseTag) {
toolContent, _ := p.splitAtTag(toolCloseTag, true)
if len(toolContent) == 0 {
slog.Warn("glm4 tool call closing tag found but no content before it")
}
events = append(events, eventRawToolCall{raw: toolContent})
p.state = parserState_CollectingContent
return events, true
} else {
// Keep accumulating - tool calls are not streamed
// We just wait for the closing tag
return events, false
}
default:
panic("unreachable")
}
}
// overlap returns the length of the overlap between the end of s and the start of tag.
func overlap(s, tag string) int {
for i := 1; i <= len(tag) && i <= len(s); i++ {
if strings.HasSuffix(s, tag[:i]) {
return i
}
}
return 0
}
// trailingWhitespaceLen returns the length of trailing whitespace in s.
func trailingWhitespaceLen(s string) int {
trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
return len(s) - len(trimmed)
}
// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
type ToolCallXML struct {
XMLName xml.Name `xml:"tool_call"`
Content string `xml:",chardata"` // Function name (text nodes between tags)
Keys []string `xml:"arg_key"` // All arg_key elements in document order
Values []string `xml:"arg_value"` // All arg_value elements in document order
}
// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
func escapeContent(s string) string {
var result strings.Builder
inTag := false
for i := range len(s) {
ch := s[i]
if ch == '<' {
// Check if this is a known tag
if strings.HasPrefix(s[i:], "<arg_key>") ||
strings.HasPrefix(s[i:], "</arg_key>") ||
strings.HasPrefix(s[i:], "<arg_value>") ||
strings.HasPrefix(s[i:], "</arg_value>") {
inTag = true
}
}
if inTag {
result.WriteByte(ch)
if ch == '>' {
inTag = false
}
} else {
// Escape special characters in text content
switch ch {
case '&':
result.WriteString("&amp;")
case '<':
result.WriteString("&lt;")
case '>':
result.WriteString("&gt;")
default:
result.WriteByte(ch)
}
}
}
return result.String()
}
// repairUnclosedArgValues inserts missing </arg_value> closing tags.
// GLM models sometimes omit the closing tag, producing XML like:
//
// <arg_value>value</tool_call>
//
// instead of:
//
// <arg_value>value</arg_value></tool_call>
func repairUnclosedArgValues(s string) string {
var result strings.Builder
for {
openIdx := strings.Index(s, "<arg_value>")
if openIdx == -1 {
result.WriteString(s)
break
}
afterOpen := openIdx + len("<arg_value>")
closeIdx := strings.Index(s[afterOpen:], "</arg_value>")
nextKeyIdx := strings.Index(s[afterOpen:], "<arg_key>")
if closeIdx != -1 && (nextKeyIdx == -1 || closeIdx < nextKeyIdx) {
end := afterOpen + closeIdx + len("</arg_value>")
result.WriteString(s[:end])
s = s[end:]
continue
}
if nextKeyIdx != -1 {
insertAt := afterOpen + nextKeyIdx
result.WriteString(s[:insertAt])
result.WriteString("</arg_value>")
s = s[insertAt:]
} else {
result.WriteString(s)
result.WriteString("</arg_value>")
break
}
}
return result.String()
}
func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
// Escape any unescaped entities in text content
escaped := escapeContent(raw.raw)
// Wrap the content in a root element to make it valid XML
xmlString := "<tool_call>" + escaped + "</tool_call>"
// Parse XML into struct, retrying once with repaired XML if it fails
var parsed ToolCallXML
if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
parsed = ToolCallXML{}
repaired := "<tool_call>" + repairUnclosedArgValues(escaped) + "</tool_call>"
if err2 := xml.Unmarshal([]byte(repaired), &parsed); err2 != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
}
}
// Extract and trim function name
functionName := strings.TrimSpace(parsed.Content)
if functionName == "" {
return api.ToolCall{}, fmt.Errorf("empty function name")
}
// Verify keys and values are paired correctly
if len(parsed.Keys) != len(parsed.Values) {
return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
}
// Find the matching tool to get parameter types
var matchedTool *api.Tool
for i := range tools {
if tools[i].Function.Name == functionName {
matchedTool = &tools[i]
break
}
}
// Build arguments map by pairing keys and values
toolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: functionName,
Arguments: api.NewToolCallFunctionArguments(),
},
}
for i := range parsed.Keys {
key := strings.TrimSpace(parsed.Keys[i])
value := parsed.Values[i] // Don't trim here - parseValue handles it
// Look up parameter type
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
// Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf {
paramType = append(paramType, anyOfProp.Type...)
}
} else {
paramType = prop.Type
}
}
}
// Parse value with type coercion
toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
}
return toolCall, nil
}
// parseValue parses a string value and coerces it to the appropriate type based on paramType.
func parseValue(value string, paramType api.PropertyType) any {
value = strings.TrimSpace(value)
// If no type specified, return as string
if len(paramType) == 0 {
return value
}
// Try to parse based on specified types
for _, t := range paramType {
switch t {
case "boolean":
if value == "true" {
return true
}
if value == "false" {
return false
}
case "integer":
var i int64
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
return i
}
case "number":
var f float64
if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
return f
}
case "array", "object":
// Try to parse as JSON
var result any
if err := json.Unmarshal([]byte(value), &result); err == nil {
return result
}
}
}
// Default to string
return value
}

View File

@@ -0,0 +1,190 @@
package glm4_moe_lite
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestParserThinking(t *testing.T) {
tests := []struct {
name string
input string
thinkEnabled bool
wantContent string
wantThinking string
wantToolCalls int
}{
{
name: "thinking enabled - simple thinking then content",
input: "Let me think about this...</think>Here is my answer.",
thinkEnabled: true,
wantThinking: "Let me think about this...",
wantContent: "Here is my answer.",
},
{
name: "thinking enabled - only thinking",
input: "I need to consider multiple factors...",
thinkEnabled: true,
wantThinking: "I need to consider multiple factors...",
wantContent: "",
},
{
name: "thinking disabled - direct content",
input: "Here is my direct answer.",
thinkEnabled: false,
wantThinking: "",
wantContent: "Here is my direct answer.",
},
{
name: "thinking with tool call",
input: "Let me search for that...</think>I'll use a tool.<tool_call>search<arg_key>query</arg_key><arg_value>test</arg_value></tool_call>",
thinkEnabled: true,
wantThinking: "Let me search for that...",
wantContent: "I'll use a tool.",
wantToolCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Parser{}
var thinkValue *api.ThinkValue
if tt.thinkEnabled {
thinkValue = &api.ThinkValue{Value: true}
} else {
thinkValue = &api.ThinkValue{Value: false}
}
// Define tools for tool call tests
props := api.NewToolPropertiesMap()
props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "search",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
p.Init(tools, nil, thinkValue)
content, thinking, calls, err := p.Add(tt.input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if thinking != tt.wantThinking {
t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
}
if content != tt.wantContent {
t.Errorf("content = %q, want %q", content, tt.wantContent)
}
if len(calls) != tt.wantToolCalls {
t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
}
})
}
}
func TestParserToolCall(t *testing.T) {
p := &Parser{}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: props,
},
},
},
}
// Initialize with thinking disabled
tv := &api.ThinkValue{Value: false}
p.Init(tools, nil, tv)
input := "<tool_call>get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value><arg_key>unit</arg_key><arg_value>celsius</arg_value></tool_call>"
_, _, calls, err := p.Add(input, true)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(calls) != 1 {
t.Fatalf("expected 1 tool call, got %d", len(calls))
}
call := calls[0]
if call.Function.Name != "get_weather" {
t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
}
location, ok := call.Function.Arguments.Get("location")
if !ok || location != "San Francisco" {
t.Errorf("location = %v, want %q", location, "San Francisco")
}
unit, ok := call.Function.Arguments.Get("unit")
if !ok || unit != "celsius" {
t.Errorf("unit = %v, want %q", unit, "celsius")
}
}
func TestOverlap(t *testing.T) {
tests := []struct {
s string
tag string
want int
}{
{"hello<", "</think>", 1},
{"hello</", "</think>", 2},
{"hello</t", "</think>", 3},
{"hello</th", "</think>", 4},
{"hello</thi", "</think>", 5},
{"hello</thin", "</think>", 6},
{"hello</think", "</think>", 7},
{"hello</think>", "</think>", 8}, // Complete tag at end returns full length
{"hello", "</think>", 0},
{"", "</think>", 0},
}
for _, tt := range tests {
t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
got := overlap(tt.s, tt.tag)
if got != tt.want {
t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
}
})
}
}
func TestTrailingWhitespaceLen(t *testing.T) {
tests := []struct {
s string
want int
}{
{"hello ", 3},
{"hello\n\t ", 3},
{"hello", 0},
{"", 0},
{" ", 3},
}
for _, tt := range tests {
t.Run(tt.s, func(t *testing.T) {
got := trailingWhitespaceLen(tt.s)
if got != tt.want {
t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,173 @@
package glm4_moe_lite
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
// Renderer renders messages for GLM4-MoE-Lite models.
//
// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
//
// 1. INTERLEAVED THINKING
// The model thinks between tool calls and after receiving tool results.
// This enables complex step-by-step reasoning: interpreting each tool output
// before deciding what to do next. Thinking blocks are preserved and returned
// with tool results to maintain reasoning continuity.
//
// 2. PRESERVED THINKING
// The model retains reasoning content from previous assistant turns in context.
// This preserves reasoning continuity across multi-turn conversations. The
// upstream API has a "clear_thinking" parameter to control this:
// - clear_thinking=true: clears reasoning from previous turns (outputs </think>)
// - clear_thinking=false: preserves <think>...</think> blocks from previous turns
//
// 3. TURN-LEVEL THINKING
// Controls whether the model should reason on each turn. The upstream API
// uses "enable_thinking" parameter:
// - enable_thinking=true: outputs <think> to start reasoning
// - enable_thinking=false: outputs </think> to skip reasoning
//
// OLLAMA DEFAULTS:
// - Thinking is ENABLED by default (thinkValue=nil or true outputs <think>)
// - Thinking is PRESERVED by default (reasoning content from previous turns is always
// included in <think>...</think> blocks, equivalent to clear_thinking=false)
// - Users can disable thinking per-turn via thinkValue=false
type Renderer struct{}
// Render renders messages into the GLM4 chat format.
func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(formatToolJSON(d))
sb.WriteString("\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}<arg_key>{arg-key-1}</arg_key><arg_value>{arg-value-1}</arg_value><arg_key>{arg-key-2}</arg_key><arg_value>{arg-value-2}</arg_value>...</tool_call>")
}
think := true
if thinkValue != nil && !thinkValue.Bool() {
think = false
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>")
sb.WriteString(message.Content)
case "assistant":
sb.WriteString("<|assistant|>")
if message.Thinking != "" {
sb.WriteString("<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("</think>")
}
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("<tool_call>" + toolCall.Function.Name)
sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("<tool_response>")
sb.WriteString(message.Content)
sb.WriteString("</tool_response>")
case "system":
sb.WriteString("<|system|>")
sb.WriteString(message.Content)
}
}
sb.WriteString("<|assistant|>")
if think {
sb.WriteString("<think>")
} else {
sb.WriteString("</think>")
}
return sb.String(), nil
}
// renderToolArguments converts tool call arguments to GLM4 XML format.
func renderToolArguments(args api.ToolCallFunctionArguments) string {
var sb strings.Builder
for key, value := range args.All() {
sb.WriteString("<arg_key>" + key + "</arg_key>")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>")
}
return sb.String()
}
// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
func formatToolJSON(raw []byte) string {
var sb strings.Builder
sb.Grow(len(raw) + len(raw)/10)
inString := false
escaped := false
for i := range raw {
ch := raw[i]
sb.WriteByte(ch)
if inString {
if escaped {
escaped = false
continue
}
if ch == '\\' {
escaped = true
continue
}
if ch == '"' {
inString = false
}
continue
}
if ch == '"' {
inString = true
continue
}
if ch == ':' || ch == ',' {
sb.WriteByte(' ')
}
}
return sb.String()
}

View File

@@ -0,0 +1,203 @@
package glm4_moe_lite
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestRendererSimple(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
// Thinking enabled (default)
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|><think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererThinkingDisabled(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "Hello"},
}
tv := &api.ThinkValue{Value: false}
result, err := r.Render(messages, nil, tv)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
expected := "[gMASK]<sop><|user|>Hello<|assistant|></think>"
if result != expected {
t.Errorf("result = %q, want %q", result, expected)
}
}
func TestRendererMultiTurn(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What is 2+2?"},
{Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
{Role: "user", Content: "And 3+3?"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check key parts
if !strings.Contains(result, "[gMASK]<sop>") {
t.Error("missing [gMASK]<sop> prefix")
}
if !strings.Contains(result, "<|user|>What is 2+2?") {
t.Error("missing first user message")
}
if !strings.Contains(result, "<|assistant|><think>Let me calculate: 2+2=4</think>4") {
t.Error("missing assistant message with thinking")
}
if !strings.Contains(result, "<|user|>And 3+3?") {
t.Error("missing second user message")
}
if !strings.HasSuffix(result, "<|assistant|><think>") {
t.Errorf("should end with <|assistant|><think>, got suffix: %q", result[len(result)-30:])
}
}
func TestRendererWithSystem(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
t.Error("missing system message")
}
}
func TestRendererWithTools(t *testing.T) {
r := &Renderer{}
messages := []api.Message{
{Role: "user", Content: "What's the weather?"},
}
props := api.NewToolPropertiesMap()
props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
tools := []api.Tool{
{
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the weather for a location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"location"},
},
},
},
}
result, err := r.Render(messages, tools, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Check for tool system prompt
if !strings.Contains(result, "<|system|>") {
t.Error("missing system tag for tools")
}
if !strings.Contains(result, "# Tools") {
t.Error("missing tools header")
}
if !strings.Contains(result, "<tools>") {
t.Error("missing tools tag")
}
if !strings.Contains(result, "get_weather") {
t.Error("missing tool name")
}
if !strings.Contains(result, "</tools>") {
t.Error("missing closing tools tag")
}
}
func TestRendererWithToolCalls(t *testing.T) {
r := &Renderer{}
args := api.NewToolCallFunctionArguments()
args.Set("location", "San Francisco")
messages := []api.Message{
{Role: "user", Content: "What's the weather in SF?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: args,
},
},
},
},
{Role: "tool", Content: "Sunny, 72F"},
}
result, err := r.Render(messages, nil, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(result, "<tool_call>get_weather") {
t.Error("missing tool call")
}
if !strings.Contains(result, "<arg_key>location</arg_key>") {
t.Error("missing arg_key")
}
if !strings.Contains(result, "<arg_value>San Francisco</arg_value>") {
t.Error("missing arg_value")
}
if !strings.Contains(result, "</tool_call>") {
t.Error("missing tool call closing tag")
}
if !strings.Contains(result, "<|observation|>") {
t.Error("missing observation tag")
}
if !strings.Contains(result, "<tool_response>Sunny, 72F</tool_response>") {
t.Error("missing tool response")
}
}
func TestFormatToolJSON(t *testing.T) {
input := []byte(`{"name":"test","value":123}`)
result := formatToolJSON(input)
// Should add spaces after : and ,
if !strings.Contains(result, ": ") {
t.Error("should add space after colon")
}
if !strings.Contains(result, ", ") {
t.Error("should add space after comma")
}
}

1098
x/models/laguna/laguna.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,509 @@
package laguna
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
func TestParseConfigLagunaXS(t *testing.T) {
skipIfNoMLX(t)
cfg, err := parseConfig([]byte(`{
"model_type": "laguna",
"hidden_size": 2048,
"intermediate_size": 8192,
"moe_intermediate_size": 512,
"shared_expert_intermediate_size": 512,
"num_hidden_layers": 4,
"num_attention_heads": 48,
"num_attention_heads_per_layer": [48, 64, 64, 64],
"num_key_value_heads": 8,
"head_dim": 128,
"vocab_size": 100352,
"max_position_embeddings": 131072,
"layer_types": ["full_attention", "sliding_attention", "sliding_attention", "sliding_attention"],
"sliding_window": 512,
"mlp_only_layers": [0],
"decoder_sparse_step": 1,
"num_experts": 256,
"num_experts_per_tok": 8,
"norm_topk_prob": true,
"moe_routed_scaling_factor": 2.5,
"gating": "per-head",
"rms_norm_eps": 1e-6,
"partial_rotary_factor": 0.5,
"rope_parameters": {
"rope_theta": 500000,
"rope_type": "yarn",
"factor": 32,
"original_max_position_embeddings": 4096,
"beta_fast": 64,
"beta_slow": 1,
"attention_factor": 1
},
"swa_rope_parameters": {
"partial_rotary_factor": 1.0,
"rope_theta": 10000,
"rope_type": "linear"
}
}`))
if err != nil {
t.Fatal(err)
}
if cfg.FullRopeDim != 64 {
t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim)
}
if cfg.FullRopeBase != 500000 {
t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase)
}
if cfg.FullRopeScale != 1 {
t.Fatalf("FullRopeScale = %v, want explicit YaRN attention_factor", cfg.FullRopeScale)
}
if cfg.FullRopeFreqs == nil {
t.Fatal("FullRopeFreqs should be precomputed for YaRN")
}
if cfg.SlidingRopeDim != 128 {
t.Fatalf("SlidingRopeDim = %d, want 128", cfg.SlidingRopeDim)
}
if cfg.SlidingRopeBase != 10000 {
t.Fatalf("SlidingRopeBase = %v, want 10000", cfg.SlidingRopeBase)
}
if !layerIsSliding(&cfg, 1) {
t.Fatal("layer 1 should use sliding attention")
}
if layerUsesMoE(&cfg, 0) {
t.Fatal("layer 0 should be dense due to mlp_only_layers")
}
if !layerUsesMoE(&cfg, 1) {
t.Fatal("layer 1 should use MoE")
}
if got := numHeadsForLayer(&cfg, 1); got != 64 {
t.Fatalf("numHeadsForLayer(1) = %d, want 64", got)
}
}
func TestParseConfigLagunaFP8RopeScaling(t *testing.T) {
skipIfNoMLX(t)
cfg, err := parseConfig([]byte(`{
"hidden_size": 2048,
"intermediate_size": 8192,
"num_hidden_layers": 1,
"num_attention_heads": 48,
"num_key_value_heads": 8,
"head_dim": 128,
"vocab_size": 100352,
"max_position_embeddings": 131072,
"rope_theta": 500000,
"partial_rotary_factor": 0.5,
"rope_scaling": {
"rope_type": "yarn",
"factor": 32
}
}`))
if err != nil {
t.Fatal(err)
}
if cfg.FullRopeBase != 500000 {
t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase)
}
if cfg.FullRopeDim != 64 {
t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim)
}
}
func TestParseConfigLagunaGASchema(t *testing.T) {
skipIfNoMLX(t)
cfg, err := parseConfig([]byte(`{
"model_type": "laguna",
"hidden_size": 2048,
"intermediate_size": 8192,
"moe_intermediate_size": 512,
"shared_expert_intermediate_size": 512,
"num_hidden_layers": 4,
"num_attention_heads": 48,
"num_attention_heads_per_layer": [48, 64, 64, 64],
"num_key_value_heads": 8,
"head_dim": 128,
"vocab_size": 100352,
"max_position_embeddings": 131072,
"layer_types": ["full_attention", "sliding_attention", "sliding_attention", "sliding_attention"],
"sliding_window": 512,
"mlp_layer_types": ["dense", "sparse", "sparse", "sparse"],
"num_experts": 256,
"num_experts_per_tok": 8,
"moe_routed_scaling_factor": 2.5,
"gating": true,
"rms_norm_eps": 1e-6,
"partial_rotary_factor": 0.5,
"rope_parameters": {
"full_attention": {
"rope_theta": 500000,
"rope_type": "yarn",
"factor": 32,
"original_max_position_embeddings": 4096,
"beta_fast": 64,
"beta_slow": 1,
"attention_factor": 1,
"partial_rotary_factor": 0.5
},
"sliding_attention": {
"rope_theta": 10000,
"rope_type": "default",
"partial_rotary_factor": 1.0
}
}
}`))
if err != nil {
t.Fatal(err)
}
if cfg.Gating != "per-head" {
t.Fatalf("Gating = %q, want per-head", cfg.Gating)
}
if !cfg.NormTopKProb {
t.Fatal("NormTopKProb should default true")
}
if cfg.FullRopeBase != 500000 {
t.Fatalf("FullRopeBase = %v, want 500000", cfg.FullRopeBase)
}
if cfg.SlidingRopeBase != 10000 {
t.Fatalf("SlidingRopeBase = %v, want 10000", cfg.SlidingRopeBase)
}
if cfg.FullRopeDim != 64 {
t.Fatalf("FullRopeDim = %d, want 64", cfg.FullRopeDim)
}
if cfg.SlidingRopeDim != 128 {
t.Fatalf("SlidingRopeDim = %d, want 128", cfg.SlidingRopeDim)
}
if layerUsesMoE(&cfg, 0) {
t.Fatal("layer 0 should be dense due to mlp_layer_types")
}
if !layerUsesMoE(&cfg, 1) {
t.Fatal("layer 1 should use MoE")
}
}
func TestTinyLagunaLoadAndForward(t *testing.T) {
skipIfNoMLX(t)
cfg, err := parseConfig([]byte(`{
"model_type": "laguna",
"hidden_size": 8,
"intermediate_size": 12,
"moe_intermediate_size": 4,
"shared_expert_intermediate_size": 4,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_attention_heads_per_layer": [2, 2],
"num_key_value_heads": 1,
"head_dim": 4,
"vocab_size": 16,
"max_position_embeddings": 64,
"layer_types": ["full_attention", "sliding_attention"],
"sliding_window": 2,
"mlp_only_layers": [0],
"decoder_sparse_step": 1,
"num_experts": 2,
"num_experts_per_tok": 1,
"norm_topk_prob": false,
"moe_routed_scaling_factor": 2.5,
"gating": "per-head",
"rms_norm_eps": 1e-5,
"partial_rotary_factor": 0.5,
"rope_parameters": {
"rope_theta": 10000,
"rope_type": "yarn",
"factor": 2,
"original_max_position_embeddings": 16,
"beta_fast": 32,
"beta_slow": 1
},
"swa_rope_parameters": {
"partial_rotary_factor": 1.0,
"rope_theta": 10000,
"rope_type": "linear"
}
}`))
if err != nil {
t.Fatal(err)
}
m := &Model{
Config: &cfg,
Layers: []*Layer{
{LayerIdx: 0, IsSliding: false},
{LayerIdx: 1, IsSliding: true},
},
}
tensors := tinyLagunaTensors()
if err := m.LoadWeights(tensors); err != nil {
t.Fatalf("LoadWeights failed: %v", err)
}
tokens := mlx.FromValues([]int32{1, 2, 3}, 1, 3)
caches := m.NewCaches()
defer func() {
for _, c := range caches {
if c != nil {
c.Free()
}
}
}()
hidden := m.Forward(&batch.Batch{
InputIDs: tokens,
SeqOffsets: []int32{0},
SeqQueryLens: []int32{int32(tokens.Dim(1))},
}, caches)
mlx.Eval(hidden)
if got := hidden.Dims(); len(got) != 3 || got[0] != 1 || got[1] != 3 || got[2] != 8 {
t.Fatalf("hidden shape = %v, want [1 3 8]", got)
}
logits := m.Unembed(hidden)
mlx.Eval(logits)
if got := logits.Dims(); len(got) != 3 || got[0] != 1 || got[1] != 3 || got[2] != 16 {
t.Fatalf("logits shape = %v, want [1 3 16]", got)
}
for i, v := range logits.Floats() {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
t.Fatalf("logits[%d] is not finite: %v", i, v)
}
}
}
func TestTinyLagunaLoadWeightsFusesDenseGateUp(t *testing.T) {
skipIfNoMLX(t)
cfg, err := parseConfig([]byte(`{
"model_type": "laguna",
"hidden_size": 8,
"intermediate_size": 12,
"moe_intermediate_size": 4,
"shared_expert_intermediate_size": 4,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_attention_heads_per_layer": [2, 2],
"num_key_value_heads": 1,
"head_dim": 4,
"vocab_size": 16,
"max_position_embeddings": 64,
"layer_types": ["full_attention", "sliding_attention"],
"sliding_window": 2,
"mlp_only_layers": [0],
"decoder_sparse_step": 1,
"num_experts": 2,
"num_experts_per_tok": 1,
"norm_topk_prob": false,
"moe_routed_scaling_factor": 2.5,
"gating": "per-head",
"rms_norm_eps": 1e-5
}`))
if err != nil {
t.Fatal(err)
}
m := &Model{
Config: &cfg,
Layers: []*Layer{
{LayerIdx: 0, IsSliding: false},
{LayerIdx: 1, IsSliding: true},
},
}
if err := m.LoadWeights(tinyLagunaTensors()); err != nil {
t.Fatalf("LoadWeights failed: %v", err)
}
moe, ok := m.Layers[1].MLP.(*SparseMoE)
if !ok {
t.Fatalf("layer 1 MLP type = %T, want *SparseMoE", m.Layers[1].MLP)
}
if !moe.SwitchMLP.UseFusedGateUp {
t.Fatal("expected dense SwitchMLP to fuse gate/up expert weights")
}
if moe.SwitchMLP.GateUpWeight == nil {
t.Fatal("expected fused GateUpWeight to be populated")
}
if got, want := moe.SwitchMLP.GateUpWeight.Dims(), []int{2, 8, 8}; len(got) != len(want) || got[0] != want[0] || got[1] != want[1] || got[2] != want[2] {
t.Fatalf("GateUpWeight dims = %v, want %v", got, want)
}
}
func TestSparseMoERouteBiasAffectsSelectionNotRoutingWeights(t *testing.T) {
skipIfNoMLX(t)
cfg := &Config{
HiddenSize: 1,
NumExperts: 2,
NumExpertsPerTok: 1,
NormTopKProb: false,
}
moe := &SparseMoE{
Gate: nn.NewLinear(mlx.FromValues([]float32{-4, -3}, 2, 1).AsType(mlx.DTypeBFloat16), nil),
EScoreCorrectionBias: mlx.FromValues([]float32{0.5, 0}, 2),
}
xFlat := mlx.FromValues([]float32{1}, 1, int(cfg.HiddenSize)).AsType(mlx.DTypeBFloat16)
scores, inds := moe.route(xFlat, cfg)
scores = scores.AsType(mlx.DTypeFloat32)
inds = inds.AsType(mlx.DTypeInt32)
mlx.Eval(scores, inds)
gates := moe.Gate.Forward(xFlat).AsType(mlx.DTypeFloat32)
probs := mlx.Sigmoid(gates)
mlx.Eval(probs)
probVals := probs.Floats()
if probVals[0] >= probVals[1] {
t.Fatalf("expected unbiased sigmoid scores to prefer expert 1, got %v", probVals)
}
if probVals[0]+0.5 <= probVals[1] {
t.Fatalf("expected bias to flip selection to expert 0, got probs=%v", probVals)
}
if got := inds.Ints(); len(got) != 1 || got[0] != 0 {
t.Fatalf("selected experts = %v, want [0]", got)
}
if got := scores.Floats(); len(got) != 1 || math.Abs(float64(got[0]-probVals[0])) > 1e-6 {
t.Fatalf("routing weights = %v, want [%v] using unbiased sigmoid scores", got, probVals[0])
}
}
func TestSwitchMLPFusedGateUpMatchesSeparate(t *testing.T) {
skipIfNoMLX(t)
cfg := &Config{HiddenSize: 4, NumExpertsPerTok: 2}
B, L := int32(2), int32(3)
xVals := make([]float32, int(B*L*cfg.HiddenSize))
for i := range xVals {
xVals[i] = float32((i%17)-8) * 0.01
}
x := mlx.FromValues(xVals, int(B), int(L), int(cfg.HiddenSize)).AsType(mlx.DTypeBFloat16)
indicesVals := make([]int32, B*L*cfg.NumExpertsPerTok)
for i := 0; i < len(indicesVals); i += int(cfg.NumExpertsPerTok) {
indicesVals[i] = int32((i / int(cfg.NumExpertsPerTok)) % 2)
indicesVals[i+1] = int32(((i / int(cfg.NumExpertsPerTok)) + 1) % 2)
}
indices := mlx.FromValues(indicesVals, int(B*L), int(cfg.NumExpertsPerTok))
separate := &SwitchMLP{
GateWeight: makePatternExpertWeight(2, 4, 3, 0.011),
UpWeight: makePatternExpertWeight(2, 4, 3, 0.017),
DownWeight: makePatternExpertWeight(2, 3, 4, 0.013),
}
fused := &SwitchMLP{
GateUpWeight: fuseExpertStacks(separate.GateWeight, separate.UpWeight, 2),
DownWeight: separate.DownWeight,
UseFusedGateUp: true,
}
gotSeparate := separate.Forward(x, indices, cfg)
gotFused := fused.Forward(x, indices, cfg)
mlx.Eval(gotSeparate, gotFused)
gotFusedF32 := gotFused.AsType(mlx.DTypeFloat32)
gotSeparateF32 := gotSeparate.AsType(mlx.DTypeFloat32)
mlx.Eval(gotFusedF32, gotSeparateF32)
assertFloatSlicesClose(t, gotFusedF32.Floats(), gotSeparateF32.Floats(), 1e-5)
}
func TestCombinedTensorGlobalScaleIgnoresInputGlobalScale(t *testing.T) {
skipIfNoMLX(t)
tensors := map[string]*mlx.Array{
"proj.weight.global_scale": mlx.FromValues([]float32{0.25}, 1),
"proj.weight.input_global_scale": mlx.FromValues([]float32{8}, 1),
}
got, _ := combinedTensorGlobalScale(tensors, "proj.weight")
if got == nil {
t.Fatal("combinedTensorGlobalScale returned nil")
}
mlx.Eval(got)
vals := got.Floats()
if len(vals) != 1 || vals[0] != 0.25 {
t.Fatalf("combinedTensorGlobalScale = %v, want [0.25]", vals)
}
}
func tinyLagunaTensors() map[string]*mlx.Array {
tensors := map[string]*mlx.Array{
"model.embed_tokens.weight": weights(16, 8),
"model.norm.weight": ones(8),
"lm_head.weight": weights(16, 8),
}
for layer := range 2 {
prefix := "model.layers." + string(rune('0'+layer))
tensors[prefix+".input_layernorm.weight"] = ones(8)
tensors[prefix+".post_attention_layernorm.weight"] = ones(8)
tensors[prefix+".self_attn.q_proj.weight"] = weights(8, 8)
tensors[prefix+".self_attn.k_proj.weight"] = weights(4, 8)
tensors[prefix+".self_attn.v_proj.weight"] = weights(4, 8)
tensors[prefix+".self_attn.o_proj.weight"] = weights(8, 8)
tensors[prefix+".self_attn.g_proj.weight"] = weights(2, 8)
tensors[prefix+".self_attn.q_norm.weight"] = ones(4)
tensors[prefix+".self_attn.k_norm.weight"] = ones(4)
}
tensors["model.layers.0.mlp.gate_proj.weight"] = weights(12, 8)
tensors["model.layers.0.mlp.up_proj.weight"] = weights(12, 8)
tensors["model.layers.0.mlp.down_proj.weight"] = weights(8, 12)
tensors["model.layers.1.mlp.gate.weight"] = weights(2, 8)
tensors["model.layers.1.mlp.experts.e_score_correction_bias"] = mlx.FromValues([]float32{0.1, -0.1}, 2)
for expert := range 2 {
prefix := "model.layers.1.mlp.experts." + string(rune('0'+expert))
tensors[prefix+".gate_proj.weight"] = weights(4, 8)
tensors[prefix+".up_proj.weight"] = weights(4, 8)
tensors[prefix+".down_proj.weight"] = weights(8, 4)
}
tensors["model.layers.1.mlp.shared_expert.gate_proj.weight"] = weights(4, 8)
tensors["model.layers.1.mlp.shared_expert.up_proj.weight"] = weights(4, 8)
tensors["model.layers.1.mlp.shared_expert.down_proj.weight"] = weights(8, 4)
return tensors
}
func makeExpertWeight(vals []float32, dims ...int) *mlx.Array {
return mlx.FromValues(vals, dims...).AsType(mlx.DTypeBFloat16)
}
func makePatternExpertWeight(numExperts, rows, cols int, scale float32) *mlx.Array {
vals := make([]float32, numExperts*rows*cols)
for i := range vals {
vals[i] = float32((i%23)-11) * scale
}
return makeExpertWeight(vals, numExperts, rows, cols)
}
func assertFloatSlicesClose(t *testing.T, got, want []float32, tol float64) {
t.Helper()
if len(got) != len(want) {
t.Fatalf("length mismatch: got %d want %d", len(got), len(want))
}
for i := range got {
if math.Abs(float64(got[i]-want[i])) > tol {
t.Fatalf("value[%d] = %v, want %v (tol=%g)", i, got[i], want[i], tol)
}
}
}
func weights(rows, cols int) *mlx.Array {
vals := make([]float32, rows*cols)
for i := range vals {
vals[i] = float32((i%7)-3) * 0.01
}
return mlx.FromValues(vals, rows, cols)
}
func ones(n int) *mlx.Array {
vals := make([]float32, n)
for i := range vals {
vals[i] = 1
}
return mlx.FromValues(vals, n)
}
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}

319
x/models/llama/llama.go Normal file
View File

@@ -0,0 +1,319 @@
// Package llama provides a Llama-style decoder-only transformer for MLX.
package llama
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("LlamaForCausalLM", newModel)
}
// Config holds Llama model configuration.
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
HeadDim int32 `json:"-"`
Scale float32 `json:"-"`
}
// Model is a Llama text model.
type Model struct {
EmbedTokens nn.EmbeddingLayer
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
weightPrefix string
}
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm
MLPNorm *nn.RMSNorm
}
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
}
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.HiddenSize <= 0 {
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumAttentionHeads <= 0 {
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
if cfg.HeadDim == 0 {
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-5
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
if embedTokens == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = embedTokens
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = m.EmbedTokens.AsLinear()
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Fallback used by many Llama checkpoints where output is tied.
m.LMHead = m.EmbedTokens.AsLinear()
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.AttentionNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
return nil
}
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, b, c, positions, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (l *Layer) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), b, c, positions, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Transpose(v, 0, 2, 1, 3)
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
// MLX SDPA supports grouped-query attention directly (Q heads can be a
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
var kv nn.SDPAOption
if c != nil {
history := c.(cache.Attention).Update(b, k, v)
kv = nn.WithKVHistory(history)
} else {
kv = nn.WithKV(k, v, b.SeqQueryLens)
}
out := nn.ScaledDotProductAttention(b, q, cfg.Scale, kv, nn.WithMask(nn.CausalMask()))
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
}

266
x/models/nn/nn.go Normal file
View File

@@ -0,0 +1,266 @@
package nn
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// Layer is the interface for neural network layers with a Forward method.
type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32
}
// EmbeddingLayer is an interface for embedding layers that can also expose a
// tied-output projection when the model reuses embedding weights as the LM head.
type EmbeddingLayer interface {
Forward(indices *mlx.Array) *mlx.Array
AsLinear() LinearLayer
}
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
Weight *mlx.Array
Bias *mlx.Array
Stride int32
Padding int32
Dilation int32
Groups int32
}
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
if stride <= 0 {
stride = 1
}
if dilation <= 0 {
dilation = 1
}
if groups <= 0 {
groups = 1
}
return &Conv1d{
Weight: weight,
Bias: bias,
Stride: stride,
Padding: padding,
Dilation: dilation,
Groups: groups,
}
}
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
}
// Linear applies an affine transformation: y = x @ W.T + b
type Linear struct {
Weight *mlx.Array
Bias *mlx.Array
}
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
return &Linear{Weight: weight, Bias: bias}
}
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
w := l.Weight.Transpose(1, 0)
if l.Bias != nil && l.Bias.Valid() {
return l.Bias.Addmm(x, w, 1.0, 1.0)
}
return x.Matmul(w)
}
func (l *Linear) OutputDim() int32 {
return int32(l.Weight.Dim(0))
}
// QuantizedLinear applies an affine transformation using quantized weights.
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (nil for nvfp4)
Bias *mlx.Array // Layer bias [output_dims] or nil
GlobalScale *mlx.Array // Per-tensor global scale for double-scale nvfp4 (nil for standard)
GroupSize int
Bits int
Mode string
}
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
if qbiases != nil {
mlx.Eval(qw, scales, qbiases)
} else {
mlx.Eval(qw, scales)
}
return &QuantizedLinear{
Weight: qw,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
var out *mlx.Array
if ql.GlobalScale != nil {
// Double-scale nvfp4 (e.g., NVIDIA ModelOpt): standard quantized_matmul
// followed by global_scale multiply. The global_scale is a per-tensor
// F32 scalar (weight_scale_2 in NVIDIA's format).
// TODO: switch to a fused double-scale matmul once MLX has kernel
// coverage for this path.
out = mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
out = mlx.Mul(out, ql.GlobalScale)
} else {
out = mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
}
if ql.Bias != nil && ql.Bias.Valid() {
out = out.Add(ql.Bias)
}
return out
}
func (ql *QuantizedLinear) OutputDim() int32 {
return int32(ql.Weight.Dim(0))
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array
Eps float32
}
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
return &RMSNorm{Weight: weight, Eps: eps}
}
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
if eps == 0 {
eps = rn.Eps
}
return mlx.RMSNormFn(x, rn.Weight, eps)
}
// Embedding represents an embedding layer.
type Embedding struct {
Weight *mlx.Array
}
func NewEmbedding(weight *mlx.Array) *Embedding {
return &Embedding{Weight: weight}
}
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
return e.Weight.TakeAxis(indices, 0)
}
func (e *Embedding) AsLinear() LinearLayer {
return NewLinear(e.Weight, nil)
}
// QuantizedEmbedding performs row-wise embedding lookup from affine/nvfp4/etc.
// packed weights and dequantizes only the selected rows.
type QuantizedEmbedding struct {
Weight *mlx.Array
Scales *mlx.Array
QBiases *mlx.Array
GroupSize int
Bits int
Mode string
}
func NewQuantizedEmbedding(weight, scales, qbiases *mlx.Array, groupSize, bits int, mode string) *QuantizedEmbedding {
return &QuantizedEmbedding{
Weight: weight,
Scales: scales,
QBiases: qbiases,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
func (qe *QuantizedEmbedding) Forward(indices *mlx.Array) *mlx.Array {
weight := qe.Weight.TakeAxis(indices, 0)
scales := qe.Scales.TakeAxis(indices, 0)
var qbiases *mlx.Array
if qe.QBiases != nil && qe.QBiases.Valid() {
qbiases = qe.QBiases.TakeAxis(indices, 0)
}
return mlx.Dequantize(weight, scales, qbiases, qe.GroupSize, qe.Bits, qe.Mode)
}
func (qe *QuantizedEmbedding) AsLinear() LinearLayer {
return &QuantizedLinear{
Weight: qe.Weight,
Scales: qe.Scales,
QBiases: qe.QBiases,
GroupSize: qe.GroupSize,
Bits: qe.Bits,
Mode: qe.Mode,
}
}
// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
Weight *mlx.Array
Bias *mlx.Array
Eps float32
}
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
eps := ln.Eps
if eps == 0 {
eps = 1e-5
}
return mlx.LayerNormFn(x, ln.Weight, ln.Bias, eps)
}
// MultiLinearLayer is an interface for per-head linear layers.
type MultiLinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
}
// MultiLinear performs per-head linear projections.
// Weight shape: [num_heads, output_dims, input_dims]
type MultiLinear struct {
Weight *mlx.Array
}
func NewMultiLinear(weight *mlx.Array) *MultiLinear {
return &MultiLinear{Weight: weight}
}
func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
wT := ml.Weight.Transpose(0, 2, 1)
return x.Matmul(wT)
}
// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
shape := scores.Dims()
seqLen := int32(shape[2])
mask := mlx.Tri(seqLen, seqLen, 0)
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mask.ExpandDims(0).ExpandDims(0)
return mlx.Where(mask, scores, negInf)
}
// ApplyCausalMaskWithOffset applies causal mask for cached attention.
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
if offset == 0 {
return ApplyCausalMask(scores)
}
shape := scores.Dims()
queryLen := int32(shape[2])
keyLen := int32(shape[3])
mask := mlx.Tri(queryLen, keyLen, int(offset))
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mask.ExpandDims(0).ExpandDims(0)
return mlx.Where(mask, scores, negInf)
}

187
x/models/nn/nn_test.go Normal file
View File

@@ -0,0 +1,187 @@
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
func approxEqual(a, b, tol float32) bool {
return float32(math.Abs(float64(a-b))) < tol
}
// TestLayerNormNoBias verifies LayerNorm without bias against manual computation.
func TestLayerNormNoBias(t *testing.T) {
skipIfNoMLX(t)
// Input: [1, 4] — single row, 4 features
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
mlx.Eval(x, weight)
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
out := ln.Forward(x)
mlx.Eval(out)
data := out.Floats()
if len(data) != 4 {
t.Fatalf("expected 4 values, got %d", len(data))
}
// Manual LayerNorm: mean=2.5, var=1.25, std=sqrt(1.25+1e-5)
// normalized = (x - mean) / std
mean := float32(2.5)
variance := float32(1.25)
std := float32(math.Sqrt(float64(variance + 1e-5)))
for i, v := range []float32{1, 2, 3, 4} {
expected := (v - mean) / std
if !approxEqual(data[i], expected, 1e-4) {
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
}
}
}
// TestLayerNormWithBias verifies LayerNorm with weight and bias.
func TestLayerNormWithBias(t *testing.T) {
skipIfNoMLX(t)
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
weight := mlx.FromValues([]float32{2, 2, 2, 2}, 4)
bias := mlx.FromValues([]float32{10, 20, 30, 40}, 4)
mlx.Eval(x, weight, bias)
ln := &LayerNorm{Weight: weight, Bias: bias, Eps: 1e-5}
out := ln.Forward(x)
mlx.Eval(out)
data := out.Floats()
if len(data) != 4 {
t.Fatalf("expected 4 values, got %d", len(data))
}
mean := float32(2.5)
variance := float32(1.25)
std := float32(math.Sqrt(float64(variance + 1e-5)))
biases := []float32{10, 20, 30, 40}
for i, v := range []float32{1, 2, 3, 4} {
expected := ((v-mean)/std)*2 + biases[i]
if !approxEqual(data[i], expected, 1e-4) {
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
}
}
}
// TestLayerNormBatched verifies LayerNorm normalizes each row independently.
func TestLayerNormBatched(t *testing.T) {
skipIfNoMLX(t)
// Input: [2, 3] — two rows
x := mlx.FromValues([]float32{
1, 2, 3,
10, 20, 30,
}, 2, 3)
weight := mlx.FromValues([]float32{1, 1, 1}, 3)
mlx.Eval(x, weight)
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
out := ln.Forward(x)
mlx.Eval(out)
data := out.Floats()
if len(data) != 6 {
t.Fatalf("expected 6 values, got %d", len(data))
}
// Each row should be independently normalized.
// Row 0: [1,2,3] mean=2, var=2/3
// Row 1: [10,20,30] mean=20, var=200/3
// After normalization both rows should have the same pattern
// since [10,20,30] = 10*[1,2,3], the normalized values are identical.
for i := range 3 {
if !approxEqual(data[i], data[i+3], 1e-4) {
t.Errorf("row 0 elem %d (%.6f) != row 1 elem %d (%.6f); expected identical normalized values",
i, data[i], i, data[i+3])
}
}
// Verify the normalized values sum to ~0 (mean-centered)
sum := data[0] + data[1] + data[2]
if !approxEqual(sum, 0, 1e-4) {
t.Errorf("normalized row sum should be ~0, got %.6f", sum)
}
}
// TestLayerNormDefaultEps verifies the default epsilon of 1e-5 is used when Eps is 0.
func TestLayerNormDefaultEps(t *testing.T) {
skipIfNoMLX(t)
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
mlx.Eval(x, weight)
// Eps=0 should use default 1e-5
ln0 := &LayerNorm{Weight: weight, Eps: 0}
out0 := ln0.Forward(x)
mlx.Eval(out0)
lnExplicit := &LayerNorm{Weight: weight, Eps: 1e-5}
outExplicit := lnExplicit.Forward(x)
mlx.Eval(outExplicit)
d0 := out0.Floats()
dE := outExplicit.Floats()
for i := range d0 {
if !approxEqual(d0[i], dE[i], 1e-6) {
t.Errorf("index %d: Eps=0 gave %.6f, Eps=1e-5 gave %.6f", i, d0[i], dE[i])
}
}
}
func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
skipIfNoMLX(t)
weightVals := make([]float32, 3*32)
for i := range weightVals {
weightVals[i] = float32((i%11)-5) / 7
}
inputVals := make([]float32, 2*32)
for i := range inputVals {
inputVals[i] = float32((i%7)-3) / 5
}
weight := mlx.FromValues(weightVals, 3, 32).AsType(mlx.DTypeBFloat16)
input := mlx.FromValues(inputVals, 2, 32).AsType(mlx.DTypeBFloat16)
mlx.Eval(weight, input)
ql := NewQuantizedLinear(weight, nil, 32, 4, "mxfp4")
if ql.QBiases != nil {
t.Fatalf("mxfp4 qbiases = %v, want nil", ql.QBiases)
}
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
mlx.Eval(dequantizedWeight)
qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
mlx.Eval(qOut, dOut)
got := qOut.Floats()
want := dOut.Floats()
if len(got) != len(want) {
t.Fatalf("output length = %d, want %d", len(got), len(want))
}
for i := range got {
if !approxEqual(got[i], want[i], 1e-3) {
t.Fatalf("output[%d] = %.6f, want %.6f", i, got[i], want[i])
}
}
}

259
x/models/nn/recurrent.go Normal file
View File

@@ -0,0 +1,259 @@
package nn
import (
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// RecurrentOption configures a call to CausalConv1D or GatedDelta.
type RecurrentOption func(*recurrentConfig)
// recurrentConfig is the resolved set of inputs supplied via
// RecurrentOption. Exactly one of history or (convState/deltaState)
// must be supplied per call.
type recurrentConfig struct {
history *RecurrentHistory
convState *mlx.Array
deltaState *mlx.Array
}
// WithRecurrentHistory supplies a cache's per-layer view of conv and
// delta state. The cache hides any storage layout (per-row, paged,
// gather/scatter) behind the history.
func WithRecurrentHistory(h *RecurrentHistory) RecurrentOption {
return func(c *recurrentConfig) { c.history = h }
}
// WithRecurrentState supplies explicit conv and delta state tensors
// for the no-cache path. Each wrapper consumes one of the two — pass
// nil for the unused slot when calling only one wrapper.
func WithRecurrentState(convState, deltaState *mlx.Array) RecurrentOption {
return func(c *recurrentConfig) {
c.convState = convState
c.deltaState = deltaState
}
}
// resolve applies opts and panics if WithRecurrentHistory and
// WithRecurrentState were combined or neither was supplied.
func resolveRecurrentConfig(opts []RecurrentOption) recurrentConfig {
var cfg recurrentConfig
for _, opt := range opts {
opt(&cfg)
}
haveHistory := cfg.history != nil
haveState := cfg.convState != nil || cfg.deltaState != nil
if haveHistory && haveState {
panic("WithRecurrentHistory and WithRecurrentState are mutually exclusive")
}
if !haveHistory && !haveState {
panic("no recurrent state supplied (use WithRecurrentHistory or WithRecurrentState)")
}
return cfg
}
// CausalConv1D runs a depthwise causal 1D convolution with recurrent
// state management. Prepends the prior conv state along axis 1, runs
// the conv, and returns (output, nextConv). nextConv is the trailing
// convTail positions of the concat — write it back to the cache via
// Put alongside the scan's new delta state.
//
// Conv selection: when conv is non-nil (a full nn.Conv1d layer), it
// runs through conv.Forward. Otherwise weight is treated as the bare
// depthwise kernel [C, K] and the fallback manual implementation runs.
// Exactly one of conv or weight should be non-nil.
//
// Shapes: input [B, L, D]; prior state [B, convTail, D]; output
// [B, L, D] (the causal conv strips the prepended state).
//
// Prior state comes from exactly one of WithRecurrentHistory (cache
// path) or WithRecurrentState (no-cache path).
func CausalConv1D(b *batch.Batch, input *mlx.Array, conv *Conv1d, weight *mlx.Array, convTail int, opts ...RecurrentOption) (out, nextConv *mlx.Array) {
cfg := resolveRecurrentConfig(opts)
var prior *mlx.Array
if cfg.history != nil {
prior = cfg.history.ConvState()
} else {
prior = cfg.convState
}
mask := paddingMask(b, int32(input.Dim(1)))
if mask != nil {
zero := mlx.FromValue(float32(0)).AsType(input.DType())
input = mlx.Where(mlx.ExpandDims(mask, 2), input, zero)
}
concat := mlx.Concatenate([]*mlx.Array{prior, input}, 1)
if conv != nil {
out = conv.Forward(concat)
} else {
out = depthwiseCausalConv1d(concat, weight, int32(input.Dim(1)))
}
B := int32(concat.Dim(0))
total := int32(concat.Dim(1))
D := int32(concat.Dim(2))
// Gather the tail from each of the non-padded sequence ends
if mask != nil && convTail > 0 {
offsets := make([]int32, int(B)*convTail)
for i := range int(B) {
end := b.SeqQueryLens[i]
for k := range convTail {
offsets[i*convTail+k] = end + int32(k)
}
}
positions := mlx.NewArrayInt32(offsets, []int32{B, int32(convTail), 1})
nextConv = mlx.TakeAlongAxis(concat, positions, 1)
} else {
nextConv = mlx.SliceStartStop(concat,
[]int32{0, total - int32(convTail), 0},
[]int32{B, total, D})
}
return out, nextConv
}
// depthwiseCausalConv1d implements a depthwise 1D causal convolution
// manually as a sum of kernel-offset multiplies. x has shape
// [B, inLen, C], weight has shape [C, K]; output has shape [B, outLen, C]
// where outLen = inLen - K + 1 (the caller passes outLen to avoid the
// subtraction). Used as the fallback path in CausalConv1D when no
// full Conv1d layer is configured.
func depthwiseCausalConv1d(x, w *mlx.Array, outLen int32) *mlx.Array {
if x == nil || w == nil {
return nil
}
if w.NumDims() != 2 {
return nil
}
B := int32(x.Dim(0))
C := int32(w.Dim(0))
K := int32(w.Dim(1))
var out *mlx.Array
for i := range K {
seg := mlx.SliceStartStop(x, []int32{0, i, 0}, []int32{B, i + outLen, C})
wi := mlx.SliceStartStop(w, []int32{0, i}, []int32{C, i + 1})
wi = mlx.Reshape(wi, 1, 1, C)
term := mlx.Mul(seg, wi)
if out == nil {
out = term
} else {
out = mlx.Add(out, term)
}
}
return out
}
// GatedDelta wraps mlx.FastGatedDelta with recurrent state management.
// Reads prior delta state from the supplied option and returns
// (output, newDelta). Write newDelta back via the cache's Put
// alongside the conv wrapper's nextConv.
//
// Shape conventions:
//
// q: [B, L, numKeyHeads, headKDim]
// k: [B, L, numKeyHeads, headKDim]
// v: [B, L, numValueHeads, headVDim]
// state: [B, numValueHeads, headVDim, headKDim]
//
// Prior state comes from exactly one of WithRecurrentHistory (cache
// path) or WithRecurrentState (no-cache path).
func GatedDelta(b *batch.Batch, q, k, v, gDecay, beta *mlx.Array, opts ...RecurrentOption) (out, newDelta *mlx.Array) {
cfg := resolveRecurrentConfig(opts)
var state *mlx.Array
if cfg.history != nil {
state = cfg.history.DeltaState()
} else {
state = cfg.deltaState
}
return mlx.FastGatedDelta(q, k, v, gDecay, beta, state, paddingMask(b, int32(q.Dim(1))))
}
// RecurrentHistory is an opaque per-forward view a recurrent cache
// hands to the SSM kernel wrappers — prior conv and delta state
// tensors. Models do not construct this directly; pass it through
// via WithRecurrentHistory, or use WithRecurrentState on the no-cache
// path.
//
// Opaque structure to model code; accessors ConvState/DeltaState
// provide the escape hatch for custom SSM paths.
type RecurrentHistory struct {
convState, deltaState *mlx.Array
}
// NewRecurrentHistory constructs a RecurrentHistory. Intended for
// cache implementations across packages; model code uses
// WithRecurrentHistory / WithRecurrentState instead.
func NewRecurrentHistory(convState, deltaState *mlx.Array) *RecurrentHistory {
return &RecurrentHistory{convState: convState, deltaState: deltaState}
}
// ConvState returns the current convolution state tensor.
//
// Last-resort escape hatch for custom SSM paths — may force a slow
// materialization to canonical form depending on the cache's
// internal storage. Prefer CausalConv1D via WithRecurrentHistory.
func (h *RecurrentHistory) ConvState() *mlx.Array { return h.convState }
// DeltaState returns the current delta state tensor.
//
// Last-resort escape hatch for custom SSM paths — may force a slow
// materialization to canonical form depending on the cache's
// internal storage. Prefer GatedDelta via WithRecurrentHistory.
func (h *RecurrentHistory) DeltaState() *mlx.Array { return h.deltaState }
type paddingMaskInputs struct {
batch *batch.Batch
L int32
}
func (in paddingMaskInputs) build() *mlx.Array {
B := len(in.batch.SeqQueryLens)
needed := false
for i := range B {
if in.batch.SeqQueryLens[i] < in.L {
needed = true
break
}
}
if !needed {
return nil
}
L := int(in.L)
vals := make([]bool, B*L)
for i := range B {
n := int(in.batch.SeqQueryLens[i])
base := i * L
for j := range n {
vals[base+j] = true
}
}
return mlx.FromValues(vals, B, L)
}
// paddingMask derives a [B, L] bool mask from b.SeqQueryLens for
// right-padded inputs (real tokens at [0, len_i), padding at
// [len_i, L)). Returns nil when b has no rows or every row is full —
// the no-padding fast path that costs nothing extra.
func paddingMask(b *batch.Batch, L int32) *mlx.Array {
inputs := paddingMaskInputs{batch: b, L: L}
if cached, ok := b.Memo.Get(inputs); ok {
return cached.(*mlx.Array)
}
mask := inputs.build()
b.Memo.Put(inputs, mask)
return mask
}

View File

@@ -0,0 +1,340 @@
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func ones(dtype mlx.DType, shape ...int) *mlx.Array {
return mlx.AddScalar(mlx.Zeros(dtype, shape...), 1)
}
// fromValues builds a tensor with sequentially-numbered float32
// values so element-by-element parity actually exercises the kernel.
func fromValues(seed float32, shape ...int) *mlx.Array {
n := 1
for _, d := range shape {
n *= d
}
vals := make([]float32, n)
for i := range vals {
vals[i] = seed + 0.1*float32(i)
}
return mlx.FromValues(vals, shape...)
}
// depthwiseCausalRef is a Go-side reference for the depthwise causal
// 1D conv fallback. concat is [B, total, C], weight is [C, K], output
// is [B, total-K+1, C]. Used to anchor the wrapper's parity tests.
func depthwiseCausalRef(concat, weight *mlx.Array) []float32 {
mlx.Eval(concat, weight)
cVals := concat.Floats()
wVals := weight.Floats()
B := concat.Dim(0)
total := concat.Dim(1)
C := concat.Dim(2)
K := weight.Dim(1)
outLen := total - K + 1
out := make([]float32, B*outLen*C)
for bi := range B {
for q := range outLen {
for c := range C {
var sum float32
for k := range K {
x := cVals[bi*total*C+(q+k)*C+c]
w := wVals[c*K+k]
sum += x * w
}
out[bi*outLen*C+q*C+c] = sum
}
}
}
return out
}
// TestCausalConv1DParity drives the wrapper with non-trivial prior,
// input, and weight values, then compares against a direct depthwise-
// causal-conv reference.
func TestCausalConv1DParity(t *testing.T) {
skipIfNoMLX(t)
B, L, D, convTail := 1, 4, 3, 2
K := convTail + 1
input := fromValues(0.5, B, L, D)
prior := fromValues(-0.3, B, convTail, D)
weight := fromValues(0.2, D, K)
out, nextConv := CausalConv1D(&batch.Batch{}, input, nil, weight, convTail, WithRecurrentState(prior, nil))
mlx.Eval(out, nextConv)
concat := mlx.Concatenate([]*mlx.Array{prior, input}, 1)
want := depthwiseCausalRef(concat, weight)
got := out.Floats()
if len(got) != len(want) {
t.Fatalf("out len = %d, want %d", len(got), len(want))
}
for i := range want {
if math.Abs(float64(got[i]-want[i])) > 1e-5 {
t.Fatalf("out[%d]: got %v, want %v", i, got[i], want[i])
}
}
// nextConv (no padding) is the trailing convTail rows of concat.
mlx.Eval(concat)
cVals := concat.Floats()
total := concat.Dim(1)
wantTail := make([]float32, B*convTail*D)
for bi := range B {
for k := range convTail {
for d := range D {
wantTail[bi*convTail*D+k*D+d] = cVals[bi*total*D+(total-convTail+k)*D+d]
}
}
}
tail := nextConv.Floats()
if len(tail) != len(wantTail) {
t.Fatalf("nextConv len = %d, want %d", len(tail), len(wantTail))
}
for i := range wantTail {
if tail[i] != wantTail[i] {
t.Fatalf("nextConv[%d]: got %v, want %v", i, tail[i], wantTail[i])
}
}
}
// TestCausalConv1DPaddedRowParity drives a B=2 batch with one short
// row (qLen<L). For the short row, (a) `out` positions [0..qLen)
// must equal a B=1 reference at length qLen, (b) `nextConv` for the
// short row must be the row's last convTail real positions (not the
// padded tail), (c) the full row must be unaffected.
func TestCausalConv1DPaddedRowParity(t *testing.T) {
skipIfNoMLX(t)
L, D, convTail := 4, 3, 2
qLenShort := 2
K := convTail + 1
weight := fromValues(0.2, D, K)
priorFull := fromValues(0.5, 2, convTail, D)
priorShort := mlx.SliceStartStop(priorFull,
[]int32{1, 0, 0},
[]int32{2, int32(convTail), int32(D)})
// Pad row 1 with arbitrary values past qLenShort — the wrapper
// must zero them before convolving. Distinct values let us catch
// any leak.
inputFull := fromValues(1.0, 1, L, D)
inputShortReal := mlx.FromValues([]float32{
2.0, 2.1, 2.2,
2.3, 2.4, 2.5,
}, 1, qLenShort, D)
inputShortPad := mlx.FromValues([]float32{
99, 99, 99,
99, 99, 99,
}, 1, L-qLenShort, D)
inputShortFull := mlx.Concatenate([]*mlx.Array{inputShortReal, inputShortPad}, 1)
input := mlx.Concatenate([]*mlx.Array{inputFull, inputShortFull}, 0)
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 2, L),
SeqOffsets: []int32{0, 0},
SeqQueryLens: []int32{int32(L), int32(qLenShort)},
}
out, nextConv := CausalConv1D(b, input, nil, weight, convTail, WithRecurrentState(priorFull, nil))
mlx.Eval(out, nextConv)
// Reference for row 0: B=1 unpadded length-L call.
refOut0, refNextConv0 := CausalConv1D(&batch.Batch{},
inputFull, nil, weight, convTail,
WithRecurrentState(mlx.SliceStartStop(priorFull,
[]int32{0, 0, 0},
[]int32{1, int32(convTail), int32(D)}), nil))
// Reference for row 1: B=1 unpadded length-qLenShort call.
refOut1, refNextConv1 := CausalConv1D(&batch.Batch{},
inputShortReal, nil, weight, convTail,
WithRecurrentState(priorShort, nil))
mlx.Eval(refOut0, refNextConv0, refOut1, refNextConv1)
gotOut := out.Floats()
wantOut0 := refOut0.Floats()
wantOut1 := refOut1.Floats()
for q := range L {
for d := range D {
i := q*D + d
if gotOut[i] != wantOut0[i] {
t.Fatalf("row 0 out[q=%d,d=%d]: got %v, want %v", q, d, gotOut[i], wantOut0[i])
}
}
}
for q := range qLenShort {
for d := range D {
gotI := L*D + q*D + d
refI := q*D + d
if math.Abs(float64(gotOut[gotI]-wantOut1[refI])) > 1e-5 {
t.Fatalf("row 1 real out[q=%d,d=%d]: got %v, want %v", q, d, gotOut[gotI], wantOut1[refI])
}
}
}
// nextConv: row 0 unaffected, row 1 must be the row's real tail
// (positions [qLenShort - convTail, qLenShort) of the per-row
// concat, i.e. the last two real input rows in this setup).
gotTail := nextConv.Floats()
wantTail0 := refNextConv0.Floats()
wantTail1 := refNextConv1.Floats()
for k := range convTail {
for d := range D {
i := k*D + d
if gotTail[i] != wantTail0[i] {
t.Fatalf("row 0 nextConv[k=%d,d=%d]: got %v, want %v", k, d, gotTail[i], wantTail0[i])
}
}
}
for k := range convTail {
for d := range D {
gotI := convTail*D + k*D + d
refI := k*D + d
if gotTail[gotI] != wantTail1[refI] {
t.Fatalf("row 1 nextConv[k=%d,d=%d]: got %v, want %v (must come from real positions, not the padded tail)",
k, d, gotTail[gotI], wantTail1[refI])
}
}
}
}
func TestGatedDeltaZeroFallback(t *testing.T) {
skipIfNoMLX(t)
B, L, nK, nV, dK, dV := 1, 2, 1, 1, 4, 4
q := ones(mlx.DTypeFloat32, B, L, nK, dK)
k := ones(mlx.DTypeFloat32, B, L, nK, dK)
v := ones(mlx.DTypeFloat32, B, L, nV, dV)
gDecay := ones(mlx.DTypeFloat32, B, L, nV)
beta := ones(mlx.DTypeFloat32, B, L, nV)
zero := mlx.Zeros(mlx.DTypeFloat32, B, nV, dV, dK)
outA, stateA := GatedDelta(&batch.Batch{}, q, k, v, gDecay, beta, WithRecurrentState(nil, zero))
outB, stateB := mlx.FastGatedDelta(q, k, v, gDecay, beta, zero, nil)
mlx.Eval(outA, stateA, outB, stateB)
gotOut, wantOut := outA.Floats(), outB.Floats()
for i := range wantOut {
if gotOut[i] != wantOut[i] {
t.Fatalf("output[%d]: wrapper=%v direct=%v", i, gotOut[i], wantOut[i])
}
}
gotState, wantState := stateA.Floats(), stateB.Floats()
for i := range wantState {
if gotState[i] != wantState[i] {
t.Fatalf("state[%d]: wrapper=%v direct=%v", i, gotState[i], wantState[i])
}
}
}
func TestGatedDeltaUsesPriorState(t *testing.T) {
skipIfNoMLX(t)
B, L, nK, nV, dK, dV := 1, 2, 1, 1, 4, 4
q := ones(mlx.DTypeFloat32, B, L, nK, dK)
k := ones(mlx.DTypeFloat32, B, L, nK, dK)
v := ones(mlx.DTypeFloat32, B, L, nV, dV)
gDecay := ones(mlx.DTypeFloat32, B, L, nV)
beta := ones(mlx.DTypeFloat32, B, L, nV)
priorState := mlx.MulScalar(ones(mlx.DTypeFloat32, B, nV, dV, dK), 3)
outA, _ := GatedDelta(&batch.Batch{}, q, k, v, gDecay, beta, WithRecurrentState(nil, priorState))
outB, _ := mlx.FastGatedDelta(q, k, v, gDecay, beta, priorState, nil)
mlx.Eval(outA, outB)
gotOut, wantOut := outA.Floats(), outB.Floats()
for i := range wantOut {
if gotOut[i] != wantOut[i] {
t.Fatalf("output[%d]: wrapper=%v direct=%v", i, gotOut[i], wantOut[i])
}
}
}
// TestGatedDeltaPaddedRowParity drives a B=2 batch where row 1 is
// short (qLen < L). The wrapper must substitute neutral values
// (q=k=v=beta=0, g=1) at row 1's padded positions so the recurrence
// is a no-op there — and row 1's final state must equal the state
// after its last real token. Pinned via parity against a B=1 length-
// qLen call on the same row.
func TestGatedDeltaPaddedRowParity(t *testing.T) {
skipIfNoMLX(t)
L, nK, nV, dK, dV := 4, 1, 1, 4, 4
qLenShort := 2
makeRows := func(seedA, seedB float32, shape ...int) *mlx.Array {
// Build a rank-(len(shape)+1) tensor with B=2 rows from two
// distinct seeds so the rows are not accidentally identical.
n := 1
for _, d := range shape {
n *= d
}
vals := make([]float32, 2*n)
for i := range n {
vals[i] = seedA + 0.1*float32(i)
}
for i := range n {
vals[n+i] = seedB + 0.1*float32(i)
}
full := append([]int{2}, shape...)
return mlx.FromValues(vals, full...)
}
q := makeRows(0.5, -0.5, L, nK, dK)
k := makeRows(0.7, -0.7, L, nK, dK)
v := makeRows(0.3, -0.3, L, nV, dV)
gDecay := makeRows(0.1, -0.1, L, nV)
beta := makeRows(0.4, -0.4, L, nV)
priorState := makeRows(0.2, -0.2, nV, dV, dK)
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 2, L),
SeqOffsets: []int32{0, 0},
SeqQueryLens: []int32{int32(L), int32(qLenShort)},
}
_, state := GatedDelta(b, q, k, v, gDecay, beta, WithRecurrentState(nil, priorState))
mlx.Eval(state)
// Reference for row 1: B=1 length-qLenShort call against the
// row's real prefix and its prior state slice.
row1Slice := func(a *mlx.Array, axisLens ...int32) *mlx.Array {
dims := a.Dims()
start := make([]int32, len(dims))
stop := make([]int32, len(dims))
start[0], stop[0] = 1, 2
for i := 1; i < len(dims); i++ {
stop[i] = int32(dims[i])
}
// Optionally truncate axis 1 (sequence axis) to qLenShort.
if len(axisLens) >= 1 && len(dims) >= 2 {
stop[1] = axisLens[0]
}
return mlx.SliceStartStop(a, start, stop)
}
q1 := row1Slice(q, int32(qLenShort))
k1 := row1Slice(k, int32(qLenShort))
v1 := row1Slice(v, int32(qLenShort))
gDecay1 := row1Slice(gDecay, int32(qLenShort))
beta1 := row1Slice(beta, int32(qLenShort))
priorRow1 := row1Slice(priorState)
_, refState := mlx.FastGatedDelta(q1, k1, v1, gDecay1, beta1, priorRow1, nil)
mlx.Eval(refState)
gotState := state.Floats()
wantState := refState.Floats()
row1Stride := nV * dV * dK
for i := range row1Stride {
gotV := gotState[row1Stride+i]
wantV := wantState[i]
if math.Abs(float64(gotV-wantV)) > 1e-4 {
t.Fatalf("row 1 final state[%d]: got %v, want %v", i, gotV, wantV)
}
}
}

129
x/models/nn/rope.go Normal file
View File

@@ -0,0 +1,129 @@
package nn
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// RopeParameters carries common RoPE metadata embedded in model configs.
type RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
RopeType string `json:"rope_type"`
Type string `json:"type"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
AttentionFactor float32 `json:"attention_factor"`
}
// TypeName returns rope_type when present, falling back to type.
func (rp *RopeParameters) TypeName() string {
if rp == nil {
return ""
}
if rp.RopeType != "" {
return rp.RopeType
}
return rp.Type
}
// BuildYarnRopeFreqs returns YaRN rotary frequencies and the mscale value.
func BuildYarnRopeFreqs(dim int, base float32, rp *RopeParameters) (*mlx.Array, float32) {
if rp == nil || dim <= 0 {
return nil, 1
}
factor := rp.Factor
if factor <= 0 {
factor = 1
}
attentionFactor := rp.AttentionFactor
if attentionFactor == 0 && factor > 1 {
attentionFactor = float32(0.1*math.Log(float64(factor)) + 1.0)
} else if attentionFactor == 0 {
attentionFactor = 1
}
if factor <= 1 {
return nil, attentionFactor
}
originalMax := rp.OriginalMaxPositionEmbeddings
if originalMax <= 0 {
originalMax = 4096
}
betaFast := rp.BetaFast
if betaFast == 0 {
betaFast = 32
}
betaSlow := rp.BetaSlow
if betaSlow == 0 {
betaSlow = 1
}
half := dim / 2
low, high := yarnCorrectionRange(betaFast, betaSlow, dim, base, originalMax)
freqs := make([]float32, half)
for i := range half {
posFreq := math.Pow(float64(base), float64(2*i)/float64(dim))
invExtrapolation := 1.0 / posFreq
invInterpolation := 1.0 / (float64(factor) * posFreq)
ramp := yarnRamp(float64(i), low, high)
mask := 1 - ramp
inv := invInterpolation*(1-mask) + invExtrapolation*mask
freqs[i] = float32(1.0 / inv)
}
arr := mlx.FromValues(freqs, half)
mlx.Eval(arr)
return arr, attentionFactor
}
func yarnCorrectionRange(betaFast, betaSlow float32, dim int, base float32, maxPosition int32) (float64, float64) {
findDim := func(rot float32) float64 {
return float64(dim) * math.Log(float64(maxPosition)/(float64(rot)*2*math.Pi)) / (2 * math.Log(float64(base)))
}
low := math.Floor(findDim(betaFast))
high := math.Ceil(findDim(betaSlow))
low = math.Max(low, 0)
high = math.Min(high, float64(dim-1))
if low == high {
high += 0.001
}
return low, high
}
func yarnRamp(i, low, high float64) float64 {
v := (i - low) / (high - low)
if v < 0 {
return 0
}
if v > 1 {
return 1
}
return v
}
// ScaleRotaryPart applies YaRN's mscale to only the rotated dimensions.
func ScaleRotaryPart(x *mlx.Array, ropeDim int, scale float32) *mlx.Array {
if scale == 1 {
return x
}
dims := x.Dims()
last := dims[len(dims)-1]
if ropeDim >= last {
return mlx.MulScalar(x, scale)
}
start := make([]int32, len(dims))
stopRot := make([]int32, len(dims))
stopPass := make([]int32, len(dims))
startPass := make([]int32, len(dims))
for i, dim := range dims {
stopRot[i] = int32(dim)
stopPass[i] = int32(dim)
}
stopRot[len(dims)-1] = int32(ropeDim)
startPass[len(dims)-1] = int32(ropeDim)
rot := mlx.MulScalar(mlx.SliceStartStop(x, start, stopRot), scale)
pass := mlx.SliceStartStop(x, startPass, stopPass)
return mlx.Concatenate([]*mlx.Array{rot, pass}, -1)
}

578
x/models/nn/sdpa.go Normal file
View File

@@ -0,0 +1,578 @@
package nn
import (
"encoding/binary"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// SDPAOption configures a call to ScaledDotProductAttention.
type SDPAOption func(*sdpaConfig)
type sdpaConfig struct {
// Exactly one of (k,v,kLens) or history supplies keys/values.
k, v *mlx.Array
kLens []int32
history *KVHistory
// Optional model-supplied logical mask.
mask AttentionMask
}
// WithKVHistory supplies a cache's per-layer view of K and V. The
// cache hides any storage layout (sliding window, ring buffer,
// k-padding) behind the history.
func WithKVHistory(h *KVHistory) SDPAOption {
return func(c *sdpaConfig) { c.history = h }
}
// WithMLAHistory supplies a cache's per-layer view for absorbed MLA
// attention, where V is the first valueDim positions of K.
func WithMLAHistory(h *KVHistory, valueDim int) SDPAOption {
v := h.K().Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, valueDim))
return WithKVHistory(&KVHistory{k: h.K(), v: v, applier: h.applier})
}
// WithKV supplies explicit K/V tensors for the no-cache path. kLens
// gives per-row real key extents — pass b.SeqQueryLens for self-
// attention, or the caller's own extents for cross-attention.
func WithKV(k, v *mlx.Array, kLens []int32) SDPAOption {
return func(c *sdpaConfig) { c.k = k; c.v = v; c.kLens = kLens }
}
// WithMask supplies the model's logical-coordinate mask.
func WithMask(m AttentionMask) SDPAOption {
return func(c *sdpaConfig) { c.mask = m }
}
// ScaledDotProductAttention runs the fast SDPA kernel against q and
// the keys/values supplied via exactly one of WithKV or
// WithKVHistory. Automatically applies any Q/K padding masking required
// for padded batches.
func ScaledDotProductAttention(b *batch.Batch, q *mlx.Array, scale float32, opts ...SDPAOption) *mlx.Array {
var cfg sdpaConfig
for _, opt := range opts {
opt(&cfg)
}
haveKV := cfg.k != nil || cfg.v != nil
haveHistory := cfg.history != nil
if haveKV && haveHistory {
panic("nn.ScaledDotProductAttention: WithKV and WithKVHistory are mutually exclusive")
}
if !haveKV && !haveHistory {
panic("nn.ScaledDotProductAttention: no keys/values supplied (use WithKV or WithKVHistory)")
}
k, v := cfg.k, cfg.v
var applier MaskApplier
if cfg.history != nil {
k = cfg.history.K()
v = cfg.history.V()
applier = cfg.history.applier
}
inputs := dispatchInputs{
batch: b,
mask: cfg.mask,
applier: applier,
K: k.Dim(2),
dtype: k.DType(),
kLens: newKLensKey(cfg.kLens),
}
if cached, ok := b.Memo.Get(inputs); ok {
d := cached.(sdpaDispatch)
return mlx.FastScaledDotProductAttention(q, k, v, scale, d.mode, d.arr)
}
d := inputs.resolve()
b.Memo.Put(inputs, d)
return mlx.FastScaledDotProductAttention(q, k, v, scale, d.mode, d.arr)
}
// sdpaDispatch is the resolved kernel call for a given SDPA key —
// either a flag-mode fast path (mode "" or "causal", arr nil) or an
// array-mode call with a materialized tensor. Memoized on b.Memo so
// sibling layers skip applier composition, padding build, and AsArray.
type sdpaDispatch struct {
mode string
arr *mlx.Array
}
// dispatchInputs bundles every value resolve reads and doubles as
// the Memo map key. All fields are comparable: batch is a
// *batch.Batch pointer, the applier interface is comparable when
// its concrete type is, and kLens is a kLensKey string that hashes
// by content.
//
// Making resolve a method on this struct is the enforcement — any
// new dependency must be added as a field, which automatically
// participates in the map key.
//
// applier and kLens are mutually exclusive by construction:
// WithKVHistory sets applier (which owns any K-padding in its output
// space) and leaves kLens ""; WithKV sets kLens and leaves applier nil.
type dispatchInputs struct {
batch *batch.Batch
mask AttentionMask
applier MaskApplier
K int
dtype mlx.DType
kLens kLensKey
}
// kLensKey is a comparable encoding of an int32 slice (four bytes
// per element, native endian) so it can live in a struct used as a
// map key. Decode back via Int32s.
type kLensKey string
func newKLensKey(vs []int32) kLensKey {
if len(vs) == 0 {
return ""
}
buf := make([]byte, len(vs)*4)
for i, v := range vs {
binary.NativeEndian.PutUint32(buf[i*4:], uint32(v))
}
return kLensKey(buf)
}
// Int32s decodes the key back to a fresh []int32.
func (k kLensKey) Int32s() []int32 {
if len(k) == 0 {
return nil
}
b := []byte(k)
out := make([]int32, len(b)/4)
for i := range out {
out[i] = int32(binary.NativeEndian.Uint32(b[i*4:]))
}
return out
}
// resolve composes model + padding + storage contributions and
// returns the kernel dispatch decision. Reads only from inputs; any
// new input must be added to dispatchInputs.
//
// Order matters: QPaddingMask is added in logical Q-space before the
// applier runs, so an applier that remaps coordinates receives the
// full logical mask. The applier and KPaddingMask branches are
// mutually exclusive — on the applier path the output may be in a
// remapped K space, so the applier owns any K-padding; on the
// WithKV path kLens describes the direct K tensor, which shares
// logical K space with QPaddingMask.
func (inputs dispatchInputs) resolve() sdpaDispatch {
mask := inputs.mask.Intersect(QPaddingMask(inputs.batch, inputs.dtype))
if inputs.applier != nil {
mask = inputs.applier.ApplyMask(mask)
} else if inputs.kLens != "" {
mask = mask.Intersect(KPaddingMask(inputs.batch, inputs.K, inputs.kLens.Int32s(), inputs.dtype))
}
switch {
case mask.IsZero():
return sdpaDispatch{mode: ""}
case mask.IsCausal():
if inputs.batch.InputIDs.Dim(1) == 1 {
// At L=1 the causal "k > q" constraint is redundant -
// drop it so the kernel dispatches to the no-mask fast path.
return sdpaDispatch{mode: ""}
} else {
return sdpaDispatch{mode: "causal"}
}
default:
return sdpaDispatch{mode: "array", arr: mask.AsArray(inputs.batch, inputs.K, inputs.dtype)}
}
}
// MaskApplier composes a cache's storage-mask contribution onto a
// fully-composed logical mask. The returned mask may live in the
// applier's own coordinate system (e.g. a rotated or compacted K layout),
// so any addition in logical K space must happen before the applier runs.
// SDPA does not add KPaddingMask on this path — the applier owns any
// K-padding its output needs.
//
// Implementations must be comparable struct values whose fields
// capture everything the composition depends on (no slice, map, or
// func fields); the value doubles as the applier's identity in
// SDPA's dispatch-cache key, where a non-comparable concrete type
// would panic at map insertion. A nil MaskApplier means "no storage
// contribution".
type MaskApplier interface {
ApplyMask(logical AttentionMask) AttentionMask
}
// KVHistory is the per-forward view a KV cache hands to SDPA:
// post-Update K and V plus an optional MaskApplier that composes
// the cache's storage mask onto the caller's model mask.
type KVHistory struct {
k, v *mlx.Array
applier MaskApplier
}
// NewKVHistory constructs a KVHistory. Intended for
// cache implementations across packages; model code uses
// WithKVHistory / WithKV instead.
func NewKVHistory(k, v *mlx.Array, applier MaskApplier) *KVHistory {
return &KVHistory{k: k, v: v, applier: applier}
}
// K returns the post-Update keys tensor.
//
// Last-resort escape hatch for custom attention paths — may force a
// slow materialization to canonical form depending on the cache's
// internal storage. Prefer ScaledDotProductAttention via
// WithKVHistory.
func (h *KVHistory) K() *mlx.Array { return h.k }
// V returns the post-Update values tensor.
//
// Last-resort escape hatch for custom attention paths — may force a
// slow materialization to canonical form depending on the cache's
// internal storage. Prefer ScaledDotProductAttention via
// WithKVHistory.
func (h *KVHistory) V() *mlx.Array { return h.v }
// Mask returns the final AttentionMask for this layer's SDPA —
// cache storage restrictions composed onto the caller's fully-
// composed logical mask.
//
// Last-resort escape hatch for custom attention paths — may force a
// slow materialization to canonical form depending on the cache's
// internal storage. Prefer ScaledDotProductAttention via
// WithKVHistory.
func (h *KVHistory) Mask(logical AttentionMask) AttentionMask {
if h.applier == nil {
return logical
}
return h.applier.ApplyMask(logical)
}
// AttentionMask describes an attention mask in four states:
// - zero value: no mask.
// - flag-form causal (causal=true only): dispatches to the MLX
// kernel's mask_mode="causal" fast path.
// - causal with relaxation rectangles: a causal mask with
// bidirectional attention rectangles, such as for images.
// - additive tensor (array!=nil): broadcast-compatible with
// [B, 1, L, K]; contributed by a custom mask, helpers such as
// QPaddingMask, KPaddingMask, or cache appliers and accumulated
// via Intersect.
//
// The mask is a pure logical description — it carries no batch and
// exists independent of cache storage layout.
//
// All fields are comparable, so AttentionMask values compare with ==
// by full identity — SDPA uses this directly as a dispatch-cache key.
type AttentionMask struct {
causal bool
relaxations *relaxNode
array *mlx.Array
}
type relaxRect struct {
seq, qLo, qHi, kLo, kHi int
}
// relaxNode is a singly-linked list node holding relaxation
// rectangles. Each AttentionMask must have a fresh set of
// nodes to avoid false sharing between masks.
type relaxNode struct {
rect relaxRect
next *relaxNode
}
// CausalMask returns a flag-form causal mask. The mask stays
// tensor-free — hitting the kernel's mask_mode="causal" fast path —
// until something composes a relaxation, padding, or applier tensor
// onto it; then SDPA materializes via AsArray.
func CausalMask() AttentionMask {
return AttentionMask{causal: true}
}
// ArrayMask wraps an explicit additive tensor broadcast-compatible
// with [B, 1, L, K].
func ArrayMask(a *mlx.Array) AttentionMask {
return AttentionMask{array: a}
}
// IsZero reports whether the mask is the zero value (no mask at all).
func (m AttentionMask) IsZero() bool {
return !m.causal && m.array == nil && m.relaxations == nil
}
// IsCausal reports whether the mask is pure flag-form causal — no
// relaxations and no accumulated array. SDPA dispatches to the
// kernel's "causal" fast path on this; any padding, applier
// contribution, or relaxation falls to the array path.
func (m AttentionMask) IsCausal() bool {
return m.causal && m.relaxations == nil && m.array == nil
}
// Relax records a relaxation rectangle for batch sequence seq —
// positions (q, k) with q in [qLo, qHi) and k in [kLo, kHi) become
// freely attendable regardless of the causal base. Coordinates are
// absolute sequence positions on both axes, matching how causal is
// defined (k <= q). Multiple calls compose as a union per sequence.
//
// Rectangles that cannot change any cell — empty or already fully
// inside causal (kHi-1 <= qLo) — are dropped so IsCausal stays true
// and the mask remains on the kernel's fast path.
//
// Panics on pure ArrayMask (the caller owns the tensor and should
// modify it directly) or on the zero mask (nothing to relax).
func (m AttentionMask) Relax(seq, qLo, qHi, kLo, kHi int) AttentionMask {
if !m.causal {
if m.array != nil {
panic("AttentionMask.Relax: cannot relax a pure ArrayMask; modify the tensor directly")
}
panic("AttentionMask.Relax: cannot relax a zero mask")
}
if qLo >= qHi || kLo >= kHi {
return m
}
if kHi-1 <= qLo {
return m
}
m.relaxations = &relaxNode{
rect: relaxRect{seq: seq, qLo: qLo, qHi: qHi, kLo: kLo, kHi: kHi},
next: m.relaxations,
}
return m
}
// Intersect returns the element-wise sum of this mask and other. Masks are
// additive and apply before softmax, so this is intersection
// semantics — a position is valid only if both sides have 0 there.
//
// At AsArray time a causal+Relax+array mask materializes as: causal
// writes -inf into the upper triangle, Relax overwrites its
// rectangles back to 0, then array is added on top — restricting 0
// cells further or no-op'ing on -inf cells.
func (m AttentionMask) Intersect(other AttentionMask) AttentionMask {
if m.IsZero() {
return other
}
if other.IsZero() {
return m
}
result := AttentionMask{
causal: m.causal || other.causal,
}
// Relax requires causal, so relaxations != nil implies causal.
switch {
case m.relaxations != nil && other.relaxations != nil:
// Both sides causal+Relax: pairwise rect intersection per sequence.
var list *relaxNode
for a := m.relaxations; a != nil; a = a.next {
for b := other.relaxations; b != nil; b = b.next {
if a.rect.seq != b.rect.seq {
continue
}
qLo := max(a.rect.qLo, b.rect.qLo)
qHi := min(a.rect.qHi, b.rect.qHi)
kLo := max(a.rect.kLo, b.rect.kLo)
kHi := min(a.rect.kHi, b.rect.kHi)
if qHi <= qLo || kHi <= kLo || kHi-1 <= qLo {
continue
}
list = &relaxNode{
rect: relaxRect{seq: a.rect.seq, qLo: qLo, qHi: qHi, kLo: kLo, kHi: kHi},
next: list,
}
}
}
result.relaxations = list
case m.relaxations != nil && !other.causal:
result.relaxations = m.relaxations
case other.relaxations != nil && !m.causal:
result.relaxations = other.relaxations
default:
// Implicit: one side causal+Relax, the other plain causal
// (no relaxations). Plain causal blocks every cell Relax
// tried to release, so intersection with its empty release
// set leaves nothing — result.relaxations stays nil and
// collapses to pure causal.
}
switch {
case m.array != nil && other.array != nil:
result.array = mlx.Add(m.array, other.array)
case m.array != nil:
result.array = m.array
case other.array != nil:
result.array = other.array
}
return result
}
// AsArray materializes the mask as a [B, 1, L, K] additive tensor
// (0 where valid, -inf where blocked). B and L come from b; K and
// dtype come from the caller.
//
// Composition order:
// 1. Start from zero.
// 2. If m.causal: -inf where oldestPos+k > SeqOffsets[b] + q per row.
// 3. Apply m.relaxations (qLo/qHi and kLo/kHi are absolute positions).
// 4. Add m.array if present.
func (m AttentionMask) AsArray(b *batch.Batch, K int, dtype mlx.DType) *mlx.Array {
// Pure ArrayMask: caller owns the tensor, nothing to compose.
if !m.causal && m.relaxations == nil && m.array != nil {
if m.array.DType() == dtype {
return m.array
}
return m.array.AsType(dtype)
}
B := len(b.SeqOffsets)
L := b.InputIDs.Dim(1)
negInf := float32(math.Inf(-1))
vals := make([]float32, B*L*K)
if m.causal {
for i := range B {
off := int(b.SeqOffsets[i])
oldestPos := max(0, off+L-K)
base := i * L * K
for q := range L {
absQ := off + q
row := base + q*K
for k := range K {
if oldestPos+k > absQ {
vals[row+k] = negInf
}
}
}
}
}
for n := m.relaxations; n != nil; n = n.next {
r := n.rect
if r.seq < 0 || r.seq >= B {
continue
}
off := int(b.SeqOffsets[r.seq])
oldestPos := max(0, off+L-K)
qLo := min(max(r.qLo-off, 0), L)
qHi := min(max(r.qHi-off, 0), L)
kLo := min(max(r.kLo-oldestPos, 0), K)
kHi := min(max(r.kHi-oldestPos, 0), K)
base := r.seq * L * K
for q := qLo; q < qHi; q++ {
row := base + q*K
for k := kLo; k < kHi; k++ {
vals[row+k] = 0
}
}
}
out := mlx.FromValues(vals, B, 1, L, K)
if m.array != nil {
out = mlx.Add(out, m.array)
}
if dtype != mlx.DTypeFloat32 {
out = out.AsType(dtype)
}
return out
}
// QPaddingMask returns an additive [B, 1, L, 1] mask that blocks
// padded query rows (q >= b.SeqQueryLens[i]) across all keys. It is
// logical — independent of whatever layout the cache uses for K.
// Returns the zero mask when every row is full.
func QPaddingMask(b *batch.Batch, dtype mlx.DType) AttentionMask {
return padTailMask(len(b.SeqOffsets), b.InputIDs.Dim(1), 2, b.SeqQueryLens, dtype)
}
// KPaddingMask returns an additive [B, 1, 1, K] mask that blocks
// padded key columns (k >= kLens[i]) across all queries. Storage-
// dependent: kLens describes where real content ends in physical K,
// so this is typically used without a cache where the caller knows
// the actual layout. Returns the zero mask when every row is full.
func KPaddingMask(b *batch.Batch, K int, kLens []int32, dtype mlx.DType) AttentionMask {
return padTailMask(len(b.SeqOffsets), K, 3, kLens, dtype)
}
// SlidingWindowMask returns an additive [B, 1, L, K] mask blocking
// keys outside a per-row window of size `window`: any key whose
// absolute position p < absQ - window + 1 is blocked. Returns the
// zero mask when window <= 0 or no row needs blocking.
//
// Defined in logical position space — the K axis is position-ordered
// with column 0 at oldestPos = max(0, b.SeqOffsets[i]+L-K).
func SlidingWindowMask(b *batch.Batch, K, window int, dtype mlx.DType) AttentionMask {
if window <= 0 {
return AttentionMask{}
}
B := len(b.SeqOffsets)
L := b.InputIDs.Dim(1)
negInf := float32(math.Inf(-1))
vals := make([]float32, B*L*K)
needed := false
for i := range B {
off := int(b.SeqOffsets[i])
oldestPos := max(0, off+L-K)
base := i * L * K
for q := range L {
absQ := off + q
lo := absQ - window + 1
maskCount := lo - oldestPos
if maskCount <= 0 {
continue
}
if maskCount > K {
maskCount = K
}
row := base + q*K
for k := range maskCount {
vals[row+k] = negInf
needed = true
}
}
}
if !needed {
return AttentionMask{}
}
out := mlx.FromValues(vals, B, 1, L, K)
if dtype != mlx.DTypeFloat32 {
out = out.AsType(dtype)
}
return ArrayMask(out)
}
func padTailMask(B, total, axis int, lens []int32, dtype mlx.DType) AttentionMask {
needed := false
for i := range B {
if int(lens[i]) < total {
needed = true
break
}
}
if !needed {
return AttentionMask{}
}
negInf := float32(math.Inf(-1))
vals := make([]float32, B*total)
for i := range B {
n := int(lens[i])
base := i * total
for j := n; j < total; j++ {
vals[base+j] = negInf
}
}
shape := [4]int{B, 1, 1, 1}
shape[axis] = total
out := mlx.FromValues(vals, shape[0], shape[1], shape[2], shape[3])
if dtype != mlx.DTypeFloat32 {
out = out.AsType(dtype)
}
return ArrayMask(out)
}

680
x/models/nn/sdpa_test.go Normal file
View File

@@ -0,0 +1,680 @@
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// newBatch constructs a synthetic batch for mask/SDPA tests.
// seqOffsets defines B (length of slice) and each row's absolute start;
// L is the padded query length along InputIDs's second axis;
// qLens is per-row real query length (defaults to all L if nil).
func newBatch(seqOffsets []int32, L int, qLens []int32) *batch.Batch {
B := len(seqOffsets)
if qLens == nil {
qLens = make([]int32, B)
for i := range qLens {
qLens[i] = int32(L)
}
}
// InputIDs values don't matter for masking, only the shape.
ids := mlx.FromValues(make([]int32, B*L), B, L)
return &batch.Batch{
InputIDs: ids,
SeqOffsets: seqOffsets,
SeqQueryLens: qLens,
}
}
func TestAttentionMaskZero(t *testing.T) {
skipIfNoMLX(t)
var m AttentionMask
if !m.IsZero() {
t.Fatal("zero value should report IsZero")
}
if m.IsCausal() {
t.Fatal("zero value should not report IsCausal")
}
b := newBatch([]int32{0}, 2, nil)
arr := m.AsArray(b, 3, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("zero value AsArray should return a zeros tensor, not nil")
}
mlx.Eval(arr)
got := arr.Floats()
for i, v := range got {
if v != 0 {
t.Fatalf("zero mask should materialize all zeros; got[%d] = %v", i, v)
}
}
}
func TestAttentionMaskAsArrayCausal(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 6
b := newBatch([]int32{2}, L, nil)
arr := CausalMask().AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("CausalMask AsArray should return a tensor")
}
dims := arr.Dims()
if len(dims) != 4 || dims[0] != 1 || dims[1] != 1 || dims[2] != L || dims[3] != K {
t.Fatalf("want shape [1,1,%d,%d], got %v", L, K, dims)
}
mlx.Eval(arr)
got := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, L*K)
for q := range L {
absQ := int(b.SeqOffsets[0]) + q
for k := range K {
if k > absQ {
want[q*K+k] = negInf
}
}
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestAttentionMaskRelaxLazy(t *testing.T) {
skipIfNoMLX(t)
// Relax must not materialize a tensor — the perf invariant the
// causal-flag fast path relies on. Everything else (predicates,
// AsArray contents) is exercised by the materialization tests.
m := CausalMask().
Relax(0, 1, 3, 2, 5).
Relax(0, 0, 2, 1, 4)
if m.array != nil {
t.Fatal("Relax should not materialize a tensor")
}
}
// TestAttentionMaskRelaxNoopRectsMatchCausal pins the contract that
// rectangles which can't change any cell — empty in q or k, or fully
// inside the causal triangle — must produce the same materialized
// tensor as plain causal.
func TestAttentionMaskRelaxNoopRectsMatchCausal(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 6
b := newBatch([]int32{0}, L, nil)
want := CausalMask().AsArray(b, K, mlx.DTypeFloat32)
mlx.Eval(want)
wantF := want.Floats()
cases := []struct {
name string
qLo, qHi, kLo, kHi int
}{
{"empty Q rect", 2, 2, 0, 3},
{"empty K rect", 0, 3, 2, 2},
{"fully under causal", 5, 7, 0, 3},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
m := CausalMask().Relax(0, tc.qLo, tc.qHi, tc.kLo, tc.kHi)
arr := m.AsArray(b, K, mlx.DTypeFloat32)
mlx.Eval(arr)
got := arr.Floats()
for i := range wantF {
if !sameF(got[i], wantF[i]) {
t.Fatalf("index %d: want %v, got %v", i, wantF[i], got[i])
}
}
})
}
}
func TestAttentionMaskAsArrayWithRelax(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 6
b := newBatch([]int32{0}, L, nil)
arr := CausalMask().Relax(0, 1, 3, 2, 5).AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("expected tensor")
}
mlx.Eval(arr)
got := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, L*K)
for q := range L {
for k := range K {
if k > q {
want[q*K+k] = negInf
}
}
}
for q := 1; q < 3; q++ {
for k := 2; k < 5; k++ {
want[q*K+k] = 0
}
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestAttentionMaskAsArrayPerRow(t *testing.T) {
skipIfNoMLX(t)
L, K := 3, 5
b := newBatch([]int32{0, 2}, L, nil)
m := CausalMask().
Relax(0, 0, 2, 0, 3).
Relax(1, 3, 5, 2, 5)
arr := m.AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("expected tensor")
}
dims := arr.Dims()
if dims[0] != 2 {
t.Fatalf("expected batch dim 2, got %v", dims)
}
mlx.Eval(arr)
got := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, 2*L*K)
for bi, off := range b.SeqOffsets {
for q := range L {
absQ := int(off) + q
for k := range K {
if k > absQ {
want[bi*L*K+q*K+k] = negInf
}
}
}
}
for q := range 2 {
for k := range 3 {
want[0*L*K+q*K+k] = 0
}
}
for q := 1; q < 3; q++ {
for k := 2; k < 5; k++ {
want[1*L*K+q*K+k] = 0
}
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestQPaddingMask(t *testing.T) {
skipIfNoMLX(t)
L := 4
// Row 0 fully real; row 1 has 2 real queries.
b := newBatch([]int32{0, 0}, L, []int32{int32(L), 2})
m := QPaddingMask(b, mlx.DTypeFloat32)
if m.array == nil {
t.Fatal("expected q-padding tensor")
}
mlx.Eval(m.array)
got := m.array.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, 2*L)
// Row 0: no blocking; row 1: q >= 2 blocked.
for q := 2; q < L; q++ {
want[1*L+q] = negInf
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestKPaddingMask(t *testing.T) {
skipIfNoMLX(t)
K := 5
// Row 0 full keys; row 1 has 3 real keys.
b := newBatch([]int32{0, 0}, 4, nil)
kLens := []int32{int32(K), 3}
m := KPaddingMask(b, K, kLens, mlx.DTypeFloat32)
if m.array == nil {
t.Fatal("expected k-padding tensor")
}
mlx.Eval(m.array)
got := m.array.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, 2*K)
for k := 3; k < K; k++ {
want[1*K+k] = negInf
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestQPaddingMaskZeroWhenFull(t *testing.T) {
skipIfNoMLX(t)
b := newBatch([]int32{0}, 4, nil)
m := QPaddingMask(b, mlx.DTypeFloat32)
if !m.IsZero() {
t.Fatal("QPaddingMask at full queries should be zero")
}
}
func TestKPaddingMaskZeroWhenFull(t *testing.T) {
skipIfNoMLX(t)
K := 4
b := newBatch([]int32{0}, 4, nil)
kLens := []int32{int32(K)}
m := KPaddingMask(b, K, kLens, mlx.DTypeFloat32)
if !m.IsZero() {
t.Fatal("KPaddingMask at full keys should be zero")
}
}
func TestAttentionMaskCombineCausal(t *testing.T) {
skipIfNoMLX(t)
var z AttentionMask
got := z.Intersect(CausalMask())
if !got.IsCausal() {
t.Fatal("zero + CausalMask should be pure causal")
}
got = CausalMask().Intersect(z)
if !got.IsCausal() {
t.Fatal("CausalMask + zero should be pure causal")
}
got = CausalMask().Intersect(CausalMask())
if !got.IsCausal() {
t.Fatal("causal + causal should stay pure causal")
}
}
func TestAttentionMaskCombineRelaxDroppedAgainstCausal(t *testing.T) {
skipIfNoMLX(t)
relaxed := CausalMask().Relax(0, 1, 3, 2, 5)
got := relaxed.Intersect(CausalMask())
if !got.IsCausal() {
t.Fatal("causal-with-Relax + causal should drop relaxations and stay pure causal")
}
got = CausalMask().Intersect(relaxed)
if !got.IsCausal() {
t.Fatal("causal + causal-with-Relax should drop relaxations and stay pure causal")
}
// Disjoint relaxations on two causals also drop — neither side
// agrees to release the cells the other side relaxed.
got = CausalMask().Relax(0, 1, 3, 2, 5).Intersect(CausalMask().Relax(0, 5, 7, 6, 9))
if !got.IsCausal() {
t.Fatal("disjoint relaxations on two causals should drop and stay pure causal")
}
}
func TestAttentionMaskCombineRelaxIntersect(t *testing.T) {
skipIfNoMLX(t)
L, K := 6, 6
b := newBatch([]int32{0}, L, nil)
// Overlapping rects on two causals: the surviving relaxation is
// the geometric intersection — q in [1,3) ∩ [2,5) = [2,3),
// k in [2,5) ∩ [3,6) = [3,5).
m := CausalMask().Relax(0, 1, 3, 2, 5).Intersect(CausalMask().Relax(0, 2, 5, 3, 6))
if m.IsCausal() {
t.Fatal("overlapping relaxations should survive as their intersection, not collapse to pure causal")
}
arr := m.AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("expected tensor")
}
mlx.Eval(arr)
vals := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, L*K)
for q := range L {
for k := range K {
if k > q {
want[q*K+k] = negInf
}
}
}
// Intersection rect: q ∈ [2,3), k ∈ [3,5).
for q := 2; q < 3; q++ {
for k := 3; k < 5; k++ {
want[q*K+k] = 0
}
}
for i := range want {
if !sameF(vals[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], vals[i])
}
}
}
func TestAttentionMaskCombineRelaxKeptAgainstNonCausal(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 6
b := newBatch([]int32{0}, L, nil)
// Pad q=3 — non-causal additive contribution that should leave
// the relaxation intact (the rect releases above-diagonal cells
// q in [1,3), k in [2,5) where k > q).
pad := QPaddingMask(newBatch([]int32{0}, L, []int32{3}), mlx.DTypeFloat32)
if pad.IsZero() {
t.Fatal("padding mask should be non-zero")
}
got := CausalMask().Relax(0, 1, 3, 2, 5).Intersect(pad)
arr := got.AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("expected tensor")
}
mlx.Eval(arr)
vals := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, L*K)
for q := range L {
for k := range K {
if k > q {
want[q*K+k] = negInf
}
}
}
for q := 1; q < 3; q++ {
for k := 2; k < 5; k++ {
want[q*K+k] = 0
}
}
for q := 3; q < L; q++ {
for k := range K {
want[q*K+k] = negInf
}
}
for i := range want {
if !sameF(vals[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], vals[i])
}
}
}
func TestAttentionMaskCombineArrays(t *testing.T) {
skipIfNoMLX(t)
a := mlx.FromValues([]float32{0, 0, 0, 0}, 1, 1, 2, 2)
bb := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 1, 2, 2)
sum := ArrayMask(a).Intersect(ArrayMask(bb))
if sum.array == nil {
t.Fatal("array + array should produce array")
}
mlx.Eval(sum.array)
got := sum.array.Floats()
want := []float32{1, 2, 3, 4}
for i := range want {
if got[i] != want[i] {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestAttentionMaskRelaxPanicOnArray(t *testing.T) {
skipIfNoMLX(t)
a := mlx.FromValues([]float32{0}, 1, 1, 1, 1)
defer func() {
if r := recover(); r == nil {
t.Fatal("Relax on ArrayMask should panic")
}
}()
ArrayMask(a).Relax(0, 0, 1, 0, 1)
}
func TestAttentionMaskRelaxPanicOnZero(t *testing.T) {
skipIfNoMLX(t)
defer func() {
if r := recover(); r == nil {
t.Fatal("Relax on zero mask should panic")
}
}()
var z AttentionMask
z.Relax(0, 0, 1, 0, 1)
}
func sameF(a, b float32) bool {
if math.IsInf(float64(a), -1) && math.IsInf(float64(b), -1) {
return true
}
return a == b
}
// sdpaInputs builds non-trivial Q/K/V so masking actually changes the
// kernel output. With zero K/V, SDPA returns zero regardless of mask
// and "parity" tests pass even when the mask path is broken.
func sdpaInputs(L, K int) (q, k, v *mlx.Array) {
const D = 4
qVals := make([]float32, L*D)
for i := range qVals {
qVals[i] = 0.1 * float32(i+1)
}
kVals := make([]float32, K*D)
for i := range kVals {
kVals[i] = 0.07 * float32(i+1)
}
vVals := make([]float32, K*D)
for i := range vVals {
vVals[i] = float32(i+1) - 0.5*float32(K*D)
}
q = mlx.FromValues(qVals, 1, 1, L, D)
k = mlx.FromValues(kVals, 1, 1, K, D)
v = mlx.FromValues(vVals, 1, 1, K, D)
return
}
func TestSDPACausalParity(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 4
q, k, v := sdpaInputs(L, K)
b := newBatch([]int32{int32(K - L)}, L, nil)
got := ScaledDotProductAttention(b, q, 1.0,
WithKV(k, v, []int32{int32(K)}),
WithMask(CausalMask()),
)
want := mlx.FastScaledDotProductAttention(q, k, v, 1.0, "causal", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPAZeroMaskParity(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 4
q, k, v := sdpaInputs(L, K)
b := newBatch([]int32{0}, L, nil)
got := ScaledDotProductAttention(b, q, 1.0, WithKV(k, v, []int32{int32(K)}))
want := mlx.FastScaledDotProductAttention(q, k, v, 1.0, "", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPAArrayMaskParity(t *testing.T) {
skipIfNoMLX(t)
L, K := 3, 3
q, k, v := sdpaInputs(L, K)
b := newBatch([]int32{0}, L, nil)
mask := mlx.FromValues([]float32{
0, -1, -1,
0, 0, -1,
0, 0, 0,
}, 1, 1, 3, 3)
got := ScaledDotProductAttention(b, q, 1.0,
WithKV(k, v, []int32{int32(K)}),
WithMask(ArrayMask(mask)),
)
want := mlx.FastScaledDotProductAttention(q, k, v, 1.0, "array", mask)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPARelaxMaskMaterializes(t *testing.T) {
skipIfNoMLX(t)
L, K := 3, 5
q, k, v := sdpaInputs(L, K)
b := newBatch([]int32{int32(K - L)}, L, nil)
got := ScaledDotProductAttention(b, q, 1.0,
WithKV(k, v, []int32{int32(K)}),
WithMask(CausalMask().Relax(0, 3, 5, 2, 5)),
)
ref := CausalMask().Relax(0, 3, 5, 2, 5).AsArray(b, K, k.DType())
want := mlx.FastScaledDotProductAttention(q, k, v, 1.0, "array", ref)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPAPanicsWithBothKVAndHistory(t *testing.T) {
skipIfNoMLX(t)
L := 3
q, k, v := sdpaInputs(L, L)
b := newBatch([]int32{0}, L, nil)
history := NewKVHistory(k, v, nil)
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic when both WithKV and WithKVHistory are supplied")
}
}()
ScaledDotProductAttention(b, q, 1.0, WithKV(k, v, []int32{int32(L)}), WithKVHistory(history))
}
func TestSDPAMLAHistorySlicesVFromK(t *testing.T) {
skipIfNoMLX(t)
L, D, valueDim := 2, 5, 3
kBuf := make([]float32, 1*1*L*D)
for i := range kBuf {
kBuf[i] = float32(i) + 1
}
k := mlx.FromValues(kBuf, 1, 1, L, D)
v := mlx.Zeros(mlx.DTypeFloat32, 1, 1, L, valueDim)
history := NewKVHistory(k, v, nil)
q := mlx.Zeros(mlx.DTypeFloat32, 1, 1, L, D)
b := newBatch([]int32{0}, L, nil)
got := ScaledDotProductAttention(b, q, 1.0,
WithMLAHistory(history, valueDim),
)
vRef := k.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, valueDim))
want := mlx.FastScaledDotProductAttention(q, k, vRef, 1.0, "", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPAPanicsWithoutKV(t *testing.T) {
skipIfNoMLX(t)
q := mlx.FromValues(make([]float32, 4), 1, 1, 1, 4)
b := newBatch([]int32{0}, 1, nil)
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic when no K/V supplied")
}
}()
ScaledDotProductAttention(b, q, 1.0)
}
// fillTensor builds a [B, H, T, D] float32 tensor whose entries are
// distinct, non-zero, and predictable so per-row slices stay distinct.
func fillTensor(seed float32, B, H, T, D int) *mlx.Array {
vals := make([]float32, B*H*T*D)
for i := range vals {
vals[i] = seed + 0.05*float32(i)
}
return mlx.FromValues(vals, B, H, T, D)
}
// TestSDPAMultiSequenceParity drives a B=2 batch with mixed
// SeqOffsets and SeqQueryLens through ScaledDotProductAttention via
// the no-cache (WithKV) path, then compares each row's real
// positions against a B=1 reference at that row's offset and length.
// Padded-tail outputs are unconstrained and not checked. Pins the
// central multi-sequence contract: right-padded rows must produce
// per-row outputs that don't depend on the padded tails.
func TestSDPAMultiSequenceParity(t *testing.T) {
skipIfNoMLX(t)
const H, D = 1, 4
const L, K = 4, 6
const qShort, kShort = 2, 2
const scale = 1.0
q := fillTensor(0.5, 2, H, L, D)
k := fillTensor(-0.3, 2, H, K, D)
v := fillTensor(0.7, 2, H, K, D)
b := newBatch([]int32{2, 0}, L, []int32{int32(L), int32(qShort)})
got := ScaledDotProductAttention(b, q, scale,
WithKV(k, v, []int32{int32(K), int32(kShort)}),
WithMask(CausalMask()))
mlx.Eval(got)
gotF := got.Floats()
// Row 0: full Q at offset 2, full K. B=1 reference.
q0 := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{1, H, L, D})
k0 := mlx.SliceStartStop(k, []int32{0, 0, 0, 0}, []int32{1, H, K, D})
v0 := mlx.SliceStartStop(v, []int32{0, 0, 0, 0}, []int32{1, H, K, D})
b0 := newBatch([]int32{2}, L, nil)
ref0 := ScaledDotProductAttention(b0, q0, scale,
WithKV(k0, v0, []int32{int32(K)}),
WithMask(CausalMask()))
mlx.Eval(ref0)
ref0F := ref0.Floats()
// Row 1: real Q at offset 0, length qShort, with kShort real keys.
q1 := mlx.SliceStartStop(q, []int32{1, 0, 0, 0}, []int32{2, H, int32(qShort), D})
k1 := mlx.SliceStartStop(k, []int32{1, 0, 0, 0}, []int32{2, H, int32(kShort), D})
v1 := mlx.SliceStartStop(v, []int32{1, 0, 0, 0}, []int32{2, H, int32(kShort), D})
b1 := newBatch([]int32{0}, qShort, nil)
ref1 := ScaledDotProductAttention(b1, q1, scale,
WithKV(k1, v1, []int32{int32(kShort)}),
WithMask(CausalMask()))
mlx.Eval(ref1)
ref1F := ref1.Floats()
// got is [2, H, L, D] = [B=2, 1, 4, 4]. Row 0 is got[0,...] and
// must match ref0 over the full [L, D]. Row 1 is got[1,...] and
// must match ref1 over [qShort, D] only — padded positions are
// unconstrained.
rowStride := H * L * D
for i := range rowStride {
if !approxEqual(gotF[i], ref0F[i], 1e-5) {
t.Fatalf("row 0 [%d]: got %v, want %v", i, gotF[i], ref0F[i])
}
}
for q := range qShort {
for d := range D {
gotI := rowStride + q*D + d
refI := q*D + d
if !approxEqual(gotF[gotI], ref1F[refI], 1e-5) {
t.Fatalf("row 1 [q=%d,d=%d]: got %v, want %v", q, d, gotF[gotI], ref1F[refI])
}
}
}
}

338
x/models/qwen3/qwen3.go Normal file
View File

@@ -0,0 +1,338 @@
// Package qwen3 provides the Qwen3 text model implementation for MLX.
package qwen3
import (
"encoding/json"
"fmt"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
"github.com/ollama/ollama/x/tokenizer"
)
func init() {
base.Register("Qwen3ForCausalLM", newModel)
}
// Config holds Qwen3 model configuration.
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
TieWordEmbeddings bool `json:"tie_word_embeddings"`
// Quantization parameters (set during load based on model quantization).
QuantGroupSize int `json:"-"`
QuantBits int `json:"-"`
QuantMode string `json:"-"`
TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
// Computed fields.
Scale float32 `json:"-"`
QKNormEps float32 `json:"-"`
}
// Model is the Qwen3 text-only model.
type Model struct {
EmbedTokens nn.EmbeddingLayer
Layers []*Layer
Norm *nn.RMSNorm
LMHead nn.LinearLayer
tok *tokenizer.Tokenizer
*Config
weightPrefix string
}
// Layer is a single Qwen3 decoder block.
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm
MLPNorm *nn.RMSNorm
}
// Attention implements Qwen3 attention with Q/K norms.
type Attention struct {
QProj nn.LinearLayer
KProj nn.LinearLayer
VProj nn.LinearLayer
OProj nn.LinearLayer
QNorm *nn.RMSNorm
KNorm *nn.RMSNorm
}
// MLP is the feed-forward network with SwiGLU activation.
type MLP struct {
GateProj nn.LinearLayer
UpProj nn.LinearLayer
DownProj nn.LinearLayer
}
func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
for _, prefix := range []string{"", "language_model."} {
if tensors[prefix+"model.embed_tokens.weight"] != nil {
return prefix
}
}
return ""
}
func newModel(root *model.Root) (base.Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(configData, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
if cfg.HiddenSize <= 0 {
return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
}
if cfg.NumAttentionHeads <= 0 {
return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
}
if cfg.NumKeyValueHeads <= 0 {
cfg.NumKeyValueHeads = cfg.NumAttentionHeads
}
if cfg.HeadDim == 0 {
if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
}
if cfg.HeadDim <= 0 {
return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
}
if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
cfg.QKNormEps = 1e-6
if qt := root.QuantType(); qt != "" {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
if gs := root.GroupSize(); gs > 0 {
cfg.QuantGroupSize = gs
}
} else {
cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
}
cfg.TensorQuant = root.AllTensorQuant()
tokData, err := root.Manifest.ReadConfig("tokenizer.json")
if err != nil {
return nil, fmt.Errorf("load tokenizer config: %w", err)
}
tokConfig := &tokenizer.TokenizerConfig{
ConfigJSON: configData,
}
if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
tokConfig.GenerationConfigJSON = genConfigData
}
if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
tokConfig.TokenizerConfigJSON = tokConfigData
}
tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
if err != nil {
return nil, fmt.Errorf("parse tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
return m, nil
}
// LoadWeights receives all tensors loaded from the manifest and assigns them
// to model fields.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
m.weightPrefix = resolveWeightPrefix(tensors)
prefix := m.weightPrefix
linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
embedTokens := model.MakeEmbeddingLayer(tensors, prefix+"model.embed_tokens", m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
if embedTokens == nil {
return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
}
m.EmbedTokens = embedTokens
normWeight := tensors[prefix+"model.norm.weight"]
if normWeight == nil {
return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
}
m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
if m.TieWordEmbeddings {
m.LMHead = m.EmbedTokens.AsLinear()
} else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
m.LMHead = lmHead
} else if lmHead := linears.Make("lm_head"); lmHead != nil {
m.LMHead = lmHead
} else {
// Qwen3 checkpoints commonly tie output projection to embeddings.
m.LMHead = m.EmbedTokens.AsLinear()
}
for i := int32(0); i < m.NumHiddenLayers; i++ {
layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
layer := &Layer{
Attention: &Attention{},
MLP: &MLP{},
}
if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
}
layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
}
if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
}
layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
if layer.AttentionNorm == nil {
return fmt.Errorf("layer %d: missing input_layernorm", i)
}
if layer.MLPNorm == nil {
return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
}
if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
return fmt.Errorf("layer %d: missing attention projections", i)
}
if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
return fmt.Errorf("layer %d: missing attention q/k norms", i)
}
if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
return fmt.Errorf("layer %d: missing mlp projections", i)
}
m.Layers[i] = layer
}
return nil
}
func (m *Model) Forward(b *batch.Batch, caches []cache.Cache) *mlx.Array {
dims := b.InputIDs.Dims()
B, L := int32(dims[0]), int32(dims[1])
positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))
h := m.EmbedTokens.Forward(b.InputIDs)
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil && i < len(caches) {
c = caches[i]
}
h = layer.Forward(h, b, c, positions, B, L, m.Config)
}
return m.Norm.Forward(h, m.RMSNormEps)
}
func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
return m.LMHead.Forward(x)
}
func (m *Model) NumLayers() int {
return len(m.Layers)
}
func (m *Model) MaxContextLength() int {
return int(m.MaxPositionEmbeddings)
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
func (m *Model) NewCaches() []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (l *Layer) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), b, c, positions, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
q = a.QNorm.Forward(q, cfg.QKNormEps)
k = a.KNorm.Forward(k, cfg.QKNormEps)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions)
// MLX SDPA supports grouped-query attention directly (Q heads can be a
// multiple of K/V heads), so avoid materializing repeated K/V tensors.
var kv nn.SDPAOption
if c != nil {
history := c.(cache.Attention).Update(b, k, v)
kv = nn.WithKVHistory(history)
} else {
kv = nn.WithKV(k, v, b.SeqQueryLens)
}
out := nn.ScaledDotProductAttention(b, q, cfg.Scale, kv, nn.WithMask(nn.CausalMask()))
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x)))
}

1372
x/models/qwen3_5/qwen3_5.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,367 @@
package qwen3_5
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
func TestParseConfigNestedDefaults(t *testing.T) {
data := []byte(`{
"model_type": "Qwen3_5MoeForConditionalGeneration",
"text_config": {
"hidden_size": 4096,
"intermediate_size": 14336,
"num_hidden_layers": 8,
"num_attention_heads": 32,
"num_key_value_heads": 8,
"head_dim": 128,
"linear_num_value_heads": 64,
"linear_num_key_heads": 16,
"linear_key_head_dim": 128,
"linear_value_head_dim": 128,
"linear_conv_kernel_dim": 4,
"num_experts": 16,
"num_experts_per_tok": 4,
"moe_intermediate_size": 2048,
"shared_expert_intermediate_size": 4096,
"rope_parameters": {
"rope_theta": 500000,
"partial_rotary_factor": 0.5
}
}
}`)
cfg, err := parseConfig(data)
if err != nil {
t.Fatalf("parseConfig failed: %v", err)
}
if cfg.RopeTheta != 500000 {
t.Fatalf("rope theta mismatch: got %v", cfg.RopeTheta)
}
if cfg.RopeDim != 64 {
t.Fatalf("rope dim mismatch: got %d want 64", cfg.RopeDim)
}
if cfg.FullAttentionInterval != 4 {
t.Fatalf("full_attention_interval default mismatch: got %d want 4", cfg.FullAttentionInterval)
}
if !cfg.NormTopKProb {
t.Fatalf("norm_topk_prob should default to true for MoE")
}
}
func TestLayerSelectionHelpers(t *testing.T) {
cfg := &Config{
NumHiddenLayers: 6,
FullAttentionInterval: 3,
NumExperts: 8,
DecoderSparseStep: 2,
MLPOnlyLayers: []int32{1},
}
if !layerIsLinear(cfg, 0) {
t.Fatalf("layer 0 should be linear")
}
if layerIsLinear(cfg, 2) {
t.Fatalf("layer 2 should be full attention")
}
if layerUsesMoE(cfg, 1) {
t.Fatalf("layer 1 should be forced dense by mlp_only_layers")
}
if !layerUsesMoE(cfg, 3) {
t.Fatalf("layer 3 should use moe with decoder_sparse_step=2")
}
}
func TestSupportsGatherQMM(t *testing.T) {
tests := []struct {
mode string
bits int
want bool
}{
{mode: "affine", bits: 4, want: true},
{mode: "affine", bits: 8, want: true},
{mode: "mxfp8", bits: 8, want: true},
{mode: "nvfp4", bits: 4, want: true},
{mode: "mxfp4", bits: 4, want: true},
{mode: "mxfp8", bits: 4, want: false},
{mode: "affine", bits: 3, want: false},
}
for _, tt := range tests {
if got := supportsGatherQMM(tt.mode, tt.bits); got != tt.want {
t.Fatalf("supportsGatherQMM(%q, %d) = %v, want %v", tt.mode, tt.bits, got, tt.want)
}
}
}
func TestResolveTensorPathLayout(t *testing.T) {
dummy := mlx.New("dummy")
tests := []struct {
name string
key string
wantContainer string
wantModel string
}{
{
name: "standard",
key: "model.embed_tokens.weight",
wantContainer: "",
wantModel: "model.",
},
{
name: "nested language model with inner model",
key: "model.language_model.model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "model.",
},
{
name: "nested language model without inner model",
key: "model.language_model.embed_tokens.weight",
wantContainer: "model.language_model.",
wantModel: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
layout := resolveTensorPathLayout(map[string]*mlx.Array{
tt.key: dummy,
})
if layout.containerPrefix != tt.wantContainer || layout.modelPrefix != tt.wantModel {
t.Fatalf(
"resolveTensorPathLayout() = {%q %q}, want {%q %q}",
layout.containerPrefix,
layout.modelPrefix,
tt.wantContainer,
tt.wantModel,
)
}
})
}
}
func TestNewCachesLayout(t *testing.T) {
m := &Model{
Config: &Config{
LinearConvKernelDim: 4,
LinearNumKeyHeads: 2,
LinearKeyHeadDim: 8,
LinearNumValueHeads: 4,
LinearValueHeadDim: 16,
},
Layers: []*Layer{
{IsLinear: true},
{IsLinear: false},
{IsLinear: true},
},
}
caches := m.NewCaches()
if len(caches) != len(m.Layers) {
t.Fatalf("len(caches) = %d, want %d", len(caches), len(m.Layers))
}
if _, ok := caches[0].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[0] = %T, want *cache.RecurrentCache", caches[0])
}
if _, ok := caches[1].(*cache.KVCache); !ok {
t.Fatalf("cache[1] = %T, want *cache.KVCache", caches[1])
}
if _, ok := caches[2].(*cache.RecurrentCache); !ok {
t.Fatalf("cache[2] = %T, want *cache.RecurrentCache", caches[2])
}
}
func TestLoadWeightsPreservesLinearAttentionNormWeightDType(t *testing.T) {
skipIfNoMLX(t)
cfg := &Config{
HiddenSize: 4,
IntermediateSize: 8,
NumHiddenLayers: 2,
NumAttentionHeads: 1,
NumKeyValueHeads: 1,
HeadDim: 4,
RMSNormEps: 1e-6,
TieWordEmbeddings: true,
LayerTypes: []string{"linear", "full"},
LinearNumValueHeads: 1,
LinearNumKeyHeads: 1,
LinearKeyHeadDim: 2,
LinearValueHeadDim: 2,
LinearConvKernelDim: 4,
FullAttentionInterval: 2,
}
m := &Model{
Config: cfg,
Layers: make([]*Layer, cfg.NumHiddenLayers),
}
bf16 := mlx.DTypeBFloat16
f32 := mlx.DTypeFloat32
tensors := map[string]*mlx.Array{
"model.embed_tokens.weight": mlx.FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 2, 4).AsType(bf16),
"model.norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
"model.layers.0.input_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
"model.layers.0.post_attention_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
"model.layers.0.linear_attn.in_proj_qkv.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
1, 1, 0, 0,
0, 1, 1, 0,
}, 6, 4),
"model.layers.0.linear_attn.in_proj_z.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
}, 2, 4),
"model.layers.0.linear_attn.in_proj_b.weight": mlx.FromValues([]float32{1, 0, 0, 0}, 1, 4),
"model.layers.0.linear_attn.in_proj_a.weight": mlx.FromValues([]float32{0, 1, 0, 0}, 1, 4),
"model.layers.0.linear_attn.out_proj.weight": mlx.FromValues([]float32{
1, 0,
0, 1,
1, 1,
0, 0,
}, 4, 2),
"model.layers.0.linear_attn.conv1d.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
1, 1, 0, 0,
0, 1, 1, 0,
}, 6, 4),
"model.layers.0.linear_attn.norm.weight": mlx.FromValues([]float32{1, 1}, 2),
"model.layers.0.linear_attn.dt_bias": mlx.FromValues([]float32{0}, 1),
"model.layers.0.linear_attn.A_log": mlx.FromValues([]float32{0}, 1),
"model.layers.0.mlp.gate_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
1, 1, 0, 0,
0, 1, 1, 0,
0, 0, 1, 1,
1, 0, 0, 1,
}, 8, 4),
"model.layers.0.mlp.up_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
1, 1, 0, 0,
0, 1, 1, 0,
0, 0, 1, 1,
1, 0, 0, 1,
}, 8, 4),
"model.layers.0.mlp.down_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0,
}, 4, 8),
"model.layers.1.input_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
"model.layers.1.post_attention_layernorm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
"model.layers.1.self_attn.q_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
1, 1, 0, 0,
0, 1, 1, 0,
0, 0, 1, 1,
1, 0, 0, 1,
}, 8, 4),
"model.layers.1.self_attn.k_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
}, 4, 4),
"model.layers.1.self_attn.v_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
}, 4, 4),
"model.layers.1.self_attn.o_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
}, 4, 4),
"model.layers.1.self_attn.q_norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
"model.layers.1.self_attn.k_norm.weight": mlx.FromValues([]float32{1, 1, 1, 1}, 4),
"model.layers.1.mlp.gate_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
1, 1, 0, 0,
0, 1, 1, 0,
0, 0, 1, 1,
1, 0, 0, 1,
}, 8, 4),
"model.layers.1.mlp.up_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
1, 1, 0, 0,
0, 1, 1, 0,
0, 0, 1, 1,
1, 0, 0, 1,
}, 8, 4),
"model.layers.1.mlp.down_proj.weight": mlx.FromValues([]float32{
1, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0,
0, 0, 0, 1, 0, 0, 0, 0,
}, 4, 8),
}
if err := m.LoadWeights(tensors); err != nil {
t.Fatalf("LoadWeights failed: %v", err)
}
if got := m.Layers[0].InputNorm.Weight.DType(); got != f32 {
t.Fatalf("layer 0 input norm dtype = %v, want %v", got, f32)
}
if got := m.Layers[0].PostAttentionNorm.Weight.DType(); got != f32 {
t.Fatalf("layer 0 post-attn norm dtype = %v, want %v", got, f32)
}
if got := m.Layers[1].InputNorm.Weight.DType(); got != f32 {
t.Fatalf("layer 1 input norm dtype = %v, want %v", got, f32)
}
if got := m.Layers[1].PostAttentionNorm.Weight.DType(); got != f32 {
t.Fatalf("layer 1 post-attn norm dtype = %v, want %v", got, f32)
}
if got := m.Norm.Weight.DType(); got != f32 {
t.Fatalf("final norm dtype = %v, want %v", got, f32)
}
if got := m.Layers[0].Linear.NormWeight.DType(); got != f32 {
t.Fatalf("linear-attn norm dtype = %v, want %v", got, f32)
}
if got := m.Layers[1].FullAttn.QNorm.Weight.DType(); got != f32 {
t.Fatalf("q norm dtype = %v, want %v", got, f32)
}
if got := m.Layers[1].FullAttn.KNorm.Weight.DType(); got != f32 {
t.Fatalf("k norm dtype = %v, want %v", got, f32)
}
}

View File

@@ -0,0 +1,14 @@
// Package qwen3_5_moe registers Qwen 3.5 MoE architecture aliases.
package qwen3_5_moe
import (
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/qwen3_5"
)
func init() {
base.Register("Qwen3_5MoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3_5MoeForCausalLM", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForConditionalGeneration", qwen3_5.NewModel)
base.Register("Qwen3NextMoeForCausalLM", qwen3_5.NewModel)
}