ollama source for Momentry Core verification
This commit is contained in:
147
x/mlxrunner/mlx/compile_test.go
Normal file
147
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileFusion(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Compile fuses the ops inside a function body into a single kernel,
|
||||
// eliminating intermediate buffers. Use a diamond-shaped graph where
|
||||
// two branches must be materialized simultaneously without fusion,
|
||||
// then compare peak memory against the compiled version which fuses
|
||||
// everything into one kernel with no intermediates.
|
||||
const n = 1024 * 1024 // 4MB per float32 array
|
||||
data := make([]float32, n)
|
||||
for i := range data {
|
||||
data[i] = float32(i + 1)
|
||||
}
|
||||
|
||||
// Diamond: both a*b and a+b must be live for the final multiply.
|
||||
// Without fusion: peak includes both intermediates (~8MB extra).
|
||||
// With fusion: single kernel, no intermediates.
|
||||
body := func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Multiply(a.Add(b))
|
||||
}
|
||||
|
||||
a := FromValues(data, n)
|
||||
b := FromValues(data, n)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
// Compiled: ops fused into a single kernel.
|
||||
EnableCompile()
|
||||
fn := Compile2("diamond", body, Shapeless())
|
||||
warm := fn(a, b)
|
||||
Eval(warm)
|
||||
Sweep()
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
y := fn(a, b)
|
||||
Eval(y)
|
||||
compiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
// Uncompiled: ops evaluated individually, intermediates materialized.
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
z := body(a, b)
|
||||
Eval(z)
|
||||
uncompiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
if compiledPeak == 0 && uncompiledPeak == 0 {
|
||||
t.Skip("peak memory tracking not available")
|
||||
}
|
||||
|
||||
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
|
||||
if compiledPeak >= uncompiledPeak {
|
||||
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileNested(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// A compiled function that calls another compiled function should
|
||||
// produce correct results. The inner function inlines via isTracing()
|
||||
// during the outer's trace.
|
||||
inner := Compile1("silu", func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
}, Shapeless())
|
||||
|
||||
outer := Compile2("swiglu", func(gate, up *Array) *Array {
|
||||
return inner(gate).Multiply(up)
|
||||
}, Shapeless())
|
||||
|
||||
gate := FromValues([]float32{0, 1, 2}, 3)
|
||||
up := FromValues([]float32{1, 1, 1}, 3)
|
||||
Pin(gate, up)
|
||||
defer Unpin(gate, up)
|
||||
|
||||
y := outer(gate, up)
|
||||
Eval(y)
|
||||
|
||||
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
|
||||
got := y.Floats()
|
||||
want := []float32{0, 0.7310586, 1.7615942}
|
||||
for i, v := range got {
|
||||
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
|
||||
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
boom := Compile1("boom", func(a *Array) *Array {
|
||||
panic("intentional test panic")
|
||||
})
|
||||
|
||||
x := FromValues([]float32{1}, 1)
|
||||
Pin(x)
|
||||
defer Unpin(x)
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatal("expected panic from Call, got none")
|
||||
}
|
||||
if _, ok := r.(string); !ok {
|
||||
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||
}
|
||||
}()
|
||||
boom(x)
|
||||
}
|
||||
|
||||
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Repeated invocations of a compiled kernel should not grow the
|
||||
// tracked-arrays list — the callback's traceScratch collects
|
||||
// intermediates during tracing and frees them when the callback returns.
|
||||
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Add(b)
|
||||
})
|
||||
|
||||
a := FromValues([]float32{1, 2}, 2)
|
||||
b := FromValues([]float32{3, 4}, 2)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
Sweep()
|
||||
before := len(arrays)
|
||||
|
||||
for range 100 {
|
||||
_ = fn(a, b)
|
||||
Sweep()
|
||||
}
|
||||
|
||||
after := len(arrays)
|
||||
if after > before+2 {
|
||||
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user