ollama source for Momentry Core verification

This commit is contained in:
Accusys
2026-05-22 17:19:10 +08:00
commit 0b31ff9135
2020 changed files with 1413145 additions and 0 deletions

3
x/mlxrunner/mlx/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
_deps
build
dist

View File

@@ -0,0 +1,32 @@
cmake_minimum_required(VERSION 3.5)
project(mlx)
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
# Read MLX-C version from top-level file (shared with imagegen CMakeLists)
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG}
)
FetchContent_MakeAvailable(mlx-c)
# Sync vendored headers with fetched version
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/include/mlx/c/")

99
x/mlxrunner/mlx/act.go Normal file
View File

@@ -0,0 +1,99 @@
package mlx
import "math"
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// as a fused kernel.
var GELUApprox = Compile1(
"GELUApprox",
func(x *Array) *Array {
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
dt := x.DType()
half := FromValue[float32](0.5).AsType(dt)
coeff := FromValue(geluCoeff).AsType(dt)
c := FromValue[float32](0.044715).AsType(dt)
one := FromValue[float32](1.0).AsType(dt)
// x^3 via x*x*x (avoids general Power which is slower).
x3 := x.Multiply(x).Multiply(x)
inner := x.Add(c.Multiply(x3))
tanh := coeff.Multiply(inner).Tanh()
return half.Multiply(x).Multiply(one.Add(tanh))
},
Shapeless(),
)
// SiLU returns a * sigmoid(a) as a fused kernel.
var SiLU = Compile1(
"SiLU",
func(a *Array) *Array {
return a.Multiply(a.Sigmoid())
},
Shapeless(),
)
// SoftplusF32 returns softplus(x) computed in float32 precision and cast back
// to x's original dtype, as a fused kernel. Matches the laguna attention
// output-gate formula: softplus(cast_f32(x)).cast(orig_dtype).
var SoftplusF32 = Compile1(
"SoftplusF32",
func(x *Array) *Array {
dt := x.DType()
zero := FromValue[float32](0)
return Logaddexp(x.AsType(DTypeFloat32), zero).AsType(dt)
},
Shapeless(),
)
// SwiGLU returns silu(gate) * up as a fused kernel.
var SwiGLU = Compile2(
"SwiGLU",
func(gate, up *Array) *Array {
return SiLU(gate).Multiply(up)
},
Shapeless(),
)
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
// geglu, used by Gemma-family MLP and MoE paths.
var GeGLU = Compile2(
"GeGLU",
func(gate, up *Array) *Array {
return GELUApprox(gate).Multiply(up)
},
Shapeless(),
)
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
// mlx_lm's logit_softcap. cap must have the same dtype as x.
var LogitSoftcap = Compile2(
"LogitSoftcap",
func(x, cap *Array) *Array {
return x.Divide(cap).Tanh().Multiply(cap)
},
Shapeless(),
)
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
// per-expert scores after top-k) and the post-bias negation (used as the
// argpartition key for top-k) share a single kernel.
var sigmoidRouterFused = Compile(
"SigmoidRouter",
func(in ...*Array) []*Array {
gates, bias := in[0], in[1]
orig := gates.Sigmoid()
neg := orig.Add(bias).Negative()
return []*Array{orig, neg}
},
Shapeless(),
)
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
out := sigmoidRouterFused(gates, bias)
return out[0], out[1]
}

295
x/mlxrunner/mlx/array.go Normal file
View File

@@ -0,0 +1,295 @@
package mlx
// #include "generated.h"
import "C"
import (
"encoding/binary"
"fmt"
"log/slog"
"reflect"
"sort"
"strings"
"sync"
"sync/atomic"
"unsafe"
"github.com/ollama/ollama/logutil"
)
type Array struct {
ctx C.mlx_array
name string
pinned atomic.Int32
}
var (
arrays []*Array
arraysMu sync.Mutex
)
// constructor utilities
func New(name string) *Array {
t := &Array{name: name}
if tracing {
traceScratch = append(traceScratch, t)
} else {
arraysMu.Lock()
defer arraysMu.Unlock()
arrays = append(arrays, t)
}
return t
}
type scalarTypes interface {
~bool | ~int | ~float32 | ~float64 | ~complex64
}
func FromValue[T scalarTypes](t T) *Array {
tt := New("")
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))
case int:
tt.ctx = C.mlx_array_new_int(C.int(v))
case float32:
tt.ctx = C.mlx_array_new_float32(C.float(v))
case float64:
tt.ctx = C.mlx_array_new_float64(C.double(v))
case complex64:
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
default:
panic("unsupported type")
}
return tt
}
type arrayTypes interface {
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~int8 | ~int16 | ~int32 | ~int64 |
~float32 | ~float64 |
~complex64
}
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
if len(shape) == 0 {
panic("shape must be provided for non-scalar tensors")
}
cShape := make([]C.int, len(shape))
for i := range shape {
cShape[i] = C.int(shape[i])
}
var dtype DType
switch reflect.TypeOf(s).Elem().Kind() {
case reflect.Bool:
dtype = DTypeBool
case reflect.Uint8:
dtype = DTypeUint8
case reflect.Uint16:
dtype = DTypeUint16
case reflect.Uint32:
dtype = DTypeUint32
case reflect.Uint64:
dtype = DTypeUint64
case reflect.Int8:
dtype = DTypeInt8
case reflect.Int16:
dtype = DTypeInt16
case reflect.Int32:
dtype = DTypeInt32
case reflect.Int64:
dtype = DTypeInt64
case reflect.Float32:
dtype = DTypeFloat32
case reflect.Float64:
dtype = DTypeFloat64
case reflect.Complex64:
dtype = DTypeComplex64
default:
panic("unsupported type")
}
bts := make([]byte, binary.Size(s))
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
panic(err)
}
tt := New("")
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
return tt
}
func (t *Array) Set(other *Array) {
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Array) Clone() *Array {
tt := New(t.name)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
// lifecycle utilities
// Pin marks arrays as in-use so they are retained during Sweep.
func Pin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned.Add(1)
}
}
}
// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
func Unpin(s ...*Array) {
for _, t := range s {
if t != nil {
if t.pinned.Add(-1) < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
}
}
}
}
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
arraysMu.Lock()
defer arraysMu.Unlock()
n := 0
for _, t := range arrays {
if t.pinned.Load() > 0 && t.Valid() {
arrays[n] = t
n++
} else if t.Valid() {
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
arrays = arrays[:n]
}
// misc. utilities
func (t *Array) Valid() bool {
return t.ctx.ctx != nil
}
func (t *Array) String() string {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_array_tostring(&str, t.ctx)
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Int("pinned", int(t.pinned.Load())),
}
if t.Valid() {
attrs = append(attrs,
slog.Any("dtype", t.DType()),
slog.Any("shape", t.Dims()),
slog.Int("num_bytes", t.NumBytes()),
)
}
return slog.GroupValue(attrs...)
}
// shape utilities
func (t *Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t *Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t *Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t *Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
}
return dims
}
func (t *Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t *Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t *Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t *Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t *Array) Ints() []int {
if dt := t.DType(); dt != DTypeInt32 {
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
}
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
}
return ints
}
func (t *Array) Floats() []float32 {
if dt := t.DType(); dt != DTypeFloat32 {
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
}
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)
}
return floats
}
func (t *Array) Save(name string) error {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
C.mlx_save(cName, t.ctx)
return nil
}
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
arraysMu.Lock()
defer arraysMu.Unlock()
sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
var total int
for _, t := range arrays {
nb := t.NumBytes()
total += nb
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
}
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
}

View File

@@ -0,0 +1,76 @@
package mlx
import "testing"
func TestFromValue(t *testing.T) {
withMLXThread(t, func() {
for got, want := range map[*Array]DType{
FromValue(true): DTypeBool,
FromValue(false): DTypeBool,
FromValue(int(7)): DTypeInt32,
FromValue(float32(3.14)): DTypeFloat32,
FromValue(float64(2.71)): DTypeFloat64,
FromValue(complex64(1 + 2i)): DTypeComplex64,
} {
if got.DType() != want {
t.Errorf("%s: want %v, got %v", want, want, got)
}
}
})
}
func TestFromValues(t *testing.T) {
withMLXThread(t, func() {
for got, want := range map[*Array]DType{
FromValues([]bool{true, false, true}, 3): DTypeBool,
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
} {
if got.DType() != want {
t.Errorf("%s: want %v, got %v", want, want, got)
}
}
})
}
func TestComparisonOpsAndBernoulli(t *testing.T) {
skipIfNoMLX(t)
a := FromValues([]float32{1, 2, 3}, 3)
b := FromValues([]float32{1, 1, 4}, 3)
eq := a.Equal(b).AsType(DTypeInt32)
gt := a.Greater(b).AsType(DTypeInt32)
le := a.LessEqual(b).AsType(DTypeInt32)
bern := Bernoulli(FromValues([]float32{1, 0}, 2)).AsType(DTypeInt32)
Eval(eq, gt, le, bern)
for name, tc := range map[string]struct {
got []int
want []int
}{
"equal": {eq.Ints(), []int{1, 0, 0}},
"greater": {gt.Ints(), []int{0, 1, 0}},
"lessEqual": {le.Ints(), []int{1, 0, 1}},
"bernoulli": {bern.Ints(), []int{1, 0}},
} {
t.Run(name, func(t *testing.T) {
if len(tc.got) != len(tc.want) {
t.Fatalf("got %v, want %v", tc.got, tc.want)
}
for i := range tc.want {
if tc.got[i] != tc.want[i] {
t.Fatalf("got %v, want %v", tc.got, tc.want)
}
}
})
}
}

192
x/mlxrunner/mlx/compile.go Normal file
View File

@@ -0,0 +1,192 @@
package mlx
// #include <stdlib.h>
// #include "generated.h"
//
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// extern void closureDestructor(void* payload);
import "C"
import (
"log/slog"
"runtime/cgo"
"sync"
"unsafe"
)
// CompileFunc is the signature of a function that can be compiled.
type CompileFunc func(inputs ...*Array) []*Array
// CompileOption configures Compile behavior.
type CompileOption func(*compileConfig)
type compileConfig struct {
shapeless bool
}
// Shapeless traces the function once against symbolic shapes so the compiled
// graph accepts any input shape afterwards. Without this option, MLX re-traces
// on each new (shape, dtype) combination and caches each specialization.
func Shapeless() CompileOption {
return func(c *compileConfig) { c.shapeless = true }
}
// Compile returns a compiled version of fn. When called during another
// compile's trace, fn is inlined directly so outer compiles can fuse through
// inner ones.
//
// Compiled functions must not have side effects outside of the function. Do
// not access data other than the arguments passed in (either Go data or MLX
// arrays) unless it is a constant.
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
var cfg compileConfig
for _, o := range opts {
o(&cfg)
}
var closure C.mlx_closure
var once sync.Once
return func(inputs ...*Array) []*Array {
if tracing {
return fn(inputs...)
}
once.Do(func() {
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
*payload = cgo.NewHandle(fn)
src := C.mlx_closure_new_func_payload(
(*[0]byte)(C.closureCallback),
unsafe.Pointer(payload),
(*[0]byte)(C.closureDestructor),
)
defer C.mlx_closure_free(src)
closure = C.mlx_closure_new()
mlxCheck(name+": compile failed", func() C.int {
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
})
})
inVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inVec, in.ctx)
}
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
mlxCheck(name+": closure apply failed", func() C.int {
return C.mlx_closure_apply(&outVec, closure, inVec)
})
n := int(C.mlx_vector_array_size(outVec))
outputs := make([]*Array, n)
for i := range n {
outputs[i] = New(name)
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
}
return outputs
}
}
// Compile1 compiles a unary function. See Compile.
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0])}
}, opts...)
return func(a *Array) *Array {
return cf(a)[0]
}
}
// Compile2 compiles a binary function. See Compile.
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0], in[1])}
}, opts...)
return func(a, b *Array) *Array {
return cf(a, b)[0]
}
}
// Compile3 compiles a ternary function. See Compile.
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0], in[1], in[2])}
}, opts...)
return func(a, b, c *Array) *Array {
return cf(a, b, c)[0]
}
}
// tracing is true while a compile callback is running. Since MLX is
// single-threaded at this level a plain Go bool suffices.
var tracing bool
// traceScratch collects arrays created during a compile trace so they can be
// freed as a group when the callback returns.
var traceScratch []*Array
//export closureCallback
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
defer func() {
if r := recover(); r != nil {
slog.Error("mlx closure callback panicked", "panic", r)
rc = 1
}
}()
handle := *(*cgo.Handle)(payload)
fn := handle.Value().(CompileFunc)
// When tracing, we track all of the intermediates that are created and free them separately at the end of
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
if tracing {
panic("mlx: nested compile trace")
}
tracing = true
traceScratch = nil
defer func() {
for _, a := range traceScratch {
if a.pinned.Load() > 0 {
panic("mlx: traced array was pinned during compilation")
}
if a.Valid() {
C.mlx_array_free(a.ctx)
a.ctx.ctx = nil
}
}
tracing = false
traceScratch = nil
}()
n := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, n)
for i := range n {
a := New("")
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
inputs[i] = a
}
outputs := fn(inputs...)
var arrPtr *C.mlx_array
if len(outputs) > 0 {
handles := make([]C.mlx_array, len(outputs))
for i, out := range outputs {
handles[i] = out.ctx
}
arrPtr = &handles[0]
}
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
return 0
}
//export closureDestructor
func closureDestructor(payload unsafe.Pointer) {
handle := *(*cgo.Handle)(payload)
handle.Delete()
C.free(payload)
}

View File

