package nn import ( "math" "testing" "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/mlx" ) func ones(dtype mlx.DType, shape ...int) *mlx.Array { return mlx.AddScalar(mlx.Zeros(dtype, shape...), 1) } // fromValues builds a tensor with sequentially-numbered float32 // values so element-by-element parity actually exercises the kernel. func fromValues(seed float32, shape ...int) *mlx.Array { n := 1 for _, d := range shape { n *= d } vals := make([]float32, n) for i := range vals { vals[i] = seed + 0.1*float32(i) } return mlx.FromValues(vals, shape...) } // depthwiseCausalRef is a Go-side reference for the depthwise causal // 1D conv fallback. concat is [B, total, C], weight is [C, K], output // is [B, total-K+1, C]. Used to anchor the wrapper's parity tests. func depthwiseCausalRef(concat, weight *mlx.Array) []float32 { mlx.Eval(concat, weight) cVals := concat.Floats() wVals := weight.Floats() B := concat.Dim(0) total := concat.Dim(1) C := concat.Dim(2) K := weight.Dim(1) outLen := total - K + 1 out := make([]float32, B*outLen*C) for bi := range B { for q := range outLen { for c := range C { var sum float32 for k := range K { x := cVals[bi*total*C+(q+k)*C+c] w := wVals[c*K+k] sum += x * w } out[bi*outLen*C+q*C+c] = sum } } } return out } // TestCausalConv1DParity drives the wrapper with non-trivial prior, // input, and weight values, then compares against a direct depthwise- // causal-conv reference. 
func TestCausalConv1DParity(t *testing.T) {
	skipIfNoMLX(t)

	// Single-row batch: 4 input steps, 3 channels, 2 carried-over steps.
	B, L, D, convTail := 1, 4, 3, 2
	// The weight is built as [D, K] with K one larger than the carried tail.
	K := convTail + 1

	// Different seeds keep input, prior and weight numerically distinct.
	input := fromValues(0.5, B, L, D)
	prior := fromValues(-0.3, B, convTail, D)
	weight := fromValues(0.2, D, K)

	// NOTE(review): the third argument is nil — presumably an optional bias;
	// confirm against CausalConv1D's signature.
	out, nextConv := CausalConv1D(&batch.Batch{}, input, nil, weight, convTail, WithRecurrentState(prior, nil))
	mlx.Eval(out, nextConv)

	// The reference runs over prior followed by input along axis 1.
	concat := mlx.Concatenate([]*mlx.Array{prior, input}, 1)
	want := depthwiseCausalRef(concat, weight)
	got := out.Floats()
	if len(got) != len(want) {
		t.Fatalf("out len = %d, want %d", len(got), len(want))
	}
	// Convolution output is compared with an absolute tolerance of 1e-5.
	for i := range want {
		if math.Abs(float64(got[i]-want[i])) > 1e-5 {
			t.Fatalf("out[%d]: got %v, want %v", i, got[i], want[i])
		}
	}

	// nextConv (no padding) is the trailing convTail rows of concat.
	mlx.Eval(concat)
	cVals := concat.Floats()
	total := concat.Dim(1)
	wantTail := make([]float32, B*convTail*D)
	for bi := range B {
		for k := range convTail {
			for d := range D {
				// Copy position total-convTail+k of row bi, channel d.
				wantTail[bi*convTail*D+k*D+d] = cVals[bi*total*D+(total-convTail+k)*D+d]
			}
		}
	}
	tail := nextConv.Floats()
	if len(tail) != len(wantTail) {
		t.Fatalf("nextConv len = %d, want %d", len(tail), len(wantTail))
	}
	// The carried state is compared exactly, with no tolerance.
	for i := range wantTail {
		if tail[i] != wantTail[i] {
			t.Fatalf("nextConv[%d]: got %v, want %v", i, tail[i], wantTail[i])
		}
	}
}

