ollama source for Momentry Core verification

This commit is contained in:
Accusys
2026-05-22 17:19:10 +08:00
commit 0b31ff9135
2020 changed files with 1413145 additions and 0 deletions

266
x/models/nn/nn.go Normal file
View File

@@ -0,0 +1,266 @@
package nn
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// Layer is the interface for neural network layers with a Forward method.
type Layer interface {
Forward(x *mlx.Array) *mlx.Array
}
// LinearLayer is an interface for linear layers (both regular and quantized).
type LinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
OutputDim() int32
}
// EmbeddingLayer is an interface for embedding layers that can also expose a
// tied-output projection when the model reuses embedding weights as the LM head.
type EmbeddingLayer interface {
Forward(indices *mlx.Array) *mlx.Array
AsLinear() LinearLayer
}
// Conv1d applies 1D convolution over NLC input.
type Conv1d struct {
Weight *mlx.Array
Bias *mlx.Array
Stride int32
Padding int32
Dilation int32
Groups int32
}
func NewConv1d(weight, bias *mlx.Array, stride, padding, dilation, groups int32) *Conv1d {
if stride <= 0 {
stride = 1
}
if dilation <= 0 {
dilation = 1
}
if groups <= 0 {
groups = 1
}
return &Conv1d{
Weight: weight,
Bias: bias,
Stride: stride,
Padding: padding,
Dilation: dilation,
Groups: groups,
}
}
func (c *Conv1d) Forward(x *mlx.Array) *mlx.Array {
return mlx.Conv1d(x, c.Weight, c.Bias, c.Stride, c.Padding, c.Dilation, c.Groups)
}
// Linear applies an affine transformation: y = x @ W.T + b
type Linear struct {
Weight *mlx.Array
Bias *mlx.Array
}
func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
return &Linear{Weight: weight, Bias: bias}
}
func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
w := l.Weight.Transpose(1, 0)
if l.Bias != nil && l.Bias.Valid() {
return l.Bias.Addmm(x, w, 1.0, 1.0)
}
return x.Matmul(w)
}
func (l *Linear) OutputDim() int32 {
return int32(l.Weight.Dim(0))
}
// QuantizedLinear applies an affine transformation using quantized weights.
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
QBiases *mlx.Array // Quantization biases (nil for nvfp4)
Bias *mlx.Array // Layer bias [output_dims] or nil
GlobalScale *mlx.Array // Per-tensor global scale for double-scale nvfp4 (nil for standard)
GroupSize int
Bits int
Mode string
}
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
if qbiases != nil {
mlx.Eval(qw, scales, qbiases)
} else {
mlx.Eval(qw, scales)
}
return &QuantizedLinear{
Weight: qw,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
var out *mlx.Array
if ql.GlobalScale != nil {
// Double-scale nvfp4 (e.g., NVIDIA ModelOpt): standard quantized_matmul
// followed by global_scale multiply. The global_scale is a per-tensor
// F32 scalar (weight_scale_2 in NVIDIA's format).
// TODO: switch to a fused double-scale matmul once MLX has kernel
// coverage for this path.
out = mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
out = mlx.Mul(out, ql.GlobalScale)
} else {
out = mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
}
if ql.Bias != nil && ql.Bias.Valid() {
out = out.Add(ql.Bias)
}
return out
}
func (ql *QuantizedLinear) OutputDim() int32 {
return int32(ql.Weight.Dim(0))
}
// RMSNorm represents an RMS normalization layer.
type RMSNorm struct {
Weight *mlx.Array
Eps float32
}
func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
return &RMSNorm{Weight: weight, Eps: eps}
}
func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
if eps == 0 {
eps = rn.Eps
}
return mlx.RMSNormFn(x, rn.Weight, eps)
}
// Embedding represents an embedding layer.
type Embedding struct {
Weight *mlx.Array
}
func NewEmbedding(weight *mlx.Array) *Embedding {
return &Embedding{Weight: weight}
}
func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
return e.Weight.TakeAxis(indices, 0)
}
func (e *Embedding) AsLinear() LinearLayer {
return NewLinear(e.Weight, nil)
}
// QuantizedEmbedding performs row-wise embedding lookup from affine/nvfp4/etc.
// packed weights and dequantizes only the selected rows.
type QuantizedEmbedding struct {
Weight *mlx.Array
Scales *mlx.Array
QBiases *mlx.Array
GroupSize int
Bits int
Mode string
}
func NewQuantizedEmbedding(weight, scales, qbiases *mlx.Array, groupSize, bits int, mode string) *QuantizedEmbedding {
return &QuantizedEmbedding{
Weight: weight,
Scales: scales,
QBiases: qbiases,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
func (qe *QuantizedEmbedding) Forward(indices *mlx.Array) *mlx.Array {
weight := qe.Weight.TakeAxis(indices, 0)
scales := qe.Scales.TakeAxis(indices, 0)
var qbiases *mlx.Array
if qe.QBiases != nil && qe.QBiases.Valid() {
qbiases = qe.QBiases.TakeAxis(indices, 0)
}
return mlx.Dequantize(weight, scales, qbiases, qe.GroupSize, qe.Bits, qe.Mode)
}
func (qe *QuantizedEmbedding) AsLinear() LinearLayer {
return &QuantizedLinear{
Weight: qe.Weight,
Scales: qe.Scales,
QBiases: qe.QBiases,
GroupSize: qe.GroupSize,
Bits: qe.Bits,
Mode: qe.Mode,
}
}
// LayerNorm represents a standard layer normalization layer (with bias).
type LayerNorm struct {
Weight *mlx.Array
Bias *mlx.Array
Eps float32
}
func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
eps := ln.Eps
if eps == 0 {
eps = 1e-5
}
return mlx.LayerNormFn(x, ln.Weight, ln.Bias, eps)
}
// MultiLinearLayer is an interface for per-head linear layers.
type MultiLinearLayer interface {
Forward(x *mlx.Array) *mlx.Array
}
// MultiLinear performs per-head linear projections.
// Weight shape: [num_heads, output_dims, input_dims]
type MultiLinear struct {
Weight *mlx.Array
}
func NewMultiLinear(weight *mlx.Array) *MultiLinear {
return &MultiLinear{Weight: weight}
}
func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
wT := ml.Weight.Transpose(0, 2, 1)
return x.Matmul(wT)
}
// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
shape := scores.Dims()
seqLen := int32(shape[2])
mask := mlx.Tri(seqLen, seqLen, 0)
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mask.ExpandDims(0).ExpandDims(0)
return mlx.Where(mask, scores, negInf)
}
// ApplyCausalMaskWithOffset applies causal mask for cached attention.
func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
if offset == 0 {
return ApplyCausalMask(scores)
}
shape := scores.Dims()
queryLen := int32(shape[2])
keyLen := int32(shape[3])
mask := mlx.Tri(queryLen, keyLen, int(offset))
negInf := mlx.NewScalarArray(float32(-1e9))
mask = mask.ExpandDims(0).ExpandDims(0)
return mlx.Where(mask, scores, negInf)
}

187
x/models/nn/nn_test.go Normal file
View File