@@ -0,0 +1,147 @@
package mlx
import (
"testing"
)
func TestCompileFusion(t *testing.T) {
skipIfNoMLX(t)
// Compile fuses the ops inside a function body into a single kernel,
// eliminating intermediate buffers. Use a diamond-shaped graph where
// two branches must be materialized simultaneously without fusion,
// then compare peak memory against the compiled version which fuses
// everything into one kernel with no intermediates.
const n = 1024 * 1024 // 4MB per float32 array
data := make([]float32, n)
for i := range data {
data[i] = float32(i + 1)
}
// Diamond: both a*b and a+b must be live for the final multiply.
// Without fusion: peak includes both intermediates (~8MB extra).
// With fusion: single kernel, no intermediates.
body := func(a, b *Array) *Array {
return a.Multiply(b).Multiply(a.Add(b))
}
a := FromValues(data, n)
b := FromValues(data, n)
Pin(a, b)
defer Unpin(a, b)
// Compiled: ops fused into a single kernel.
EnableCompile()
fn := Compile2("diamond", body, Shapeless())
warm := fn(a, b)
Eval(warm)
Sweep()
ClearCache()
ResetPeakMemory()
y := fn(a, b)
Eval(y)
compiledPeak := PeakMemory()
Sweep()
// Uncompiled: ops evaluated individually, intermediates materialized.
ClearCache()
ResetPeakMemory()
z := body(a, b)
Eval(z)
uncompiledPeak := PeakMemory()
Sweep()
if compiledPeak == 0 && uncompiledPeak == 0 {
t.Skip("peak memory tracking not available")
}
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
if compiledPeak >= uncompiledPeak {
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
}
}
func TestCompileNested(t *testing.T) {
skipIfNoMLX(t)
// A compiled function that calls another compiled function should
// produce correct results. The inner function inlines via isTracing()
// during the outer's trace.
inner := Compile1("silu", func(a *Array) *Array {
return a.Multiply(a.Sigmoid())
}, Shapeless())
outer := Compile2("swiglu", func(gate, up *Array) *Array {
return inner(gate).Multiply(up)
}, Shapeless())
gate := FromValues([]float32{0, 1, 2}, 3)
up := FromValues([]float32{1, 1, 1}, 3)
Pin(gate, up)
defer Unpin(gate, up)
y := outer(gate, up)
Eval(y)
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
got := y.Floats()
want := []float32{0, 0.7310586, 1.7615942}
for i, v := range got {
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
}
}
}
func TestCompileCallbackPanicRecovers(t *testing.T) {
skipIfNoMLX(t)
boom := Compile1("boom", func(a *Array) *Array {
panic("intentional test panic")
})
x := FromValues([]float32{1}, 1)
Pin(x)
defer Unpin(x)
defer func() {
r := recover()
if r == nil {
t.Fatal("expected panic from Call, got none")
}
if _, ok := r.(string); !ok {
t.Fatalf("expected string panic, got %T: %v", r, r)
}
}()
boom(x)
}
func TestCompileNoTrackingGrowth(t *testing.T) {
skipIfNoMLX(t)
// Repeated invocations of a compiled kernel should not grow the
// tracked-arrays list — the callback's traceScratch collects
// intermediates during tracing and frees them when the callback returns.
fn := Compile2("mul_add", func(a, b *Array) *Array {
return a.Multiply(b).Add(b)
})
a := FromValues([]float32{1, 2}, 2)
b := FromValues([]float32{3, 4}, 2)
Pin(a, b)
defer Unpin(a, b)
Sweep()
before := len(arrays)
for range 100 {
_ = fn(a, b)
Sweep()
}
after := len(arrays)
if after > before+2 {
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
}
}

94
x/mlxrunner/mlx/dtype.go Normal file
View File

@@ -0,0 +1,94 @@
package mlx
// #include "generated.h"
import "C"
type DType int
func (t DType) String() string {
switch t {
case DTypeBool:
return "BOOL"
case DTypeUint8:
return "U8"
case DTypeUint16:
return "U16"
case DTypeUint32:
return "U32"
case DTypeUint64:
return "U64"
case DTypeInt8:
return "I8"
case DTypeInt16:
return "I16"
case DTypeInt32:
return "I32"
case DTypeInt64:
return "I64"
case DTypeFloat16:
return "F16"
case DTypeFloat32:
return "F32"
case DTypeFloat64:
return "F64"
case DTypeBFloat16:
return "BF16"
case DTypeComplex64:
return "C64"
default:
return "Unknown"
}
}
func (t *DType) UnmarshalJSON(b []byte) error {
switch string(b) {
case `"BOOL"`:
*t = DTypeBool
case `"U8"`:
*t = DTypeUint8
case `"U16"`:
*t = DTypeUint16
case `"U32"`:
*t = DTypeUint32
case `"U64"`:
*t = DTypeUint64
case `"I8"`:
*t = DTypeInt8
case `"I16"`:
*t = DTypeInt16
case `"I32"`:
*t = DTypeInt32
case `"I64"`:
*t = DTypeInt64
case `"F16"`:
*t = DTypeFloat16
case `"F64"`:
*t = DTypeFloat64
case `"F32"`:
*t = DTypeFloat32
case `"BF16"`:
*t = DTypeBFloat16
case `"C64"`:
*t = DTypeComplex64
default:
return nil
}
return nil
}
const (
DTypeBool DType = C.MLX_BOOL
DTypeUint8 DType = C.MLX_UINT8
DTypeUint16 DType = C.MLX_UINT16
DTypeUint32 DType = C.MLX_UINT32
DTypeUint64 DType = C.MLX_UINT64
DTypeInt8 DType = C.MLX_INT8
DTypeInt16 DType = C.MLX_INT16
DTypeInt32 DType = C.MLX_INT32
DTypeInt64 DType = C.MLX_INT64
DTypeFloat16 DType = C.MLX_FLOAT16
DTypeFloat32 DType = C.MLX_FLOAT32
DTypeFloat64 DType = C.MLX_FLOAT64
DTypeBFloat16 DType = C.MLX_BFLOAT16
DTypeComplex64 DType = C.MLX_COMPLEX64
)

36
x/mlxrunner/mlx/dynamic.c Normal file
View File

@@ -0,0 +1,36 @@
#include "dynamic.h"
#include <stdio.h>
#ifdef _WIN32
#include <windows.h>
#define DLOPEN(path) LoadLibraryA(path)
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
#else
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#include <dlfcn.h>
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define DLCLOSE(handle) dlclose(handle)
#endif
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
handle->ctx = (void*) DLOPEN(path);
if (handle->ctx == NULL) {
return 1;
}
return 0;
}
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
return mlx_dynamic_open(handle, path);
}
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
if (handle->ctx) {
DLCLOSE(handle->ctx);
handle->ctx = NULL;
}
}

253
x/mlxrunner/mlx/dynamic.go Normal file
View File

@@ -0,0 +1,253 @@
package mlx
// #include "dynamic.h"
// #include "generated.h"
// #include <stdlib.h>
import "C"
import (
"fmt"
"io/fs"
"log/slog"
"os"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"unsafe"
)
var initError error
var initLoadError string
var initLoadedPath string
// CheckInit returns any error that occurred during MLX dynamic library initialization.
func CheckInit() error {
if initLoadedPath != "" {
slog.Debug("MLX dynamic library loaded", "path", initLoadedPath)
}
if initError != nil && initLoadError != "" {
slog.Error(initLoadError)
}
return initError
}
// tryLoadFromDir searches a directory for the mlxc shared library and loads it.
func tryLoadFromDir(dir string) bool {
// On Windows, MSVC produces mlxc.dll (no lib prefix)
// On Unix, it's libmlxc.so or libmlxc.dylib
pattern := "libmlxc.*"
if runtime.GOOS == "windows" {
pattern = "mlxc.*"
}
matches, err := fs.Glob(os.DirFS(dir), pattern)
if err != nil || len(matches) == 0 {
return false
}
for _, match := range matches {
path := filepath.Join(dir, match)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
initLoadError = fmt.Sprintf("failed to load MLX dynamic library: path=%s", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
initLoadError = fmt.Sprintf("failed to load MLX dynamic library symbols: path=%s", path)
C.mlx_dynamic_unload(&handle)
continue
}
initLoadedPath = path
return true
}
return false
}
// libOllamaRoots returns candidate directories for MLX dynamic libraries.
// Production: exe_dir/lib/ollama (dist tarball) and exe_dir (app bundle).
// Development: build/lib/ollama and build/*/lib/ollama.
func libOllamaRoots() []string {
var roots []string
// Production paths relative to executable
if exe, err := os.Executable(); err == nil {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
exeDir := filepath.Dir(exe)
switch runtime.GOOS {
case "darwin":
roots = append(roots, filepath.Join(exeDir, "lib", "ollama"))
roots = append(roots, exeDir) // app bundle: Contents/Resources/
case "linux":
roots = append(roots, filepath.Join(exeDir, "..", "lib", "ollama"))
case "windows":
roots = append(roots, filepath.Join(exeDir, "lib", "ollama"))
}
}
// Development paths: build/lib/ollama and build/*/lib/ollama.
// Reverse-sort and filter the glob results so higher-versioned Metal
// builds (e.g., metal-v4) are tried before lower ones (metal-v3),
// and incompatible variants are skipped. Without this, alphabetical
// order would always pick v3 over v4 in dev builds.
for _, base := range repoBuildDirs() {
roots = append(roots, filepath.Join(base, "lib", "ollama"))
if matches, err := filepath.Glob(filepath.Join(base, "*", "lib", "ollama")); err == nil {
sort.Sort(sort.Reverse(sort.StringSlice(matches)))
for _, m := range matches {
// Extract the build dir name (e.g., "metal-v4" from "build/metal-v4/lib/ollama")
rel, _ := filepath.Rel(base, m)
variant := strings.SplitN(rel, string(filepath.Separator), 2)[0]
if isCompatibleMLXVariant(variant) {
roots = append(roots, m)
}
}
}
}
return roots
}
// repoBuildDirs returns candidate build/ directories relative to cwd and repo root.
func repoBuildDirs() []string {
var dirs []string
if cwd, err := os.Getwd(); err == nil {
dirs = append(dirs, filepath.Join(cwd, "build"))
for dir := cwd; ; {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
if dir != cwd {
dirs = append(dirs, filepath.Join(dir, "build"))
}
break
}
parent := filepath.Dir(dir)
if parent == dir {
break
}
dir = parent
}
}
return dirs
}
// prependLibraryPath prepends dir to the platform's dynamic library search
// path so the linker finds colocated libmlx before any stale copies.
// Called once after successful library load.
func prependLibraryPath(dir string) {
var envVar string
switch runtime.GOOS {
case "darwin":
envVar = "DYLD_LIBRARY_PATH"
case "linux":
envVar = "LD_LIBRARY_PATH"
default:
return
}
if existing := os.Getenv(envVar); existing != "" {
os.Setenv(envVar, dir+string(filepath.ListSeparator)+existing)
} else {
os.Setenv(envVar, dir)
}
}
func init() {
switch runtime.GOOS {
case "darwin", "linux", "windows":
default:
return
}
// OLLAMA_LLM_LIBRARY overrides variant selection (e.g., "mlx_metal_v3").
// When set to an mlx_* value, only that specific subdir is tried.
// The GGML runner ignores mlx_* values (see discover/runner.go).
forcedVariant, _ := os.LookupEnv("OLLAMA_LLM_LIBRARY")
if forcedVariant != "" && !strings.HasPrefix(forcedVariant, "mlx_") {
forcedVariant = "" // not an MLX variant, ignore
}
found := findMLXLibrary(forcedVariant)
if !found {
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", libOllamaRoots())
return
}
prependLibraryPath(filepath.Dir(initLoadedPath))
}
func findMLXLibrary(forcedVariant string) bool {
for _, root := range libOllamaRoots() {
if forcedVariant != "" {
if tryLoadFromDir(filepath.Join(root, forcedVariant)) {
return true
}
} else {
if tryLoadFromMLXSubdirs(root) {
return true
}
if tryLoadFromDir(root) {
return true
}
}
}
return false
}
// tryLoadFromMLXSubdirs globs for mlx_* subdirs within dir, filters out
// incompatible variants, tries the remainder in reverse sorted order (so
// higher-versioned variants are preferred), and returns true on first
// successful load.
func tryLoadFromMLXSubdirs(dir string) bool {
mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*"))
if err != nil || len(mlxDirs) == 0 {
return false
}
// Reverse sort: mlx_metal_v4 before mlx_metal_v3, mlx_cuda_v13 before v12
sort.Sort(sort.Reverse(sort.StringSlice(mlxDirs)))
for _, mlxDir := range mlxDirs {
if !isCompatibleMLXVariant(filepath.Base(mlxDir)) {
slog.Debug("skipping incompatible MLX variant", "dir", mlxDir)
continue
}
if tryLoadFromDir(mlxDir) {
return true
}
}
return false
}
// isCompatibleMLXVariant checks whether an MLX variant directory is
// compatible with the current OS. On macOS, dlopen does NOT enforce
// the deployment target for dynamically loaded libraries, so we must
// check compatibility ourselves to avoid loading Metal 4.x shaders
// on a Metal 3.x driver.
func isCompatibleMLXVariant(name string) bool {
if runtime.GOOS != "darwin" {
return true // non-macOS variants use dlopen failure for filtering
}
// Metal variant naming:
// Production: mlx_metal_v3, mlx_metal_v4
// Dev build: metal-v3, metal-v4
var verStr string
switch {
case strings.HasPrefix(name, "mlx_metal_v"):
verStr = strings.TrimPrefix(name, "mlx_metal_v")
case strings.HasPrefix(name, "metal-v"):
verStr = strings.TrimPrefix(name, "metal-v")
}
if verStr != "" {
metalVer, err := strconv.Atoi(verStr)
if err != nil {
return true // unknown format, try it
}
// Metal 4.x requires macOS 26+
if metalVer >= 4 && macOSMajorVersion() < 26 {
return false
}
}
return true
}

47
x/mlxrunner/mlx/dynamic.h Normal file
View File

@@ -0,0 +1,47 @@
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef _WIN32
#include <windows.h>
#define DLSYM(handle, symbol) (void*)GetProcAddress((HMODULE)(handle.ctx), symbol)
#else
#include <dlfcn.h>
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
#endif
#include <stdint.h>
// Provide fallback typedefs for float16_t and bfloat16_t on non-ARM64
// platforms where arm_fp16.h and arm_bf16.h are not available. These are
// only used as function pointer signature placeholders since MLX requires
// Apple Silicon at runtime.
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
typedef uint16_t float16_t;
#endif
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_BF16)
typedef uint16_t bfloat16_t;
#endif
// Undef ERROR to avoid conflict with wingdi.h on Windows
#ifdef ERROR
#undef ERROR
#endif
#define MLX_ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
#define CHECK(x) if (!(x)) { MLX_ERROR("CHECK failed: " #x); }
#define CHECK_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x); CHECK(x##_)
// OPTIONAL_LOAD: load symbol if available, leave function pointer NULL otherwise
#define OPTIONAL_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x)
typedef struct {
void* ctx;
} mlx_dynamic_handle;
int mlx_dynamic_load(
mlx_dynamic_handle* handle,
const char *path);
void mlx_dynamic_unload(
mlx_dynamic_handle* handle);
#endif // MLX_DYNAMIC_H