// TestCausalConv1DPaddedRowParity drives a B=2 batch with one short
// row (qLen
//
// NOTE(review): this copy of the file is damaged at this point. Everything
// between "row (qLen" above and the "1e-5 {" fragment below is missing: the
// rest of this doc comment, the func header, and the first part of the body,
// including the definitions of gotOut, wantOut1, refNextConv0, refNextConv1
// and nextConv. Restore it from version control; the surviving fragments are
// kept verbatim and do not compile on their own.
1e-5 {
				t.Fatalf("row 1 real out[q=%d,d=%d]: got %v, want %v", q, d, gotOut[gotI], wantOut1[refI])
			}
		}
	}

	// nextConv: row 0 unaffected, row 1 must be the row's real tail
	// (positions [qLenShort - convTail, qLenShort) of the per-row
	// concat, i.e. the last two real input rows in this setup).
	gotTail := nextConv.Floats()
	wantTail0 := refNextConv0.Floats()
	wantTail1 := refNextConv1.Floats()
	// Row 0 occupies the first convTail*D elements of gotTail.
	for k := range convTail {
		for d := range D {
			i := k*D + d
			if gotTail[i] != wantTail0[i] {
				t.Fatalf("row 0 nextConv[k=%d,d=%d]: got %v, want %v", k, d, gotTail[i], wantTail0[i])
			}
		}
	}
	// Row 1 starts convTail*D elements into gotTail; its reference is indexed
	// from 0.
	for k := range convTail {
		for d := range D {
			gotI := convTail*D + k*D + d
			refI := k*D + d
			if gotTail[gotI] != wantTail1[refI] {
				t.Fatalf("row 1 nextConv[k=%d,d=%d]: got %v, want %v (must come from real positions, not the padded tail)", k, d, gotTail[gotI], wantTail1[refI])
			}
		}
	}
}

// TestGatedDeltaZeroFallback calls the GatedDelta wrapper with an empty
// batch, all-ones inputs and an all-zero recurrent state, and requires
// its output and final state to match a direct mlx.FastGatedDelta call
// on the same arguments exactly.
func TestGatedDeltaZeroFallback(t *testing.T) {
	skipIfNoMLX(t)

	B, L, nK, nV, dK, dV := 1, 2, 1, 1, 4, 4
	q := ones(mlx.DTypeFloat32, B, L, nK, dK)
	k := ones(mlx.DTypeFloat32, B, L, nK, dK)
	v := ones(mlx.DTypeFloat32, B, L, nV, dV)
	gDecay := ones(mlx.DTypeFloat32, B, L, nV)
	beta := ones(mlx.DTypeFloat32, B, L, nV)

	// The same zero state array is passed to both the wrapper and the direct call.
	zero := mlx.Zeros(mlx.DTypeFloat32, B, nV, dV, dK)
	outA, stateA := GatedDelta(&batch.Batch{}, q, k, v, gDecay, beta, WithRecurrentState(nil, zero))
	outB, stateB := mlx.FastGatedDelta(q, k, v, gDecay, beta, zero, nil)
	mlx.Eval(outA, stateA, outB, stateB)

	// NOTE(review): lengths are not compared before indexing; a shorter
	// wrapper result would panic here instead of failing with a message.
	gotOut, wantOut := outA.Floats(), outB.Floats()
	for i := range wantOut {
		if gotOut[i] != wantOut[i] {
			t.Fatalf("output[%d]: wrapper=%v direct=%v", i, gotOut[i], wantOut[i])
		}
	}
	gotState, wantState := stateA.Floats(), stateB.Floats()
	for i := range wantState {
		if gotState[i] != wantState[i] {
			t.Fatalf("state[%d]: wrapper=%v direct=%v", i, gotState[i], wantState[i])
		}
	}
}

// TestGatedDeltaUsesPriorState repeats the wrapper-vs-direct comparison
// with a recurrent state filled with 3 instead of zeros. Only the
// outputs are compared; the returned states are discarded.
func TestGatedDeltaUsesPriorState(t *testing.T) {
	skipIfNoMLX(t)

	B, L, nK, nV, dK, dV := 1, 2, 1, 1, 4, 4
	q := ones(mlx.DTypeFloat32, B, L, nK, dK)
	k := ones(mlx.DTypeFloat32, B, L, nK, dK)
	v := ones(mlx.DTypeFloat32, B, L, nV, dV)
	gDecay := ones(mlx.DTypeFloat32, B, L, nV)
	beta := ones(mlx.DTypeFloat32, B, L, nV)

	// Every element of the prior state is 1*3.
	priorState := mlx.MulScalar(ones(mlx.DTypeFloat32, B, nV, dV, dK), 3)
	outA, _ := GatedDelta(&batch.Batch{}, q, k, v, gDecay, beta, WithRecurrentState(nil, priorState))
	outB, _ := mlx.FastGatedDelta(q, k, v, gDecay, beta, priorState, nil)
	mlx.Eval(outA, outB)

	gotOut, wantOut := outA.Floats(), outB.Floats()
	for i := range wantOut {
		if gotOut[i] != wantOut[i] {
			t.Fatalf("output[%d]: wrapper=%v direct=%v", i, gotOut[i], wantOut[i])
		}
	}
}