@@ -0,0 +1,187 @@
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
func approxEqual(a, b, tol float32) bool {
return float32(math.Abs(float64(a-b))) < tol
}
// TestLayerNormNoBias verifies LayerNorm without bias against manual computation.
func TestLayerNormNoBias(t *testing.T) {
skipIfNoMLX(t)
// Input: [1, 4] — single row, 4 features
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
mlx.Eval(x, weight)
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
out := ln.Forward(x)
mlx.Eval(out)
data := out.Floats()
if len(data) != 4 {
t.Fatalf("expected 4 values, got %d", len(data))
}
// Manual LayerNorm: mean=2.5, var=1.25, std=sqrt(1.25+1e-5)
// normalized = (x - mean) / std
mean := float32(2.5)
variance := float32(1.25)
std := float32(math.Sqrt(float64(variance + 1e-5)))
for i, v := range []float32{1, 2, 3, 4} {
expected := (v - mean) / std
if !approxEqual(data[i], expected, 1e-4) {
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
}
}
}
// TestLayerNormWithBias verifies LayerNorm with weight and bias.
func TestLayerNormWithBias(t *testing.T) {
skipIfNoMLX(t)
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
weight := mlx.FromValues([]float32{2, 2, 2, 2}, 4)
bias := mlx.FromValues([]float32{10, 20, 30, 40}, 4)
mlx.Eval(x, weight, bias)
ln := &LayerNorm{Weight: weight, Bias: bias, Eps: 1e-5}
out := ln.Forward(x)
mlx.Eval(out)
data := out.Floats()
if len(data) != 4 {
t.Fatalf("expected 4 values, got %d", len(data))
}
mean := float32(2.5)
variance := float32(1.25)
std := float32(math.Sqrt(float64(variance + 1e-5)))
biases := []float32{10, 20, 30, 40}
for i, v := range []float32{1, 2, 3, 4} {
expected := ((v-mean)/std)*2 + biases[i]
if !approxEqual(data[i], expected, 1e-4) {
t.Errorf("index %d: expected %.6f, got %.6f", i, expected, data[i])
}
}
}
// TestLayerNormBatched verifies LayerNorm normalizes each row independently.
func TestLayerNormBatched(t *testing.T) {
skipIfNoMLX(t)
// Input: [2, 3] — two rows
x := mlx.FromValues([]float32{
1, 2, 3,
10, 20, 30,
}, 2, 3)
weight := mlx.FromValues([]float32{1, 1, 1}, 3)
mlx.Eval(x, weight)
ln := &LayerNorm{Weight: weight, Eps: 1e-5}
out := ln.Forward(x)
mlx.Eval(out)
data := out.Floats()
if len(data) != 6 {
t.Fatalf("expected 6 values, got %d", len(data))
}
// Each row should be independently normalized.
// Row 0: [1,2,3] mean=2, var=2/3
// Row 1: [10,20,30] mean=20, var=200/3
// After normalization both rows should have the same pattern
// since [10,20,30] = 10*[1,2,3], the normalized values are identical.
for i := range 3 {
if !approxEqual(data[i], data[i+3], 1e-4) {
t.Errorf("row 0 elem %d (%.6f) != row 1 elem %d (%.6f); expected identical normalized values",
i, data[i], i, data[i+3])
}
}
// Verify the normalized values sum to ~0 (mean-centered)
sum := data[0] + data[1] + data[2]
if !approxEqual(sum, 0, 1e-4) {
t.Errorf("normalized row sum should be ~0, got %.6f", sum)
}
}
// TestLayerNormDefaultEps verifies the default epsilon of 1e-5 is used when Eps is 0.
func TestLayerNormDefaultEps(t *testing.T) {
skipIfNoMLX(t)
x := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 4)
weight := mlx.FromValues([]float32{1, 1, 1, 1}, 4)
mlx.Eval(x, weight)
// Eps=0 should use default 1e-5
ln0 := &LayerNorm{Weight: weight, Eps: 0}
out0 := ln0.Forward(x)
mlx.Eval(out0)
lnExplicit := &LayerNorm{Weight: weight, Eps: 1e-5}
outExplicit := lnExplicit.Forward(x)
mlx.Eval(outExplicit)
d0 := out0.Floats()
dE := outExplicit.Floats()
for i := range d0 {
if !approxEqual(d0[i], dE[i], 1e-6) {
t.Errorf("index %d: Eps=0 gave %.6f, Eps=1e-5 gave %.6f", i, d0[i], dE[i])
}
}
}
func TestQuantizedLinearMXFP4MatchesDequantizedWeight(t *testing.T) {
skipIfNoMLX(t)
weightVals := make([]float32, 3*32)
for i := range weightVals {
weightVals[i] = float32((i%11)-5) / 7
}
inputVals := make([]float32, 2*32)
for i := range inputVals {
inputVals[i] = float32((i%7)-3) / 5
}
weight := mlx.FromValues(weightVals, 3, 32).AsType(mlx.DTypeBFloat16)
input := mlx.FromValues(inputVals, 2, 32).AsType(mlx.DTypeBFloat16)
mlx.Eval(weight, input)
ql := NewQuantizedLinear(weight, nil, 32, 4, "mxfp4")
if ql.QBiases != nil {
t.Fatalf("mxfp4 qbiases = %v, want nil", ql.QBiases)
}
dequantizedWeight := mlx.Dequantize(ql.Weight, ql.Scales, ql.QBiases, 32, 4, "mxfp4")
mlx.Eval(dequantizedWeight)
qOut := ql.Forward(input).AsType(mlx.DTypeFloat32)
dOut := NewLinear(dequantizedWeight, nil).Forward(input).AsType(mlx.DTypeFloat32)
mlx.Eval(qOut, dOut)
got := qOut.Floats()
want := dOut.Floats()
if len(got) != len(want) {
t.Fatalf("output length = %d, want %d", len(got), len(want))
}
for i := range got {
if !approxEqual(got[i], want[i], 1e-3) {
t.Fatalf("output[%d] = %.6f, want %.6f", i, got[i], want[i])
}
}
}

259
x/models/nn/recurrent.go Normal file
View File

