package cache

import (
	"math"
	"testing"

	"github.com/ollama/ollama/x/mlxrunner/batch"
	"github.com/ollama/ollama/x/mlxrunner/mlx"
	"github.com/ollama/ollama/x/models/nn"
)

// TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
// only succeeds when target exactly matches the snapshot's offset. Recurrent
// state is cumulative, so it can't be rewound or fast-forwarded.
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
	skipIfNoMLX(t)
	c := NewRecurrentCache(3, 12, 4, 8, 8)

	b1 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1)}
	c.Get(b1, mlx.DTypeFloat16) // lazy-init

	b10 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 10), SeqQueryLens: []int32{10}}
	c.Put(b10, nil, nil) // advance to 10

	snap := c.Snapshot(0) // snap.offset == 10

	b5 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 5), SeqQueryLens: []int32{5}}
	c.Put(b5, nil, nil) // cache now at 15

	// Pin the precondition: the Restore checks below are only meaningful if
	// the cache has really moved past the snapshot's offset.
	if got := c.Offset(); got != 15 {
		t.Fatalf("offset before restore = %d, want 15", got)
	}

	// target < snap.offset: fails (can't rewind past snapshot)
	if c.Restore(snap, 5) {
		t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
	}

	// target > snap.offset: fails (can't advance without feeding tokens)
	if c.Restore(snap, 15) {
		t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
	}

	// target == snap.offset: succeeds
	if !c.Restore(snap, 10) {
		t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
	}
	if c.Offset() != 10 {
		t.Fatalf("offset = %d, want 10", c.Offset())
	}
}

// TestRecurrentCacheGetLazyInit verifies that the first Get on a fresh
// RecurrentCache does not advance its offset, returns a history exposing
// non-nil conv and delta state, gives the conv state the dtype passed to Get,
// and keeps the delta state in float32 even when Get asks for bfloat16.
func TestRecurrentCacheGetLazyInit(t *testing.T) {
	skipIfNoMLX(t)
	c := NewRecurrentCache(3, 4, 2, 4, 4)

	b := &batch.Batch{
		InputIDs:     mlx.Zeros(mlx.DTypeInt32, 1, 1),
		SeqOffsets:   []int32{0},
		SeqQueryLens: []int32{1},
	}
	h := c.Get(b, mlx.DTypeBFloat16)

	if c.Offset() != 0 {
		t.Fatalf("Get should not advance; got offset %d", c.Offset())
	}
	if h.ConvState() == nil || h.DeltaState() == nil {
		t.Fatal("history should expose conv/delta tensors")
	}
	if got := h.ConvState().DType(); got != mlx.DTypeBFloat16 {
		t.Fatalf("conv state dtype = %v, want %v", got, mlx.DTypeBFloat16)
	}
	if got := h.DeltaState().DType(); got != mlx.DTypeFloat32 {
		t.Fatalf("delta state dtype = %v, want %v", got, mlx.DTypeFloat32)
	}
}