// TestGatedDeltaPaddedRowParity drives a B=2 batch where row 1 is
// short (qLen < L). The wrapper must substitute neutral values
// (q=k=v=beta=0, g=1) at row 1's padded positions so the recurrence
// is a no-op there — and row 1's final state must equal the state
// after its last real token. Pinned via parity against a B=1 length-
// qLen call on the same row.
func TestGatedDeltaPaddedRowParity(t *testing.T) {
	skipIfNoMLX(t)

	L, nK, nV, dK, dV := 4, 1, 1, 4, 4
	// Row 1 has 2 real tokens out of L=4.
	qLenShort := 2

	makeRows := func(seedA, seedB float32, shape ...int) *mlx.Array {
		// Build a rank-(len(shape)+1) tensor with B=2 rows from two
		// distinct seeds so the rows are not accidentally identical.
		n := 1
		for _, d := range shape {
			n *= d
		}
		vals := make([]float32, 2*n)
		// First n values form row 0, the next n form row 1.
		for i := range n {
			vals[i] = seedA + 0.1*float32(i)
		}
		for i := range n {
			vals[n+i] = seedB + 0.1*float32(i)
		}
		full := append([]int{2}, shape...)
		return mlx.FromValues(vals, full...)
	}

	q := makeRows(0.5, -0.5, L, nK, dK)
	k := makeRows(0.7, -0.7, L, nK, dK)
	v := makeRows(0.3, -0.3, L, nV, dV)
	gDecay := makeRows(0.1, -0.1, L, nV)
	beta := makeRows(0.4, -0.4, L, nV)
	priorState := makeRows(0.2, -0.2, nV, dV, dK)

	// Row 0 is given query length L, row 1 the shorter qLenShort.
	b := &batch.Batch{
		InputIDs:     mlx.Zeros(mlx.DTypeInt32, 2, L),
		SeqOffsets:   []int32{0, 0},
		SeqQueryLens: []int32{int32(L), int32(qLenShort)},
	}

	// Only the final state is checked; the output is discarded.
	_, state := GatedDelta(b, q, k, v, gDecay, beta, WithRecurrentState(nil, priorState))
	mlx.Eval(state)

	// Reference for row 1: B=1 length-qLenShort call against the
	// row's real prefix and its prior state slice.
	row1Slice := func(a *mlx.Array, axisLens ...int32) *mlx.Array {
		dims := a.Dims()
		start := make([]int32, len(dims))
		stop := make([]int32, len(dims))
		// Axis 0: take index 1 only. Every other axis starts at 0 and
		// runs to its full extent.
		start[0], stop[0] = 1, 2
		for i := 1; i < len(dims); i++ {
			stop[i] = int32(dims[i])
		}
		// Optionally truncate axis 1 (sequence axis) to qLenShort.
		if len(axisLens) >= 1 && len(dims) >= 2 {
			stop[1] = axisLens[0]
		}
		return mlx.SliceStartStop(a, start, stop)
	}
	q1 := row1Slice(q, int32(qLenShort))
	k1 := row1Slice(k, int32(qLenShort))
	v1 := row1Slice(v, int32(qLenShort))
	gDecay1 := row1Slice(gDecay, int32(qLenShort))
	beta1 := row1Slice(beta, int32(qLenShort))
	// No length is passed for the state, so it is sliced on axis 0 only.
	priorRow1 := row1Slice(priorState)
	_, refState := mlx.FastGatedDelta(q1, k1, v1, gDecay1, beta1, priorRow1, nil)
	mlx.Eval(refState)

	gotState := state.Floats()
	wantState := refState.Floats()
	// Row 1 of the batched state starts one row's worth of elements in.
	row1Stride := nV * dV * dK
	for i := range row1Stride {
		gotV := gotState[row1Stride+i]
		wantV := wantState[i]
		// Compared with an absolute tolerance of 1e-4.
		if math.Abs(float64(gotV-wantV)) > 1e-4 {
			t.Fatalf("row 1 final state[%d]: got %v, want %v", i, gotV, wantV)
		}
	}
}