@@ -0,0 +1,259 @@
package nn
import (
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// RecurrentOption configures a call to CausalConv1D or GatedDelta.
type RecurrentOption func(*recurrentConfig)
// recurrentConfig is the resolved set of inputs supplied via
// RecurrentOption. Exactly one of history or (convState/deltaState)
// must be supplied per call.
type recurrentConfig struct {
history *RecurrentHistory
convState *mlx.Array
deltaState *mlx.Array
}
// WithRecurrentHistory supplies a cache's per-layer view of conv and
// delta state. The cache hides any storage layout (per-row, paged,
// gather/scatter) behind the history.
func WithRecurrentHistory(h *RecurrentHistory) RecurrentOption {
return func(c *recurrentConfig) { c.history = h }
}
// WithRecurrentState supplies explicit conv and delta state tensors
// for the no-cache path. Each wrapper consumes one of the two — pass
// nil for the unused slot when calling only one wrapper.
func WithRecurrentState(convState, deltaState *mlx.Array) RecurrentOption {
return func(c *recurrentConfig) {
c.convState = convState
c.deltaState = deltaState
}
}
// resolve applies opts and panics if WithRecurrentHistory and
// WithRecurrentState were combined or neither was supplied.
func resolveRecurrentConfig(opts []RecurrentOption) recurrentConfig {
var cfg recurrentConfig
for _, opt := range opts {
opt(&cfg)
}
haveHistory := cfg.history != nil
haveState := cfg.convState != nil || cfg.deltaState != nil
if haveHistory && haveState {
panic("WithRecurrentHistory and WithRecurrentState are mutually exclusive")
}
if !haveHistory && !haveState {
panic("no recurrent state supplied (use WithRecurrentHistory or WithRecurrentState)")
}
return cfg
}
// CausalConv1D runs a depthwise causal 1D convolution with recurrent
// state management. Prepends the prior conv state along axis 1, runs
// the conv, and returns (output, nextConv). nextConv is the trailing
// convTail positions of the concat — write it back to the cache via
// Put alongside the scan's new delta state.
//
// Conv selection: when conv is non-nil (a full nn.Conv1d layer), it
// runs through conv.Forward. Otherwise weight is treated as the bare
// depthwise kernel [C, K] and the fallback manual implementation runs.
// Exactly one of conv or weight should be non-nil.
//
// Shapes: input [B, L, D]; prior state [B, convTail, D]; output
// [B, L, D] (the causal conv strips the prepended state).
//
// Prior state comes from exactly one of WithRecurrentHistory (cache
// path) or WithRecurrentState (no-cache path).
func CausalConv1D(b *batch.Batch, input *mlx.Array, conv *Conv1d, weight *mlx.Array, convTail int, opts ...RecurrentOption) (out, nextConv *mlx.Array) {
cfg := resolveRecurrentConfig(opts)
var prior *mlx.Array
if cfg.history != nil {
prior = cfg.history.ConvState()
} else {
prior = cfg.convState
}
mask := paddingMask(b, int32(input.Dim(1)))
if mask != nil {
zero := mlx.FromValue(float32(0)).AsType(input.DType())
input = mlx.Where(mlx.ExpandDims(mask, 2), input, zero)
}
concat := mlx.Concatenate([]*mlx.Array{prior, input}, 1)
if conv != nil {
out = conv.Forward(concat)
} else {
out = depthwiseCausalConv1d(concat, weight, int32(input.Dim(1)))
}
B := int32(concat.Dim(0))
total := int32(concat.Dim(1))
D := int32(concat.Dim(2))
// Gather the tail from each of the non-padded sequence ends
if mask != nil && convTail > 0 {
offsets := make([]int32, int(B)*convTail)
for i := range int(B) {
end := b.SeqQueryLens[i]
for k := range convTail {
offsets[i*convTail+k] = end + int32(k)
}
}
positions := mlx.NewArrayInt32(offsets, []int32{B, int32(convTail), 1})
nextConv = mlx.TakeAlongAxis(concat, positions, 1)
} else {
nextConv = mlx.SliceStartStop(concat,
[]int32{0, total - int32(convTail), 0},
[]int32{B, total, D})
}
return out, nextConv
}
// depthwiseCausalConv1d implements a depthwise 1D causal convolution
// manually as a sum of kernel-offset multiplies. x has shape
// [B, inLen, C], weight has shape [C, K]; output has shape [B, outLen, C]
// where outLen = inLen - K + 1 (the caller passes outLen to avoid the
// subtraction). Used as the fallback path in CausalConv1D when no
// full Conv1d layer is configured.
func depthwiseCausalConv1d(x, w *mlx.Array, outLen int32) *mlx.Array {
if x == nil || w == nil {
return nil
}
if w.NumDims() != 2 {
return nil
}
B := int32(x.Dim(0))
C := int32(w.Dim(0))
K := int32(w.Dim(1))
var out *mlx.Array
for i := range K {
seg := mlx.SliceStartStop(x, []int32{0, i, 0}, []int32{B, i + outLen, C})
wi := mlx.SliceStartStop(w, []int32{0, i}, []int32{C, i + 1})
wi = mlx.Reshape(wi, 1, 1, C)
term := mlx.Mul(seg, wi)
if out == nil {
out = term
} else {
out = mlx.Add(out, term)
}
}
return out
}
// GatedDelta wraps mlx.FastGatedDelta with recurrent state management.
// Reads prior delta state from the supplied option and returns
// (output, newDelta). Write newDelta back via the cache's Put
// alongside the conv wrapper's nextConv.
//
// Shape conventions:
//
// q: [B, L, numKeyHeads, headKDim]
// k: [B, L, numKeyHeads, headKDim]
// v: [B, L, numValueHeads, headVDim]
// state: [B, numValueHeads, headVDim, headKDim]
//
// Prior state comes from exactly one of WithRecurrentHistory (cache
// path) or WithRecurrentState (no-cache path).
func GatedDelta(b *batch.Batch, q, k, v, gDecay, beta *mlx.Array, opts ...RecurrentOption) (out, newDelta *mlx.Array) {
cfg := resolveRecurrentConfig(opts)
var state *mlx.Array
if cfg.history != nil {
state = cfg.history.DeltaState()
} else {
state = cfg.deltaState
}
return mlx.FastGatedDelta(q, k, v, gDecay, beta, state, paddingMask(b, int32(q.Dim(1))))
}
// RecurrentHistory is an opaque per-forward view a recurrent cache
// hands to the SSM kernel wrappers — prior conv and delta state
// tensors. Models do not construct this directly; pass it through
// via WithRecurrentHistory, or use WithRecurrentState on the no-cache
// path.
//
// Opaque structure to model code; accessors ConvState/DeltaState
// provide the escape hatch for custom SSM paths.
type RecurrentHistory struct {
convState, deltaState *mlx.Array
}
// NewRecurrentHistory constructs a RecurrentHistory. Intended for
// cache implementations across packages; model code uses
// WithRecurrentHistory / WithRecurrentState instead.
func NewRecurrentHistory(convState, deltaState *mlx.Array) *RecurrentHistory {
return &RecurrentHistory{convState: convState, deltaState: deltaState}
}
// ConvState returns the current convolution state tensor.
//
// Last-resort escape hatch for custom SSM paths — may force a slow
// materialization to canonical form depending on the cache's
// internal storage. Prefer CausalConv1D via WithRecurrentHistory.
func (h *RecurrentHistory) ConvState() *mlx.Array { return h.convState }
// DeltaState returns the current delta state tensor.
//
// Last-resort escape hatch for custom SSM paths — may force a slow
// materialization to canonical form depending on the cache's
// internal storage. Prefer GatedDelta via WithRecurrentHistory.
func (h *RecurrentHistory) DeltaState() *mlx.Array { return h.deltaState }
type paddingMaskInputs struct {
batch *batch.Batch
L int32
}
func (in paddingMaskInputs) build() *mlx.Array {
B := len(in.batch.SeqQueryLens)
needed := false
for i := range B {
if in.batch.SeqQueryLens[i] < in.L {
needed = true
break
}
}
if !needed {
return nil
}
L := int(in.L)
vals := make([]bool, B*L)
for i := range B {
n := int(in.batch.SeqQueryLens[i])
base := i * L
for j := range n {
vals[base+j] = true
}
}
return mlx.FromValues(vals, B, L)
}
// paddingMask derives a [B, L] bool mask from b.SeqQueryLens for
// right-padded inputs (real tokens at [0, len_i), padding at
// [len_i, L)). Returns nil when b has no rows or every row is full —
// the no-padding fast path that costs nothing extra.
func paddingMask(b *batch.Batch, L int32) *mlx.Array {
inputs := paddingMaskInputs{batch: b, L: L}
if cached, ok := b.Memo.Get(inputs); ok {
return cached.(*mlx.Array)
}
mask := inputs.build()
b.Memo.Put(inputs, mask)
return mask
}

View File

@@ -0,0 +1,340 @@
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func ones(dtype mlx.DType, shape ...int) *mlx.Array {
return mlx.AddScalar(mlx.Zeros(dtype, shape...), 1)
}
// fromValues builds a tensor with sequentially-numbered float32
// values so element-by-element parity actually exercises the kernel.
func fromValues(seed float32, shape ...int) *mlx.Array {
n := 1
for _, d := range shape {
n *= d
}
vals := make([]float32, n)
for i := range vals {
vals[i] = seed + 0.1*float32(i)
}
return mlx.FromValues(vals, shape...)
}
// depthwiseCausalRef is a Go-side reference for the depthwise causal
// 1D conv fallback. concat is [B, total, C], weight is [C, K], output
// is [B, total-K+1, C]. Used to anchor the wrapper's parity tests.
func depthwiseCausalRef(concat, weight *mlx.Array) []float32 {
mlx.Eval(concat, weight)
cVals := concat.Floats()
wVals := weight.Floats()
B := concat.Dim(0)
total := concat.Dim(1)
C := concat.Dim(2)
K := weight.Dim(1)
outLen := total - K + 1
out := make([]float32, B*outLen*C)
for bi := range B {
for q := range outLen {
for c := range C {
var sum float32
for k := range K {
x := cVals[bi*total*C+(q+k)*C+c]
w := wVals[c*K+k]
sum += x * w
}
out[bi*outLen*C+q*C+c] = sum
}
}
}
return out
}
// TestCausalConv1DParity drives the wrapper with non-trivial prior,
// input, and weight values, then compares against a direct depthwise-
// causal-conv reference.
func TestCausalConv1DParity(t *testing.T) {
skipIfNoMLX(t)
B, L, D, convTail := 1, 4, 3, 2
K := convTail + 1
input := fromValues(0.5, B, L, D)
prior := fromValues(-0.3, B, convTail, D)
weight := fromValues(0.2, D, K)
out, nextConv := CausalConv1D(&batch.Batch{}, input, nil, weight, convTail, WithRecurrentState(prior, nil))
mlx.Eval(out, nextConv)
concat := mlx.Concatenate([]*mlx.Array{prior, input}, 1)
want := depthwiseCausalRef(concat, weight)
got := out.Floats()
if len(got) != len(want) {
t.Fatalf("out len = %d, want %d", len(got), len(want))
}
for i := range want {
if math.Abs(float64(got[i]-want[i])) > 1e-5 {
t.Fatalf("out[%d]: got %v, want %v", i, got[i], want[i])
}
}
// nextConv (no padding) is the trailing convTail rows of concat.
mlx.Eval(concat)
cVals := concat.Floats()
total := concat.Dim(1)
wantTail := make([]float32, B*convTail*D)
for bi := range B {
for k := range convTail {
for d := range D {
wantTail[bi*convTail*D+k*D+d] = cVals[bi*total*D+(total-convTail+k)*D+d]
}
}
}
tail := nextConv.Floats()
if len(tail) != len(wantTail) {
t.Fatalf("nextConv len = %d, want %d", len(tail), len(wantTail))
}
for i := range wantTail {
if tail[i] != wantTail[i] {
t.Fatalf("nextConv[%d]: got %v, want %v", i, tail[i], wantTail[i])
}
}
}
// TestCausalConv1DPaddedRowParity drives a B=2 batch with one short
// row (qLen<L). For the short row, (a) `out` positions [0..qLen)
// must equal a B=1 reference at length qLen, (b) `nextConv` for the
// short row must be the row's last convTail real positions (not the
// padded tail), (c) the full row must be unaffected.
func TestCausalConv1DPaddedRowParity(t *testing.T) {
skipIfNoMLX(t)
L, D, convTail := 4, 3, 2
qLenShort := 2
K := convTail + 1
weight := fromValues(0.2, D, K)
priorFull := fromValues(0.5, 2, convTail, D)
priorShort := mlx.SliceStartStop(priorFull,
[]int32{1, 0, 0},
[]int32{2, int32(convTail), int32(D)})
// Pad row 1 with arbitrary values past qLenShort — the wrapper
// must zero them before convolving. Distinct values let us catch
// any leak.
inputFull := fromValues(1.0, 1, L, D)
inputShortReal := mlx.FromValues([]float32{
2.0, 2.1, 2.2,
2.3, 2.4, 2.5,
}, 1, qLenShort, D)
inputShortPad := mlx.FromValues([]float32{
99, 99, 99,
99, 99, 99,
}, 1, L-qLenShort, D)
inputShortFull := mlx.Concatenate([]*mlx.Array{inputShortReal, inputShortPad}, 1)
input := mlx.Concatenate([]*mlx.Array{inputFull, inputShortFull}, 0)
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 2, L),
SeqOffsets: []int32{0, 0},
SeqQueryLens: []int32{int32(L), int32(qLenShort)},
}
out, nextConv := CausalConv1D(b, input, nil, weight, convTail, WithRecurrentState(priorFull, nil))
mlx.Eval(out, nextConv)
// Reference for row 0: B=1 unpadded length-L call.
refOut0, refNextConv0 := CausalConv1D(&batch.Batch{},
inputFull, nil, weight, convTail,
WithRecurrentState(mlx.SliceStartStop(priorFull,
[]int32{0, 0, 0},
[]int32{1, int32(convTail), int32(D)}), nil))
// Reference for row 1: B=1 unpadded length-qLenShort call.
refOut1, refNextConv1 := CausalConv1D(&batch.Batch{},
inputShortReal, nil, weight, convTail,
WithRecurrentState(priorShort, nil))
mlx.Eval(refOut0, refNextConv0, refOut1, refNextConv1)
gotOut := out.Floats()
wantOut0 := refOut0.Floats()
wantOut1 := refOut1.Floats()
for q := range L {
for d := range D {
i := q*D + d
if gotOut[i] != wantOut0[i] {
t.Fatalf("row 0 out[q=%d,d=%d]: got %v, want %v", q, d, gotOut[i], wantOut0[i])
}
}
}
for q := range qLenShort {
for d := range D {
gotI := L*D + q*D + d
refI := q*D + d
if math.Abs(float64(gotOut[gotI]-wantOut1[refI])) > 1e-5 {
t.Fatalf("row 1 real out[q=%d,d=%d]: got %v, want %v", q, d, gotOut[gotI], wantOut1[refI])
}
}
}
// nextConv: row 0 unaffected, row 1 must be the row's real tail
// (positions [qLenShort - convTail, qLenShort) of the per-row
// concat, i.e. the last two real input rows in this setup).
gotTail := nextConv.Floats()
wantTail0 := refNextConv0.Floats()
wantTail1 := refNextConv1.Floats()
for k := range convTail {
for d := range D {
i := k*D + d
if gotTail[i] != wantTail0[i] {
t.Fatalf("row 0 nextConv[k=%d,d=%d]: got %v, want %v", k, d, gotTail[i], wantTail0[i])
}
}
}
for k := range convTail {
for d := range D {
gotI := convTail*D + k*D + d
refI := k*D + d
if gotTail[gotI] != wantTail1[refI] {
t.Fatalf("row 1 nextConv[k=%d,d=%d]: got %v, want %v (must come from real positions, not the padded tail)",
k, d, gotTail[gotI], wantTail1[refI])
}
}
}
}
func TestGatedDeltaZeroFallback(t *testing.T) {
skipIfNoMLX(t)
B, L, nK, nV, dK, dV := 1, 2, 1, 1, 4, 4
q := ones(mlx.DTypeFloat32, B, L, nK, dK)
k := ones(mlx.DTypeFloat32, B, L, nK, dK)
v := ones(mlx.DTypeFloat32, B, L, nV, dV)
gDecay := ones(mlx.DTypeFloat32, B, L, nV)
beta := ones(mlx.DTypeFloat32, B, L, nV)
zero := mlx.Zeros(mlx.DTypeFloat32, B, nV, dV, dK)
outA, stateA := GatedDelta(&batch.Batch{}, q, k, v, gDecay, beta, WithRecurrentState(nil, zero))
outB, stateB := mlx.FastGatedDelta(q, k, v, gDecay, beta, zero, nil)
mlx.Eval(outA, stateA, outB, stateB)
gotOut, wantOut := outA.Floats(), outB.Floats()
for i := range wantOut {
if gotOut[i] != wantOut[i] {
t.Fatalf("output[%d]: wrapper=%v direct=%v", i, gotOut[i], wantOut[i])
}
}
gotState, wantState := stateA.Floats(), stateB.Floats()
for i := range wantState {
if gotState[i] != wantState[i] {
t.Fatalf("state[%d]: wrapper=%v direct=%v", i, gotState[i], wantState[i])
}
}
}
func TestGatedDeltaUsesPriorState(t *testing.T) {
skipIfNoMLX(t)
B, L, nK, nV, dK, dV := 1, 2, 1, 1, 4, 4
q := ones(mlx.DTypeFloat32, B, L, nK, dK)
k := ones(mlx.DTypeFloat32, B, L, nK, dK)
v := ones(mlx.DTypeFloat32, B, L, nV, dV)
gDecay := ones(mlx.DTypeFloat32, B, L, nV)
beta := ones(mlx.DTypeFloat32, B, L, nV)
priorState := mlx.MulScalar(ones(mlx.DTypeFloat32, B, nV, dV, dK), 3)
outA, _ := GatedDelta(&batch.Batch{}, q, k, v, gDecay, beta, WithRecurrentState(nil, priorState))
outB, _ := mlx.FastGatedDelta(q, k, v, gDecay, beta, priorState, nil)
mlx.Eval(outA, outB)
gotOut, wantOut := outA.Floats(), outB.Floats()
for i := range wantOut {
if gotOut[i] != wantOut[i] {
t.Fatalf("output[%d]: wrapper=%v direct=%v", i, gotOut[i], wantOut[i])
}
}
}
// TestGatedDeltaPaddedRowParity drives a B=2 batch where row 1 is
// short (qLen < L). The wrapper must substitute neutral values
// (q=k=v=beta=0, g=1) at row 1's padded positions so the recurrence
// is a no-op there — and row 1's final state must equal the state
// after its last real token. Pinned via parity against a B=1 length-
// qLen call on the same row.
func TestGatedDeltaPaddedRowParity(t *testing.T) {
skipIfNoMLX(t)
L, nK, nV, dK, dV := 4, 1, 1, 4, 4
qLenShort := 2
makeRows := func(seedA, seedB float32, shape ...int) *mlx.Array {
// Build a rank-(len(shape)+1) tensor with B=2 rows from two
// distinct seeds so the rows are not accidentally identical.
n := 1
for _, d := range shape {
n *= d
}
vals := make([]float32, 2*n)
for i := range n {
vals[i] = seedA + 0.1*float32(i)
}
for i := range n {
vals[n+i] = seedB + 0.1*float32(i)
}
full := append([]int{2}, shape...)
return mlx.FromValues(vals, full...)
}
q := makeRows(0.5, -0.5, L, nK, dK)
k := makeRows(0.7, -0.7, L, nK, dK)
v := makeRows(0.3, -0.3, L, nV, dV)
gDecay := makeRows(0.1, -0.1, L, nV)
beta := makeRows(0.4, -0.4, L, nV)
priorState := makeRows(0.2, -0.2, nV, dV, dK)
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 2, L),
SeqOffsets: []int32{0, 0},
SeqQueryLens: []int32{int32(L), int32(qLenShort)},
}
_, state := GatedDelta(b, q, k, v, gDecay, beta, WithRecurrentState(nil, priorState))
mlx.Eval(state)
// Reference for row 1: B=1 length-qLenShort call against the
// row's real prefix and its prior state slice.
row1Slice := func(a *mlx.Array, axisLens ...int32) *mlx.Array {
dims := a.Dims()
start := make([]int32, len(dims))
stop := make([]int32, len(dims))
start[0], stop[0] = 1, 2
for i := 1; i < len(dims); i++ {
stop[i] = int32(dims[i])
}
// Optionally truncate axis 1 (sequence axis) to qLenShort.
if len(axisLens) >= 1 && len(dims) >= 2 {
stop[1] = axisLens[0]
}
return mlx.SliceStartStop(a, start, stop)
}
q1 := row1Slice(q, int32(qLenShort))
k1 := row1Slice(k, int32(qLenShort))
v1 := row1Slice(v, int32(qLenShort))
gDecay1 := row1Slice(gDecay, int32(qLenShort))
beta1 := row1Slice(beta, int32(qLenShort))
priorRow1 := row1Slice(priorState)
_, refState := mlx.FastGatedDelta(q1, k1, v1, gDecay1, beta1, priorRow1, nil)
mlx.Eval(refState)
gotState := state.Floats()
wantState := refState.Floats()
row1Stride := nV * dV * dK
for i := range row1Stride {
gotV := gotState[row1Stride+i]
wantV := wantState[i]
if math.Abs(float64(gotV-wantV)) > 1e-4 {
t.Fatalf("row 1 final state[%d]: got %v, want %v", i, gotV, wantV)
}
}
}