View File

@@ -0,0 +1,17 @@
package mlx
import (
"strconv"
"strings"
"syscall"
)
func macOSMajorVersion() int {
ver, err := syscall.Sysctl("kern.osproductversion")
if err != nil {
return 0
}
parts := strings.SplitN(ver, ".", 2)
major, _ := strconv.Atoi(parts[0])
return major
}

View File

@@ -0,0 +1,5 @@
//go:build !darwin
package mlx
func macOSMajorVersion() int { return 0 }

47
x/mlxrunner/mlx/fast.go Normal file
View File

@@ -0,0 +1,47 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func FastScaledDotProductAttention(q, k, v *Array, scale float32, mode string, mask *Array) *Array {
sinks := New("")
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
var maskCtx C.mlx_array
if mask != nil {
maskCtx = mask.ctx
} else {
empty := New("")
maskCtx = empty.ctx
}
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, maskCtx, sinks.ctx, DefaultStream().ctx)
return out
}
type LayerNorm struct {
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM")
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RMSNorm struct {
Weight *Array `weight:"weight"`
}
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}

View File

@@ -0,0 +1,663 @@
package mlx
// #include <stdlib.h>
// #include "generated.h"
import "C"
import (
"sync"
"unsafe"
)
var (
gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled bool
gatedDeltaCUDAKernelOnce sync.Once
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
gatedDeltaCUDADisabled bool
)
const gatedDeltaMetalKernelSource = `
auto n = thread_position_in_grid.z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
y += b_idx * T * Hv * Dv + hv_idx * Dv;
auto dk_idx = thread_position_in_threadgroup.x;
auto dv_idx = thread_position_in_grid.y;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T * Hv;
auto beta_ = beta + b_idx * T * Hv;
for (int t = 0; t < T; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * g_[hv_idx];
kv_mem += state[i] * k_[s_idx];
}
kv_mem = simd_sum(kv_mem);
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + k_[s_idx] * delta;
out += state[i] * q_[s_idx];
}
out = simd_sum(out);
if (thread_index_in_simdgroup == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<StT>(state[i]);
}
`
const gatedDeltaCUDAKernelSource = `
auto tid_x = threadIdx.x;
auto tid_y = threadIdx.y;
auto grid_y = blockIdx.y * blockDim.y + tid_y;
auto grid_z = blockIdx.z;
int T_val = static_cast<int>(*T);
auto n = grid_z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto dv_idx = grid_y;
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;
auto dk_idx = tid_x;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T_val * Hv;
auto beta_ = beta + b_idx * T_val * Hv;
for (int t = 0; t < T_val; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * static_cast<float>(g_[hv_idx]);
kv_mem += state[i] * static_cast<float>(k_[s_idx]);
}
// Warp reduction (full warp, 32 threads in x)
for (int offset = 16; offset > 0; offset >>= 1)
kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);
auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
out += state[i] * static_cast<float>(q_[s_idx]);
}
// Warp reduction
for (int offset = 16; offset > 0; offset >>= 1)
out += __shfl_down_sync(0xffffffff, out, offset);
if (tid_x == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<StT>(state[i]);
}
`
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new()
ok := true
for _, s := range values {
cs := C.CString(s)
if C.mlx_vector_string_append_value(vec, cs) != 0 {
ok = false
}
C.free(unsafe.Pointer(cs))
if !ok {
break
}
}
cleanup := func() {
C.mlx_vector_string_free(vec)
}
return vec, cleanup, ok
}
func initGatedDeltaMetalKernel() {
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaMetalDisabled = true
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaMetalDisabled = true
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaMetalKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.bool(false),
)
}
// gatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaMetalDisabled {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
inputDType := q.DType()
stateDType := state.DType()
if k.DType() != inputDType || v.DType() != inputDType || g.DType() != inputDType || beta.DType() != inputDType {
return nil, nil, false
}
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
if gatedDeltaMetalDisabled {
return nil, nil, false
}
cfg := C.mlx_fast_metal_kernel_config_new()
defer C.mlx_fast_metal_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(inputDType)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
cStT := C.CString("StT")
defer C.free(unsafe.Pointer(cStT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(stateDType)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(inputDType)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(stateDType)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_METAL_Y")
nextState = New("GATED_DELTA_METAL_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}
func repeatHeadsForGatedDelta(x *Array, repeatFactor int) *Array {
if repeatFactor <= 1 {
return x
}
shape := x.Dims()
x = ExpandDims(x, 3)
x = Tile(x, []int32{1, 1, 1, int32(repeatFactor), 1})
return Reshape(x, int32(shape[0]), int32(shape[1]), int32(shape[2]*repeatFactor), int32(shape[3]))
}
func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil
}
B, T, Hk, Dk := int32(qd[0]), int32(qd[1]), int32(qd[2]), int32(qd[3])
Hv, Dv := int32(vd[2]), int32(vd[3])
if T <= 0 || Hk <= 0 || Dk <= 0 || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil
}
if kd[0] != int(B) || kd[1] != int(T) || kd[2] != int(Hk) || kd[3] != int(Dk) {
return nil, nil
}
if vd[0] != int(B) || vd[1] != int(T) {
return nil, nil
}
if gd[0] != int(B) || gd[1] != int(T) || gd[2] != int(Hv) {
return nil, nil
}
if bd[0] != int(B) || bd[1] != int(T) || bd[2] != int(Hv) {
return nil, nil
}
if sd[0] != int(B) || sd[1] != int(Hv) || sd[2] != int(Dv) || sd[3] != int(Dk) {
return nil, nil
}
repeatFactor := int(Hv / Hk)
q = repeatHeadsForGatedDelta(q, repeatFactor)
k = repeatHeadsForGatedDelta(k, repeatFactor)
nextState = state
if T == 1 {
qt := Squeeze(q, 1)
kt := Squeeze(k, 1)
vt := Squeeze(v, 1)
gt := Squeeze(g, 1)
bt := Squeeze(beta, 1)
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
return ExpandDims(yt, 1), nextState
}
outs := make([]*Array, 0, T)
for t := int32(0); t < T; t++ {
qt := Squeeze(SliceStartStop(q, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
kt := Squeeze(SliceStartStop(k, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
vt := Squeeze(SliceStartStop(v, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dv}), 1)
gt := Squeeze(SliceStartStop(g, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
bt := Squeeze(SliceStartStop(beta, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
outs = append(outs, ExpandDims(yt, 1))
}
return Concatenate(outs, 1), nextState
}
func initGatedDeltaCUDAKernel() {
var cudaAvail C.bool
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
gatedDeltaCUDADisabled = true
return
}
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaCUDADisabled = true
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaCUDADisabled = true
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaCUDAKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.int(0),
)
}
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaCUDADisabled {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
inputDType := q.DType()
stateDType := state.DType()
if k.DType() != inputDType || v.DType() != inputDType || g.DType() != inputDType || beta.DType() != inputDType {
return nil, nil, false
}
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
if gatedDeltaCUDADisabled {
return nil, nil, false
}
cfg := C.mlx_fast_cuda_kernel_config_new()
defer C.mlx_fast_cuda_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(inputDType)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
cStT := C.CString("StT")
defer C.free(unsafe.Pointer(cStT))
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(stateDType)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(inputDType)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(stateDType)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_CUDA_Y")
nextState = New("GATED_DELTA_CUDA_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}
// FastGatedDelta runs the recurrent update operation.
//
// When mask is non-nil, it must be a [B, T] bool tensor identifying real
// (true) vs. padded (false) positions in q/k/v/g/beta. Padded positions
// are substituted with neutral values (q=k=v=beta=0, g=1) so each padded
// kernel iteration is a no-op — state passes through unchanged and the
// final state equals the state after the last real token of each row.
//
// It tries the fused CUDA kernel first, then Metal, then falls back to a
// backend-agnostic MLX implementation with identical inputs/outputs.
func FastGatedDelta(q, k, v, g, beta, state, mask *Array) (y, nextState *Array) {
// TODO: handle this more efficiently with a masked kernel (MLX-LM has one).
if mask != nil {
B := int32(mask.Dim(0))
T := int32(mask.Dim(1))
m4 := Reshape(mask, B, T, 1, 1)
m3 := Reshape(mask, B, T, 1)
zeroQ := FromValue(float32(0)).AsType(q.DType())
zeroK := FromValue(float32(0)).AsType(k.DType())
zeroV := FromValue(float32(0)).AsType(v.DType())
zeroBeta := FromValue(float32(0)).AsType(beta.DType())
oneG := FromValue(float32(1)).AsType(g.DType())
q = Where(m4, q, zeroQ)
k = Where(m4, k, zeroK)
v = Where(m4, v, zeroV)
beta = Where(m3, beta, zeroBeta)
g = Where(m3, g, oneG)
}
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
return y, nextState
}
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
return y, nextState
}
y, nextState = gatedDeltaFallback(q, k, v, g, beta, state)
if y == nil || nextState == nil {
panic("mlx.FastGatedDelta: fallback failed (invalid inputs or unsupported shapes)")
}
return y, nextState
}

3012
x/mlxrunner/mlx/generated.c Normal file

File diff suppressed because it is too large Load Diff

7256
x/mlxrunner/mlx/generated.h Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
// This code is auto-generated; DO NOT EDIT.
#include "generated.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
{{ range .Functions }}
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
{{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }});
{{- end }}
return 0;
}

View File

@@ -0,0 +1,26 @@
// This code is auto-generated; DO NOT EDIT.
#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H
#include "dynamic.h"
{{ range .Functions }}
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
{{- end }}
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}
{{- end }}
{{ range .Functions }}
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
{{ range .Functions }}
static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
return {{ .Name }}_({{ .Args }});
{{ "}" }}
{{- end }}
#endif // MLX_GENERATED_H

View File

@@ -0,0 +1,157 @@
package main
import (
"embed"
"flag"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"text/template"
tree_sitter "github.com/tree-sitter/go-tree-sitter"
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
)
//go:embed *.gotmpl
var fsys embed.FS
// optionalSymbols lists symbols that may not be present in all builds
// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX).
var optionalSymbols = map[string]bool{
"mlx_array_item_float16": true,
"mlx_array_item_bfloat16": true,
"mlx_array_data_float16": true,
"mlx_array_data_bfloat16": true,
}
type Function struct {
Type,
Name,
Parameters,
Args string
Optional bool
}
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
var fn Function
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
if params := node.ChildByFieldName("parameters"); params != nil {
fn.Parameters = params.Utf8Text(source)
fn.Args = ParseParameters(params, tc, source)
}
var types []string
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
if node.Parent().Kind() == "pointer_declarator" {
types = append(types, "*")
}
node = node.Parent()
}
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
types = append(types, sibling.Utf8Text(source))
}
slices.Reverse(types)
fn.Type = strings.Join(types, " ")
return fn
}
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
var s []string
for _, child := range node.Children(tc) {
if child.IsNamed() {
child := child.ChildByFieldName("declarator")
for child != nil && child.Kind() != "identifier" {
if child.Kind() == "parenthesized_declarator" {
child = child.Child(1)
} else {
child = child.ChildByFieldName("declarator")
}
}
if child != nil {
s = append(s, child.Utf8Text(source))
}
}
}
return strings.Join(s, ", ")
}
func main() {
var output string
flag.StringVar(&output, "output", ".", "Output directory for generated files")
flag.Parse()
parser := tree_sitter.NewParser()
defer parser.Close()
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
parser.SetLanguage(language)
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
defer query.Close()
qc := tree_sitter.NewQueryCursor()
defer qc.Close()
var files []string
for _, arg := range flag.Args() {
matches, err := filepath.Glob(arg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error expanding glob %s: %v\n", arg, err)
continue
}
files = append(files, matches...)
}
var funs []Function
for _, arg := range files {
bts, err := os.ReadFile(arg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
continue
}
tree := parser.Parse(bts, nil)
defer tree.Close()
tc := tree.Walk()
defer tc.Close()
matches := qc.Matches(query, tree.RootNode(), bts)
for match := matches.Next(); match != nil; match = matches.Next() {
for _, capture := range match.Captures {
fn := ParseFunction(&capture.Node, tc, bts)
fn.Optional = optionalSymbols[fn.Name]
funs = append(funs, fn)
}
}
}
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
return
}
for _, tmpl := range tmpl.Templates() {
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
fmt.Println("Generating", name)
f, err := os.Create(name)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
continue
}
defer f.Close()
if err := tmpl.Execute(f, map[string]any{
"Functions": funs,
}); err != nil {
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
}
}
}

View File