// TestSpeculativeRecurrentCacheUsesStagedState verifies that a speculative
// cache returned by BeginIsolatedSpeculation serves the arrays staged by Put
// from both Get (via the history) and State. It also checks that the
// speculative offset advances while the wrapped target cache's offset is
// left untouched.
func TestSpeculativeRecurrentCacheUsesStagedState(t *testing.T) {
	skipIfNoMLX(t)
	target := NewRecurrentCache(2, 3, 1, 2, 3)

	caches, ok := BeginIsolatedSpeculation([]Cache{target})
	if !ok {
		t.Fatal("BeginIsolatedSpeculation failed")
	}
	// Guard the index and the type assertion so an unexpected result fails
	// this test cleanly instead of panicking and aborting the test binary.
	if len(caches) != 1 {
		t.Fatalf("BeginIsolatedSpeculation returned %d caches, want 1", len(caches))
	}
	c, ok := caches[0].(*speculativeRecurrentCache)
	if !ok {
		t.Fatalf("caches[0] is %T, want *speculativeRecurrentCache", caches[0])
	}

	b := &batch.Batch{
		InputIDs:     mlx.Zeros(mlx.DTypeInt32, 1, 1),
		SeqOffsets:   []int32{0},
		SeqQueryLens: []int32{1},
	}
	c.Get(b, mlx.DTypeFloat32)

	convVals := []float32{1, 2, 3, 4, 5, 6}
	deltaVals := []float32{7, 8, 9, 10, 11, 12}
	nextConv := mlx.FromValues(convVals, 1, 2, 3)
	nextDelta := mlx.FromValues(deltaVals, 1, 1, 2, 3)
	c.Put(b, nextConv, nextDelta)

	h := c.Get(b, mlx.DTypeFloat32)
	state := c.State()
	if len(state) != 2 {
		t.Fatalf("State() returned %d arrays, want 2", len(state))
	}

	// Compare by pointer identity: the staged arrays themselves must be
	// returned, not copies with equal contents.
	assertArray := func(name string, got, want *mlx.Array) {
		t.Helper()
		if got != want {
			t.Fatalf("%s = %p, want %p", name, got, want)
		}
	}
	assertArray("history conv", h.ConvState(), nextConv)
	assertArray("history delta", h.DeltaState(), nextDelta)
	assertArray("state conv", state[0], nextConv)
	assertArray("state delta", state[1], nextDelta)

	if got := c.Offset(); got != 1 {
		t.Fatalf("speculative offset = %d, want 1", got)
	}
	if got := target.Offset(); got != 0 {
		t.Fatalf("target offset = %d, want 0", got)
	}
}

// TestRecurrentCachePaddedRoundTrip runs Get → CausalConv1D →
// GatedDelta → Put on a B=1 batch with qLen
//
// NOTE(review): the text of this test is damaged at this point. Everything
// between "qLen" above and "1e-4 {" below is missing: the rest of this doc
// comment, the func signature, and most of the test body. It appears to have
// been stripped from a "<" in the comment through the ">" of an
// `if math.Abs(...) > 1e-4 {` comparison, likely treated as a markup tag.
// The surviving tail is reproduced verbatim. Restore the missing text from
// version control rather than reconstructing it by hand.
1e-4 {
			t.Fatalf("nextConv[%d]: padded=%v unpadded=%v (padding leaked into conv state)", i, gp[i], gr[i])
		}
	}

	dp := deltaPad.Floats()
	dr := deltaRef.Floats()
	if len(dp) != len(dr) {
		t.Fatalf("delta state shape mismatch: padded %d vs unpadded %d", len(dp), len(dr))
	}
	for i := range dp {
		if math.Abs(float64(dp[i]-dr[i])) > 1e-3 {
			t.Fatalf("delta state[%d]: padded=%v unpadded=%v (padding leaked into recurrent state)", i, dp[i], dr[i])
		}
	}
}

// TestRecurrentCachePutAdvances verifies that Put on a fresh RecurrentCache
// (no prior Get) advances the offset to 2 for a two-token batch.
func TestRecurrentCachePutAdvances(t *testing.T) {
	skipIfNoMLX(t)
	c := NewRecurrentCache(3, 4, 2, 4, 4)

	b := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 2), SeqQueryLens: []int32{2}}
	newConv := mlx.Zeros(mlx.DTypeFloat16, 1, 3, 4)
	// NOTE(review): the lazy-init test expects delta state to be float32;
	// float16 here only works because this test checks the offset alone —
	// confirm whether Put should be exercised with a float32 delta.
	newDelta := mlx.Zeros(mlx.DTypeFloat16, 1, 2, 4, 4)
	c.Put(b, newConv, newDelta)

	if c.Offset() != 2 {
		t.Fatalf("cache offset not advanced: %d", c.Offset())
	}
}