129
x/models/nn/rope.go Normal file
View File

@@ -0,0 +1,129 @@
package nn
import (
"math"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// RopeParameters carries common RoPE metadata embedded in model configs.
type RopeParameters struct {
RopeTheta float32 `json:"rope_theta"`
RopeType string `json:"rope_type"`
Type string `json:"type"`
PartialRotaryFactor float32 `json:"partial_rotary_factor"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
AttentionFactor float32 `json:"attention_factor"`
}
// TypeName returns rope_type when present, falling back to type.
func (rp *RopeParameters) TypeName() string {
if rp == nil {
return ""
}
if rp.RopeType != "" {
return rp.RopeType
}
return rp.Type
}
// BuildYarnRopeFreqs returns YaRN rotary frequencies and the mscale value.
func BuildYarnRopeFreqs(dim int, base float32, rp *RopeParameters) (*mlx.Array, float32) {
if rp == nil || dim <= 0 {
return nil, 1
}
factor := rp.Factor
if factor <= 0 {
factor = 1
}
attentionFactor := rp.AttentionFactor
if attentionFactor == 0 && factor > 1 {
attentionFactor = float32(0.1*math.Log(float64(factor)) + 1.0)
} else if attentionFactor == 0 {
attentionFactor = 1
}
if factor <= 1 {
return nil, attentionFactor
}
originalMax := rp.OriginalMaxPositionEmbeddings
if originalMax <= 0 {
originalMax = 4096
}
betaFast := rp.BetaFast
if betaFast == 0 {
betaFast = 32
}
betaSlow := rp.BetaSlow
if betaSlow == 0 {
betaSlow = 1
}
half := dim / 2
low, high := yarnCorrectionRange(betaFast, betaSlow, dim, base, originalMax)
freqs := make([]float32, half)
for i := range half {
posFreq := math.Pow(float64(base), float64(2*i)/float64(dim))
invExtrapolation := 1.0 / posFreq
invInterpolation := 1.0 / (float64(factor) * posFreq)
ramp := yarnRamp(float64(i), low, high)
mask := 1 - ramp
inv := invInterpolation*(1-mask) + invExtrapolation*mask
freqs[i] = float32(1.0 / inv)
}
arr := mlx.FromValues(freqs, half)
mlx.Eval(arr)
return arr, attentionFactor
}
func yarnCorrectionRange(betaFast, betaSlow float32, dim int, base float32, maxPosition int32) (float64, float64) {
findDim := func(rot float32) float64 {
return float64(dim) * math.Log(float64(maxPosition)/(float64(rot)*2*math.Pi)) / (2 * math.Log(float64(base)))
}
low := math.Floor(findDim(betaFast))
high := math.Ceil(findDim(betaSlow))
low = math.Max(low, 0)
high = math.Min(high, float64(dim-1))
if low == high {
high += 0.001
}
return low, high
}
func yarnRamp(i, low, high float64) float64 {
v := (i - low) / (high - low)
if v < 0 {
return 0
}
if v > 1 {
return 1
}
return v
}
// ScaleRotaryPart applies YaRN's mscale to only the rotated dimensions.
func ScaleRotaryPart(x *mlx.Array, ropeDim int, scale float32) *mlx.Array {
if scale == 1 {
return x
}
dims := x.Dims()
last := dims[len(dims)-1]
if ropeDim >= last {
return mlx.MulScalar(x, scale)
}
start := make([]int32, len(dims))
stopRot := make([]int32, len(dims))
stopPass := make([]int32, len(dims))
startPass := make([]int32, len(dims))
for i, dim := range dims {
stopRot[i] = int32(dim)
stopPass[i] = int32(dim)
}
stopRot[len(dims)-1] = int32(ropeDim)
startPass[len(dims)-1] = int32(ropeDim)
rot := mlx.MulScalar(mlx.SliceStartStop(x, start, stopRot), scale)
pass := mlx.SliceStartStop(x, startPass, stopPass)
return mlx.Concatenate([]*mlx.Array{rot, pass}, -1)
}

578
x/models/nn/sdpa.go Normal file
View File

@@ -0,0 +1,578 @@
package nn
import (
"encoding/binary"
"math"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// SDPAOption configures a call to ScaledDotProductAttention.
type SDPAOption func(*sdpaConfig)
type sdpaConfig struct {
// Exactly one of (k,v,kLens) or history supplies keys/values.
k, v *mlx.Array
kLens []int32
history *KVHistory
// Optional model-supplied logical mask.
mask AttentionMask
}
// WithKVHistory supplies a cache's per-layer view of K and V. The
// cache hides any storage layout (sliding window, ring buffer,
// k-padding) behind the history.
func WithKVHistory(h *KVHistory) SDPAOption {
return func(c *sdpaConfig) { c.history = h }
}
// WithMLAHistory supplies a cache's per-layer view for absorbed MLA
// attention, where V is the first valueDim positions of K.
func WithMLAHistory(h *KVHistory, valueDim int) SDPAOption {
v := h.K().Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, valueDim))
return WithKVHistory(&KVHistory{k: h.K(), v: v, applier: h.applier})
}
// WithKV supplies explicit K/V tensors for the no-cache path. kLens
// gives per-row real key extents — pass b.SeqQueryLens for self-
// attention, or the caller's own extents for cross-attention.
func WithKV(k, v *mlx.Array, kLens []int32) SDPAOption {
return func(c *sdpaConfig) { c.k = k; c.v = v; c.kLens = kLens }
}
// WithMask supplies the model's logical-coordinate mask.
func WithMask(m AttentionMask) SDPAOption {
return func(c *sdpaConfig) { c.mask = m }
}
// ScaledDotProductAttention runs the fast SDPA kernel against q and
// the keys/values supplied via exactly one of WithKV or
// WithKVHistory. Automatically applies any Q/K padding masking required
// for padded batches.
func ScaledDotProductAttention(b *batch.Batch, q *mlx.Array, scale float32, opts ...SDPAOption) *mlx.Array {
var cfg sdpaConfig
for _, opt := range opts {
opt(&cfg)
}
haveKV := cfg.k != nil || cfg.v != nil
haveHistory := cfg.history != nil
if haveKV && haveHistory {
panic("nn.ScaledDotProductAttention: WithKV and WithKVHistory are mutually exclusive")
}
if !haveKV && !haveHistory {
panic("nn.ScaledDotProductAttention: no keys/values supplied (use WithKV or WithKVHistory)")
}
k, v := cfg.k, cfg.v
var applier MaskApplier
if cfg.history != nil {
k = cfg.history.K()
v = cfg.history.V()
applier = cfg.history.applier
}
inputs := dispatchInputs{
batch: b,
mask: cfg.mask,
applier: applier,
K: k.Dim(2),
dtype: k.DType(),
kLens: newKLensKey(cfg.kLens),
}
if cached, ok := b.Memo.Get(inputs); ok {
d := cached.(sdpaDispatch)
return mlx.FastScaledDotProductAttention(q, k, v, scale, d.mode, d.arr)
}
d := inputs.resolve()
b.Memo.Put(inputs, d)
return mlx.FastScaledDotProductAttention(q, k, v, scale, d.mode, d.arr)
}
// sdpaDispatch is the resolved kernel call for a given SDPA key —
// either a flag-mode fast path (mode "" or "causal", arr nil) or an
// array-mode call with a materialized tensor. Memoized on b.Memo so
// sibling layers skip applier composition, padding build, and AsArray.
type sdpaDispatch struct {
mode string
arr *mlx.Array
}
// dispatchInputs bundles every value resolve reads and doubles as
// the Memo map key. All fields are comparable: batch is a
// *batch.Batch pointer, the applier interface is comparable when
// its concrete type is, and kLens is a kLensKey string that hashes
// by content.
//
// Making resolve a method on this struct is the enforcement — any
// new dependency must be added as a field, which automatically
// participates in the map key.
//
// applier and kLens are mutually exclusive by construction:
// WithKVHistory sets applier (which owns any K-padding in its output
// space) and leaves kLens ""; WithKV sets kLens and leaves applier nil.
type dispatchInputs struct {
batch *batch.Batch
mask AttentionMask
applier MaskApplier
K int
dtype mlx.DType
kLens kLensKey
}
// kLensKey is a comparable encoding of an int32 slice (four bytes
// per element, native endian) so it can live in a struct used as a
// map key. Decode back via Int32s.
type kLensKey string
func newKLensKey(vs []int32) kLensKey {
if len(vs) == 0 {
return ""
}
buf := make([]byte, len(vs)*4)
for i, v := range vs {
binary.NativeEndian.PutUint32(buf[i*4:], uint32(v))
}
return kLensKey(buf)
}
// Int32s decodes the key back to a fresh []int32.
func (k kLensKey) Int32s() []int32 {
if len(k) == 0 {
return nil
}
b := []byte(k)
out := make([]int32, len(b)/4)
for i := range out {
out[i] = int32(binary.NativeEndian.Uint32(b[i*4:]))
}
return out
}
// resolve composes model + padding + storage contributions and
// returns the kernel dispatch decision. Reads only from inputs; any
// new input must be added to dispatchInputs.
//
// Order matters: QPaddingMask is added in logical Q-space before the
// applier runs, so an applier that remaps coordinates receives the
// full logical mask. The applier and KPaddingMask branches are
// mutually exclusive — on the applier path the output may be in a
// remapped K space, so the applier owns any K-padding; on the
// WithKV path kLens describes the direct K tensor, which shares
// logical K space with QPaddingMask.
func (inputs dispatchInputs) resolve() sdpaDispatch {
mask := inputs.mask.Intersect(QPaddingMask(inputs.batch, inputs.dtype))
if inputs.applier != nil {
mask = inputs.applier.ApplyMask(mask)
} else if inputs.kLens != "" {
mask = mask.Intersect(KPaddingMask(inputs.batch, inputs.K, inputs.kLens.Int32s(), inputs.dtype))
}
switch {
case mask.IsZero():
return sdpaDispatch{mode: ""}
case mask.IsCausal():
if inputs.batch.InputIDs.Dim(1) == 1 {
// At L=1 the causal "k > q" constraint is redundant -
// drop it so the kernel dispatches to the no-mask fast path.
return sdpaDispatch{mode: ""}
} else {
return sdpaDispatch{mode: "causal"}
}
default:
return sdpaDispatch{mode: "array", arr: mask.AsArray(inputs.batch, inputs.K, inputs.dtype)}
}
}
// MaskApplier composes a cache's storage-mask contribution onto a
// fully-composed logical mask. The returned mask may live in the
// applier's own coordinate system (e.g. a rotated or compacted K layout),
// so any addition in logical K space must happen before the applier runs.
// SDPA does not add KPaddingMask on this path — the applier owns any
// K-padding its output needs.
//
// Implementations must be comparable struct values whose fields
// capture everything the composition depends on (no slice, map, or
// func fields); the value doubles as the applier's identity in
// SDPA's dispatch-cache key, where a non-comparable concrete type
// would panic at map insertion. A nil MaskApplier means "no storage
// contribution".
type MaskApplier interface {
ApplyMask(logical AttentionMask) AttentionMask
}
// KVHistory is the per-forward view a KV cache hands to SDPA:
// post-Update K and V plus an optional MaskApplier that composes
// the cache's storage mask onto the caller's model mask.
type KVHistory struct {
k, v *mlx.Array
applier MaskApplier
}
// NewKVHistory constructs a KVHistory. Intended for
// cache implementations across packages; model code uses
// WithKVHistory / WithKV instead.
func NewKVHistory(k, v *mlx.Array, applier MaskApplier) *KVHistory {
return &KVHistory{k: k, v: v, applier: applier}
}
// K returns the post-Update keys tensor.
//
// Last-resort escape hatch for custom attention paths — may force a
// slow materialization to canonical form depending on the cache's
// internal storage. Prefer ScaledDotProductAttention via
// WithKVHistory.
func (h *KVHistory) K() *mlx.Array { return h.k }
// V returns the post-Update values tensor.
//
// Last-resort escape hatch for custom attention paths — may force a
// slow materialization to canonical form depending on the cache's
// internal storage. Prefer ScaledDotProductAttention via
// WithKVHistory.
func (h *KVHistory) V() *mlx.Array { return h.v }
// Mask returns the final AttentionMask for this layer's SDPA —
// cache storage restrictions composed onto the caller's fully-
// composed logical mask.
//
// Last-resort escape hatch for custom attention paths — may force a
// slow materialization to canonical form depending on the cache's
// internal storage. Prefer ScaledDotProductAttention via
// WithKVHistory.
func (h *KVHistory) Mask(logical AttentionMask) AttentionMask {
if h.applier == nil {
return logical
}
return h.applier.ApplyMask(logical)
}
// AttentionMask describes an attention mask in four states:
// - zero value: no mask.
// - flag-form causal (causal=true only): dispatches to the MLX
// kernel's mask_mode="causal" fast path.
// - causal with relaxation rectangles: a causal mask with
// bidirectional attention rectangles, such as for images.
// - additive tensor (array!=nil): broadcast-compatible with
// [B, 1, L, K]; contributed by a custom mask, helpers such as
// QPaddingMask, KPaddingMask, or cache appliers and accumulated
// via Intersect.
//
// The mask is a pure logical description — it carries no batch and
// exists independent of cache storage layout.
//
// All fields are comparable, so AttentionMask values compare with ==
// by full identity — SDPA uses this directly as a dispatch-cache key.
type AttentionMask struct {
causal bool
relaxations *relaxNode
array *mlx.Array
}
type relaxRect struct {
seq, qLo, qHi, kLo, kHi int
}
// relaxNode is a singly-linked list node holding relaxation
// rectangles. Each AttentionMask must have a fresh set of
// nodes to avoid false sharing between masks.
type relaxNode struct {
rect relaxRect
next *relaxNode
}
// CausalMask returns a flag-form causal mask. The mask stays
// tensor-free — hitting the kernel's mask_mode="causal" fast path —
// until something composes a relaxation, padding, or applier tensor
// onto it; then SDPA materializes via AsArray.
func CausalMask() AttentionMask {
return AttentionMask{causal: true}
}
// ArrayMask wraps an explicit additive tensor broadcast-compatible
// with [B, 1, L, K].
func ArrayMask(a *mlx.Array) AttentionMask {
return AttentionMask{array: a}
}
// IsZero reports whether the mask is the zero value (no mask at all).
func (m AttentionMask) IsZero() bool {
return !m.causal && m.array == nil && m.relaxations == nil
}
// IsCausal reports whether the mask is pure flag-form causal — no
// relaxations and no accumulated array. SDPA dispatches to the
// kernel's "causal" fast path on this; any padding, applier
// contribution, or relaxation falls to the array path.
func (m AttentionMask) IsCausal() bool {
return m.causal && m.relaxations == nil && m.array == nil
}
// Relax records a relaxation rectangle for batch sequence seq —
// positions (q, k) with q in [qLo, qHi) and k in [kLo, kHi) become
// freely attendable regardless of the causal base. Coordinates are
// absolute sequence positions on both axes, matching how causal is
// defined (k <= q). Multiple calls compose as a union per sequence.
//
// Rectangles that cannot change any cell — empty or already fully
// inside causal (kHi-1 <= qLo) — are dropped so IsCausal stays true
// and the mask remains on the kernel's fast path.
//
// Panics on pure ArrayMask (the caller owns the tensor and should
// modify it directly) or on the zero mask (nothing to relax).
func (m AttentionMask) Relax(seq, qLo, qHi, kLo, kHi int) AttentionMask {
if !m.causal {
if m.array != nil {
panic("AttentionMask.Relax: cannot relax a pure ArrayMask; modify the tensor directly")
}
panic("AttentionMask.Relax: cannot relax a zero mask")
}
if qLo >= qHi || kLo >= kHi {
return m
}
if kHi-1 <= qLo {
return m
}
m.relaxations = &relaxNode{
rect: relaxRect{seq: seq, qLo: qLo, qHi: qHi, kLo: kLo, kHi: kHi},
next: m.relaxations,
}
return m
}
// Intersect returns the element-wise sum of this mask and other. Masks are
// additive and apply before softmax, so this is intersection
// semantics — a position is valid only if both sides have 0 there.
//
// At AsArray time a causal+Relax+array mask materializes as: causal
// writes -inf into the upper triangle, Relax overwrites its
// rectangles back to 0, then array is added on top — restricting 0
// cells further or no-op'ing on -inf cells.
func (m AttentionMask) Intersect(other AttentionMask) AttentionMask {
if m.IsZero() {
return other
}
if other.IsZero() {
return m
}
result := AttentionMask{
causal: m.causal || other.causal,
}
// Relax requires causal, so relaxations != nil implies causal.
switch {
case m.relaxations != nil && other.relaxations != nil:
// Both sides causal+Relax: pairwise rect intersection per sequence.
var list *relaxNode
for a := m.relaxations; a != nil; a = a.next {
for b := other.relaxations; b != nil; b = b.next {
if a.rect.seq != b.rect.seq {
continue
}
qLo := max(a.rect.qLo, b.rect.qLo)
qHi := min(a.rect.qHi, b.rect.qHi)
kLo := max(a.rect.kLo, b.rect.kLo)
kHi := min(a.rect.kHi, b.rect.kHi)
if qHi <= qLo || kHi <= kLo || kHi-1 <= qLo {
continue
}
list = &relaxNode{
rect: relaxRect{seq: a.rect.seq, qLo: qLo, qHi: qHi, kLo: kLo, kHi: kHi},
next: list,
}
}
}
result.relaxations = list
case m.relaxations != nil && !other.causal:
result.relaxations = m.relaxations
case other.relaxations != nil && !m.causal:
result.relaxations = other.relaxations
default:
// Implicit: one side causal+Relax, the other plain causal
// (no relaxations). Plain causal blocks every cell Relax
// tried to release, so intersection with its empty release
// set leaves nothing — result.relaxations stays nil and
// collapses to pure causal.
}
switch {
case m.array != nil && other.array != nil:
result.array = mlx.Add(m.array, other.array)
case m.array != nil:
result.array = m.array
case other.array != nil:
result.array = other.array
}
return result
}
// AsArray materializes the mask as a [B, 1, L, K] additive tensor
// (0 where valid, -inf where blocked). B and L come from b; K and
// dtype come from the caller.
//
// Composition order:
// 1. Start from zero.
// 2. If m.causal: -inf where oldestPos+k > SeqOffsets[b] + q per row.
// 3. Apply m.relaxations (qLo/qHi and kLo/kHi are absolute positions).
// 4. Add m.array if present.
func (m AttentionMask) AsArray(b *batch.Batch, K int, dtype mlx.DType) *mlx.Array {
// Pure ArrayMask: caller owns the tensor, nothing to compose.
if !m.causal && m.relaxations == nil && m.array != nil {
if m.array.DType() == dtype {
return m.array
}
return m.array.AsType(dtype)
}
B := len(b.SeqOffsets)
L := b.InputIDs.Dim(1)
negInf := float32(math.Inf(-1))
vals := make([]float32, B*L*K)
if m.causal {
for i := range B {
off := int(b.SeqOffsets[i])
oldestPos := max(0, off+L-K)
base := i * L * K
for q := range L {
absQ := off + q
row := base + q*K
for k := range K {
if oldestPos+k > absQ {
vals[row+k] = negInf
}
}
}
}
}
for n := m.relaxations; n != nil; n = n.next {
r := n.rect
if r.seq < 0 || r.seq >= B {
continue
}
off := int(b.SeqOffsets[r.seq])
oldestPos := max(0, off+L-K)
qLo := min(max(r.qLo-off, 0), L)
qHi := min(max(r.qHi-off, 0), L)
kLo := min(max(r.kLo-oldestPos, 0), K)
kHi := min(max(r.kHi-oldestPos, 0), K)
base := r.seq * L * K
for q := qLo; q < qHi; q++ {
row := base + q*K
for k := kLo; k < kHi; k++ {
vals[row+k] = 0
}
}
}
out := mlx.FromValues(vals, B, 1, L, K)
if m.array != nil {
out = mlx.Add(out, m.array)
}
if dtype != mlx.DTypeFloat32 {
out = out.AsType(dtype)
}
return out
}
// QPaddingMask returns an additive [B, 1, L, 1] mask that blocks
// padded query rows (q >= b.SeqQueryLens[i]) across all keys. It is
// logical — independent of whatever layout the cache uses for K.
// Returns the zero mask when every row is full.
func QPaddingMask(b *batch.Batch, dtype mlx.DType) AttentionMask {
return padTailMask(len(b.SeqOffsets), b.InputIDs.Dim(1), 2, b.SeqQueryLens, dtype)
}
// KPaddingMask returns an additive [B, 1, 1, K] mask that blocks
// padded key columns (k >= kLens[i]) across all queries. Storage-
// dependent: kLens describes where real content ends in physical K,
// so this is typically used without a cache where the caller knows
// the actual layout. Returns the zero mask when every row is full.
func KPaddingMask(b *batch.Batch, K int, kLens []int32, dtype mlx.DType) AttentionMask {
return padTailMask(len(b.SeqOffsets), K, 3, kLens, dtype)
}
// SlidingWindowMask returns an additive [B, 1, L, K] mask blocking
// keys outside a per-row window of size `window`: any key whose
// absolute position p < absQ - window + 1 is blocked. Returns the
// zero mask when window <= 0 or no row needs blocking.
//
// Defined in logical position space — the K axis is position-ordered
// with column 0 at oldestPos = max(0, b.SeqOffsets[i]+L-K).
func SlidingWindowMask(b *batch.Batch, K, window int, dtype mlx.DType) AttentionMask {
if window <= 0 {
return AttentionMask{}
}
B := len(b.SeqOffsets)
L := b.InputIDs.Dim(1)
negInf := float32(math.Inf(-1))
vals := make([]float32, B*L*K)
needed := false
for i := range B {
off := int(b.SeqOffsets[i])
oldestPos := max(0, off+L-K)
base := i * L * K
for q := range L {
absQ := off + q
lo := absQ - window + 1
maskCount := lo - oldestPos
if maskCount <= 0 {
continue
}
if maskCount > K {
maskCount = K
}
row := base + q*K
for k := range maskCount {
vals[row+k] = negInf
needed = true
}
}
}
if !needed {
return AttentionMask{}
}
out := mlx.FromValues(vals, B, 1, L, K)
if dtype != mlx.DTypeFloat32 {
out = out.AsType(dtype)
}
return ArrayMask(out)
}
func padTailMask(B, total, axis int, lens []int32, dtype mlx.DType) AttentionMask {
needed := false
for i := range B {
if int(lens[i]) < total {
needed = true
break
}
}
if !needed {
return AttentionMask{}
}
negInf := float32(math.Inf(-1))
vals := make([]float32, B*total)
for i := range B {
n := int(lens[i])
base := i * total
for j := n; j < total; j++ {
vals[base+j] = negInf
}
}
shape := [4]int{B, 1, 1, 1}
shape[axis] = total
out := mlx.FromValues(vals, shape[0], shape[1], shape[2], shape[3])
if dtype != mlx.DTypeFloat32 {
out = out.AsType(dtype)
}
return ArrayMask(out)
}

680
x/models/nn/sdpa_test.go Normal file
View File

@@ -0,0 +1,680 @@
package nn
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// newBatch constructs a synthetic batch for mask/SDPA tests.
// seqOffsets defines B (length of slice) and each row's absolute start;
// L is the padded query length along InputIDs's second axis;
// qLens is per-row real query length (defaults to all L if nil).
func newBatch(seqOffsets []int32, L int, qLens []int32) *batch.Batch {
B := len(seqOffsets)
if qLens == nil {
qLens = make([]int32, B)
for i := range qLens {
qLens[i] = int32(L)
}
}
// InputIDs values don't matter for masking, only the shape.
ids := mlx.FromValues(make([]int32, B*L), B, L)
return &batch.Batch{
InputIDs: ids,
SeqOffsets: seqOffsets,
SeqQueryLens: qLens,
}
}
func TestAttentionMaskZero(t *testing.T) {
skipIfNoMLX(t)
var m AttentionMask
if !m.IsZero() {
t.Fatal("zero value should report IsZero")
}
if m.IsCausal() {
t.Fatal("zero value should not report IsCausal")
}
b := newBatch([]int32{0}, 2, nil)
arr := m.AsArray(b, 3, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("zero value AsArray should return a zeros tensor, not nil")
}
mlx.Eval(arr)
got := arr.Floats()
for i, v := range got {
if v != 0 {
t.Fatalf("zero mask should materialize all zeros; got[%d] = %v", i, v)
}
}
}
func TestAttentionMaskAsArrayCausal(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 6
b := newBatch([]int32{2}, L, nil)
arr := CausalMask().AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("CausalMask AsArray should return a tensor")
}
dims := arr.Dims()
if len(dims) != 4 || dims[0] != 1 || dims[1] != 1 || dims[2] != L || dims[3] != K {
t.Fatalf("want shape [1,1,%d,%d], got %v", L, K, dims)
}
mlx.Eval(arr)
got := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, L*K)
for q := range L {
absQ := int(b.SeqOffsets[0]) + q
for k := range K {
if k > absQ {
want[q*K+k] = negInf
}
}
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestAttentionMaskRelaxLazy(t *testing.T) {
skipIfNoMLX(t)
// Relax must not materialize a tensor — the perf invariant the
// causal-flag fast path relies on. Everything else (predicates,
// AsArray contents) is exercised by the materialization tests.
m := CausalMask().
Relax(0, 1, 3, 2, 5).
Relax(0, 0, 2, 1, 4)
if m.array != nil {
t.Fatal("Relax should not materialize a tensor")
}
}
// TestAttentionMaskRelaxNoopRectsMatchCausal pins the contract that
// rectangles which can't change any cell — empty in q or k, or fully
// inside the causal triangle — must produce the same materialized
// tensor as plain causal.
func TestAttentionMaskRelaxNoopRectsMatchCausal(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 6
b := newBatch([]int32{0}, L, nil)
want := CausalMask().AsArray(b, K, mlx.DTypeFloat32)
mlx.Eval(want)
wantF := want.Floats()
cases := []struct {
name string
qLo, qHi, kLo, kHi int
}{
{"empty Q rect", 2, 2, 0, 3},
{"empty K rect", 0, 3, 2, 2},
{"fully under causal", 5, 7, 0, 3},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
m := CausalMask().Relax(0, tc.qLo, tc.qHi, tc.kLo, tc.kHi)
arr := m.AsArray(b, K, mlx.DTypeFloat32)
mlx.Eval(arr)
got := arr.Floats()
for i := range wantF {
if !sameF(got[i], wantF[i]) {
t.Fatalf("index %d: want %v, got %v", i, wantF[i], got[i])
}
}
})
}
}
func TestAttentionMaskAsArrayWithRelax(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 6
b := newBatch([]int32{0}, L, nil)
arr := CausalMask().Relax(0, 1, 3, 2, 5).AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("expected tensor")
}
mlx.Eval(arr)
got := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, L*K)
for q := range L {
for k := range K {
if k > q {
want[q*K+k] = negInf
}
}
}
for q := 1; q < 3; q++ {
for k := 2; k < 5; k++ {
want[q*K+k] = 0
}
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestAttentionMaskAsArrayPerRow(t *testing.T) {
skipIfNoMLX(t)
L, K := 3, 5
b := newBatch([]int32{0, 2}, L, nil)
m := CausalMask().
Relax(0, 0, 2, 0, 3).
Relax(1, 3, 5, 2, 5)
arr := m.AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("expected tensor")
}
dims := arr.Dims()
if dims[0] != 2 {
t.Fatalf("expected batch dim 2, got %v", dims)
}
mlx.Eval(arr)
got := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, 2*L*K)
for bi, off := range b.SeqOffsets {
for q := range L {
absQ := int(off) + q
for k := range K {
if k > absQ {
want[bi*L*K+q*K+k] = negInf
}
}
}
}
for q := range 2 {
for k := range 3 {
want[0*L*K+q*K+k] = 0
}
}
for q := 1; q < 3; q++ {
for k := 2; k < 5; k++ {
want[1*L*K+q*K+k] = 0
}
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestQPaddingMask(t *testing.T) {
skipIfNoMLX(t)
L := 4
// Row 0 fully real; row 1 has 2 real queries.
b := newBatch([]int32{0, 0}, L, []int32{int32(L), 2})
m := QPaddingMask(b, mlx.DTypeFloat32)
if m.array == nil {
t.Fatal("expected q-padding tensor")
}
mlx.Eval(m.array)
got := m.array.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, 2*L)
// Row 0: no blocking; row 1: q >= 2 blocked.
for q := 2; q < L; q++ {
want[1*L+q] = negInf
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestKPaddingMask(t *testing.T) {
skipIfNoMLX(t)
K := 5
// Row 0 full keys; row 1 has 3 real keys.
b := newBatch([]int32{0, 0}, 4, nil)
kLens := []int32{int32(K), 3}
m := KPaddingMask(b, K, kLens, mlx.DTypeFloat32)
if m.array == nil {
t.Fatal("expected k-padding tensor")
}
mlx.Eval(m.array)
got := m.array.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, 2*K)
for k := 3; k < K; k++ {
want[1*K+k] = negInf
}
for i := range want {
if !sameF(got[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestQPaddingMaskZeroWhenFull(t *testing.T) {
skipIfNoMLX(t)
b := newBatch([]int32{0}, 4, nil)
m := QPaddingMask(b, mlx.DTypeFloat32)
if !m.IsZero() {
t.Fatal("QPaddingMask at full queries should be zero")
}
}
func TestKPaddingMaskZeroWhenFull(t *testing.T) {
skipIfNoMLX(t)
K := 4
b := newBatch([]int32{0}, 4, nil)
kLens := []int32{int32(K)}
m := KPaddingMask(b, K, kLens, mlx.DTypeFloat32)
if !m.IsZero() {
t.Fatal("KPaddingMask at full keys should be zero")
}
}
func TestAttentionMaskCombineCausal(t *testing.T) {
skipIfNoMLX(t)
var z AttentionMask
got := z.Intersect(CausalMask())
if !got.IsCausal() {
t.Fatal("zero + CausalMask should be pure causal")
}
got = CausalMask().Intersect(z)
if !got.IsCausal() {
t.Fatal("CausalMask + zero should be pure causal")
}
got = CausalMask().Intersect(CausalMask())
if !got.IsCausal() {
t.Fatal("causal + causal should stay pure causal")
}
}
func TestAttentionMaskCombineRelaxDroppedAgainstCausal(t *testing.T) {
skipIfNoMLX(t)
relaxed := CausalMask().Relax(0, 1, 3, 2, 5)
got := relaxed.Intersect(CausalMask())
if !got.IsCausal() {
t.Fatal("causal-with-Relax + causal should drop relaxations and stay pure causal")
}
got = CausalMask().Intersect(relaxed)
if !got.IsCausal() {
t.Fatal("causal + causal-with-Relax should drop relaxations and stay pure causal")
}
// Disjoint relaxations on two causals also drop — neither side
// agrees to release the cells the other side relaxed.
got = CausalMask().Relax(0, 1, 3, 2, 5).Intersect(CausalMask().Relax(0, 5, 7, 6, 9))
if !got.IsCausal() {
t.Fatal("disjoint relaxations on two causals should drop and stay pure causal")
}
}
func TestAttentionMaskCombineRelaxIntersect(t *testing.T) {
skipIfNoMLX(t)
L, K := 6, 6
b := newBatch([]int32{0}, L, nil)
// Overlapping rects on two causals: the surviving relaxation is
// the geometric intersection — q in [1,3) ∩ [2,5) = [2,3),
// k in [2,5) ∩ [3,6) = [3,5).
m := CausalMask().Relax(0, 1, 3, 2, 5).Intersect(CausalMask().Relax(0, 2, 5, 3, 6))
if m.IsCausal() {
t.Fatal("overlapping relaxations should survive as their intersection, not collapse to pure causal")
}
arr := m.AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("expected tensor")
}
mlx.Eval(arr)
vals := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, L*K)
for q := range L {
for k := range K {
if k > q {
want[q*K+k] = negInf
}
}
}
// Intersection rect: q ∈ [2,3), k ∈ [3,5).
for q := 2; q < 3; q++ {
for k := 3; k < 5; k++ {
want[q*K+k] = 0
}
}
for i := range want {
if !sameF(vals[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], vals[i])
}
}
}
func TestAttentionMaskCombineRelaxKeptAgainstNonCausal(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 6
b := newBatch([]int32{0}, L, nil)
// Pad q=3 — non-causal additive contribution that should leave
// the relaxation intact (the rect releases above-diagonal cells
// q in [1,3), k in [2,5) where k > q).
pad := QPaddingMask(newBatch([]int32{0}, L, []int32{3}), mlx.DTypeFloat32)
if pad.IsZero() {
t.Fatal("padding mask should be non-zero")
}
got := CausalMask().Relax(0, 1, 3, 2, 5).Intersect(pad)
arr := got.AsArray(b, K, mlx.DTypeFloat32)
if arr == nil {
t.Fatal("expected tensor")
}
mlx.Eval(arr)
vals := arr.Floats()
negInf := float32(math.Inf(-1))
want := make([]float32, L*K)
for q := range L {
for k := range K {
if k > q {
want[q*K+k] = negInf
}
}
}
for q := 1; q < 3; q++ {
for k := 2; k < 5; k++ {
want[q*K+k] = 0
}
}
for q := 3; q < L; q++ {
for k := range K {
want[q*K+k] = negInf
}
}
for i := range want {
if !sameF(vals[i], want[i]) {
t.Fatalf("index %d: want %v, got %v", i, want[i], vals[i])
}
}
}
func TestAttentionMaskCombineArrays(t *testing.T) {
skipIfNoMLX(t)
a := mlx.FromValues([]float32{0, 0, 0, 0}, 1, 1, 2, 2)
bb := mlx.FromValues([]float32{1, 2, 3, 4}, 1, 1, 2, 2)
sum := ArrayMask(a).Intersect(ArrayMask(bb))
if sum.array == nil {
t.Fatal("array + array should produce array")
}
mlx.Eval(sum.array)
got := sum.array.Floats()
want := []float32{1, 2, 3, 4}
for i := range want {
if got[i] != want[i] {
t.Fatalf("index %d: want %v, got %v", i, want[i], got[i])
}
}
}
func TestAttentionMaskRelaxPanicOnArray(t *testing.T) {
skipIfNoMLX(t)
a := mlx.FromValues([]float32{0}, 1, 1, 1, 1)
defer func() {
if r := recover(); r == nil {
t.Fatal("Relax on ArrayMask should panic")
}
}()
ArrayMask(a).Relax(0, 0, 1, 0, 1)
}
func TestAttentionMaskRelaxPanicOnZero(t *testing.T) {
skipIfNoMLX(t)
defer func() {
if r := recover(); r == nil {
t.Fatal("Relax on zero mask should panic")
}
}()
var z AttentionMask
z.Relax(0, 0, 1, 0, 1)
}
func sameF(a, b float32) bool {
if math.IsInf(float64(a), -1) && math.IsInf(float64(b), -1) {
return true
}
return a == b
}
// sdpaInputs builds non-trivial Q/K/V so masking actually changes the
// kernel output. With zero K/V, SDPA returns zero regardless of mask
// and "parity" tests pass even when the mask path is broken.
func sdpaInputs(L, K int) (q, k, v *mlx.Array) {
const D = 4
qVals := make([]float32, L*D)
for i := range qVals {
qVals[i] = 0.1 * float32(i+1)
}
kVals := make([]float32, K*D)
for i := range kVals {
kVals[i] = 0.07 * float32(i+1)
}
vVals := make([]float32, K*D)
for i := range vVals {
vVals[i] = float32(i+1) - 0.5*float32(K*D)
}
q = mlx.FromValues(qVals, 1, 1, L, D)
k = mlx.FromValues(kVals, 1, 1, K, D)
v = mlx.FromValues(vVals, 1, 1, K, D)
return
}
func TestSDPACausalParity(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 4
q, k, v := sdpaInputs(L, K)
b := newBatch([]int32{int32(K - L)}, L, nil)
got := ScaledDotProductAttention(b, q, 1.0,
WithKV(k, v, []int32{int32(K)}),
WithMask(CausalMask()),
)
want := mlx.FastScaledDotProductAttention(q, k, v, 1.0, "causal", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPAZeroMaskParity(t *testing.T) {
skipIfNoMLX(t)
L, K := 4, 4
q, k, v := sdpaInputs(L, K)
b := newBatch([]int32{0}, L, nil)
got := ScaledDotProductAttention(b, q, 1.0, WithKV(k, v, []int32{int32(K)}))
want := mlx.FastScaledDotProductAttention(q, k, v, 1.0, "", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPAArrayMaskParity(t *testing.T) {
skipIfNoMLX(t)
L, K := 3, 3
q, k, v := sdpaInputs(L, K)
b := newBatch([]int32{0}, L, nil)
mask := mlx.FromValues([]float32{
0, -1, -1,
0, 0, -1,
0, 0, 0,
}, 1, 1, 3, 3)
got := ScaledDotProductAttention(b, q, 1.0,
WithKV(k, v, []int32{int32(K)}),
WithMask(ArrayMask(mask)),
)
want := mlx.FastScaledDotProductAttention(q, k, v, 1.0, "array", mask)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPARelaxMaskMaterializes(t *testing.T) {
skipIfNoMLX(t)
L, K := 3, 5
q, k, v := sdpaInputs(L, K)
b := newBatch([]int32{int32(K - L)}, L, nil)
got := ScaledDotProductAttention(b, q, 1.0,
WithKV(k, v, []int32{int32(K)}),
WithMask(CausalMask().Relax(0, 3, 5, 2, 5)),
)
ref := CausalMask().Relax(0, 3, 5, 2, 5).AsArray(b, K, k.DType())
want := mlx.FastScaledDotProductAttention(q, k, v, 1.0, "array", ref)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPAPanicsWithBothKVAndHistory(t *testing.T) {
skipIfNoMLX(t)
L := 3
q, k, v := sdpaInputs(L, L)
b := newBatch([]int32{0}, L, nil)
history := NewKVHistory(k, v, nil)
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic when both WithKV and WithKVHistory are supplied")
}
}()
ScaledDotProductAttention(b, q, 1.0, WithKV(k, v, []int32{int32(L)}), WithKVHistory(history))
}
func TestSDPAMLAHistorySlicesVFromK(t *testing.T) {
skipIfNoMLX(t)
L, D, valueDim := 2, 5, 3
kBuf := make([]float32, 1*1*L*D)
for i := range kBuf {
kBuf[i] = float32(i) + 1
}
k := mlx.FromValues(kBuf, 1, 1, L, D)
v := mlx.Zeros(mlx.DTypeFloat32, 1, 1, L, valueDim)
history := NewKVHistory(k, v, nil)
q := mlx.Zeros(mlx.DTypeFloat32, 1, 1, L, D)
b := newBatch([]int32{0}, L, nil)
got := ScaledDotProductAttention(b, q, 1.0,
WithMLAHistory(history, valueDim),
)
vRef := k.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, valueDim))
want := mlx.FastScaledDotProductAttention(q, k, vRef, 1.0, "", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if gs[i] != ws[i] {
t.Fatalf("index %d: want %v, got %v", i, ws[i], gs[i])
}
}
}
func TestSDPAPanicsWithoutKV(t *testing.T) {
skipIfNoMLX(t)
q := mlx.FromValues(make([]float32, 4), 1, 1, 1, 4)
b := newBatch([]int32{0}, 1, nil)
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic when no K/V supplied")
}
}()
ScaledDotProductAttention(b, q, 1.0)
}
// fillTensor builds a [B, H, T, D] float32 tensor whose entries are
// distinct, non-zero, and predictable so per-row slices stay distinct.
func fillTensor(seed float32, B, H, T, D int) *mlx.Array {
vals := make([]float32, B*H*T*D)
for i := range vals {
vals[i] = seed + 0.05*float32(i)
}
return mlx.FromValues(vals, B, H, T, D)
}
// TestSDPAMultiSequenceParity drives a B=2 batch with mixed
// SeqOffsets and SeqQueryLens through ScaledDotProductAttention via
// the no-cache (WithKV) path, then compares each row's real
// positions against a B=1 reference at that row's offset and length.
// Padded-tail outputs are unconstrained and not checked. Pins the
// central multi-sequence contract: right-padded rows must produce
// per-row outputs that don't depend on the padded tails.
func TestSDPAMultiSequenceParity(t *testing.T) {
skipIfNoMLX(t)
const H, D = 1, 4
const L, K = 4, 6
const qShort, kShort = 2, 2
const scale = 1.0
q := fillTensor(0.5, 2, H, L, D)
k := fillTensor(-0.3, 2, H, K, D)
v := fillTensor(0.7, 2, H, K, D)
b := newBatch([]int32{2, 0}, L, []int32{int32(L), int32(qShort)})
got := ScaledDotProductAttention(b, q, scale,
WithKV(k, v, []int32{int32(K), int32(kShort)}),
WithMask(CausalMask()))
mlx.Eval(got)
gotF := got.Floats()
// Row 0: full Q at offset 2, full K. B=1 reference.
q0 := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{1, H, L, D})
k0 := mlx.SliceStartStop(k, []int32{0, 0, 0, 0}, []int32{1, H, K, D})
v0 := mlx.SliceStartStop(v, []int32{0, 0, 0, 0}, []int32{1, H, K, D})
b0 := newBatch([]int32{2}, L, nil)
ref0 := ScaledDotProductAttention(b0, q0, scale,
WithKV(k0, v0, []int32{int32(K)}),
WithMask(CausalMask()))
mlx.Eval(ref0)
ref0F := ref0.Floats()
// Row 1: real Q at offset 0, length qShort, with kShort real keys.
q1 := mlx.SliceStartStop(q, []int32{1, 0, 0, 0}, []int32{2, H, int32(qShort), D})
k1 := mlx.SliceStartStop(k, []int32{1, 0, 0, 0}, []int32{2, H, int32(kShort), D})
v1 := mlx.SliceStartStop(v, []int32{1, 0, 0, 0}, []int32{2, H, int32(kShort), D})
b1 := newBatch([]int32{0}, qShort, nil)
ref1 := ScaledDotProductAttention(b1, q1, scale,
WithKV(k1, v1, []int32{int32(kShort)}),
WithMask(CausalMask()))
mlx.Eval(ref1)
ref1F := ref1.Floats()
// got is [2, H, L, D] = [B=2, 1, 4, 4]. Row 0 is got[0,...] and
// must match ref0 over the full [L, D]. Row 1 is got[1,...] and
// must match ref1 over [qShort, D] only — padded positions are
// unconstrained.
rowStride := H * L * D
for i := range rowStride {
if !approxEqual(gotF[i], ref0F[i], 1e-5) {
t.Fatalf("row 0 [%d]: got %v, want %v", i, gotF[i], ref0F[i])
}
}
for q := range qShort {
for d := range D {
gotI := rowStride + q*D + d
refI := q*D + d
if !approxEqual(gotF[gotI], ref1F[refI], 1e-5) {
t.Fatalf("row 1 [q=%d,d=%d]: got %v, want %v", q, d, gotF[gotI], ref1F[refI])
}
}
}
}