@@ -0,0 +1,12 @@
# Vendored MLX-C Headers
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
The pinned version is in `MLX_C_VERSION` at the repo root.
Headers are automatically refreshed when you run a CMake build:
```shell
cmake --preset 'MLX CUDA 13'
```
See the [MLX Engine](../../../../../../../docs/development.md#mlx-engine-optional) section of the development docs for full build instructions.

View File

@@ -0,0 +1,420 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_ARRAY_H
#define MLX_ARRAY_H
#include "mlx/c/string.h"
#include <float.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
// Complex number support
#ifdef _MSC_VER
#define _CRT_USE_C_COMPLEX_H
#include <complex.h>
typedef _Fcomplex mlx_complex64_t;
#else
#include <complex.h>
typedef float _Complex mlx_complex64_t;
#endif
#include "half.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_array Array
* MLX N-dimensional array object.
*/
/**@{*/
/**
* A N-dimensional array object.
*/
typedef struct mlx_array_ {
void* ctx;
} mlx_array;
static mlx_array mlx_array_empty;
/**
* Array element type.
*/
typedef enum mlx_dtype_ {
MLX_BOOL,
MLX_UINT8,
MLX_UINT16,
MLX_UINT32,
MLX_UINT64,
MLX_INT8,
MLX_INT16,
MLX_INT32,
MLX_INT64,
MLX_FLOAT16,
MLX_FLOAT32,
MLX_FLOAT64,
MLX_BFLOAT16,
MLX_COMPLEX64,
} mlx_dtype;
/**
* Size of given mlx_dtype datatype in bytes.
*/
size_t mlx_dtype_size(mlx_dtype dtype);
/**
* Get array description.
*/
int mlx_array_tostring(mlx_string* str, const mlx_array arr);
/**
* New empty array.
*/
mlx_array mlx_array_new(void);
/**
* Free an array.
*/
int mlx_array_free(mlx_array arr);
/**
* New array from a bool scalar.
*/
mlx_array mlx_array_new_bool(bool val);
/**
* New array from a int scalar.
*/
mlx_array mlx_array_new_int(int val);
/**
* New array from a float32 scalar.
*/
mlx_array mlx_array_new_float32(float val);
/**
* New array from a float scalar.
* Same as float32.
*/
mlx_array mlx_array_new_float(float val);
/**
* New array from a float64 scalar.
*/
mlx_array mlx_array_new_float64(double val);
/**
* New array from a double scalar.
* Same as float64.
*/
mlx_array mlx_array_new_double(double val);
/**
* New array from a complex scalar.
*/
mlx_array mlx_array_new_complex(float real_val, float imag_val);
/**
* New array from existing buffer.
* @param data A buffer which will be copied.
* @param shape Shape of the array.
* @param dim Number of dimensions (size of `shape`).
* @param dtype Type of array elements.
*/
mlx_array mlx_array_new_data(
const void* data,
const int* shape,
int dim,
mlx_dtype dtype);
/**
* New array from existing buffer.
* @param data A buffer which will be copied.
* @param shape Shape of the array.
* @param dim Number of dimensions (size of `shape`).
* @param dtype Type of array elements.
* @param dtor Callback for when the buffer is no longer needed.
*/
mlx_array mlx_array_new_data_managed(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void (*dtor)(void*));
/**
* New array from existing buffer.
* @param data A buffer which will be copied.
* @param shape Shape of the array.
* @param dim Number of dimensions (size of `shape`).
* @param dtype Type of array elements.
* @param payload Payload pointer passed to the `dtor` callback instead of
* `data`.
* @param dtor Callback for when the buffer is no longer needed.
*/
mlx_array mlx_array_new_data_managed_payload(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void* payload,
void (*dtor)(void*));
/**
* Set array to provided src array.
*/
int mlx_array_set(mlx_array* arr, const mlx_array src);
/**
* Set array to a bool scalar.
*/
int mlx_array_set_bool(mlx_array* arr, bool val);
/**
* Set array to a int scalar.
*/
int mlx_array_set_int(mlx_array* arr, int val);
/**
* Set array to a float32 scalar.
*/
int mlx_array_set_float32(mlx_array* arr, float val);
/**
* Set array to a float scalar.
*/
int mlx_array_set_float(mlx_array* arr, float val);
/**
* Set array to a float64 scalar.
*/
int mlx_array_set_float64(mlx_array* arr, double val);
/**
* Set array to a double scalar.
*/
int mlx_array_set_double(mlx_array* arr, double val);
/**
* Set array to a complex scalar.
*/
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val);
/**
* Set array to specified data and shape.
* @param arr Destination array.
* @param data A buffer which will be copied.
* @param shape Shape of the array.
* @param dim Number of dimensions (size of `shape`).
* @param dtype Type of array elements.
*/
int mlx_array_set_data(
mlx_array* arr,
const void* data,
const int* shape,
int dim,
mlx_dtype dtype);
/**
* The size of the array's datatype in bytes.
*/
size_t mlx_array_itemsize(const mlx_array arr);
/**
* Number of elements in the array.
*/
size_t mlx_array_size(const mlx_array arr);
/**
* The number of bytes in the array.
*/
size_t mlx_array_nbytes(const mlx_array arr);
/**
* The array's dimension.
*/
size_t mlx_array_ndim(const mlx_array arr);
/**
* The shape of the array.
* Returns: a pointer to the sizes of each dimension.
*/
const int* mlx_array_shape(const mlx_array arr);
/**
* The strides of the array.
* Returns: a pointer to the sizes of each dimension.
*/
const size_t* mlx_array_strides(const mlx_array arr);
/**
* The shape of the array in a particular dimension.
*/
int mlx_array_dim(const mlx_array arr, int dim);
/**
* The array element type.
*/
mlx_dtype mlx_array_dtype(const mlx_array arr);
/**
* Evaluate the array.
*/
int mlx_array_eval(mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_bool(bool* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_int8(int8_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_int16(int16_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_int32(int32_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_int64(int64_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_float32(float* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_float64(double* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
#ifdef HAS_FLOAT16
/**
* Access the value of a scalar array.
*/
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
#endif
#ifdef HAS_BFLOAT16
/**
* Access the value of a scalar array.
*/
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr);
#endif
/**
* Returns a pointer to the array data, cast to `bool*`.
* Array must be evaluated, otherwise returns NULL.
*/
const bool* mlx_array_data_bool(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `uint8_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const uint8_t* mlx_array_data_uint8(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `uint16_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const uint16_t* mlx_array_data_uint16(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `uint32_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const uint32_t* mlx_array_data_uint32(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `uint64_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const uint64_t* mlx_array_data_uint64(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `int8_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const int8_t* mlx_array_data_int8(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `int16_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const int16_t* mlx_array_data_int16(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `int32_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const int32_t* mlx_array_data_int32(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `int64_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const int64_t* mlx_array_data_int64(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `float32*`.
* Array must be evaluated, otherwise returns NULL.
*/
const float* mlx_array_data_float32(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `float64*`.
* Array must be evaluated, otherwise returns NULL.
*/
const double* mlx_array_data_float64(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `_Complex*`.
* Array must be evaluated, otherwise returns NULL.
*/
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
#ifdef HAS_FLOAT16
/**
* Returns a pointer to the array data, cast to `float16_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const float16_t* mlx_array_data_float16(const mlx_array arr);
#endif
#ifdef HAS_BFLOAT16
/**
* Returns a pointer to the array data, cast to `bfloat16_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr);
#endif
/**
* Check if the array is available.
* Internal function: use at your own risk.
*/
int _mlx_array_is_available(bool* res, const mlx_array arr);
/**
* Wait on the array to be available. After this `_mlx_array_is_available`
* returns `true`. Internal function: use at your own risk.
*/
int _mlx_array_wait(const mlx_array arr);
/**
* Whether the array is contiguous in memory.
* Internal function: use at your own risk.
*/
int _mlx_array_is_contiguous(bool* res, const mlx_array arr);
/**
* Whether the array's rows are contiguous in memory.
* Internal function: use at your own risk.
*/
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr);
/**
* Whether the array's columns are contiguous in memory.
* Internal function: use at your own risk.
*/
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,197 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_CLOSURE_H
#define MLX_CLOSURE_H
#include "mlx/c/array.h"
#include "mlx/c/map.h"
#include "mlx/c/optional.h"
#include "mlx/c/stream.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_closure Closures
* MLX closure objects.
*/
/**@{*/
typedef struct mlx_closure_ {
void* ctx;
} mlx_closure;
mlx_closure mlx_closure_new(void);
int mlx_closure_free(mlx_closure cls);
mlx_closure mlx_closure_new_func(
int (*fun)(mlx_vector_array*, const mlx_vector_array));
mlx_closure mlx_closure_new_func_payload(
int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_set(mlx_closure* cls, const mlx_closure src);
int mlx_closure_apply(
mlx_vector_array* res,
mlx_closure cls,
const mlx_vector_array input);
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
typedef struct mlx_closure_kwargs_ {
void* ctx;
} mlx_closure_kwargs;
mlx_closure_kwargs mlx_closure_kwargs_new(void);
int mlx_closure_kwargs_free(mlx_closure_kwargs cls);
mlx_closure_kwargs mlx_closure_kwargs_new_func(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array));
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_kwargs_set(
mlx_closure_kwargs* cls,
const mlx_closure_kwargs src);
int mlx_closure_kwargs_apply(
mlx_vector_array* res,
mlx_closure_kwargs cls,
const mlx_vector_array input_0,
const mlx_map_string_to_array input_1);
typedef struct mlx_closure_value_and_grad_ {
void* ctx;
} mlx_closure_value_and_grad;
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void);
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls);
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
int (*fun)(
mlx_vector_array*,
mlx_vector_array*,
const mlx_vector_array,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_value_and_grad_set(
mlx_closure_value_and_grad* cls,
const mlx_closure_value_and_grad src);
int mlx_closure_value_and_grad_apply(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
mlx_closure_value_and_grad cls,
const mlx_vector_array input);
typedef struct mlx_closure_custom_ {
void* ctx;
} mlx_closure_custom;
mlx_closure_custom mlx_closure_custom_new(void);
int mlx_closure_custom_free(mlx_closure_custom cls);
mlx_closure_custom mlx_closure_custom_new_func(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array));
mlx_closure_custom mlx_closure_custom_new_func_payload(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_custom_set(
mlx_closure_custom* cls,
const mlx_closure_custom src);
int mlx_closure_custom_apply(
mlx_vector_array* res,
mlx_closure_custom cls,
const mlx_vector_array input_0,
const mlx_vector_array input_1,
const mlx_vector_array input_2);
typedef struct mlx_closure_custom_jvp_ {
void* ctx;
} mlx_closure_custom_jvp;
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void);
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls);
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num));
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_custom_jvp_set(
mlx_closure_custom_jvp* cls,
const mlx_closure_custom_jvp src);
int mlx_closure_custom_jvp_apply(
mlx_vector_array* res,
mlx_closure_custom_jvp cls,
const mlx_vector_array input_0,
const mlx_vector_array input_1,
const int* input_2,
size_t input_2_num);
typedef struct mlx_closure_custom_vmap_ {
void* ctx;
} mlx_closure_custom_vmap;
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void);
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls);
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num));
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_custom_vmap_set(
mlx_closure_custom_vmap* cls,
const mlx_closure_custom_vmap src);
int mlx_closure_custom_vmap_apply(
mlx_vector_array* res_0,
mlx_vector_int* res_1,
mlx_closure_custom_vmap cls,
const mlx_vector_array input_0,
const int* input_1,
size_t input_1_num);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,58 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_COMPILE_H
#define MLX_COMPILE_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup compile Compilation operations
*/
/**@{*/
typedef enum mlx_compile_mode_ {
MLX_COMPILE_MODE_DISABLED,
MLX_COMPILE_MODE_NO_SIMPLIFY,
MLX_COMPILE_MODE_NO_FUSE,
MLX_COMPILE_MODE_ENABLED
} mlx_compile_mode;
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless);
int mlx_detail_compile(
mlx_closure* res,
const mlx_closure fun,
uintptr_t fun_id,
bool shapeless,
const uint64_t* constants,
size_t constants_num);
int mlx_detail_compile_clear_cache(void);
int mlx_detail_compile_erase(uintptr_t fun_id);
int mlx_disable_compile(void);
int mlx_enable_compile(void);
int mlx_set_compile_mode(mlx_compile_mode mode);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,39 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_CUDA_H
#define MLX_CUDA_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup cuda Cuda specific operations
*/
/**@{*/
int mlx_cuda_is_available(bool* res);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,154 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_DEVICE_H
#define MLX_DEVICE_H
#include <stdbool.h>
#include <stddef.h>
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_device Device
* MLX device object.
*/
/**@{*/
/**
* A MLX device object.
*/
typedef struct mlx_device_ {
void* ctx;
} mlx_device;
/**
* Device type.
*/
typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type;
/**
* Returns a new empty device.
*/
mlx_device mlx_device_new(void);
/**
* Returns a new device of specified `type`, with specified `index`.
*/
mlx_device mlx_device_new_type(mlx_device_type type, int index);
/**
* Free a device.
*/
int mlx_device_free(mlx_device dev);
/**
* Set device to provided src device.
*/
int mlx_device_set(mlx_device* dev, const mlx_device src);
/**
* Get device description.
*/
int mlx_device_tostring(mlx_string* str, mlx_device dev);
/**
* Check if devices are the same.
*/
bool mlx_device_equal(mlx_device lhs, mlx_device rhs);
/**
* Returns the index of the device.
*/
int mlx_device_get_index(int* index, mlx_device dev);
/**
* Returns the type of the device.
*/
int mlx_device_get_type(mlx_device_type* type, mlx_device dev);
/**
* Returns the default MLX device.
*/
int mlx_get_default_device(mlx_device* dev);
/**
* Set the default MLX device.
*/
int mlx_set_default_device(mlx_device dev);
/**
* Check if device is available.
*/
int mlx_device_is_available(bool* avail, mlx_device dev);
/**
* Get the number of available devices for a device type.
*/
int mlx_device_count(int* count, mlx_device_type type);
/**
* A MLX device info object.
* Contains key-value pairs with device properties.
* Keys vary by backend but common keys include:
* - device_name (string): Device name
* - architecture (string): Architecture identifier
* Additional keys may be present depending on the backend.
*/
typedef struct mlx_device_info_ {
void* ctx;
} mlx_device_info;
/**
* Returns a new empty device info object.
*/
mlx_device_info mlx_device_info_new(void);
/**
* Get device information for a device.
*/
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
/**
* Free a device info object.
*/
int mlx_device_info_free(mlx_device_info info);
/**
* Check if a key exists in the device info.
* Returns 0 on success, 1 on error.
* Sets *exists to true if the key exists, false otherwise.
*/
int mlx_device_info_has_key(
bool* exists,
mlx_device_info info,
const char* key);
/**
* Check if a value is a string type.
* Returns 0 on success, 1 on error.
* Sets *is_string to true if the value is a string, false if it's a size_t.
*/
int mlx_device_info_is_string(
bool* is_string,
mlx_device_info info,
const char* key);
/**
* Get a string value from device info.
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
*/
int mlx_device_info_get_string(
const char** value,
mlx_device_info info,
const char* key);
/**
* Get a size_t value from device info.
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
*/
int mlx_device_info_get_size(
size_t* value,
mlx_device_info info,
const char* key);
/**
* Get all keys from device info.
* Returns 0 on success, 1 on error.
*/
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,83 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_DISTRIBUTED_H
#define MLX_DISTRIBUTED_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup distributed Distributed collectives
*/
/**@{*/
int mlx_distributed_all_gather(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream S);
int mlx_distributed_all_max(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_all_min(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_all_sum(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_recv(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
int src,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_recv_like(
mlx_array* res,
const mlx_array x,
int src,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_send(
mlx_array* res,
const mlx_array x,
int dst,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_sum_scatter(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,74 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_DISTRIBUTED_GROUP_H
#define MLX_DISTRIBUTED_GROUP_H
#include <stdbool.h>
#include "mlx/c/stream.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_distributed_group MLX distributed
*/
/**@{*/
/**
* A MLX distributed group object.
*/
typedef struct mlx_distributed_group_ {
void* ctx;
} mlx_distributed_group;
/**
* Create an empty group.
*/
mlx_distributed_group mlx_distributed_group_new(void);
/**
* Free the group.
*/
int mlx_distributed_group_free(mlx_distributed_group group);
/**
* Initialize distributed.
*/
int mlx_distributed_init(
mlx_distributed_group* res,
bool strict,
const char* bk /* may be null */);
/**
* Get the rank.
*/
int mlx_distributed_group_rank(mlx_distributed_group group);
/**
* Get the group size.
*/
int mlx_distributed_group_size(mlx_distributed_group group);
/**
* Split the group.
*/
int mlx_distributed_group_split(
mlx_distributed_group* res,
mlx_distributed_group group,
int color,
int key);
/**
* Check if distributed is available.
*/
bool mlx_distributed_is_available(const char* bk /* may be null */);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,41 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_ERROR_H
#define MLX_ERROR_H
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_error Error management
*/
/**@{*/
typedef void (*mlx_error_handler_func)(const char* msg, void* data);
/**
* Set the error handler.
*/
void mlx_set_error_handler(
mlx_error_handler_func handler,
void* data,
void (*dtor)(void*));
/**
* Throw an error.
*/
void _mlx_error(const char* file, const int line, const char* fmt, ...);
/**
* Throw an error. Macro which passes file name and line number to _mlx_error().
*/
#define mlx_error(...) _mlx_error(__FILE__, __LINE__, __VA_ARGS__)
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,75 @@
/* Copyright © 2023-2025 Apple Inc. */
#ifndef MLX_EXPORT_H
#define MLX_EXPORT_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup export Function serialization
*/
/**@{*/
int mlx_export_function(
const char* file,
const mlx_closure fun,
const mlx_vector_array args,
bool shapeless);
int mlx_export_function_kwargs(
const char* file,
const mlx_closure_kwargs fun,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs,
bool shapeless);
typedef struct mlx_function_exporter_ {
void* ctx;
} mlx_function_exporter;
mlx_function_exporter mlx_function_exporter_new(
const char* file,
const mlx_closure fun,
bool shapeless);
int mlx_function_exporter_free(mlx_function_exporter xfunc);
int mlx_function_exporter_apply(
const mlx_function_exporter xfunc,
const mlx_vector_array args);
int mlx_function_exporter_apply_kwargs(
const mlx_function_exporter xfunc,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs);
typedef struct mlx_imported_function_ {
void* ctx;
} mlx_imported_function;
mlx_imported_function mlx_imported_function_new(const char* file);
int mlx_imported_function_free(mlx_imported_function xfunc);
int mlx_imported_function_apply(
mlx_vector_array* res,
const mlx_imported_function xfunc,
const mlx_vector_array args);
int mlx_imported_function_apply_kwargs(
mlx_vector_array* res,
const mlx_imported_function xfunc,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,206 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_FAST_H
#define MLX_FAST_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup fast Fast custom operations
*/
/**@{*/
typedef struct mlx_fast_cuda_kernel_config_ {
void* ctx;
} mlx_fast_cuda_kernel_config;
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void);
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls);
int mlx_fast_cuda_kernel_config_add_output_arg(
mlx_fast_cuda_kernel_config cls,
const int* shape,
size_t size,
mlx_dtype dtype);
int mlx_fast_cuda_kernel_config_set_grid(
mlx_fast_cuda_kernel_config cls,
int grid1,
int grid2,
int grid3);
int mlx_fast_cuda_kernel_config_set_thread_group(
mlx_fast_cuda_kernel_config cls,
int thread1,
int thread2,
int thread3);
int mlx_fast_cuda_kernel_config_set_init_value(
mlx_fast_cuda_kernel_config cls,
float value);
int mlx_fast_cuda_kernel_config_set_verbose(
mlx_fast_cuda_kernel_config cls,
bool verbose);
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
mlx_fast_cuda_kernel_config cls,
const char* name,
mlx_dtype dtype);
int mlx_fast_cuda_kernel_config_add_template_arg_int(
mlx_fast_cuda_kernel_config cls,
const char* name,
int value);
int mlx_fast_cuda_kernel_config_add_template_arg_bool(
mlx_fast_cuda_kernel_config cls,
const char* name,
bool value);
typedef struct mlx_fast_cuda_kernel_ {
void* ctx;
} mlx_fast_cuda_kernel;
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
const char* name,
const mlx_vector_string input_names,
const mlx_vector_string output_names,
const char* source,
const char* header,
bool ensure_row_contiguous,
int shared_memory);
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls);
int mlx_fast_cuda_kernel_apply(
mlx_vector_array* outputs,
mlx_fast_cuda_kernel cls,
const mlx_vector_array inputs,
const mlx_fast_cuda_kernel_config config,
const mlx_stream stream);
int mlx_fast_layer_norm(
mlx_array* res,
const mlx_array x,
const mlx_array weight /* may be null */,
const mlx_array bias /* may be null */,
float eps,
const mlx_stream s);
typedef struct mlx_fast_metal_kernel_config_ {
void* ctx;
} mlx_fast_metal_kernel_config;
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void);
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls);
int mlx_fast_metal_kernel_config_add_output_arg(
mlx_fast_metal_kernel_config cls,
const int* shape,
size_t size,
mlx_dtype dtype);
int mlx_fast_metal_kernel_config_set_grid(
mlx_fast_metal_kernel_config cls,
int grid1,
int grid2,
int grid3);
int mlx_fast_metal_kernel_config_set_thread_group(
mlx_fast_metal_kernel_config cls,
int thread1,
int thread2,
int thread3);
int mlx_fast_metal_kernel_config_set_init_value(
mlx_fast_metal_kernel_config cls,
float value);
int mlx_fast_metal_kernel_config_set_verbose(
mlx_fast_metal_kernel_config cls,
bool verbose);
int mlx_fast_metal_kernel_config_add_template_arg_dtype(
mlx_fast_metal_kernel_config cls,
const char* name,
mlx_dtype dtype);
int mlx_fast_metal_kernel_config_add_template_arg_int(
mlx_fast_metal_kernel_config cls,
const char* name,
int value);
int mlx_fast_metal_kernel_config_add_template_arg_bool(
mlx_fast_metal_kernel_config cls,
const char* name,
bool value);
typedef struct mlx_fast_metal_kernel_ {
void* ctx;
} mlx_fast_metal_kernel;
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
const char* name,
const mlx_vector_string input_names,
const mlx_vector_string output_names,
const char* source,
const char* header,
bool ensure_row_contiguous,
bool atomic_outputs);
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
int mlx_fast_metal_kernel_apply(
mlx_vector_array* outputs,
mlx_fast_metal_kernel cls,
const mlx_vector_array inputs,
const mlx_fast_metal_kernel_config config,
const mlx_stream stream);
int mlx_fast_rms_norm(
mlx_array* res,
const mlx_array x,
const mlx_array weight /* may be null */,
float eps,
const mlx_stream s);
int mlx_fast_rope(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
int offset,
const mlx_array freqs /* may be null */,
const mlx_stream s);
int mlx_fast_rope_dynamic(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
const mlx_array offset,
const mlx_array freqs /* may be null */,
const mlx_stream s);
int mlx_fast_scaled_dot_product_attention(
mlx_array* res,
const mlx_array queries,
const mlx_array keys,
const mlx_array values,
float scale,
const char* mask_mode,
const mlx_array mask_arr /* may be null */,
const mlx_array sinks /* may be null */,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,158 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_FFT_H
#define MLX_FFT_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup fft FFT operations
*/
/**@{*/
typedef enum mlx_fft_norm_ {
MLX_FFT_NORM_BACKWARD,
MLX_FFT_NORM_ORTHO,
MLX_FFT_NORM_FORWARD
} mlx_fft_norm;
int mlx_fft_fft(
mlx_array* res,
const mlx_array a,
int n,
int axis,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_fft2(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_fftfreq(mlx_array* res, int n, double d, const mlx_stream s);
int mlx_fft_fftn(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_fftshift(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_ifft(
mlx_array* res,
const mlx_array a,
int n,
int axis,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_ifft2(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_ifftn(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_ifftshift(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_irfft(
mlx_array* res,
const mlx_array a,
int n,
int axis,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_irfft2(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_irfftn(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_rfft(
mlx_array* res,
const mlx_array a,
int n,
int axis,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_rfft2(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_rfftfreq(mlx_array* res, int n, double d, const mlx_stream s);
int mlx_fft_rfftn(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,61 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_GRAPH_UTILS_H
#define MLX_GRAPH_UTILS_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup graph_utils Graph Utils
*/
/**@{*/
typedef struct mlx_node_namer_ {
void* ctx;
} mlx_node_namer;
mlx_node_namer mlx_node_namer_new();
int mlx_node_namer_free(mlx_node_namer namer);
int mlx_node_namer_set_name(
mlx_node_namer namer,
const mlx_array arr,
const char* name);
int mlx_node_namer_get_name(
const char** name,
mlx_node_namer namer,
const mlx_array arr);
int mlx_export_to_dot(
FILE* os,
const mlx_node_namer namer,
const mlx_vector_array outputs);
int mlx_print_graph(
FILE* os,
const mlx_node_namer namer,
const mlx_vector_array outputs);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,26 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_HALF_H
#define MLX_HALF_H
#ifdef __cplusplus
extern "C" {
#endif
#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__)
#define HAS_FLOAT16
#include <arm_fp16.h>
typedef __fp16 float16_t;
#endif
#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__)
#define HAS_BFLOAT16
#include <arm_bf16.h>
typedef __bf16 bfloat16_t;
#endif
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,68 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_IO_H
#define MLX_IO_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup io IO operations
*/
/**@{*/
int mlx_load_reader(
mlx_array* res,
mlx_io_reader in_stream,
const mlx_stream s);
int mlx_load(mlx_array* res, const char* file, const mlx_stream s);
int mlx_load_gguf(mlx_io_gguf* gguf, const char* file, const mlx_stream s);
int mlx_load_safetensors_reader(
mlx_map_string_to_array* res_0,
mlx_map_string_to_string* res_1,
mlx_io_reader in_stream,
const mlx_stream s);
int mlx_load_safetensors(
mlx_map_string_to_array* res_0,
mlx_map_string_to_string* res_1,
const char* file,
const mlx_stream s);
int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a);
int mlx_save(const char* file, const mlx_array a);
int mlx_save_gguf(const char* file, mlx_io_gguf gguf);
int mlx_save_safetensors_writer(
mlx_io_writer in_stream,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata);
int mlx_save_safetensors(
const char* file,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,150 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_IO_TYPES_H
#define MLX_IO_TYPES_H
#include <stdbool.h>
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_io_types IO Types
* MLX IO type objects.
*/
/**@{*/
/**
* A MLX IO reader object.
*/
typedef struct mlx_io_reader_ {
void* ctx;
} mlx_io_reader;
/**
* A MLX IO writer object.
*/
typedef struct mlx_io_writer_ {
void* ctx;
} mlx_io_writer;
/**
* Virtual table for custom IO reader and writer objects.
*/
typedef struct mlx_io_vtable_ {
bool (*is_open)(void*);
bool (*good)(void*);
size_t (*tell)(void*);
void (*seek)(void*, int64_t off, int whence);
void (*read)(void*, char* data, size_t n);
void (*read_at_offset)(void*, char* data, size_t n, size_t off);
void (*write)(void*, const char* data, size_t n);
const char* (*label)(void*);
void (*free)(void*);
} mlx_io_vtable;
/**
* Returns a new custom IO reader.
* `vtable` operates on user descriptor `desc`.
*/
mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable);
/**
* Get IO reader user descriptor.
*/
int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io);
/**
* Get IO reader description.
*/
int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io);
/**
* Free IO reader.
*
* Note that MLX arrays are lazily evaluated, so the underlying object may
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
* will be called when the underlying object is actually freed.
*/
int mlx_io_reader_free(mlx_io_reader io);
/**
* Returns a new custom IO writer.
* `vtable` operates on user descriptor `desc`.
*/
mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable);
/**
* Get IO writer user descriptor.
*/
int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io);
/**
* Get IO writer description.
*/
int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io);
/**
* Free IO writer.
*
* Note that MLX arrays are lazily evaluated, so the underlying object may
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
* will be called when the underlying object is actually freed.
*/
int mlx_io_writer_free(mlx_io_writer io);
/**
* A MLX GGUF object.
*/
typedef struct mlx_io_gguf_ {
void* ctx;
} mlx_io_gguf;
mlx_io_gguf mlx_io_gguf_new(void);
int mlx_io_gguf_free(mlx_io_gguf io);
int mlx_io_gguf_get_keys(mlx_vector_string* keys, mlx_io_gguf io);
int mlx_io_gguf_get_array(mlx_array* arr, mlx_io_gguf io, const char* key);
int mlx_io_gguf_get_metadata_array(
mlx_array* arr,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_get_metadata_string(
mlx_string* str,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_get_metadata_vector_string(
mlx_vector_string* vstr,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_has_metadata_array(bool* flag, mlx_io_gguf io, const char* key);
int mlx_io_gguf_has_metadata_string(
bool* flag,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_has_metadata_vector_string(
bool* flag,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_set_array(mlx_io_gguf io, const char* key, const mlx_array arr);
int mlx_io_gguf_set_metadata_array(
mlx_io_gguf io,
const char* key,
const mlx_array marr);
int mlx_io_gguf_set_metadata_string(
mlx_io_gguf io,
const char* key,
const char* mstr);
int mlx_io_gguf_set_metadata_vector_string(
mlx_io_gguf io,
const char* key,
const mlx_vector_string mvstr);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,128 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_LINALG_H
#define MLX_LINALG_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup linalg Linear algebra operations
*/
/**@{*/
int mlx_linalg_cholesky(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
int mlx_linalg_cholesky_inv(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
int mlx_linalg_cross(
mlx_array* res,
const mlx_array a,
const mlx_array b,
int axis,
const mlx_stream s);
int mlx_linalg_eig(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
int mlx_linalg_eigh(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const char* UPLO,
const mlx_stream s);
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_eigvalsh(
mlx_array* res,
const mlx_array a,
const char* UPLO,
const mlx_stream s);
int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_lu_factor(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
int mlx_linalg_norm(
mlx_array* res,
const mlx_array a,
double ord,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
int mlx_linalg_norm_matrix(
mlx_array* res,
const mlx_array a,
const char* ord,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
int mlx_linalg_norm_l2(
mlx_array* res,
const mlx_array a,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_qr(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
int mlx_linalg_solve(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
int mlx_linalg_solve_triangular(
mlx_array* res,
const mlx_array a,
const mlx_array b,
bool upper,
const mlx_stream s);
int mlx_linalg_svd(
mlx_vector_array* res,
const mlx_array a,
bool compute_uv,
const mlx_stream s);
int mlx_linalg_tri_inv(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,149 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_MAP_H
#define MLX_MAP_H
#include "mlx/c/array.h"
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_map Maps
* MLX map objects.
*/
/**@{*/
/**
* A string-to-array map
*/
typedef struct mlx_map_string_to_array_ {
void* ctx;
} mlx_map_string_to_array;
/**
* Returns a new empty string-to-array map.
*/
mlx_map_string_to_array mlx_map_string_to_array_new(void);
/**
* Set map to provided src map.
*/
int mlx_map_string_to_array_set(
mlx_map_string_to_array* map,
const mlx_map_string_to_array src);
/**
* Free a string-to-array map.
*/
int mlx_map_string_to_array_free(mlx_map_string_to_array map);
/**
* Insert a new `value` at the specified `key` in the map.
*/
int mlx_map_string_to_array_insert(
mlx_map_string_to_array map,
const char* key,
const mlx_array value);
/**
* Returns the value indexed at the specified `key` in the map.
*/
int mlx_map_string_to_array_get(
mlx_array* value,
const mlx_map_string_to_array map,
const char* key);
/**
* An iterator over a string-to-array map.
*/
typedef struct mlx_map_string_to_array_iterator_ {
void* ctx;
void* map_ctx;
} mlx_map_string_to_array_iterator;
/**
* Returns a new iterator over the given map.
*/
mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(
mlx_map_string_to_array map);
/**
* Free iterator.
*/
int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it);
/**
* Increment iterator.
*/
int mlx_map_string_to_array_iterator_next(
const char** key,
mlx_array* value,
mlx_map_string_to_array_iterator it);
/**
* A string-to-string map
*/
typedef struct mlx_map_string_to_string_ {
void* ctx;
} mlx_map_string_to_string;
/**
* Returns a new empty string-to-string map.
*/
mlx_map_string_to_string mlx_map_string_to_string_new(void);
/**
* Set map to provided src map.
*/
int mlx_map_string_to_string_set(
mlx_map_string_to_string* map,
const mlx_map_string_to_string src);
/**
* Free a string-to-string map.
*/
int mlx_map_string_to_string_free(mlx_map_string_to_string map);
/**
* Insert a new `value` at the specified `key` in the map.
*/
int mlx_map_string_to_string_insert(
mlx_map_string_to_string map,
const char* key,
const char* value);
/**
* Returns the value indexed at the specified `key` in the map.
*/
int mlx_map_string_to_string_get(
const char** value,
const mlx_map_string_to_string map,
const char* key);
/**
* An iterator over a string-to-string map.
*/
typedef struct mlx_map_string_to_string_iterator_ {
void* ctx;
void* map_ctx;
} mlx_map_string_to_string_iterator;
/**
* Returns a new iterator over the given map.
*/
mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(
mlx_map_string_to_string map);
/**
* Free iterator.
*/
int mlx_map_string_to_string_iterator_free(
mlx_map_string_to_string_iterator it);
/**
* Increment iterator.
*/
int mlx_map_string_to_string_iterator_next(
const char** key,
const char** value,
mlx_map_string_to_string_iterator it);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,47 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_MEMORY_H
#define MLX_MEMORY_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup memory Memory operations
*/
/**@{*/
int mlx_clear_cache(void);
int mlx_get_active_memory(size_t* res);
int mlx_get_cache_memory(size_t* res);
int mlx_get_memory_limit(size_t* res);
int mlx_get_peak_memory(size_t* res);
int mlx_reset_peak_memory(void);
int mlx_set_cache_limit(size_t* res, size_t limit);
int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,41 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_METAL_H
#define MLX_METAL_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup metal Metal specific operations
*/
/**@{*/
int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);
int mlx_metal_stop_capture(void);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,35 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_ALL_H
#define MLX_ALL_H
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/compile.h"
#include "mlx/c/cuda.h"
#include "mlx/c/device.h"
#include "mlx/c/distributed.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/error.h"
#include "mlx/c/export.h"
#include "mlx/c/fast.h"
#include "mlx/c/fft.h"
#include "mlx/c/graph_utils.h"
#include "mlx/c/half.h"
#include "mlx/c/io.h"
#include "mlx/c/io_types.h"
#include "mlx/c/linalg.h"
#include "mlx/c/map.h"
#include "mlx/c/memory.h"
#include "mlx/c/metal.h"
#include "mlx/c/ops.h"
#include "mlx/c/optional.h"
#include "mlx/c/random.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/transforms.h"
#include "mlx/c/transforms_impl.h"
#include "mlx/c/vector.h"
#include "mlx/c/version.h"
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,51 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_OPTIONAL_H
#define MLX_OPTIONAL_H
#include <stdbool.h>
#include "mlx/c/array.h"
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_optional Optionals
* MLX optional scalars.
*/
/**@{*/
/**
* A int optional.
*/
typedef struct mlx_optional_int_ {
int value;
bool has_value;
} mlx_optional_int;
/**
* A float optional.
*/
typedef struct mlx_optional_float_ {
float value;
bool has_value;
} mlx_optional_float;
/**
* A dtype optional.
*/
typedef struct mlx_optional_dtype_ {
mlx_dtype value;
bool has_value;
} mlx_optional_dtype;
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,166 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_RANDOM_H
#define MLX_RANDOM_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup random Random number operations
*/
/**@{*/
int mlx_random_bernoulli(
mlx_array* res,
const mlx_array p,
const int* shape,
size_t shape_num,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_bits(
mlx_array* res,
const int* shape,
size_t shape_num,
int width,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_categorical_shape(
mlx_array* res,
const mlx_array logits,
int axis,
const int* shape,
size_t shape_num,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_categorical_num_samples(
mlx_array* res,
const mlx_array logits_,
int axis,
int num_samples,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_categorical(
mlx_array* res,
const mlx_array logits,
int axis,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_gumbel(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_key(mlx_array* res, uint64_t seed);
int mlx_random_laplace(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
float loc,
float scale,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_multivariate_normal(
mlx_array* res,
const mlx_array mean,
const mlx_array cov,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_normal_broadcast(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array loc /* may be null */,
const mlx_array scale /* may be null */,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_normal(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
float loc,
float scale,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_permutation(
mlx_array* res,
const mlx_array x,
int axis,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_permutation_arange(
mlx_array* res,
int x,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_randint(
mlx_array* res,
const mlx_array low,
const mlx_array high,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_seed(uint64_t seed);
int mlx_random_split_num(
mlx_array* res,
const mlx_array key,
int num,
const mlx_stream s);
int mlx_random_split(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array key,
const mlx_stream s);
int mlx_random_truncated_normal(
mlx_array* res,
const mlx_array lower,
const mlx_array upper,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_uniform(
mlx_array* res,
const mlx_array low,
const mlx_array high,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,88 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_STREAM_H
#define MLX_STREAM_H
#include <stdbool.h>
#include "mlx/c/device.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_stream Stream
* MLX stream object.
*/
/**@{*/
/**
* A MLX stream object.
*/
typedef struct mlx_stream_ {
void* ctx;
} mlx_stream;
/**
* Returns a new empty stream.
*/
mlx_stream mlx_stream_new(void);
/**
* Returns a new stream on a device.
*/
mlx_stream mlx_stream_new_device(mlx_device dev);
/**
* Set stream to provided src stream.
*/
int mlx_stream_set(mlx_stream* stream, const mlx_stream src);
/**
* Free a stream.
*/
int mlx_stream_free(mlx_stream stream);
/**
* Get stream description.
*/
int mlx_stream_tostring(mlx_string* str, mlx_stream stream);
/**
* Check if streams are the same.
*/
bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs);
/**
* Return the device of the stream.
*/
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream);
/**
* Return the index of the stream.
*/
int mlx_stream_get_index(int* index, mlx_stream stream);
/**
* Synchronize with the provided stream.
*/
int mlx_synchronize(mlx_stream stream);
/**
* Returns the default stream on the given device.
*/
int mlx_get_default_stream(mlx_stream* stream, mlx_device dev);
/**
* Set default stream.
*/
int mlx_set_default_stream(mlx_stream stream);
/**
* Returns the current default CPU stream.
*/
mlx_stream mlx_default_cpu_stream_new(void);
/**
* Returns the current default GPU stream.
*/
mlx_stream mlx_default_gpu_stream_new(void);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,55 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_STRING_H
#define MLX_STRING_H
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_string String
* MLX string object.
*/
/**@{*/
/**
* A MLX string object.
*/
typedef struct mlx_string_ {
void* ctx;
} mlx_string;
/**
* Returns a new empty string.
*/
mlx_string mlx_string_new(void);
/**
* Returns a new string, copying contents from `str`, which must end with `\0`.
*/
mlx_string mlx_string_new_data(const char* str);
/**
* Set string to src string.
*/
int mlx_string_set(mlx_string* str, const mlx_string src);
/**
* Returns a pointer to the string contents.
* The pointer is valid for the life duration of the string.
*/
const char* mlx_string_data(mlx_string str);
/**
* Free string.
*/
int mlx_string_free(mlx_string str);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,68 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_TRANSFORMS_H
#define MLX_TRANSFORMS_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup transforms Transform operations
*/
/**@{*/
int mlx_async_eval(const mlx_vector_array outputs);
int mlx_checkpoint(mlx_closure* res, const mlx_closure fun);
int mlx_custom_function(
mlx_closure* res,
const mlx_closure fun,
const mlx_closure_custom fun_vjp /* may be null */,
const mlx_closure_custom_jvp fun_jvp /* may be null */,
const mlx_closure_custom_vmap fun_vmap /* may be null */);
int mlx_custom_vjp(
mlx_closure* res,
const mlx_closure fun,
const mlx_closure_custom fun_vjp);
int mlx_eval(const mlx_vector_array outputs);
int mlx_jvp(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array tangents);
int mlx_value_and_grad(
mlx_closure_value_and_grad* res,
const mlx_closure fun,
const int* argnums,
size_t argnums_num);
int mlx_vjp(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array cotangents);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,54 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_TRANSFORMS_IMPL_H
#define MLX_TRANSFORMS_IMPL_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup transforms_impl Implementation detail operations
*/
/**@{*/
int mlx_detail_vmap_replace(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num);
int mlx_detail_vmap_trace(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,133 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_VECTOR_H
#define MLX_VECTOR_H
#include "mlx/c/array.h"
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_vector Vectors
* MLX vector objects.
*/
/**@{*/
/**
* A vector of array.
*/
typedef struct mlx_vector_array_ {
void* ctx;
} mlx_vector_array;
mlx_vector_array mlx_vector_array_new(void);
int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src);
int mlx_vector_array_free(mlx_vector_array vec);
mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size);
mlx_vector_array mlx_vector_array_new_value(const mlx_array val);
int mlx_vector_array_set_data(
mlx_vector_array* vec,
const mlx_array* data,
size_t size);
int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val);
int mlx_vector_array_append_data(
mlx_vector_array vec,
const mlx_array* data,
size_t size);
int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val);
size_t mlx_vector_array_size(mlx_vector_array vec);
int mlx_vector_array_get(
mlx_array* res,
const mlx_vector_array vec,
size_t idx);
/**
* A vector of vector_array.
*/
typedef struct mlx_vector_vector_array_ {
void* ctx;
} mlx_vector_vector_array;
mlx_vector_vector_array mlx_vector_vector_array_new(void);
int mlx_vector_vector_array_set(
mlx_vector_vector_array* vec,
const mlx_vector_vector_array src);
int mlx_vector_vector_array_free(mlx_vector_vector_array vec);
mlx_vector_vector_array mlx_vector_vector_array_new_data(
const mlx_vector_array* data,
size_t size);
mlx_vector_vector_array mlx_vector_vector_array_new_value(
const mlx_vector_array val);
int mlx_vector_vector_array_set_data(
mlx_vector_vector_array* vec,
const mlx_vector_array* data,
size_t size);
int mlx_vector_vector_array_set_value(
mlx_vector_vector_array* vec,
const mlx_vector_array val);
int mlx_vector_vector_array_append_data(
mlx_vector_vector_array vec,
const mlx_vector_array* data,
size_t size);
int mlx_vector_vector_array_append_value(
mlx_vector_vector_array vec,
const mlx_vector_array val);
size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec);
int mlx_vector_vector_array_get(
mlx_vector_array* res,
const mlx_vector_vector_array vec,
size_t idx);
/**
* A vector of int.
*/
typedef struct mlx_vector_int_ {
void* ctx;
} mlx_vector_int;
mlx_vector_int mlx_vector_int_new(void);
int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src);
int mlx_vector_int_free(mlx_vector_int vec);
mlx_vector_int mlx_vector_int_new_data(int* data, size_t size);
mlx_vector_int mlx_vector_int_new_value(int val);
int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size);
int mlx_vector_int_set_value(mlx_vector_int* vec, int val);
int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size);
int mlx_vector_int_append_value(mlx_vector_int vec, int val);
size_t mlx_vector_int_size(mlx_vector_int vec);
int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx);
/**
* A vector of string.
*/
typedef struct mlx_vector_string_ {
void* ctx;
} mlx_vector_string;
mlx_vector_string mlx_vector_string_new(void);
int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src);
int mlx_vector_string_free(mlx_vector_string vec);
mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size);
mlx_vector_string mlx_vector_string_new_value(const char* val);
int mlx_vector_string_set_data(
mlx_vector_string* vec,
const char** data,
size_t size);
int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val);
int mlx_vector_string_append_data(
mlx_vector_string vec,
const char** data,
size_t size);
int mlx_vector_string_append_value(mlx_vector_string vec, const char* val);
size_t mlx_vector_string_size(mlx_vector_string vec);
int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,18 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_VERSION_H
#define MLX_VERSION_H
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
int mlx_version(mlx_string* str_);
#ifdef __cplusplus
}
#endif
#endif

164
x/mlxrunner/mlx/io.go Normal file
View File

@@ -0,0 +1,164 @@
package mlx
// #include "generated.h"
import "C"
import (
"fmt"
"iter"
"runtime"
"sort"
"unsafe"
)
// SafetensorsFile represents a loaded safetensors file.
type SafetensorsFile struct {
arrays C.mlx_map_string_to_array
metadata C.mlx_map_string_to_string
}
func loadSafetensorsStream() C.mlx_stream {
if runtime.GOOS == "darwin" {
return C.mlx_default_cpu_stream_new()
}
return C.mlx_default_gpu_stream_new()
}
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
var arrays C.mlx_map_string_to_array
var metadata C.mlx_map_string_to_string
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
stream := loadSafetensorsStream()
defer C.mlx_stream_free(stream)
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
return nil, fmt.Errorf("failed to load safetensors: %s", path)
}
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
}
// Get retrieves a tensor by name.
func (s *SafetensorsFile) Get(name string) *Array {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
value := C.mlx_array_new()
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
return nil
}
if value.ctx == nil {
return nil
}
arr := New(name)
arr.ctx = value
return arr
}
// GetMetadata retrieves a metadata value by key.
func (s *SafetensorsFile) GetMetadata(key string) string {
cKey := C.CString(key)
defer C.free(unsafe.Pointer(cKey))
var cValue *C.char
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
return ""
}
return C.GoString(cValue)
}
// Free releases the loaded safetensors maps.
func (s *SafetensorsFile) Free() {
if s == nil {
return
}
C.mlx_map_string_to_array_free(s.arrays)
C.mlx_map_string_to_string_free(s.metadata)
}
func Load(path string) iter.Seq2[string, *Array] {
return func(yield func(string, *Array) bool) {
sf, err := LoadSafetensorsNative(path)
if err != nil {
return
}
defer sf.Free()
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
defer C.mlx_map_string_to_array_iterator_free(it)
for {
var key *C.char
value := C.mlx_array_new()
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
break
}
name := C.GoString(key)
arr := New(name)
arr.ctx = value
if !yield(name, arr) {
break
}
}
}
}
// SaveSafetensors saves arrays to a safetensors file without metadata.
func SaveSafetensors(path string, arrays map[string]*Array) error {
return SaveSafetensorsWithMetadata(path, arrays, nil)
}
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
arrayNames := make([]string, 0, len(arrays))
for name, arr := range arrays {
if arr == nil {
continue
}
arrayNames = append(arrayNames, name)
}
sort.Strings(arrayNames)
for _, name := range arrayNames {
arr := arrays[name]
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
C.free(unsafe.Pointer(cName))
}
cMetadata := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMetadata)
metadataKeys := make([]string, 0, len(metadata))
for key := range metadata {
metadataKeys = append(metadataKeys, key)
}
sort.Strings(metadataKeys)
for _, key := range metadataKeys {
value := metadata[key]
cKey := C.CString(key)
cValue := C.CString(value)
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
C.free(unsafe.Pointer(cKey))
C.free(unsafe.Pointer(cValue))
}
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}

89
x/mlxrunner/mlx/memory.go Normal file
View File

@@ -0,0 +1,89 @@
package mlx
// #include "generated.h"
import "C"
import (
"fmt"
"log/slog"
"strconv"
)
func (b Byte) String() string {
return strconv.FormatInt(int64(b), 10) + " B"
}
func (b KibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
}
func (b MebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
}
func (b GibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
}
func (b TebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
}
func PrettyBytes(n int) fmt.Stringer {
switch {
case n < 1<<10:
return Byte(n)
case n < 1<<(2*10):
return KibiByte(n)
case n < 1<<(3*10):
return MebiByte(n)
case n < 1<<(4*10):
return GibiByte(n)
default:
return TebiByte(n)
}
}
func ActiveMemory() int {
var active C.size_t
C.mlx_get_active_memory(&active)
return int(active)
}
func CacheMemory() int {
var cache C.size_t
C.mlx_get_cache_memory(&cache)
return int(cache)
}
func PeakMemory() int {
var peak C.size_t
C.mlx_get_peak_memory(&peak)
return int(peak)
}
func ResetPeakMemory() {
C.mlx_reset_peak_memory()
}
type Memory struct{}
func (Memory) LogValue() slog.Value {
return slog.GroupValue(
slog.Any("active", PrettyBytes(ActiveMemory())),
slog.Any("cache", PrettyBytes(CacheMemory())),
slog.Any("peak", PrettyBytes(PeakMemory())),
)
}
type (
Byte int
KibiByte int
MebiByte int
GibiByte int
TebiByte int
)
func ClearCache() {
C.mlx_clear_cache()
}

107
x/mlxrunner/mlx/mlx.go Normal file
View File

@@ -0,0 +1,107 @@
package mlx
//go:generate go run generator/main.go -output=. ./include/mlx/c/*.h
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/include
// #cgo LDFLAGS: -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
// #include <string.h>
//
// static __thread char _mlx_last_error_msg[1024] = {0};
// static __thread int _mlx_last_error_flag = 0;
//
// static void _mlx_capture_error_handler(const char* msg, void* data) {
// (void)data;
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
// _mlx_last_error_flag = 1;
// }
//
// static void mlx_install_capture_handler(void) {
// if (mlx_set_error_handler_) {
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
// }
// }
//
// static void mlx_clear_last_error(void) {
// _mlx_last_error_flag = 0;
// _mlx_last_error_msg[0] = '\0';
// }
//
// static const char* mlx_get_last_error(void) {
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
// }
import "C"
import "runtime"
func init() {
// Replace the default exit(-1) error handler with one that captures
// the error message so we can surface it in Go.
C.mlx_install_capture_handler()
}
// Version returns the MLX core library version string.
func Version() string {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_version(&str)
return C.GoString(C.mlx_string_data(str))
}
// mlxCheck locks the goroutine to its OS thread, clears the captured error
// state, calls fn, and panics with the captured message if fn returns non-zero.
// The thread lock ensures the thread-local error state is read from the same
// thread that executed the call.
func mlxCheck(fallback string, fn func() C.int) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.mlx_clear_last_error()
if fn() != 0 {
msg := C.GoString(C.mlx_get_last_error())
if msg == "" {
msg = fallback
}
panic("mlx: " + msg)
}
}
func doEval(outputs []*Array, async bool) {
if len(outputs) == 0 {
return
}
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
if output != nil && output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx)
}
}
mlxCheck("eval failed", func() C.int {
if async {
return C.mlx_async_eval(vector)
}
return C.mlx_eval(vector)
})
}
func AsyncEval(outputs ...*Array) {
doEval(outputs, true)
}
func Eval(outputs ...*Array) {
doEval(outputs, false)
}
// MetalIsAvailable returns true if a Metal GPU is available.
func MetalIsAvailable() bool {
var available C._Bool
C.mlx_metal_is_available(&available)
return bool(available)
}

36
x/mlxrunner/mlx/nn.go Normal file
View File

@@ -0,0 +1,36 @@
package mlx
type Linear struct {
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m *Linear) Forward(x *Array) *Array {
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
}
return x.Matmul(w)
}
func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
w := m.Weight.Transpose(0, 2, 1)
// TODO: bias
return x.GatherMM(w, lhs, rhs, sorted)
}
type Embedding struct {
Weight *Array `weight:"weight"`
}
func (e *Embedding) Forward(indices *Array) *Array {
return e.Weight.TakeAxis(indices, 0)
}
func (e *Embedding) AsLinear() Linear {
return Linear{
Weight: e.Weight,
}
}

300
x/mlxrunner/mlx/ops.go Normal file
View File

@@ -0,0 +1,300 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func (t *Array) Abs() *Array {
out := New("ABS")
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Add(other *Array) *Array {
out := New("ADD")
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM")
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX")
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION")
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ArgsortAxis(axis int) *Array {
out := New("ARGSORT_AXIS")
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE")
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
cStrides := make([]C.int64_t, len(strides))
for i, s := range strides {
cStrides[i] = C.int64_t(s)
}
out := New("AS_STRIDED")
C.mlx_as_strided(
&out.ctx, t.ctx,
unsafe.SliceData(cShape), C.size_t(len(shape)),
unsafe.SliceData(cStrides), C.size_t(len(strides)),
C.size_t(offset),
DefaultStream().ctx,
)
return out
}
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
if len(others) == 0 {
return t.Clone()
}
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
s := append([]*Array{t}, others...)
for _, other := range s {
C.mlx_vector_array_append_value(vector, other.ctx)
}
out := New("CONCATENATE")
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
out := New("CUMSUM")
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
return out
}
func (t *Array) Divide(other *Array) *Array {
out := New("DIVIDE")
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS")
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Flatten(startAxis, endAxis int) *Array {
out := New("FLATTEN")
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
return out
}
func (t *Array) FloorDivide(other *Array) *Array {
out := New("FLOOR_DIVIDE")
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
if lhs == nil {
lhs = New("")
}
if rhs == nil {
rhs = New("")
}
out := New("GATHER_MM")
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
return out
}
func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
out := New("LOGSUMEXP_AXIS")
C.mlx_logsumexp_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Equal(other *Array) *Array {
out := New("EQUAL")
C.mlx_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Greater(other *Array) *Array {
out := New("GREATER")
C.mlx_greater(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Less(other *Array) *Array {
out := New("LESS")
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) LessEqual(other *Array) *Array {
out := New("LESS_EQUAL")
C.mlx_less_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
out := New("MAX_AXIS")
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY")
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Negative() *Array {
out := New("NEGATIVE")
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER")
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS")
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
out := New("SCATTER_ADD_AXIS")
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
out := New("RESHAPE")
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID")
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Sqrt() *Array {
out := New("SQRT")
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE")
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
vectorData := make([]C.mlx_array, len(others)+1)
vectorData[0] = t.ctx
for i := range others {
vectorData[i+1] = others[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK_AXIS")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT")
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
out := New("SUM_AXIS")
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS")
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS")
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Tanh() *Array {
out := New("TANH")
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Transpose(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i, axis := range axes {
cAxes[i] = C.int(axis)
}
out := New("TRANSPOSE")
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func Zeros(dtype DType, shape ...int) *Array {
cAxes := make([]C.int, len(shape))
for i := range shape {
cAxes[i] = C.int(shape[i])
}
t := New("ZEROS")
C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
return t
}

View File

@@ -0,0 +1,666 @@
package mlx
// #include "generated.h"
import "C"
import (
"reflect"
"unsafe"
)
// Quantization operations
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(res)
var globalScale C.mlx_array
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
vecSize := int(C.mlx_vector_array_size(res))
w0 := New("QUANTIZE_W")
C.mlx_vector_array_get(&w0.ctx, res, 0)
w1 := New("QUANTIZE_S")
C.mlx_vector_array_get(&w1.ctx, res, 1)
if vecSize >= 3 {
w2 := New("QUANTIZE_B")
C.mlx_vector_array_get(&w2.ctx, res, 2)
return w0, w1, w2
}
return w0, w1, nil
}
func FromFP8(x *Array, dtype DType) *Array {
out := New("FROM_FP8")
C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func ToFP8(x *Array) *Array {
out := New("TO_FP8")
C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
return out
}
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
var b C.mlx_array
if biases != nil {
b = biases.ctx
}
out := New("DEQUANTIZE")
var globalScale C.mlx_array
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
return out
}
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var b C.mlx_array
if biases != nil {
b = biases.ctx
}
out := New("QUANTIZED_MATMUL")
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
return out
}
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var b, lhs, rhs C.mlx_array
if biases != nil {
b = biases.ctx
}
if lhsIndices != nil {
lhs = lhsIndices.ctx
}
if rhsIndices != nil {
rhs = rhsIndices.ctx
}
out := New("GATHER_QMM")
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
return out
}
// Missing tensor ops
func Tile(a *Array, reps []int32) *Array {
cReps := make([]C.int, len(reps))
for i, r := range reps {
cReps[i] = C.int(r)
}
out := New("TILE")
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
return out
}
func Tri(n, m int32, k int) *Array {
out := New("TRI")
C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
return out
}
func Where(condition, a, b *Array) *Array {
out := New("WHERE")
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
out := New("CONV1D")
C.mlx_conv1d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(stride),
C.int(padding),
C.int(dilation),
C.int(groups),
DefaultStream().ctx,
)
if bias != nil && bias.Valid() {
out = Add(out, bias)
}
return out
}
func Contiguous(a *Array, allowColMajor bool) *Array {
out := New("CONTIGUOUS")
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
return out
}
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
// MLX uses NHWC layout.
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
out := New("CONV2D")
C.mlx_conv2d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(strideH), C.int(strideW),
C.int(padH), C.int(padW),
C.int(dilationH), C.int(dilationW),
C.int(groups),
DefaultStream().ctx,
)
return out
}
// Pad pads array a along the given axes with specified low/high pad sizes.
// mode should be "constant", "edge", or "reflect".
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
cAxes := make([]C.int, len(axes))
cLow := make([]C.int, len(lowPad))
cHigh := make([]C.int, len(highPad))
for i := range axes {
cAxes[i] = C.int(axes[i])
cLow[i] = C.int(lowPad[i])
cHigh[i] = C.int(highPad[i])
}
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("PAD")
C.mlx_pad(
&out.ctx,
a.ctx,
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
unsafe.SliceData(cLow), C.size_t(len(cLow)),
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
padValue.ctx,
cMode,
DefaultStream().ctx,
)
return out
}
// PadConstant pads with zeros along the given axes.
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
zero := NewScalarArray(float32(0))
return Pad(a, axes, lowPad, highPad, zero, "constant")
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)
}
// Maximum returns element-wise maximum of two arrays.
func Maximum(a, b *Array) *Array {
out := New("MAXIMUM")
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Minimum returns element-wise minimum of two arrays.
func Minimum(a, b *Array) *Array {
out := New("MINIMUM")
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
func Softplus(a *Array) *Array {
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
}
// ReLU computes max(0, x).
func ReLU(a *Array) *Array {
return Maximum(a, NewScalarArray(float32(0)))
}
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
// returns first * sigmoid(second).
func GLU(a *Array) *Array {
lastDim := a.NumDims() - 1
halfSize := a.Dim(lastDim) / 2
first := SliceStartStop(a,
make([]int32, lastDim+1), // all zeros for start
appendDims(a, lastDim, int32(halfSize)),
)
second := SliceStartStop(a,
appendDimsStart(a, lastDim, int32(halfSize)),
appendDims(a, lastDim, int32(a.Dim(lastDim))),
)
return first.Multiply(second.Sigmoid())
}
// helper: builds stop array for SliceStartStop where the target axis = val
func appendDims(a *Array, targetAxis int, val int32) []int32 {
n := a.NumDims()
out := make([]int32, n)
for i := range n {
if i == targetAxis {
out[i] = val
} else {
out[i] = int32(a.Dim(i))
}
}
return out
}
// helper: builds start array for SliceStartStop where the target axis = val
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
n := a.NumDims()
out := make([]int32, n)
for i := range n {
if i == targetAxis {
out[i] = val
}
}
return out
}
// Clamp clamps array values to [min, max].
func Clamp(a *Array, minVal, maxVal float32) *Array {
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
vectorData := make([]C.mlx_array, len(arrays))
for i := range arrays {
vectorData[i] = arrays[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func Neg(a *Array) *Array {
return a.Negative()
}
func Sum(a *Array, axis int, keepDims bool) *Array {
return a.SumAxis(axis, keepDims)
}
func Argsort(a *Array, axis int) *Array {
return a.ArgsortAxis(axis)
}
func Take(a *Array, indices *Array, axis int) *Array {
return a.TakeAxis(indices, axis)
}
func RSqrt(a *Array) *Array {
out := New("RSQRT")
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN_AXIS")
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func Argpartition(a *Array, kth int, axis int) *Array {
return a.ArgpartitionAxis(kth, axis)
}
func TakeAlongAxis(a, indices *Array, axis int) *Array {
return a.TakeAlongAxis(indices, axis)
}
// Function-style wrappers matching imagegen API
func Add(a, b *Array) *Array {
return a.Add(b)
}
func Sub(a, b *Array) *Array {
return a.Subtract(b)
}
func Mul(a, b *Array) *Array {
return a.Multiply(b)
}
func Div(a, b *Array) *Array {
return a.Divide(b)
}
func Matmul(a, b *Array) *Array {
return a.Matmul(b)
}
func Reshape(a *Array, shape ...int32) *Array {
axes := make([]int, len(shape))
for i, s := range shape {
axes[i] = int(s)
}
return a.Reshape(axes...)
}
func Transpose(a *Array, axes ...int) *Array {
return a.Transpose(axes...)
}
func ExpandDims(a *Array, axis int) *Array {
return a.ExpandDims(axis)
}
func Squeeze(a *Array, axis int) *Array {
return a.Squeeze(axis)
}
func Flatten(a *Array) *Array {
return a.Flatten(0, -1)
}
func Concatenate(arrays []*Array, axis int) *Array {
if len(arrays) == 0 {
return nil
}
if len(arrays) == 1 {
return arrays[0].Clone()
}
return arrays[0].Concatenate(axis, arrays[1:]...)
}
func SliceStartStop(a *Array, start, stop []int32) *Array {
n := len(start)
cStart := make([]C.int, n)
cStop := make([]C.int, n)
cStrides := make([]C.int, n)
for i := 0; i < n; i++ {
cStart[i] = C.int(start[i])
cStop[i] = C.int(stop[i])
cStrides[i] = 1
}
out := New("SLICE")
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
return out
}
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
if lhsIndices == nil {
lhsIndices = New("")
}
if rhsIndices == nil {
rhsIndices = New("")
}
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}
// RoPEWithBase applies rotary position embeddings to x. offsets is an
// int32 array of shape [B] giving each batch row's starting position;
// the kernel applies positions offsets[b] + 0..T-1 per row.
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offsets *Array) *Array {
return RoPEWithFreqs(x, dims, traditional, base, scale, offsets, nil)
}
// RoPEWithFreqs applies RoPE with optional custom frequencies.
// When freqs is non-nil, it is used instead of computing from base.
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offsets *Array, freqs *Array) *Array {
var freqsCtx C.mlx_array
var optBase C.mlx_optional_float
if freqs != nil {
freqsCtx = freqs.ctx
optBase = C.mlx_optional_float{has_value: C.bool(false)}
} else {
empty := New("")
freqsCtx = empty.ctx
optBase = C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
}
}
out := New("FAST_ROPE")
C.mlx_fast_rope_dynamic(
&out.ctx,
x.ctx,
C.int(dims),
C.bool(traditional),
optBase,
C.float(scale),
offsets.ctx,
freqsCtx,
DefaultStream().ctx,
)
return out
}
func Sigmoid(a *Array) *Array {
return a.Sigmoid()
}
func Exp(a *Array) *Array {
out := New("EXP")
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Log(a *Array) *Array {
out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Sin(a *Array) *Array {
out := New("SIN")
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Cos(a *Array) *Array {
out := New("COS")
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Clip(a, aMin, aMax *Array) *Array {
out := New("CLIP")
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
return out
}
func Logaddexp(a, b *Array) *Array {
out := New("LOGADDEXP")
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
out := New("SOFTMAX_AXIS")
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
return out
}
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
out := New("FAST_LAYERNORM")
var w, b C.mlx_array
if weight != nil {
w = weight.ctx
}
if bias != nil {
b = bias.ctx
}
C.mlx_fast_layer_norm(&out.ctx, x.ctx, w, b, C.float(eps), DefaultStream().ctx)
return out
}
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
var w C.mlx_array
if weight != nil {
w = weight.ctx
}
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
return out
}
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
return c.Addmm(a, b, alpha, beta)
}
// Scalar helpers
// scalarWithDtype creates a scalar array matching the dtype of a.
// Matching dtype is important for graph fusion and avoiding implicit casts.
func scalarWithDtype(s float32, a *Array) C.mlx_array {
f32 := C.mlx_array_new_float(C.float(s))
dtype := a.DType()
if dtype == DTypeFloat32 {
return f32
}
casted := C.mlx_array_new()
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
C.mlx_array_free(f32)
return casted
}
func AddScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("ADD_SCALAR")
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func MulScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("MUL_SCALAR")
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func DivScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("DIV_SCALAR")
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func FloorDivideScalar(a *Array, s int32) *Array {
scalar := FromValue(int(s))
return a.FloorDivide(scalar)
}
// Array constructors
func NewArrayInt32(data []int32, shape []int32) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
out := New("NEW_ARRAY_INT32")
out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
return out
}
func NewScalarArray(value float32) *Array {
out := New("SCALAR")
out.ctx = C.mlx_array_new_float32(C.float(value))
return out
}
func ZerosF32(shape []int32) *Array {
return Zeros(DTypeFloat32, func() []int {
ints := make([]int, len(shape))
for i, s := range shape {
ints[i] = int(s)
}
return ints
}()...)
}
// Utility
func Collect(v any) []*Array {
var arrays []*Array
seen := make(map[uintptr]bool)
collect(reflect.ValueOf(v), &arrays, seen)
return arrays
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return
}
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return
}
ptr := v.Pointer()
if seen[ptr] {
return
}
seen[ptr] = true
if arr, ok := v.Interface().(*Array); ok {
if arr != nil && arr.Valid() {
*arrays = append(*arrays, arr)
}
return
}
collect(v.Elem(), arrays, seen)
return
}
switch v.Kind() {
case reflect.Struct:
// Check if this struct IS an Array (not a pointer to one)
if arr, ok := v.Addr().Interface().(*Array); ok {
if arr != nil && arr.Valid() {
*arrays = append(*arrays, arr)
}
return
}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.CanInterface() {
collect(field, arrays, seen)
}
}
case reflect.Slice:
for i := 0; i < v.Len(); i++ {
collect(v.Index(i), arrays, seen)
}
case reflect.Map:
for _, key := range v.MapKeys() {
collect(v.MapIndex(key), arrays, seen)
}
case reflect.Interface:
if !v.IsNil() {
collect(v.Elem(), arrays, seen)
}
}
}
func EnableCompile() {
C.mlx_enable_compile()
}
func DisableCompile() {
C.mlx_disable_compile()
}

44
x/mlxrunner/mlx/random.go Normal file
View File

@@ -0,0 +1,44 @@
package mlx
// #include "generated.h"
import "C"
import "unsafe"
func RandomKey(seed uint64) *Array {
out := New("RANDOM_KEY")
C.mlx_random_key(&out.ctx, C.uint64_t(seed))
return out
}
func (t *Array) Categorical(axis int) *Array {
return t.CategoricalWithKey(axis, nil)
}
func (t *Array) CategoricalWithKey(axis int, key *Array) *Array {
if key == nil {
key = New("")
}
out := New("")
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}
func Bernoulli(p *Array) *Array {
return BernoulliWithKey(p, nil)
}
func BernoulliWithKey(p *Array, key *Array) *Array {
dims := p.Dims()
shape := make([]C.int, len(dims))
for i, d := range dims {
shape[i] = C.int(d)
}
if key == nil {
key = New("")
}
out := New("BERNOULLI")
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
return out
}

100
x/mlxrunner/mlx/slice.go Normal file
View File

@@ -0,0 +1,100 @@
package mlx
// #include "generated.h"
import "C"
import (
"math"
"unsafe"
)
// End is a sentinel value meaning "to the end of the dimension",
// equivalent to an omitted stop in Python (e.g. a[i:]).
const End = math.MaxInt32
type slice struct {
args []int
}
func Slice(args ...int) slice {
return slice{args: args}
}
func resolve(val, dim int) C.int {
if val == End {
return C.int(dim)
}
if val < 0 {
return C.int(dim + val)
}
return C.int(val)
}
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
if len(slices) != len(dims) {
panic("number of slice arguments must match number of tensor dimensions")
}
args := [3][]C.int{
make([]C.int, len(slices)),
make([]C.int, len(slices)),
make([]C.int, len(slices)),
}
for i, s := range slices {
dim := dims[i]
switch len(s.args) {
case 0:
// slice[:]
args[0][i] = C.int(0)
args[1][i] = C.int(dim)
args[2][i] = C.int(1)
case 1:
// slice[i]
start := resolve(s.args[0], dim)
args[0][i] = start
args[1][i] = start + 1
args[2][i] = C.int(1)
case 2:
// slice[i:j]
args[0][i] = resolve(s.args[0], dim)
args[1][i] = resolve(s.args[1], dim)
args[2][i] = C.int(1)
case 3:
// slice[i:j:k]
args[0][i] = resolve(s.args[0], dim)
args[1][i] = resolve(s.args[1], dim)
args[2][i] = C.int(s.args[2])
default:
panic("invalid slice arguments")
}
}
return args[0], args[1], args[2]
}
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE")
C.mlx_slice(
&out.ctx, t.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE")
C.mlx_slice_update(
&out.ctx, t.ctx, other.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}

79
x/mlxrunner/mlx/stream.go Normal file
View File

@@ -0,0 +1,79 @@
package mlx
// #include "generated.h"
import "C"
import "log/slog"
type Device struct {
ctx C.mlx_device
}
func (d Device) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_device_tostring(&str, d.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
var (
defaultDevice Device
defaultDeviceSet bool
defaultStream Stream
defaultStreamSet bool
)
func resetDefaultStreamCache() {
defaultDeviceSet = false
defaultStreamSet = false
}
func DefaultDevice() Device {
if !defaultDeviceSet {
d := C.mlx_device_new()
C.mlx_get_default_device(&d)
defaultDevice = Device{d}
defaultDeviceSet = true
}
return defaultDevice
}
// GPUIsAvailable returns true if a GPU device is available.
func GPUIsAvailable() bool {
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
defer C.mlx_device_free(dev)
var avail C.bool
C.mlx_device_is_available(&avail, dev)
return bool(avail)
}
// SetDefaultDeviceGPU sets the default MLX device to GPU.
func SetDefaultDeviceGPU() {
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
C.mlx_set_default_device(dev)
C.mlx_device_free(dev)
resetDefaultStreamCache()
}
type Stream struct {
ctx C.mlx_stream
}
func (s Stream) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_stream_tostring(&str, s.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
func DefaultStream() Stream {
if !defaultStreamSet {
s := C.mlx_stream_new()
C.mlx_get_default_stream(&s, DefaultDevice().ctx)
defaultStream = Stream{s}
defaultStreamSet = true
}
return defaultStream
}

View File

@@ -0,0 +1,104 @@
package mlx
import (
"context"
"runtime"
"sync"
"testing"
"github.com/ollama/ollama/x/internal/mlxthread"
)
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
func startMLXThread(t *testing.T) *mlxthread.Thread {
t.Helper()
thread, err := mlxthread.Start("mlx-test", func() error {
if err := CheckInit(); err != nil {
return err
}
if GPUIsAvailable() {
SetDefaultDeviceGPU()
}
return nil
})
if err != nil {
t.Skipf("MLX not available: %v", err)
}
return thread
}
func stopMLXThread(t *testing.T, thread *mlxthread.Thread) {
t.Helper()
if err := thread.Stop(context.Background(), func() {
Sweep()
ClearCache()
resetDefaultStreamCache()
}); err != nil {
t.Fatal(err)
}
}
func withMLXThread(t *testing.T, fn func()) {
t.Helper()
thread := startMLXThread(t)
defer stopMLXThread(t, thread)
if err := thread.Do(context.Background(), func() error {
fn()
return nil
}); err != nil {
t.Fatal(err)
}
}
func TestThreadedMLXOperations(t *testing.T) {
thread := startMLXThread(t)
defer stopMLXThread(t, thread)
oldProcs := runtime.GOMAXPROCS(8)
defer runtime.GOMAXPROCS(oldProcs)
const goroutines = 8
const iterations = 8
var wg sync.WaitGroup
errCh := make(chan error, goroutines)
for range goroutines {
wg.Add(1)
go func() {
defer wg.Done()
for range iterations {
if err := thread.Do(context.Background(), func() error {
a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
b := Matmul(a, a)
AsyncEval(b)
Eval(b)
Sweep()
ClearCache()
return nil
}); err != nil {
errCh <- err
return
}
}
}()
}
wg.Wait()
close(errCh)
for err := range errCh {
t.Fatal(err)
}
}