ollama source for Momentry Core verification
This commit is contained in:
3
x/mlxrunner/mlx/.gitignore
vendored
Normal file
3
x/mlxrunner/mlx/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
_deps
|
||||
build
|
||||
dist
|
||||
32
x/mlxrunner/mlx/CMakeLists.txt
Normal file
32
x/mlxrunner/mlx/CMakeLists.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
cmake_minimum_required(VERSION 3.5)
|
||||
|
||||
project(mlx)
|
||||
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
|
||||
endif()
|
||||
|
||||
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
|
||||
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
|
||||
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
||||
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
# Read MLX-C version from top-level file (shared with imagegen CMakeLists)
|
||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
|
||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG}
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
|
||||
# Sync vendored headers with fetched version
|
||||
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
||||
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/include/mlx/c/")
|
||||
99
x/mlxrunner/mlx/act.go
Normal file
99
x/mlxrunner/mlx/act.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package mlx
|
||||
|
||||
import "math"
|
||||
|
||||
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||
|
||||
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
// as a fused kernel.
|
||||
var GELUApprox = Compile1(
|
||||
"GELUApprox",
|
||||
func(x *Array) *Array {
|
||||
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
|
||||
dt := x.DType()
|
||||
half := FromValue[float32](0.5).AsType(dt)
|
||||
coeff := FromValue(geluCoeff).AsType(dt)
|
||||
c := FromValue[float32](0.044715).AsType(dt)
|
||||
one := FromValue[float32](1.0).AsType(dt)
|
||||
|
||||
// x^3 via x*x*x (avoids general Power which is slower).
|
||||
x3 := x.Multiply(x).Multiply(x)
|
||||
inner := x.Add(c.Multiply(x3))
|
||||
tanh := coeff.Multiply(inner).Tanh()
|
||||
return half.Multiply(x).Multiply(one.Add(tanh))
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SiLU returns a * sigmoid(a) as a fused kernel.
|
||||
var SiLU = Compile1(
|
||||
"SiLU",
|
||||
func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SoftplusF32 returns softplus(x) computed in float32 precision and cast back
|
||||
// to x's original dtype, as a fused kernel. Matches the laguna attention
|
||||
// output-gate formula: softplus(cast_f32(x)).cast(orig_dtype).
|
||||
var SoftplusF32 = Compile1(
|
||||
"SoftplusF32",
|
||||
func(x *Array) *Array {
|
||||
dt := x.DType()
|
||||
zero := FromValue[float32](0)
|
||||
return Logaddexp(x.AsType(DTypeFloat32), zero).AsType(dt)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SwiGLU returns silu(gate) * up as a fused kernel.
|
||||
var SwiGLU = Compile2(
|
||||
"SwiGLU",
|
||||
func(gate, up *Array) *Array {
|
||||
return SiLU(gate).Multiply(up)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
|
||||
// geglu, used by Gemma-family MLP and MoE paths.
|
||||
var GeGLU = Compile2(
|
||||
"GeGLU",
|
||||
func(gate, up *Array) *Array {
|
||||
return GELUApprox(gate).Multiply(up)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
|
||||
// mlx_lm's logit_softcap. cap must have the same dtype as x.
|
||||
var LogitSoftcap = Compile2(
|
||||
"LogitSoftcap",
|
||||
func(x, cap *Array) *Array {
|
||||
return x.Divide(cap).Tanh().Multiply(cap)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
|
||||
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
|
||||
// per-expert scores after top-k) and the post-bias negation (used as the
|
||||
// argpartition key for top-k) share a single kernel.
|
||||
var sigmoidRouterFused = Compile(
|
||||
"SigmoidRouter",
|
||||
func(in ...*Array) []*Array {
|
||||
gates, bias := in[0], in[1]
|
||||
orig := gates.Sigmoid()
|
||||
neg := orig.Add(bias).Negative()
|
||||
return []*Array{orig, neg}
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
|
||||
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
|
||||
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
|
||||
out := sigmoidRouterFused(gates, bias)
|
||||
return out[0], out[1]
|
||||
}
|
||||
295
x/mlxrunner/mlx/array.go
Normal file
295
x/mlxrunner/mlx/array.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
name string
|
||||
pinned atomic.Int32
|
||||
}
|
||||
|
||||
var (
|
||||
arrays []*Array
|
||||
arraysMu sync.Mutex
|
||||
)
|
||||
|
||||
// constructor utilities
|
||||
|
||||
func New(name string) *Array {
|
||||
t := &Array{name: name}
|
||||
|
||||
if tracing {
|
||||
traceScratch = append(traceScratch, t)
|
||||
} else {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
|
||||
arrays = append(arrays, t)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type scalarTypes interface {
|
||||
~bool | ~int | ~float32 | ~float64 | ~complex64
|
||||
}
|
||||
|
||||
func FromValue[T scalarTypes](t T) *Array {
|
||||
tt := New("")
|
||||
switch v := any(t).(type) {
|
||||
case bool:
|
||||
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
||||
case int:
|
||||
tt.ctx = C.mlx_array_new_int(C.int(v))
|
||||
case float32:
|
||||
tt.ctx = C.mlx_array_new_float32(C.float(v))
|
||||
case float64:
|
||||
tt.ctx = C.mlx_array_new_float64(C.double(v))
|
||||
case complex64:
|
||||
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
|
||||
default:
|
||||
panic("unsupported type")
|
||||
}
|
||||
return tt
|
||||
}
|
||||
|
||||
type arrayTypes interface {
|
||||
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
~int8 | ~int16 | ~int32 | ~int64 |
|
||||
~float32 | ~float64 |
|
||||
~complex64
|
||||
}
|
||||
|
||||
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
|
||||
if len(shape) == 0 {
|
||||
panic("shape must be provided for non-scalar tensors")
|
||||
}
|
||||
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i := range shape {
|
||||
cShape[i] = C.int(shape[i])
|
||||
}
|
||||
|
||||
var dtype DType
|
||||
switch reflect.TypeOf(s).Elem().Kind() {
|
||||
case reflect.Bool:
|
||||
dtype = DTypeBool
|
||||
case reflect.Uint8:
|
||||
dtype = DTypeUint8
|
||||
case reflect.Uint16:
|
||||
dtype = DTypeUint16
|
||||
case reflect.Uint32:
|
||||
dtype = DTypeUint32
|
||||
case reflect.Uint64:
|
||||
dtype = DTypeUint64
|
||||
case reflect.Int8:
|
||||
dtype = DTypeInt8
|
||||
case reflect.Int16:
|
||||
dtype = DTypeInt16
|
||||
case reflect.Int32:
|
||||
dtype = DTypeInt32
|
||||
case reflect.Int64:
|
||||
dtype = DTypeInt64
|
||||
case reflect.Float32:
|
||||
dtype = DTypeFloat32
|
||||
case reflect.Float64:
|
||||
dtype = DTypeFloat64
|
||||
case reflect.Complex64:
|
||||
dtype = DTypeComplex64
|
||||
default:
|
||||
panic("unsupported type")
|
||||
}
|
||||
|
||||
bts := make([]byte, binary.Size(s))
|
||||
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tt := New("")
|
||||
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
|
||||
return tt
|
||||
}
|
||||
|
||||
func (t *Array) Set(other *Array) {
|
||||
C.mlx_array_set(&t.ctx, other.ctx)
|
||||
}
|
||||
|
||||
func (t *Array) Clone() *Array {
|
||||
tt := New(t.name)
|
||||
C.mlx_array_set(&tt.ctx, t.ctx)
|
||||
return tt
|
||||
}
|
||||
|
||||
// lifecycle utilities
|
||||
|
||||
// Pin marks arrays as in-use so they are retained during Sweep.
|
||||
func Pin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
|
||||
func Unpin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
if t.pinned.Add(-1) < 0 {
|
||||
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
|
||||
// free them when there are no other references, including dependencies in the graph.
|
||||
func Sweep() {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
n := 0
|
||||
for _, t := range arrays {
|
||||
if t.pinned.Load() > 0 && t.Valid() {
|
||||
arrays[n] = t
|
||||
n++
|
||||
} else if t.Valid() {
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
arrays = arrays[:n]
|
||||
}
|
||||
|
||||
// misc. utilities
|
||||
|
||||
func (t *Array) Valid() bool {
|
||||
return t.ctx.ctx != nil
|
||||
}
|
||||
|
||||
func (t *Array) String() string {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_array_tostring(&str, t.ctx)
|
||||
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
func (t *Array) LogValue() slog.Value {
|
||||
attrs := []slog.Attr{
|
||||
slog.String("name", t.name),
|
||||
slog.Int("pinned", int(t.pinned.Load())),
|
||||
}
|
||||
if t.Valid() {
|
||||
attrs = append(attrs,
|
||||
slog.Any("dtype", t.DType()),
|
||||
slog.Any("shape", t.Dims()),
|
||||
slog.Int("num_bytes", t.NumBytes()),
|
||||
)
|
||||
}
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// shape utilities
|
||||
|
||||
func (t *Array) Size() int {
|
||||
return int(C.mlx_array_size(t.ctx))
|
||||
}
|
||||
|
||||
func (t *Array) NumBytes() int {
|
||||
return int(C.mlx_array_nbytes(t.ctx))
|
||||
}
|
||||
|
||||
func (t *Array) NumDims() int {
|
||||
return int(C.mlx_array_ndim(t.ctx))
|
||||
}
|
||||
|
||||
func (t *Array) Dims() []int {
|
||||
dims := make([]int, t.NumDims())
|
||||
for i := range dims {
|
||||
dims[i] = t.Dim(i)
|
||||
}
|
||||
|
||||
return dims
|
||||
}
|
||||
|
||||
func (t *Array) Dim(dim int) int {
|
||||
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
|
||||
}
|
||||
|
||||
func (t *Array) DType() DType {
|
||||
return DType(C.mlx_array_dtype(t.ctx))
|
||||
}
|
||||
|
||||
// data utilities
|
||||
|
||||
func (t *Array) Int() int {
|
||||
var item C.int64_t
|
||||
C.mlx_array_item_int64(&item, t.ctx)
|
||||
return int(item)
|
||||
}
|
||||
|
||||
func (t *Array) Float() float64 {
|
||||
var item C.double
|
||||
C.mlx_array_item_float64(&item, t.ctx)
|
||||
return float64(item)
|
||||
}
|
||||
|
||||
func (t *Array) Ints() []int {
|
||||
if dt := t.DType(); dt != DTypeInt32 {
|
||||
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
|
||||
}
|
||||
ints := make([]int, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
|
||||
ints[i] = int(f)
|
||||
}
|
||||
return ints
|
||||
}
|
||||
|
||||
func (t *Array) Floats() []float32 {
|
||||
if dt := t.DType(); dt != DTypeFloat32 {
|
||||
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
|
||||
}
|
||||
floats := make([]float32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
|
||||
floats[i] = float32(f)
|
||||
}
|
||||
return floats
|
||||
}
|
||||
|
||||
func (t *Array) Save(name string) error {
|
||||
cName := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
C.mlx_save(cName, t.ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogArrays logs all live arrays, sorted by size
|
||||
func LogArrays() {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
sort.Slice(arrays, func(i, j int) bool {
|
||||
return arrays[i].NumBytes() > arrays[j].NumBytes()
|
||||
})
|
||||
|
||||
var total int
|
||||
for _, t := range arrays {
|
||||
nb := t.NumBytes()
|
||||
total += nb
|
||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
|
||||
}
|
||||
76
x/mlxrunner/mlx/array_test.go
Normal file
76
x/mlxrunner/mlx/array_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package mlx
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFromValue(t *testing.T) {
|
||||
withMLXThread(t, func() {
|
||||
for got, want := range map[*Array]DType{
|
||||
FromValue(true): DTypeBool,
|
||||
FromValue(false): DTypeBool,
|
||||
FromValue(int(7)): DTypeInt32,
|
||||
FromValue(float32(3.14)): DTypeFloat32,
|
||||
FromValue(float64(2.71)): DTypeFloat64,
|
||||
FromValue(complex64(1 + 2i)): DTypeComplex64,
|
||||
} {
|
||||
if got.DType() != want {
|
||||
t.Errorf("%s: want %v, got %v", want, want, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFromValues(t *testing.T) {
|
||||
withMLXThread(t, func() {
|
||||
for got, want := range map[*Array]DType{
|
||||
FromValues([]bool{true, false, true}, 3): DTypeBool,
|
||||
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
|
||||
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
|
||||
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
|
||||
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
|
||||
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
|
||||
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
|
||||
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
|
||||
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
|
||||
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
|
||||
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
|
||||
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
|
||||
} {
|
||||
if got.DType() != want {
|
||||
t.Errorf("%s: want %v, got %v", want, want, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestComparisonOpsAndBernoulli(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
a := FromValues([]float32{1, 2, 3}, 3)
|
||||
b := FromValues([]float32{1, 1, 4}, 3)
|
||||
eq := a.Equal(b).AsType(DTypeInt32)
|
||||
gt := a.Greater(b).AsType(DTypeInt32)
|
||||
le := a.LessEqual(b).AsType(DTypeInt32)
|
||||
bern := Bernoulli(FromValues([]float32{1, 0}, 2)).AsType(DTypeInt32)
|
||||
Eval(eq, gt, le, bern)
|
||||
|
||||
for name, tc := range map[string]struct {
|
||||
got []int
|
||||
want []int
|
||||
}{
|
||||
"equal": {eq.Ints(), []int{1, 0, 0}},
|
||||
"greater": {gt.Ints(), []int{0, 1, 0}},
|
||||
"lessEqual": {le.Ints(), []int{1, 0, 1}},
|
||||
"bernoulli": {bern.Ints(), []int{1, 0}},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if len(tc.got) != len(tc.want) {
|
||||
t.Fatalf("got %v, want %v", tc.got, tc.want)
|
||||
}
|
||||
for i := range tc.want {
|
||||
if tc.got[i] != tc.want[i] {
|
||||
t.Fatalf("got %v, want %v", tc.got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
192
x/mlxrunner/mlx/compile.go
Normal file
192
x/mlxrunner/mlx/compile.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
//
|
||||
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||
// extern void closureDestructor(void* payload);
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"runtime/cgo"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// CompileFunc is the signature of a function that can be compiled.
|
||||
type CompileFunc func(inputs ...*Array) []*Array
|
||||
|
||||
// CompileOption configures Compile behavior.
|
||||
type CompileOption func(*compileConfig)
|
||||
|
||||
type compileConfig struct {
|
||||
shapeless bool
|
||||
}
|
||||
|
||||
// Shapeless traces the function once against symbolic shapes so the compiled
|
||||
// graph accepts any input shape afterwards. Without this option, MLX re-traces
|
||||
// on each new (shape, dtype) combination and caches each specialization.
|
||||
func Shapeless() CompileOption {
|
||||
return func(c *compileConfig) { c.shapeless = true }
|
||||
}
|
||||
|
||||
// Compile returns a compiled version of fn. When called during another
|
||||
// compile's trace, fn is inlined directly so outer compiles can fuse through
|
||||
// inner ones.
|
||||
//
|
||||
// Compiled functions must not have side effects outside of the function. Do
|
||||
// not access data other than the arguments passed in (either Go data or MLX
|
||||
// arrays) unless it is a constant.
|
||||
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
|
||||
var cfg compileConfig
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
var closure C.mlx_closure
|
||||
var once sync.Once
|
||||
|
||||
return func(inputs ...*Array) []*Array {
|
||||
if tracing {
|
||||
return fn(inputs...)
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
|
||||
*payload = cgo.NewHandle(fn)
|
||||
src := C.mlx_closure_new_func_payload(
|
||||
(*[0]byte)(C.closureCallback),
|
||||
unsafe.Pointer(payload),
|
||||
(*[0]byte)(C.closureDestructor),
|
||||
)
|
||||
defer C.mlx_closure_free(src)
|
||||
|
||||
closure = C.mlx_closure_new()
|
||||
mlxCheck(name+": compile failed", func() C.int {
|
||||
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
|
||||
})
|
||||
})
|
||||
|
||||
inVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
for _, in := range inputs {
|
||||
C.mlx_vector_array_append_value(inVec, in.ctx)
|
||||
}
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
mlxCheck(name+": closure apply failed", func() C.int {
|
||||
return C.mlx_closure_apply(&outVec, closure, inVec)
|
||||
})
|
||||
|
||||
n := int(C.mlx_vector_array_size(outVec))
|
||||
outputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
outputs[i] = New(name)
|
||||
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
|
||||
}
|
||||
return outputs
|
||||
}
|
||||
}
|
||||
|
||||
// Compile1 compiles a unary function. See Compile.
|
||||
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0])}
|
||||
}, opts...)
|
||||
return func(a *Array) *Array {
|
||||
return cf(a)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile2 compiles a binary function. See Compile.
|
||||
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1])}
|
||||
}, opts...)
|
||||
return func(a, b *Array) *Array {
|
||||
return cf(a, b)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile3 compiles a ternary function. See Compile.
|
||||
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1], in[2])}
|
||||
}, opts...)
|
||||
return func(a, b, c *Array) *Array {
|
||||
return cf(a, b, c)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// tracing is true while a compile callback is running. Since MLX is
|
||||
// single-threaded at this level a plain Go bool suffices.
|
||||
var tracing bool
|
||||
|
||||
// traceScratch collects arrays created during a compile trace so they can be
|
||||
// freed as a group when the callback returns.
|
||||
var traceScratch []*Array
|
||||
|
||||
//export closureCallback
|
||||
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("mlx closure callback panicked", "panic", r)
|
||||
rc = 1
|
||||
}
|
||||
}()
|
||||
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
fn := handle.Value().(CompileFunc)
|
||||
|
||||
// When tracing, we track all of the intermediates that are created and free them separately at the end of
|
||||
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
|
||||
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
|
||||
if tracing {
|
||||
panic("mlx: nested compile trace")
|
||||
}
|
||||
tracing = true
|
||||
traceScratch = nil
|
||||
defer func() {
|
||||
for _, a := range traceScratch {
|
||||
if a.pinned.Load() > 0 {
|
||||
panic("mlx: traced array was pinned during compilation")
|
||||
}
|
||||
if a.Valid() {
|
||||
C.mlx_array_free(a.ctx)
|
||||
a.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
tracing = false
|
||||
traceScratch = nil
|
||||
}()
|
||||
|
||||
n := int(C.mlx_vector_array_size(input))
|
||||
inputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
a := New("")
|
||||
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
|
||||
inputs[i] = a
|
||||
}
|
||||
|
||||
outputs := fn(inputs...)
|
||||
|
||||
var arrPtr *C.mlx_array
|
||||
if len(outputs) > 0 {
|
||||
handles := make([]C.mlx_array, len(outputs))
|
||||
for i, out := range outputs {
|
||||
handles[i] = out.ctx
|
||||
}
|
||||
arrPtr = &handles[0]
|
||||
}
|
||||
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
|
||||
return 0
|
||||
}
|
||||
|
||||
//export closureDestructor
|
||||
func closureDestructor(payload unsafe.Pointer) {
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
handle.Delete()
|
||||
C.free(payload)
|
||||
}
|
||||
147
x/mlxrunner/mlx/compile_test.go
Normal file
147
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileFusion(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Compile fuses the ops inside a function body into a single kernel,
|
||||
// eliminating intermediate buffers. Use a diamond-shaped graph where
|
||||
// two branches must be materialized simultaneously without fusion,
|
||||
// then compare peak memory against the compiled version which fuses
|
||||
// everything into one kernel with no intermediates.
|
||||
const n = 1024 * 1024 // 4MB per float32 array
|
||||
data := make([]float32, n)
|
||||
for i := range data {
|
||||
data[i] = float32(i + 1)
|
||||
}
|
||||
|
||||
// Diamond: both a*b and a+b must be live for the final multiply.
|
||||
// Without fusion: peak includes both intermediates (~8MB extra).
|
||||
// With fusion: single kernel, no intermediates.
|
||||
body := func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Multiply(a.Add(b))
|
||||
}
|
||||
|
||||
a := FromValues(data, n)
|
||||
b := FromValues(data, n)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
// Compiled: ops fused into a single kernel.
|
||||
EnableCompile()
|
||||
fn := Compile2("diamond", body, Shapeless())
|
||||
warm := fn(a, b)
|
||||
Eval(warm)
|
||||
Sweep()
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
y := fn(a, b)
|
||||
Eval(y)
|
||||
compiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
// Uncompiled: ops evaluated individually, intermediates materialized.
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
z := body(a, b)
|
||||
Eval(z)
|
||||
uncompiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
if compiledPeak == 0 && uncompiledPeak == 0 {
|
||||
t.Skip("peak memory tracking not available")
|
||||
}
|
||||
|
||||
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
|
||||
if compiledPeak >= uncompiledPeak {
|
||||
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileNested(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// A compiled function that calls another compiled function should
|
||||
// produce correct results. The inner function inlines via isTracing()
|
||||
// during the outer's trace.
|
||||
inner := Compile1("silu", func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
}, Shapeless())
|
||||
|
||||
outer := Compile2("swiglu", func(gate, up *Array) *Array {
|
||||
return inner(gate).Multiply(up)
|
||||
}, Shapeless())
|
||||
|
||||
gate := FromValues([]float32{0, 1, 2}, 3)
|
||||
up := FromValues([]float32{1, 1, 1}, 3)
|
||||
Pin(gate, up)
|
||||
defer Unpin(gate, up)
|
||||
|
||||
y := outer(gate, up)
|
||||
Eval(y)
|
||||
|
||||
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
|
||||
got := y.Floats()
|
||||
want := []float32{0, 0.7310586, 1.7615942}
|
||||
for i, v := range got {
|
||||
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
|
||||
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
boom := Compile1("boom", func(a *Array) *Array {
|
||||
panic("intentional test panic")
|
||||
})
|
||||
|
||||
x := FromValues([]float32{1}, 1)
|
||||
Pin(x)
|
||||
defer Unpin(x)
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatal("expected panic from Call, got none")
|
||||
}
|
||||
if _, ok := r.(string); !ok {
|
||||
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||
}
|
||||
}()
|
||||
boom(x)
|
||||
}
|
||||
|
||||
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Repeated invocations of a compiled kernel should not grow the
|
||||
// tracked-arrays list — the callback's traceScratch collects
|
||||
// intermediates during tracing and frees them when the callback returns.
|
||||
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Add(b)
|
||||
})
|
||||
|
||||
a := FromValues([]float32{1, 2}, 2)
|
||||
b := FromValues([]float32{3, 4}, 2)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
Sweep()
|
||||
before := len(arrays)
|
||||
|
||||
for range 100 {
|
||||
_ = fn(a, b)
|
||||
Sweep()
|
||||
}
|
||||
|
||||
after := len(arrays)
|
||||
if after > before+2 {
|
||||
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
|
||||
}
|
||||
}
|
||||
94
x/mlxrunner/mlx/dtype.go
Normal file
94
x/mlxrunner/mlx/dtype.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
type DType int
|
||||
|
||||
func (t DType) String() string {
|
||||
switch t {
|
||||
case DTypeBool:
|
||||
return "BOOL"
|
||||
case DTypeUint8:
|
||||
return "U8"
|
||||
case DTypeUint16:
|
||||
return "U16"
|
||||
case DTypeUint32:
|
||||
return "U32"
|
||||
case DTypeUint64:
|
||||
return "U64"
|
||||
case DTypeInt8:
|
||||
return "I8"
|
||||
case DTypeInt16:
|
||||
return "I16"
|
||||
case DTypeInt32:
|
||||
return "I32"
|
||||
case DTypeInt64:
|
||||
return "I64"
|
||||
case DTypeFloat16:
|
||||
return "F16"
|
||||
case DTypeFloat32:
|
||||
return "F32"
|
||||
case DTypeFloat64:
|
||||
return "F64"
|
||||
case DTypeBFloat16:
|
||||
return "BF16"
|
||||
case DTypeComplex64:
|
||||
return "C64"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DType) UnmarshalJSON(b []byte) error {
|
||||
switch string(b) {
|
||||
case `"BOOL"`:
|
||||
*t = DTypeBool
|
||||
case `"U8"`:
|
||||
*t = DTypeUint8
|
||||
case `"U16"`:
|
||||
*t = DTypeUint16
|
||||
case `"U32"`:
|
||||
*t = DTypeUint32
|
||||
case `"U64"`:
|
||||
*t = DTypeUint64
|
||||
case `"I8"`:
|
||||
*t = DTypeInt8
|
||||
case `"I16"`:
|
||||
*t = DTypeInt16
|
||||
case `"I32"`:
|
||||
*t = DTypeInt32
|
||||
case `"I64"`:
|
||||
*t = DTypeInt64
|
||||
case `"F16"`:
|
||||
*t = DTypeFloat16
|
||||
case `"F64"`:
|
||||
*t = DTypeFloat64
|
||||
case `"F32"`:
|
||||
*t = DTypeFloat32
|
||||
case `"BF16"`:
|
||||
*t = DTypeBFloat16
|
||||
case `"C64"`:
|
||||
*t = DTypeComplex64
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
DTypeBool DType = C.MLX_BOOL
|
||||
DTypeUint8 DType = C.MLX_UINT8
|
||||
DTypeUint16 DType = C.MLX_UINT16
|
||||
DTypeUint32 DType = C.MLX_UINT32
|
||||
DTypeUint64 DType = C.MLX_UINT64
|
||||
DTypeInt8 DType = C.MLX_INT8
|
||||
DTypeInt16 DType = C.MLX_INT16
|
||||
DTypeInt32 DType = C.MLX_INT32
|
||||
DTypeInt64 DType = C.MLX_INT64
|
||||
DTypeFloat16 DType = C.MLX_FLOAT16
|
||||
DTypeFloat32 DType = C.MLX_FLOAT32
|
||||
DTypeFloat64 DType = C.MLX_FLOAT64
|
||||
DTypeBFloat16 DType = C.MLX_BFLOAT16
|
||||
DTypeComplex64 DType = C.MLX_COMPLEX64
|
||||
)
|
||||
36
x/mlxrunner/mlx/dynamic.c
Normal file
36
x/mlxrunner/mlx/dynamic.c
Normal file
@@ -0,0 +1,36 @@
|
||||
#include "dynamic.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#define DLOPEN(path) LoadLibraryA(path)
|
||||
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
|
||||
#else
|
||||
#ifdef __APPLE__
|
||||
#include <mach-o/dyld.h>
|
||||
#include <libgen.h>
|
||||
#endif
|
||||
#include <dlfcn.h>
|
||||
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define DLCLOSE(handle) dlclose(handle)
|
||||
#endif
|
||||
|
||||
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
|
||||
handle->ctx = (void*) DLOPEN(path);
|
||||
if (handle->ctx == NULL) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
|
||||
return mlx_dynamic_open(handle, path);
|
||||
}
|
||||
|
||||
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
|
||||
if (handle->ctx) {
|
||||
DLCLOSE(handle->ctx);
|
||||
handle->ctx = NULL;
|
||||
}
|
||||
}
|
||||
253
x/mlxrunner/mlx/dynamic.go
Normal file
253
x/mlxrunner/mlx/dynamic.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package mlx
|
||||
|
||||
// #include "dynamic.h"
|
||||
// #include "generated.h"
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var initError error
|
||||
var initLoadError string
|
||||
var initLoadedPath string
|
||||
|
||||
// CheckInit returns any error that occurred during MLX dynamic library initialization.
|
||||
func CheckInit() error {
|
||||
if initLoadedPath != "" {
|
||||
slog.Debug("MLX dynamic library loaded", "path", initLoadedPath)
|
||||
}
|
||||
if initError != nil && initLoadError != "" {
|
||||
slog.Error(initLoadError)
|
||||
}
|
||||
return initError
|
||||
}
|
||||
|
||||
// tryLoadFromDir searches a directory for the mlxc shared library and loads it.
|
||||
func tryLoadFromDir(dir string) bool {
|
||||
// On Windows, MSVC produces mlxc.dll (no lib prefix)
|
||||
// On Unix, it's libmlxc.so or libmlxc.dylib
|
||||
pattern := "libmlxc.*"
|
||||
if runtime.GOOS == "windows" {
|
||||
pattern = "mlxc.*"
|
||||
}
|
||||
matches, err := fs.Glob(os.DirFS(dir), pattern)
|
||||
if err != nil || len(matches) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, match := range matches {
|
||||
path := filepath.Join(dir, match)
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
initLoadError = fmt.Sprintf("failed to load MLX dynamic library: path=%s", path)
|
||||
continue
|
||||
}
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
initLoadError = fmt.Sprintf("failed to load MLX dynamic library symbols: path=%s", path)
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
continue
|
||||
}
|
||||
|
||||
initLoadedPath = path
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// libOllamaRoots returns candidate directories for MLX dynamic libraries.
|
||||
// Production: exe_dir/lib/ollama (dist tarball) and exe_dir (app bundle).
|
||||
// Development: build/lib/ollama and build/*/lib/ollama.
|
||||
func libOllamaRoots() []string {
|
||||
var roots []string
|
||||
|
||||
// Production paths relative to executable
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
exeDir := filepath.Dir(exe)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
roots = append(roots, filepath.Join(exeDir, "lib", "ollama"))
|
||||
roots = append(roots, exeDir) // app bundle: Contents/Resources/
|
||||
case "linux":
|
||||
roots = append(roots, filepath.Join(exeDir, "..", "lib", "ollama"))
|
||||
case "windows":
|
||||
roots = append(roots, filepath.Join(exeDir, "lib", "ollama"))
|
||||
}
|
||||
}
|
||||
|
||||
// Development paths: build/lib/ollama and build/*/lib/ollama.
|
||||
// Reverse-sort and filter the glob results so higher-versioned Metal
|
||||
// builds (e.g., metal-v4) are tried before lower ones (metal-v3),
|
||||
// and incompatible variants are skipped. Without this, alphabetical
|
||||
// order would always pick v3 over v4 in dev builds.
|
||||
for _, base := range repoBuildDirs() {
|
||||
roots = append(roots, filepath.Join(base, "lib", "ollama"))
|
||||
if matches, err := filepath.Glob(filepath.Join(base, "*", "lib", "ollama")); err == nil {
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(matches)))
|
||||
for _, m := range matches {
|
||||
// Extract the build dir name (e.g., "metal-v4" from "build/metal-v4/lib/ollama")
|
||||
rel, _ := filepath.Rel(base, m)
|
||||
variant := strings.SplitN(rel, string(filepath.Separator), 2)[0]
|
||||
if isCompatibleMLXVariant(variant) {
|
||||
roots = append(roots, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return roots
|
||||
}
|
||||
|
||||
// repoBuildDirs returns candidate build/ directories relative to cwd and repo root.
|
||||
func repoBuildDirs() []string {
|
||||
var dirs []string
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
dirs = append(dirs, filepath.Join(cwd, "build"))
|
||||
for dir := cwd; ; {
|
||||
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||
if dir != cwd {
|
||||
dirs = append(dirs, filepath.Join(dir, "build"))
|
||||
}
|
||||
break
|
||||
}
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
break
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
return dirs
|
||||
}
|
||||
|
||||
// prependLibraryPath prepends dir to the platform's dynamic library search
|
||||
// path so the linker finds colocated libmlx before any stale copies.
|
||||
// Called once after successful library load.
|
||||
func prependLibraryPath(dir string) {
|
||||
var envVar string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
envVar = "DYLD_LIBRARY_PATH"
|
||||
case "linux":
|
||||
envVar = "LD_LIBRARY_PATH"
|
||||
default:
|
||||
return
|
||||
}
|
||||
if existing := os.Getenv(envVar); existing != "" {
|
||||
os.Setenv(envVar, dir+string(filepath.ListSeparator)+existing)
|
||||
} else {
|
||||
os.Setenv(envVar, dir)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "linux", "windows":
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
// OLLAMA_LLM_LIBRARY overrides variant selection (e.g., "mlx_metal_v3").
|
||||
// When set to an mlx_* value, only that specific subdir is tried.
|
||||
// The GGML runner ignores mlx_* values (see discover/runner.go).
|
||||
forcedVariant, _ := os.LookupEnv("OLLAMA_LLM_LIBRARY")
|
||||
if forcedVariant != "" && !strings.HasPrefix(forcedVariant, "mlx_") {
|
||||
forcedVariant = "" // not an MLX variant, ignore
|
||||
}
|
||||
|
||||
found := findMLXLibrary(forcedVariant)
|
||||
if !found {
|
||||
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", libOllamaRoots())
|
||||
return
|
||||
}
|
||||
|
||||
prependLibraryPath(filepath.Dir(initLoadedPath))
|
||||
}
|
||||
|
||||
func findMLXLibrary(forcedVariant string) bool {
|
||||
for _, root := range libOllamaRoots() {
|
||||
if forcedVariant != "" {
|
||||
if tryLoadFromDir(filepath.Join(root, forcedVariant)) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if tryLoadFromMLXSubdirs(root) {
|
||||
return true
|
||||
}
|
||||
if tryLoadFromDir(root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tryLoadFromMLXSubdirs globs for mlx_* subdirs within dir, filters out
|
||||
// incompatible variants, tries the remainder in reverse sorted order (so
|
||||
// higher-versioned variants are preferred), and returns true on first
|
||||
// successful load.
|
||||
func tryLoadFromMLXSubdirs(dir string) bool {
|
||||
mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*"))
|
||||
if err != nil || len(mlxDirs) == 0 {
|
||||
return false
|
||||
}
|
||||
// Reverse sort: mlx_metal_v4 before mlx_metal_v3, mlx_cuda_v13 before v12
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(mlxDirs)))
|
||||
for _, mlxDir := range mlxDirs {
|
||||
if !isCompatibleMLXVariant(filepath.Base(mlxDir)) {
|
||||
slog.Debug("skipping incompatible MLX variant", "dir", mlxDir)
|
||||
continue
|
||||
}
|
||||
if tryLoadFromDir(mlxDir) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isCompatibleMLXVariant checks whether an MLX variant directory is
|
||||
// compatible with the current OS. On macOS, dlopen does NOT enforce
|
||||
// the deployment target for dynamically loaded libraries, so we must
|
||||
// check compatibility ourselves to avoid loading Metal 4.x shaders
|
||||
// on a Metal 3.x driver.
|
||||
func isCompatibleMLXVariant(name string) bool {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return true // non-macOS variants use dlopen failure for filtering
|
||||
}
|
||||
// Metal variant naming:
|
||||
// Production: mlx_metal_v3, mlx_metal_v4
|
||||
// Dev build: metal-v3, metal-v4
|
||||
var verStr string
|
||||
switch {
|
||||
case strings.HasPrefix(name, "mlx_metal_v"):
|
||||
verStr = strings.TrimPrefix(name, "mlx_metal_v")
|
||||
case strings.HasPrefix(name, "metal-v"):
|
||||
verStr = strings.TrimPrefix(name, "metal-v")
|
||||
}
|
||||
if verStr != "" {
|
||||
metalVer, err := strconv.Atoi(verStr)
|
||||
if err != nil {
|
||||
return true // unknown format, try it
|
||||
}
|
||||
// Metal 4.x requires macOS 26+
|
||||
if metalVer >= 4 && macOSMajorVersion() < 26 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
47
x/mlxrunner/mlx/dynamic.h
Normal file
47
x/mlxrunner/mlx/dynamic.h
Normal file
@@ -0,0 +1,47 @@
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#define DLSYM(handle, symbol) (void*)GetProcAddress((HMODULE)(handle.ctx), symbol)
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
// Provide fallback typedefs for float16_t and bfloat16_t on non-ARM64
|
||||
// platforms where arm_fp16.h and arm_bf16.h are not available. These are
|
||||
// only used as function pointer signature placeholders since MLX requires
|
||||
// Apple Silicon at runtime.
|
||||
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
|
||||
typedef uint16_t float16_t;
|
||||
#endif
|
||||
|
||||
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_BF16)
|
||||
typedef uint16_t bfloat16_t;
|
||||
#endif
|
||||
|
||||
// Undef ERROR to avoid conflict with wingdi.h on Windows
|
||||
#ifdef ERROR
|
||||
#undef ERROR
|
||||
#endif
|
||||
#define MLX_ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
|
||||
#define CHECK(x) if (!(x)) { MLX_ERROR("CHECK failed: " #x); }
|
||||
#define CHECK_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x); CHECK(x##_)
|
||||
// OPTIONAL_LOAD: load symbol if available, leave function pointer NULL otherwise
|
||||
#define OPTIONAL_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x)
|
||||
|
||||
typedef struct {
|
||||
void* ctx;
|
||||
} mlx_dynamic_handle;
|
||||
|
||||
int mlx_dynamic_load(
|
||||
mlx_dynamic_handle* handle,
|
||||
const char *path);
|
||||
|
||||
void mlx_dynamic_unload(
|
||||
mlx_dynamic_handle* handle);
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
17
x/mlxrunner/mlx/dynamic_darwin.go
Normal file
17
x/mlxrunner/mlx/dynamic_darwin.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func macOSMajorVersion() int {
|
||||
ver, err := syscall.Sysctl("kern.osproductversion")
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
parts := strings.SplitN(ver, ".", 2)
|
||||
major, _ := strconv.Atoi(parts[0])
|
||||
return major
|
||||
}
|
||||
5
x/mlxrunner/mlx/dynamic_other.go
Normal file
5
x/mlxrunner/mlx/dynamic_other.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !darwin
|
||||
|
||||
package mlx
|
||||
|
||||
func macOSMajorVersion() int { return 0 }
|
||||
47
x/mlxrunner/mlx/fast.go
Normal file
47
x/mlxrunner/mlx/fast.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func FastScaledDotProductAttention(q, k, v *Array, scale float32, mode string, mask *Array) *Array {
|
||||
sinks := New("")
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
var maskCtx C.mlx_array
|
||||
if mask != nil {
|
||||
maskCtx = mask.ctx
|
||||
} else {
|
||||
empty := New("")
|
||||
maskCtx = empty.ctx
|
||||
}
|
||||
|
||||
out := New("FAST_SDPA")
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, maskCtx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
type LayerNorm struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
Bias *Array `weight:"bias"`
|
||||
}
|
||||
|
||||
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
type RMSNorm struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
}
|
||||
|
||||
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
663
x/mlxrunner/mlx/gated_delta.go
Normal file
663
x/mlxrunner/mlx/gated_delta.go
Normal file
@@ -0,0 +1,663 @@
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
gatedDeltaMetalKernelOnce sync.Once
|
||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||
gatedDeltaMetalDisabled bool
|
||||
|
||||
gatedDeltaCUDAKernelOnce sync.Once
|
||||
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
|
||||
gatedDeltaCUDADisabled bool
|
||||
)
|
||||
|
||||
const gatedDeltaMetalKernelSource = `
|
||||
auto n = thread_position_in_grid.z;
|
||||
auto b_idx = n / Hv;
|
||||
auto hv_idx = n % Hv;
|
||||
auto hk_idx = hv_idx / (Hv / Hk);
|
||||
constexpr int n_per_t = Dk / 32;
|
||||
|
||||
// q, k: [B, T, Hk, Dk]
|
||||
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
|
||||
// v, y: [B, T, Hv, Dv]
|
||||
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
y += b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
|
||||
auto dk_idx = thread_position_in_threadgroup.x;
|
||||
auto dv_idx = thread_position_in_grid.y;
|
||||
|
||||
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||
|
||||
float state[n_per_t];
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = static_cast<float>(i_state[s_idx]);
|
||||
}
|
||||
|
||||
// g: [B, T, Hv]
|
||||
auto g_ = g + b_idx * T * Hv;
|
||||
auto beta_ = beta + b_idx * T * Hv;
|
||||
|
||||
for (int t = 0; t < T; ++t) {
|
||||
float kv_mem = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] * g_[hv_idx];
|
||||
kv_mem += state[i] * k_[s_idx];
|
||||
}
|
||||
kv_mem = simd_sum(kv_mem);
|
||||
|
||||
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
|
||||
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] + k_[s_idx] * delta;
|
||||
out += state[i] * q_[s_idx];
|
||||
}
|
||||
out = simd_sum(out);
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
y[dv_idx] = static_cast<InT>(out);
|
||||
}
|
||||
|
||||
q_ += Hk * Dk;
|
||||
k_ += Hk * Dk;
|
||||
v_ += Hv * Dv;
|
||||
y += Hv * Dv;
|
||||
g_ += Hv;
|
||||
beta_ += Hv;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
o_state[s_idx] = static_cast<StT>(state[i]);
|
||||
}
|
||||
`
|
||||
|
||||
const gatedDeltaCUDAKernelSource = `
|
||||
auto tid_x = threadIdx.x;
|
||||
auto tid_y = threadIdx.y;
|
||||
auto grid_y = blockIdx.y * blockDim.y + tid_y;
|
||||
auto grid_z = blockIdx.z;
|
||||
|
||||
int T_val = static_cast<int>(*T);
|
||||
|
||||
auto n = grid_z;
|
||||
auto b_idx = n / Hv;
|
||||
auto hv_idx = n % Hv;
|
||||
auto hk_idx = hv_idx / (Hv / Hk);
|
||||
constexpr int n_per_t = Dk / 32;
|
||||
|
||||
// q, k: [B, T, Hk, Dk]
|
||||
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
|
||||
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;
|
||||
|
||||
// v, y: [B, T, Hv, Dv]
|
||||
auto dv_idx = grid_y;
|
||||
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
|
||||
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;
|
||||
|
||||
auto dk_idx = tid_x;
|
||||
|
||||
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||
|
||||
float state[n_per_t];
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = static_cast<float>(i_state[s_idx]);
|
||||
}
|
||||
|
||||
// g: [B, T, Hv]
|
||||
auto g_ = g + b_idx * T_val * Hv;
|
||||
auto beta_ = beta + b_idx * T_val * Hv;
|
||||
|
||||
for (int t = 0; t < T_val; ++t) {
|
||||
float kv_mem = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] * static_cast<float>(g_[hv_idx]);
|
||||
kv_mem += state[i] * static_cast<float>(k_[s_idx]);
|
||||
}
|
||||
// Warp reduction (full warp, 32 threads in x)
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
|
||||
kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);
|
||||
|
||||
auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);
|
||||
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
|
||||
out += state[i] * static_cast<float>(q_[s_idx]);
|
||||
}
|
||||
// Warp reduction
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
out += __shfl_down_sync(0xffffffff, out, offset);
|
||||
if (tid_x == 0) {
|
||||
y[dv_idx] = static_cast<InT>(out);
|
||||
}
|
||||
|
||||
q_ += Hk * Dk;
|
||||
k_ += Hk * Dk;
|
||||
v_ += Hv * Dv;
|
||||
y += Hv * Dv;
|
||||
g_ += Hv;
|
||||
beta_ += Hv;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
o_state[s_idx] = static_cast<StT>(state[i]);
|
||||
}
|
||||
`
|
||||
|
||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||
vec := C.mlx_vector_string_new()
|
||||
ok := true
|
||||
for _, s := range values {
|
||||
cs := C.CString(s)
|
||||
if C.mlx_vector_string_append_value(vec, cs) != 0 {
|
||||
ok = false
|
||||
}
|
||||
C.free(unsafe.Pointer(cs))
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
cleanup := func() {
|
||||
C.mlx_vector_string_free(vec)
|
||||
}
|
||||
return vec, cleanup, ok
|
||||
}
|
||||
|
||||
func initGatedDeltaMetalKernel() {
|
||||
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled = true
|
||||
freeInputs()
|
||||
return
|
||||
}
|
||||
defer freeInputs()
|
||||
|
||||
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled = true
|
||||
freeOutputs()
|
||||
return
|
||||
}
|
||||
defer freeOutputs()
|
||||
|
||||
cName := C.CString("gated_delta_step")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
cSource := C.CString(gatedDeltaMetalKernelSource)
|
||||
defer C.free(unsafe.Pointer(cSource))
|
||||
cHeader := C.CString("")
|
||||
defer C.free(unsafe.Pointer(cHeader))
|
||||
|
||||
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
|
||||
cName,
|
||||
inputs,
|
||||
outputs,
|
||||
cSource,
|
||||
cHeader,
|
||||
C.bool(true),
|
||||
C.bool(false),
|
||||
)
|
||||
}
|
||||
|
||||
// gatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
|
||||
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
|
||||
func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||
if gatedDeltaMetalDisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
Hv, Dv := vd[2], vd[3]
|
||||
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
inputDType := q.DType()
|
||||
stateDType := state.DType()
|
||||
if k.DType() != inputDType || v.DType() != inputDType || g.DType() != inputDType || beta.DType() != inputDType {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
|
||||
if gatedDeltaMetalDisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
cfg := C.mlx_fast_metal_kernel_config_new()
|
||||
defer C.mlx_fast_metal_kernel_config_free(cfg)
|
||||
|
||||
cInT := C.CString("InT")
|
||||
defer C.free(unsafe.Pointer(cInT))
|
||||
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(inputDType)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
cStT := C.CString("StT")
|
||||
defer C.free(unsafe.Pointer(cStT))
|
||||
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(stateDType)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
for _, tpl := range []struct {
|
||||
name string
|
||||
value int
|
||||
}{
|
||||
{name: "Dk", value: Dk},
|
||||
{name: "Dv", value: Dv},
|
||||
{name: "Hk", value: Hk},
|
||||
{name: "Hv", value: Hv},
|
||||
} {
|
||||
cn := C.CString(tpl.name)
|
||||
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||
C.free(unsafe.Pointer(cn))
|
||||
if rc != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(inputDType)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(stateDType)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
threadY := Dv
|
||||
if threadY > 4 {
|
||||
threadY = 4
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tScalar := FromValue(T)
|
||||
inputs := []C.mlx_array{
|
||||
q.ctx,
|
||||
k.ctx,
|
||||
v.ctx,
|
||||
g.ctx,
|
||||
beta.ctx,
|
||||
state.ctx,
|
||||
tScalar.ctx,
|
||||
}
|
||||
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
y = New("GATED_DELTA_METAL_Y")
|
||||
nextState = New("GATED_DELTA_METAL_STATE")
|
||||
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||
return y, nextState, true
|
||||
}
|
||||
|
||||
func repeatHeadsForGatedDelta(x *Array, repeatFactor int) *Array {
|
||||
if repeatFactor <= 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Dims()
|
||||
x = ExpandDims(x, 3)
|
||||
x = Tile(x, []int32{1, 1, 1, int32(repeatFactor), 1})
|
||||
return Reshape(x, int32(shape[0]), int32(shape[1]), int32(shape[2]*repeatFactor), int32(shape[3]))
|
||||
}
|
||||
|
||||
func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := int32(qd[0]), int32(qd[1]), int32(qd[2]), int32(qd[3])
|
||||
Hv, Dv := int32(vd[2]), int32(vd[3])
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if kd[0] != int(B) || kd[1] != int(T) || kd[2] != int(Hk) || kd[3] != int(Dk) {
|
||||
return nil, nil
|
||||
}
|
||||
if vd[0] != int(B) || vd[1] != int(T) {
|
||||
return nil, nil
|
||||
}
|
||||
if gd[0] != int(B) || gd[1] != int(T) || gd[2] != int(Hv) {
|
||||
return nil, nil
|
||||
}
|
||||
if bd[0] != int(B) || bd[1] != int(T) || bd[2] != int(Hv) {
|
||||
return nil, nil
|
||||
}
|
||||
if sd[0] != int(B) || sd[1] != int(Hv) || sd[2] != int(Dv) || sd[3] != int(Dk) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
repeatFactor := int(Hv / Hk)
|
||||
q = repeatHeadsForGatedDelta(q, repeatFactor)
|
||||
k = repeatHeadsForGatedDelta(k, repeatFactor)
|
||||
|
||||
nextState = state
|
||||
if T == 1 {
|
||||
qt := Squeeze(q, 1)
|
||||
kt := Squeeze(k, 1)
|
||||
vt := Squeeze(v, 1)
|
||||
gt := Squeeze(g, 1)
|
||||
bt := Squeeze(beta, 1)
|
||||
|
||||
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
|
||||
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
|
||||
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
|
||||
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
|
||||
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
|
||||
return ExpandDims(yt, 1), nextState
|
||||
}
|
||||
|
||||
outs := make([]*Array, 0, T)
|
||||
for t := int32(0); t < T; t++ {
|
||||
qt := Squeeze(SliceStartStop(q, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
|
||||
kt := Squeeze(SliceStartStop(k, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
|
||||
vt := Squeeze(SliceStartStop(v, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dv}), 1)
|
||||
gt := Squeeze(SliceStartStop(g, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
|
||||
bt := Squeeze(SliceStartStop(beta, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
|
||||
|
||||
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
|
||||
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
|
||||
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
|
||||
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
|
||||
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
|
||||
outs = append(outs, ExpandDims(yt, 1))
|
||||
}
|
||||
return Concatenate(outs, 1), nextState
|
||||
}
|
||||
|
||||
func initGatedDeltaCUDAKernel() {
|
||||
var cudaAvail C.bool
|
||||
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return
|
||||
}
|
||||
|
||||
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||
if !ok {
|
||||
gatedDeltaCUDADisabled = true
|
||||
freeInputs()
|
||||
return
|
||||
}
|
||||
defer freeInputs()
|
||||
|
||||
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||
if !ok {
|
||||
gatedDeltaCUDADisabled = true
|
||||
freeOutputs()
|
||||
return
|
||||
}
|
||||
defer freeOutputs()
|
||||
|
||||
cName := C.CString("gated_delta_step")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
cSource := C.CString(gatedDeltaCUDAKernelSource)
|
||||
defer C.free(unsafe.Pointer(cSource))
|
||||
cHeader := C.CString("")
|
||||
defer C.free(unsafe.Pointer(cHeader))
|
||||
|
||||
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
|
||||
cName,
|
||||
inputs,
|
||||
outputs,
|
||||
cSource,
|
||||
cHeader,
|
||||
C.bool(true),
|
||||
C.int(0),
|
||||
)
|
||||
}
|
||||
|
||||
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||
if gatedDeltaCUDADisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
Hv, Dv := vd[2], vd[3]
|
||||
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
inputDType := q.DType()
|
||||
stateDType := state.DType()
|
||||
if k.DType() != inputDType || v.DType() != inputDType || g.DType() != inputDType || beta.DType() != inputDType {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
|
||||
if gatedDeltaCUDADisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
cfg := C.mlx_fast_cuda_kernel_config_new()
|
||||
defer C.mlx_fast_cuda_kernel_config_free(cfg)
|
||||
|
||||
cInT := C.CString("InT")
|
||||
defer C.free(unsafe.Pointer(cInT))
|
||||
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(inputDType)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
cStT := C.CString("StT")
|
||||
defer C.free(unsafe.Pointer(cStT))
|
||||
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(stateDType)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
for _, tpl := range []struct {
|
||||
name string
|
||||
value int
|
||||
}{
|
||||
{name: "Dk", value: Dk},
|
||||
{name: "Dv", value: Dv},
|
||||
{name: "Hk", value: Hk},
|
||||
{name: "Hv", value: Hv},
|
||||
} {
|
||||
cn := C.CString(tpl.name)
|
||||
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||
C.free(unsafe.Pointer(cn))
|
||||
if rc != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(inputDType)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(stateDType)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
threadY := Dv
|
||||
if threadY > 4 {
|
||||
threadY = 4
|
||||
}
|
||||
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tScalar := FromValue(T)
|
||||
inputs := []C.mlx_array{
|
||||
q.ctx,
|
||||
k.ctx,
|
||||
v.ctx,
|
||||
g.ctx,
|
||||
beta.ctx,
|
||||
state.ctx,
|
||||
tScalar.ctx,
|
||||
}
|
||||
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
y = New("GATED_DELTA_CUDA_Y")
|
||||
nextState = New("GATED_DELTA_CUDA_STATE")
|
||||
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||
return y, nextState, true
|
||||
}
|
||||
|
||||
// FastGatedDelta runs the recurrent update operation.
|
||||
//
|
||||
// When mask is non-nil, it must be a [B, T] bool tensor identifying real
|
||||
// (true) vs. padded (false) positions in q/k/v/g/beta. Padded positions
|
||||
// are substituted with neutral values (q=k=v=beta=0, g=1) so each padded
|
||||
// kernel iteration is a no-op — state passes through unchanged and the
|
||||
// final state equals the state after the last real token of each row.
|
||||
//
|
||||
// It tries the fused CUDA kernel first, then Metal, then falls back to a
|
||||
// backend-agnostic MLX implementation with identical inputs/outputs.
|
||||
func FastGatedDelta(q, k, v, g, beta, state, mask *Array) (y, nextState *Array) {
|
||||
// TODO: handle this more efficiently with a masked kernel (MLX-LM has one).
|
||||
if mask != nil {
|
||||
B := int32(mask.Dim(0))
|
||||
T := int32(mask.Dim(1))
|
||||
m4 := Reshape(mask, B, T, 1, 1)
|
||||
m3 := Reshape(mask, B, T, 1)
|
||||
zeroQ := FromValue(float32(0)).AsType(q.DType())
|
||||
zeroK := FromValue(float32(0)).AsType(k.DType())
|
||||
zeroV := FromValue(float32(0)).AsType(v.DType())
|
||||
zeroBeta := FromValue(float32(0)).AsType(beta.DType())
|
||||
oneG := FromValue(float32(1)).AsType(g.DType())
|
||||
q = Where(m4, q, zeroQ)
|
||||
k = Where(m4, k, zeroK)
|
||||
v = Where(m4, v, zeroV)
|
||||
beta = Where(m3, beta, zeroBeta)
|
||||
g = Where(m3, g, oneG)
|
||||
}
|
||||
|
||||
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
|
||||
return y, nextState
|
||||
}
|
||||
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
||||
return y, nextState
|
||||
}
|
||||
y, nextState = gatedDeltaFallback(q, k, v, g, beta, state)
|
||||
if y == nil || nextState == nil {
|
||||
panic("mlx.FastGatedDelta: fallback failed (invalid inputs or unsupported shapes)")
|
||||
}
|
||||
return y, nextState
|
||||
}
|
||||
3012
x/mlxrunner/mlx/generated.c
Normal file
3012
x/mlxrunner/mlx/generated.c
Normal file
File diff suppressed because it is too large
Load Diff
7256
x/mlxrunner/mlx/generated.h
Normal file
7256
x/mlxrunner/mlx/generated.h
Normal file
File diff suppressed because it is too large
Load Diff
17
x/mlxrunner/mlx/generator/generated.c.gotmpl
Normal file
17
x/mlxrunner/mlx/generator/generated.c.gotmpl
Normal file
@@ -0,0 +1,17 @@
|
||||
// This code is auto-generated; DO NOT EDIT.
|
||||
|
||||
#include "generated.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
{{ range .Functions }}
|
||||
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
|
||||
{{- end }}
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
{{- range .Functions }}
|
||||
{{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }});
|
||||
{{- end }}
|
||||
return 0;
|
||||
}
|
||||
26
x/mlxrunner/mlx/generator/generated.h.gotmpl
Normal file
26
x/mlxrunner/mlx/generator/generated.h.gotmpl
Normal file
@@ -0,0 +1,26 @@
|
||||
// This code is auto-generated; DO NOT EDIT.
|
||||
|
||||
#ifndef MLX_GENERATED_H
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
{{ range .Functions }}
|
||||
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
|
||||
{{- end }}
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
{{- end }}
|
||||
{{ range .Functions }}
|
||||
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
|
||||
{{- end }}
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
|
||||
{{ range .Functions }}
|
||||
static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
|
||||
return {{ .Name }}_({{ .Args }});
|
||||
{{ "}" }}
|
||||
{{- end }}
|
||||
|
||||
#endif // MLX_GENERATED_H
|
||||
157
x/mlxrunner/mlx/generator/main.go
Normal file
157
x/mlxrunner/mlx/generator/main.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
tree_sitter "github.com/tree-sitter/go-tree-sitter"
|
||||
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
|
||||
)
|
||||
|
||||
//go:embed *.gotmpl
|
||||
var fsys embed.FS
|
||||
|
||||
// optionalSymbols lists symbols that may not be present in all builds
|
||||
// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX).
|
||||
var optionalSymbols = map[string]bool{
|
||||
"mlx_array_item_float16": true,
|
||||
"mlx_array_item_bfloat16": true,
|
||||
"mlx_array_data_float16": true,
|
||||
"mlx_array_data_bfloat16": true,
|
||||
}
|
||||
|
||||
type Function struct {
|
||||
Type,
|
||||
Name,
|
||||
Parameters,
|
||||
Args string
|
||||
Optional bool
|
||||
}
|
||||
|
||||
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
|
||||
var fn Function
|
||||
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
|
||||
if params := node.ChildByFieldName("parameters"); params != nil {
|
||||
fn.Parameters = params.Utf8Text(source)
|
||||
fn.Args = ParseParameters(params, tc, source)
|
||||
}
|
||||
|
||||
var types []string
|
||||
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
|
||||
if node.Parent().Kind() == "pointer_declarator" {
|
||||
types = append(types, "*")
|
||||
}
|
||||
node = node.Parent()
|
||||
}
|
||||
|
||||
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
|
||||
types = append(types, sibling.Utf8Text(source))
|
||||
}
|
||||
|
||||
slices.Reverse(types)
|
||||
fn.Type = strings.Join(types, " ")
|
||||
return fn
|
||||
}
|
||||
|
||||
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
|
||||
var s []string
|
||||
for _, child := range node.Children(tc) {
|
||||
if child.IsNamed() {
|
||||
child := child.ChildByFieldName("declarator")
|
||||
for child != nil && child.Kind() != "identifier" {
|
||||
if child.Kind() == "parenthesized_declarator" {
|
||||
child = child.Child(1)
|
||||
} else {
|
||||
child = child.ChildByFieldName("declarator")
|
||||
}
|
||||
}
|
||||
|
||||
if child != nil {
|
||||
s = append(s, child.Utf8Text(source))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(s, ", ")
|
||||
}
|
||||
|
||||
func main() {
|
||||
var output string
|
||||
flag.StringVar(&output, "output", ".", "Output directory for generated files")
|
||||
flag.Parse()
|
||||
|
||||
parser := tree_sitter.NewParser()
|
||||
defer parser.Close()
|
||||
|
||||
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
|
||||
parser.SetLanguage(language)
|
||||
|
||||
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
|
||||
defer query.Close()
|
||||
|
||||
qc := tree_sitter.NewQueryCursor()
|
||||
defer qc.Close()
|
||||
|
||||
var files []string
|
||||
for _, arg := range flag.Args() {
|
||||
matches, err := filepath.Glob(arg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error expanding glob %s: %v\n", arg, err)
|
||||
continue
|
||||
}
|
||||
files = append(files, matches...)
|
||||
}
|
||||
|
||||
var funs []Function
|
||||
for _, arg := range files {
|
||||
bts, err := os.ReadFile(arg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
|
||||
continue
|
||||
}
|
||||
|
||||
tree := parser.Parse(bts, nil)
|
||||
defer tree.Close()
|
||||
|
||||
tc := tree.Walk()
|
||||
defer tc.Close()
|
||||
|
||||
matches := qc.Matches(query, tree.RootNode(), bts)
|
||||
for match := matches.Next(); match != nil; match = matches.Next() {
|
||||
for _, capture := range match.Captures {
|
||||
fn := ParseFunction(&capture.Node, tc, bts)
|
||||
fn.Optional = optionalSymbols[fn.Name]
|
||||
funs = append(funs, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, tmpl := range tmpl.Templates() {
|
||||
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
|
||||
|
||||
fmt.Println("Generating", name)
|
||||
f, err := os.Create(name)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := tmpl.Execute(f, map[string]any{
|
||||
"Functions": funs,
|
||||
}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
12
x/mlxrunner/mlx/include/mlx/c/README.md
Normal file
12
x/mlxrunner/mlx/include/mlx/c/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# Vendored MLX-C Headers
|
||||
|
||||
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
||||
The pinned version is in `MLX_C_VERSION` at the repo root.
|
||||
|
||||
Headers are automatically refreshed when you run a CMake build:
|
||||
|
||||
```shell
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
```
|
||||
|
||||
See the [MLX Engine](../../../../../../../docs/development.md#mlx-engine-optional) section of the development docs for full build instructions.
|
||||
420
x/mlxrunner/mlx/include/mlx/c/array.h
Normal file
420
x/mlxrunner/mlx/include/mlx/c/array.h
Normal file
@@ -0,0 +1,420 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ARRAY_H
|
||||
#define MLX_ARRAY_H
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#include <float.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Complex number support
|
||||
#ifdef _MSC_VER
|
||||
#define _CRT_USE_C_COMPLEX_H
|
||||
#include <complex.h>
|
||||
typedef _Fcomplex mlx_complex64_t;
|
||||
#else
|
||||
#include <complex.h>
|
||||
typedef float _Complex mlx_complex64_t;
|
||||
#endif
|
||||
|
||||
#include "half.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_array Array
|
||||
* MLX N-dimensional array object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A N-dimensional array object.
|
||||
*/
|
||||
typedef struct mlx_array_ {
|
||||
void* ctx;
|
||||
} mlx_array;
|
||||
|
||||
static mlx_array mlx_array_empty;
|
||||
|
||||
/**
|
||||
* Array element type.
|
||||
*/
|
||||
typedef enum mlx_dtype_ {
|
||||
MLX_BOOL,
|
||||
MLX_UINT8,
|
||||
MLX_UINT16,
|
||||
MLX_UINT32,
|
||||
MLX_UINT64,
|
||||
MLX_INT8,
|
||||
MLX_INT16,
|
||||
MLX_INT32,
|
||||
MLX_INT64,
|
||||
MLX_FLOAT16,
|
||||
MLX_FLOAT32,
|
||||
MLX_FLOAT64,
|
||||
MLX_BFLOAT16,
|
||||
MLX_COMPLEX64,
|
||||
} mlx_dtype;
|
||||
|
||||
/**
|
||||
* Size of given mlx_dtype datatype in bytes.
|
||||
*/
|
||||
size_t mlx_dtype_size(mlx_dtype dtype);
|
||||
|
||||
/**
|
||||
* Get array description.
|
||||
*/
|
||||
int mlx_array_tostring(mlx_string* str, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* New empty array.
|
||||
*/
|
||||
mlx_array mlx_array_new(void);
|
||||
|
||||
/**
|
||||
* Free an array.
|
||||
*/
|
||||
int mlx_array_free(mlx_array arr);
|
||||
|
||||
/**
|
||||
* New array from a bool scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_bool(bool val);
|
||||
/**
|
||||
* New array from a int scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_int(int val);
|
||||
/**
|
||||
* New array from a float32 scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_float32(float val);
|
||||
/**
|
||||
* New array from a float scalar.
|
||||
* Same as float32.
|
||||
*/
|
||||
mlx_array mlx_array_new_float(float val);
|
||||
/**
|
||||
* New array from a float64 scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_float64(double val);
|
||||
/**
|
||||
* New array from a double scalar.
|
||||
* Same as float64.
|
||||
*/
|
||||
mlx_array mlx_array_new_double(double val);
|
||||
/**
|
||||
* New array from a complex scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_complex(float real_val, float imag_val);
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
*/
|
||||
mlx_array mlx_array_new_data(
|
||||
const void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype);
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
* @param dtor Callback for when the buffer is no longer needed.
|
||||
*/
|
||||
mlx_array mlx_array_new_data_managed(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void (*dtor)(void*));
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
* @param payload Payload pointer passed to the `dtor` callback instead of
|
||||
* `data`.
|
||||
* @param dtor Callback for when the buffer is no longer needed.
|
||||
*/
|
||||
mlx_array mlx_array_new_data_managed_payload(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
/**
|
||||
* Set array to provided src array.
|
||||
*/
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src);
|
||||
/**
|
||||
* Set array to a bool scalar.
|
||||
*/
|
||||
int mlx_array_set_bool(mlx_array* arr, bool val);
|
||||
/**
|
||||
* Set array to a int scalar.
|
||||
*/
|
||||
int mlx_array_set_int(mlx_array* arr, int val);
|
||||
/**
|
||||
* Set array to a float32 scalar.
|
||||
*/
|
||||
int mlx_array_set_float32(mlx_array* arr, float val);
|
||||
/**
|
||||
* Set array to a float scalar.
|
||||
*/
|
||||
int mlx_array_set_float(mlx_array* arr, float val);
|
||||
/**
|
||||
* Set array to a float64 scalar.
|
||||
*/
|
||||
int mlx_array_set_float64(mlx_array* arr, double val);
|
||||
/**
|
||||
* Set array to a double scalar.
|
||||
*/
|
||||
int mlx_array_set_double(mlx_array* arr, double val);
|
||||
/**
|
||||
* Set array to a complex scalar.
|
||||
*/
|
||||
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val);
|
||||
/**
|
||||
* Set array to specified data and shape.
|
||||
* @param arr Destination array.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
*/
|
||||
int mlx_array_set_data(
|
||||
mlx_array* arr,
|
||||
const void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype);
|
||||
|
||||
/**
|
||||
* The size of the array's datatype in bytes.
|
||||
*/
|
||||
size_t mlx_array_itemsize(const mlx_array arr);
|
||||
/**
|
||||
* Number of elements in the array.
|
||||
*/
|
||||
size_t mlx_array_size(const mlx_array arr);
|
||||
/**
|
||||
* The number of bytes in the array.
|
||||
*/
|
||||
size_t mlx_array_nbytes(const mlx_array arr);
|
||||
/**
|
||||
* The array's dimension.
|
||||
*/
|
||||
size_t mlx_array_ndim(const mlx_array arr);
|
||||
/**
|
||||
* The shape of the array.
|
||||
* Returns: a pointer to the sizes of each dimension.
|
||||
*/
|
||||
const int* mlx_array_shape(const mlx_array arr);
|
||||
/**
|
||||
* The strides of the array.
|
||||
* Returns: a pointer to the sizes of each dimension.
|
||||
*/
|
||||
const size_t* mlx_array_strides(const mlx_array arr);
|
||||
/**
|
||||
* The shape of the array in a particular dimension.
|
||||
*/
|
||||
int mlx_array_dim(const mlx_array arr, int dim);
|
||||
/**
|
||||
* The array element type.
|
||||
*/
|
||||
mlx_dtype mlx_array_dtype(const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Evaluate the array.
|
||||
*/
|
||||
int mlx_array_eval(mlx_array arr);
|
||||
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_bool(bool* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int8(int8_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int16(int16_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int32(int32_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int64(int64_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float32(float* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float64(double* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
|
||||
|
||||
#ifdef HAS_FLOAT16
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
|
||||
#ifdef HAS_BFLOAT16
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `bool*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const bool* mlx_array_data_bool(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint8_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint8_t* mlx_array_data_uint8(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint16_t* mlx_array_data_uint16(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint32_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint32_t* mlx_array_data_uint32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint64_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint64_t* mlx_array_data_uint64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int8_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int8_t* mlx_array_data_int8(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int16_t* mlx_array_data_int16(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int32_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int32_t* mlx_array_data_int32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int64_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int64_t* mlx_array_data_int64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float32*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const float* mlx_array_data_float32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float64*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const double* mlx_array_data_float64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `_Complex*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
|
||||
|
||||
#ifdef HAS_FLOAT16
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const float16_t* mlx_array_data_float16(const mlx_array arr);
|
||||
#endif
|
||||
|
||||
#ifdef HAS_BFLOAT16
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `bfloat16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Check if the array is available.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_available(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Wait on the array to be available. After this `_mlx_array_is_available`
|
||||
* returns `true`. Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_wait(const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array is contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array's rows are contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array's columns are contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
197
x/mlxrunner/mlx/include/mlx/c/closure.h
Normal file
197
x/mlxrunner/mlx/include/mlx/c/closure.h
Normal file
@@ -0,0 +1,197 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_CLOSURE_H
|
||||
#define MLX_CLOSURE_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/optional.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_closure Closures
|
||||
* MLX closure objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef struct mlx_closure_ {
|
||||
void* ctx;
|
||||
} mlx_closure;
|
||||
mlx_closure mlx_closure_new(void);
|
||||
int mlx_closure_free(mlx_closure cls);
|
||||
mlx_closure mlx_closure_new_func(
|
||||
int (*fun)(mlx_vector_array*, const mlx_vector_array));
|
||||
mlx_closure mlx_closure_new_func_payload(
|
||||
int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_set(mlx_closure* cls, const mlx_closure src);
|
||||
int mlx_closure_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure cls,
|
||||
const mlx_vector_array input);
|
||||
|
||||
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
|
||||
|
||||
typedef struct mlx_closure_kwargs_ {
|
||||
void* ctx;
|
||||
} mlx_closure_kwargs;
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new(void);
|
||||
int mlx_closure_kwargs_free(mlx_closure_kwargs cls);
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array));
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_kwargs_set(
|
||||
mlx_closure_kwargs* cls,
|
||||
const mlx_closure_kwargs src);
|
||||
int mlx_closure_kwargs_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_kwargs cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_map_string_to_array input_1);
|
||||
|
||||
typedef struct mlx_closure_value_and_grad_ {
|
||||
void* ctx;
|
||||
} mlx_closure_value_and_grad;
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void);
|
||||
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls);
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
|
||||
int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_value_and_grad_set(
|
||||
mlx_closure_value_and_grad* cls,
|
||||
const mlx_closure_value_and_grad src);
|
||||
int mlx_closure_value_and_grad_apply(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
mlx_closure_value_and_grad cls,
|
||||
const mlx_vector_array input);
|
||||
|
||||
typedef struct mlx_closure_custom_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom;
|
||||
mlx_closure_custom mlx_closure_custom_new(void);
|
||||
int mlx_closure_custom_free(mlx_closure_custom cls);
|
||||
mlx_closure_custom mlx_closure_custom_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array));
|
||||
mlx_closure_custom mlx_closure_custom_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_set(
|
||||
mlx_closure_custom* cls,
|
||||
const mlx_closure_custom src);
|
||||
int mlx_closure_custom_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_custom cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_vector_array input_1,
|
||||
const mlx_vector_array input_2);
|
||||
|
||||
typedef struct mlx_closure_custom_jvp_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom_jvp;
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void);
|
||||
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls);
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num));
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_jvp_set(
|
||||
mlx_closure_custom_jvp* cls,
|
||||
const mlx_closure_custom_jvp src);
|
||||
int mlx_closure_custom_jvp_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_custom_jvp cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_vector_array input_1,
|
||||
const int* input_2,
|
||||
size_t input_2_num);
|
||||
|
||||
typedef struct mlx_closure_custom_vmap_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom_vmap;
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void);
|
||||
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls);
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num));
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_vmap_set(
|
||||
mlx_closure_custom_vmap* cls,
|
||||
const mlx_closure_custom_vmap src);
|
||||
int mlx_closure_custom_vmap_apply(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_int* res_1,
|
||||
mlx_closure_custom_vmap cls,
|
||||
const mlx_vector_array input_0,
|
||||
const int* input_1,
|
||||
size_t input_1_num);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
58
x/mlxrunner/mlx/include/mlx/c/compile.h
Normal file
58
x/mlxrunner/mlx/include/mlx/c/compile.h
Normal file
@@ -0,0 +1,58 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_COMPILE_H
|
||||
#define MLX_COMPILE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup compile Compilation operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef enum mlx_compile_mode_ {
|
||||
MLX_COMPILE_MODE_DISABLED,
|
||||
MLX_COMPILE_MODE_NO_SIMPLIFY,
|
||||
MLX_COMPILE_MODE_NO_FUSE,
|
||||
MLX_COMPILE_MODE_ENABLED
|
||||
} mlx_compile_mode;
|
||||
|
||||
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless);
|
||||
int mlx_detail_compile(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
uintptr_t fun_id,
|
||||
bool shapeless,
|
||||
const uint64_t* constants,
|
||||
size_t constants_num);
|
||||
int mlx_detail_compile_clear_cache(void);
|
||||
int mlx_detail_compile_erase(uintptr_t fun_id);
|
||||
int mlx_disable_compile(void);
|
||||
int mlx_enable_compile(void);
|
||||
int mlx_set_compile_mode(mlx_compile_mode mode);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
39
x/mlxrunner/mlx/include/mlx/c/cuda.h
Normal file
39
x/mlxrunner/mlx/include/mlx/c/cuda.h
Normal file
@@ -0,0 +1,39 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_CUDA_H
|
||||
#define MLX_CUDA_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup cuda Cuda specific operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_cuda_is_available(bool* res);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
154
x/mlxrunner/mlx/include/mlx/c/device.h
Normal file
154
x/mlxrunner/mlx/include/mlx/c/device.h
Normal file
@@ -0,0 +1,154 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_DEVICE_H
|
||||
#define MLX_DEVICE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_device Device
|
||||
* MLX device object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX device object.
|
||||
*/
|
||||
typedef struct mlx_device_ {
|
||||
void* ctx;
|
||||
} mlx_device;
|
||||
|
||||
/**
|
||||
* Device type.
|
||||
*/
|
||||
typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type;
|
||||
|
||||
/**
|
||||
* Returns a new empty device.
|
||||
*/
|
||||
mlx_device mlx_device_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new device of specified `type`, with specified `index`.
|
||||
*/
|
||||
mlx_device mlx_device_new_type(mlx_device_type type, int index);
|
||||
/**
|
||||
* Free a device.
|
||||
*/
|
||||
int mlx_device_free(mlx_device dev);
|
||||
/**
|
||||
* Set device to provided src device.
|
||||
*/
|
||||
int mlx_device_set(mlx_device* dev, const mlx_device src);
|
||||
/**
|
||||
* Get device description.
|
||||
*/
|
||||
int mlx_device_tostring(mlx_string* str, mlx_device dev);
|
||||
/**
|
||||
* Check if devices are the same.
|
||||
*/
|
||||
bool mlx_device_equal(mlx_device lhs, mlx_device rhs);
|
||||
/**
|
||||
* Returns the index of the device.
|
||||
*/
|
||||
int mlx_device_get_index(int* index, mlx_device dev);
|
||||
/**
|
||||
* Returns the type of the device.
|
||||
*/
|
||||
int mlx_device_get_type(mlx_device_type* type, mlx_device dev);
|
||||
/**
|
||||
* Returns the default MLX device.
|
||||
*/
|
||||
int mlx_get_default_device(mlx_device* dev);
|
||||
/**
|
||||
* Set the default MLX device.
|
||||
*/
|
||||
int mlx_set_default_device(mlx_device dev);
|
||||
/**
|
||||
* Check if device is available.
|
||||
*/
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev);
|
||||
/**
|
||||
* Get the number of available devices for a device type.
|
||||
*/
|
||||
int mlx_device_count(int* count, mlx_device_type type);
|
||||
|
||||
/**
|
||||
* A MLX device info object.
|
||||
* Contains key-value pairs with device properties.
|
||||
* Keys vary by backend but common keys include:
|
||||
* - device_name (string): Device name
|
||||
* - architecture (string): Architecture identifier
|
||||
* Additional keys may be present depending on the backend.
|
||||
*/
|
||||
typedef struct mlx_device_info_ {
|
||||
void* ctx;
|
||||
} mlx_device_info;
|
||||
|
||||
/**
|
||||
* Returns a new empty device info object.
|
||||
*/
|
||||
mlx_device_info mlx_device_info_new(void);
|
||||
/**
|
||||
* Get device information for a device.
|
||||
*/
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
|
||||
/**
|
||||
* Free a device info object.
|
||||
*/
|
||||
int mlx_device_info_free(mlx_device_info info);
|
||||
/**
|
||||
* Check if a key exists in the device info.
|
||||
* Returns 0 on success, 1 on error.
|
||||
* Sets *exists to true if the key exists, false otherwise.
|
||||
*/
|
||||
int mlx_device_info_has_key(
|
||||
bool* exists,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Check if a value is a string type.
|
||||
* Returns 0 on success, 1 on error.
|
||||
* Sets *is_string to true if the value is a string, false if it's a size_t.
|
||||
*/
|
||||
int mlx_device_info_is_string(
|
||||
bool* is_string,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get a string value from device info.
|
||||
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
|
||||
*/
|
||||
int mlx_device_info_get_string(
|
||||
const char** value,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get a size_t value from device info.
|
||||
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
|
||||
*/
|
||||
int mlx_device_info_get_size(
|
||||
size_t* value,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get all keys from device info.
|
||||
* Returns 0 on success, 1 on error.
|
||||
*/
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
83
x/mlxrunner/mlx/include/mlx/c/distributed.h
Normal file
83
x/mlxrunner/mlx/include/mlx/c/distributed.h
Normal file
@@ -0,0 +1,83 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_DISTRIBUTED_H
|
||||
#define MLX_DISTRIBUTED_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup distributed Distributed collectives
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_distributed_all_gather(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream S);
|
||||
int mlx_distributed_all_max(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_all_min(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_all_sum(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_recv(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
int src,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_recv_like(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int src,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_send(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dst,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_sum_scatter(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
74
x/mlxrunner/mlx/include/mlx/c/distributed_group.h
Normal file
74
x/mlxrunner/mlx/include/mlx/c/distributed_group.h
Normal file
@@ -0,0 +1,74 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_DISTRIBUTED_GROUP_H
|
||||
#define MLX_DISTRIBUTED_GROUP_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/stream.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_distributed_group MLX distributed
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX distributed group object.
|
||||
*/
|
||||
typedef struct mlx_distributed_group_ {
|
||||
void* ctx;
|
||||
} mlx_distributed_group;
|
||||
|
||||
/**
|
||||
* Create an empty group.
|
||||
*/
|
||||
mlx_distributed_group mlx_distributed_group_new(void);
|
||||
|
||||
/**
|
||||
* Free the group.
|
||||
*/
|
||||
int mlx_distributed_group_free(mlx_distributed_group group);
|
||||
|
||||
/**
|
||||
* Initialize distributed.
|
||||
*/
|
||||
int mlx_distributed_init(
|
||||
mlx_distributed_group* res,
|
||||
bool strict,
|
||||
const char* bk /* may be null */);
|
||||
|
||||
/**
|
||||
* Get the rank.
|
||||
*/
|
||||
int mlx_distributed_group_rank(mlx_distributed_group group);
|
||||
|
||||
/**
|
||||
* Get the group size.
|
||||
*/
|
||||
int mlx_distributed_group_size(mlx_distributed_group group);
|
||||
|
||||
/**
|
||||
* Split the group.
|
||||
*/
|
||||
int mlx_distributed_group_split(
|
||||
mlx_distributed_group* res,
|
||||
mlx_distributed_group group,
|
||||
int color,
|
||||
int key);
|
||||
|
||||
/**
|
||||
* Check if distributed is available.
|
||||
*/
|
||||
bool mlx_distributed_is_available(const char* bk /* may be null */);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
41
x/mlxrunner/mlx/include/mlx/c/error.h
Normal file
41
x/mlxrunner/mlx/include/mlx/c/error.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ERROR_H
|
||||
#define MLX_ERROR_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_error Error management
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef void (*mlx_error_handler_func)(const char* msg, void* data);
|
||||
|
||||
/**
|
||||
* Set the error handler.
|
||||
*/
|
||||
void mlx_set_error_handler(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
void (*dtor)(void*));
|
||||
|
||||
/**
|
||||
* Throw an error.
|
||||
*/
|
||||
void _mlx_error(const char* file, const int line, const char* fmt, ...);
|
||||
|
||||
/**
|
||||
* Throw an error. Macro which passes file name and line number to _mlx_error().
|
||||
*/
|
||||
#define mlx_error(...) _mlx_error(__FILE__, __LINE__, __VA_ARGS__)
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
75
x/mlxrunner/mlx/include/mlx/c/export.h
Normal file
75
x/mlxrunner/mlx/include/mlx/c/export.h
Normal file
@@ -0,0 +1,75 @@
|
||||
/* Copyright © 2023-2025 Apple Inc. */
|
||||
|
||||
#ifndef MLX_EXPORT_H
|
||||
#define MLX_EXPORT_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup export Function serialization
|
||||
*/
|
||||
/**@{*/
|
||||
int mlx_export_function(
|
||||
const char* file,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array args,
|
||||
bool shapeless);
|
||||
int mlx_export_function_kwargs(
|
||||
const char* file,
|
||||
const mlx_closure_kwargs fun,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs,
|
||||
bool shapeless);
|
||||
|
||||
typedef struct mlx_function_exporter_ {
|
||||
void* ctx;
|
||||
} mlx_function_exporter;
|
||||
mlx_function_exporter mlx_function_exporter_new(
|
||||
const char* file,
|
||||
const mlx_closure fun,
|
||||
bool shapeless);
|
||||
int mlx_function_exporter_free(mlx_function_exporter xfunc);
|
||||
int mlx_function_exporter_apply(
|
||||
const mlx_function_exporter xfunc,
|
||||
const mlx_vector_array args);
|
||||
int mlx_function_exporter_apply_kwargs(
|
||||
const mlx_function_exporter xfunc,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs);
|
||||
|
||||
typedef struct mlx_imported_function_ {
|
||||
void* ctx;
|
||||
} mlx_imported_function;
|
||||
mlx_imported_function mlx_imported_function_new(const char* file);
|
||||
int mlx_imported_function_free(mlx_imported_function xfunc);
|
||||
int mlx_imported_function_apply(
|
||||
mlx_vector_array* res,
|
||||
const mlx_imported_function xfunc,
|
||||
const mlx_vector_array args);
|
||||
int mlx_imported_function_apply_kwargs(
|
||||
mlx_vector_array* res,
|
||||
const mlx_imported_function xfunc,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs);
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
206
x/mlxrunner/mlx/include/mlx/c/fast.h
Normal file
206
x/mlxrunner/mlx/include/mlx/c/fast.h
Normal file
@@ -0,0 +1,206 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_FAST_H
|
||||
#define MLX_FAST_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup fast Fast custom operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef struct mlx_fast_cuda_kernel_config_ {
|
||||
void* ctx;
|
||||
} mlx_fast_cuda_kernel_config;
|
||||
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void);
|
||||
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls);
|
||||
|
||||
int mlx_fast_cuda_kernel_config_add_output_arg(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const int* shape,
|
||||
size_t size,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_cuda_kernel_config_set_grid(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
int grid1,
|
||||
int grid2,
|
||||
int grid3);
|
||||
int mlx_fast_cuda_kernel_config_set_thread_group(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
int thread1,
|
||||
int thread2,
|
||||
int thread3);
|
||||
int mlx_fast_cuda_kernel_config_set_init_value(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
float value);
|
||||
int mlx_fast_cuda_kernel_config_set_verbose(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
bool verbose);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_int(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
int value);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_bool(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
bool value);
|
||||
|
||||
typedef struct mlx_fast_cuda_kernel_ {
|
||||
void* ctx;
|
||||
} mlx_fast_cuda_kernel;
|
||||
|
||||
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
|
||||
const char* name,
|
||||
const mlx_vector_string input_names,
|
||||
const mlx_vector_string output_names,
|
||||
const char* source,
|
||||
const char* header,
|
||||
bool ensure_row_contiguous,
|
||||
int shared_memory);
|
||||
|
||||
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls);
|
||||
|
||||
int mlx_fast_cuda_kernel_apply(
|
||||
mlx_vector_array* outputs,
|
||||
mlx_fast_cuda_kernel cls,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_fast_cuda_kernel_config config,
|
||||
const mlx_stream stream);
|
||||
|
||||
int mlx_fast_layer_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array weight /* may be null */,
|
||||
const mlx_array bias /* may be null */,
|
||||
float eps,
|
||||
const mlx_stream s);
|
||||
|
||||
typedef struct mlx_fast_metal_kernel_config_ {
|
||||
void* ctx;
|
||||
} mlx_fast_metal_kernel_config;
|
||||
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void);
|
||||
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls);
|
||||
|
||||
int mlx_fast_metal_kernel_config_add_output_arg(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const int* shape,
|
||||
size_t size,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_metal_kernel_config_set_grid(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
int grid1,
|
||||
int grid2,
|
||||
int grid3);
|
||||
int mlx_fast_metal_kernel_config_set_thread_group(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
int thread1,
|
||||
int thread2,
|
||||
int thread3);
|
||||
int mlx_fast_metal_kernel_config_set_init_value(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
float value);
|
||||
int mlx_fast_metal_kernel_config_set_verbose(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
bool verbose);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_dtype(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_int(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
int value);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_bool(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
bool value);
|
||||
|
||||
typedef struct mlx_fast_metal_kernel_ {
|
||||
void* ctx;
|
||||
} mlx_fast_metal_kernel;
|
||||
|
||||
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
|
||||
const char* name,
|
||||
const mlx_vector_string input_names,
|
||||
const mlx_vector_string output_names,
|
||||
const char* source,
|
||||
const char* header,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs);
|
||||
|
||||
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
|
||||
|
||||
int mlx_fast_metal_kernel_apply(
|
||||
mlx_vector_array* outputs,
|
||||
mlx_fast_metal_kernel cls,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_fast_metal_kernel_config config,
|
||||
const mlx_stream stream);
|
||||
|
||||
int mlx_fast_rms_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array weight /* may be null */,
|
||||
float eps,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_rope(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
int offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_rope_dynamic(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
const mlx_array offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_scaled_dot_product_attention(
|
||||
mlx_array* res,
|
||||
const mlx_array queries,
|
||||
const mlx_array keys,
|
||||
const mlx_array values,
|
||||
float scale,
|
||||
const char* mask_mode,
|
||||
const mlx_array mask_arr /* may be null */,
|
||||
const mlx_array sinks /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
158
x/mlxrunner/mlx/include/mlx/c/fft.h
Normal file
158
x/mlxrunner/mlx/include/mlx/c/fft.h
Normal file
@@ -0,0 +1,158 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_FFT_H
|
||||
#define MLX_FFT_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup fft FFT operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef enum mlx_fft_norm_ {
|
||||
MLX_FFT_NORM_BACKWARD,
|
||||
MLX_FFT_NORM_ORTHO,
|
||||
MLX_FFT_NORM_FORWARD
|
||||
} mlx_fft_norm;
|
||||
|
||||
int mlx_fft_fft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fftfreq(mlx_array* res, int n, double d, const mlx_stream s);
|
||||
int mlx_fft_fftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fftshift(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifftshift(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfftfreq(mlx_array* res, int n, double d, const mlx_stream s);
|
||||
int mlx_fft_rfftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
61
x/mlxrunner/mlx/include/mlx/c/graph_utils.h
Normal file
61
x/mlxrunner/mlx/include/mlx/c/graph_utils.h
Normal file
@@ -0,0 +1,61 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_GRAPH_UTILS_H
|
||||
#define MLX_GRAPH_UTILS_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup graph_utils Graph Utils
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef struct mlx_node_namer_ {
|
||||
void* ctx;
|
||||
} mlx_node_namer;
|
||||
|
||||
mlx_node_namer mlx_node_namer_new();
|
||||
int mlx_node_namer_free(mlx_node_namer namer);
|
||||
int mlx_node_namer_set_name(
|
||||
mlx_node_namer namer,
|
||||
const mlx_array arr,
|
||||
const char* name);
|
||||
int mlx_node_namer_get_name(
|
||||
const char** name,
|
||||
mlx_node_namer namer,
|
||||
const mlx_array arr);
|
||||
|
||||
int mlx_export_to_dot(
|
||||
FILE* os,
|
||||
const mlx_node_namer namer,
|
||||
const mlx_vector_array outputs);
|
||||
int mlx_print_graph(
|
||||
FILE* os,
|
||||
const mlx_node_namer namer,
|
||||
const mlx_vector_array outputs);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
26
x/mlxrunner/mlx/include/mlx/c/half.h
Normal file
26
x/mlxrunner/mlx/include/mlx/c/half.h
Normal file
@@ -0,0 +1,26 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_HALF_H
|
||||
#define MLX_HALF_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__)
|
||||
#define HAS_FLOAT16
|
||||
#include <arm_fp16.h>
|
||||
typedef __fp16 float16_t;
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__)
|
||||
#define HAS_BFLOAT16
|
||||
#include <arm_bf16.h>
|
||||
typedef __bf16 bfloat16_t;
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
68
x/mlxrunner/mlx/include/mlx/c/io.h
Normal file
68
x/mlxrunner/mlx/include/mlx/c/io.h
Normal file
@@ -0,0 +1,68 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_IO_H
|
||||
#define MLX_IO_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup io IO operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_load_reader(
|
||||
mlx_array* res,
|
||||
mlx_io_reader in_stream,
|
||||
const mlx_stream s);
|
||||
int mlx_load(mlx_array* res, const char* file, const mlx_stream s);
|
||||
|
||||
int mlx_load_gguf(mlx_io_gguf* gguf, const char* file, const mlx_stream s);
|
||||
|
||||
int mlx_load_safetensors_reader(
|
||||
mlx_map_string_to_array* res_0,
|
||||
mlx_map_string_to_string* res_1,
|
||||
mlx_io_reader in_stream,
|
||||
const mlx_stream s);
|
||||
int mlx_load_safetensors(
|
||||
mlx_map_string_to_array* res_0,
|
||||
mlx_map_string_to_string* res_1,
|
||||
const char* file,
|
||||
const mlx_stream s);
|
||||
int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a);
|
||||
int mlx_save(const char* file, const mlx_array a);
|
||||
int mlx_save_gguf(const char* file, mlx_io_gguf gguf);
|
||||
|
||||
int mlx_save_safetensors_writer(
|
||||
mlx_io_writer in_stream,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata);
|
||||
int mlx_save_safetensors(
|
||||
const char* file,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
150
x/mlxrunner/mlx/include/mlx/c/io_types.h
Normal file
150
x/mlxrunner/mlx/include/mlx/c/io_types.h
Normal file
@@ -0,0 +1,150 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_IO_TYPES_H
|
||||
#define MLX_IO_TYPES_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_io_types IO Types
|
||||
* MLX IO type objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX IO reader object.
|
||||
*/
|
||||
typedef struct mlx_io_reader_ {
|
||||
void* ctx;
|
||||
} mlx_io_reader;
|
||||
/**
|
||||
* A MLX IO writer object.
|
||||
*/
|
||||
typedef struct mlx_io_writer_ {
|
||||
void* ctx;
|
||||
} mlx_io_writer;
|
||||
|
||||
/**
|
||||
* Virtual table for custom IO reader and writer objects.
|
||||
*/
|
||||
typedef struct mlx_io_vtable_ {
|
||||
bool (*is_open)(void*);
|
||||
bool (*good)(void*);
|
||||
size_t (*tell)(void*);
|
||||
void (*seek)(void*, int64_t off, int whence);
|
||||
void (*read)(void*, char* data, size_t n);
|
||||
void (*read_at_offset)(void*, char* data, size_t n, size_t off);
|
||||
void (*write)(void*, const char* data, size_t n);
|
||||
const char* (*label)(void*);
|
||||
void (*free)(void*);
|
||||
} mlx_io_vtable;
|
||||
|
||||
/**
|
||||
* Returns a new custom IO reader.
|
||||
* `vtable` operates on user descriptor `desc`.
|
||||
*/
|
||||
mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable);
|
||||
|
||||
/**
|
||||
* Get IO reader user descriptor.
|
||||
*/
|
||||
int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Get IO reader description.
|
||||
*/
|
||||
int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Free IO reader.
|
||||
*
|
||||
* Note that MLX arrays are lazily evaluated, so the underlying object may
|
||||
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
|
||||
* will be called when the underlying object is actually freed.
|
||||
*/
|
||||
int mlx_io_reader_free(mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Returns a new custom IO writer.
|
||||
* `vtable` operates on user descriptor `desc`.
|
||||
*/
|
||||
mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable);
|
||||
|
||||
/**
|
||||
* Get IO writer user descriptor.
|
||||
*/
|
||||
int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io);
|
||||
|
||||
/**
|
||||
* Get IO writer description.
|
||||
*/
|
||||
int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io);
|
||||
|
||||
/**
|
||||
* Free IO writer.
|
||||
*
|
||||
* Note that MLX arrays are lazily evaluated, so the underlying object may
|
||||
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
|
||||
* will be called when the underlying object is actually freed.
|
||||
*/
|
||||
int mlx_io_writer_free(mlx_io_writer io);
|
||||
|
||||
/**
|
||||
* A MLX GGUF object.
|
||||
*/
|
||||
typedef struct mlx_io_gguf_ {
|
||||
void* ctx;
|
||||
} mlx_io_gguf;
|
||||
|
||||
mlx_io_gguf mlx_io_gguf_new(void);
|
||||
int mlx_io_gguf_free(mlx_io_gguf io);
|
||||
int mlx_io_gguf_get_keys(mlx_vector_string* keys, mlx_io_gguf io);
|
||||
int mlx_io_gguf_get_array(mlx_array* arr, mlx_io_gguf io, const char* key);
|
||||
int mlx_io_gguf_get_metadata_array(
|
||||
mlx_array* arr,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_get_metadata_string(
|
||||
mlx_string* str,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_get_metadata_vector_string(
|
||||
mlx_vector_string* vstr,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_has_metadata_array(bool* flag, mlx_io_gguf io, const char* key);
|
||||
int mlx_io_gguf_has_metadata_string(
|
||||
bool* flag,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_has_metadata_vector_string(
|
||||
bool* flag,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_set_array(mlx_io_gguf io, const char* key, const mlx_array arr);
|
||||
int mlx_io_gguf_set_metadata_array(
|
||||
mlx_io_gguf io,
|
||||
const char* key,
|
||||
const mlx_array marr);
|
||||
int mlx_io_gguf_set_metadata_string(
|
||||
mlx_io_gguf io,
|
||||
const char* key,
|
||||
const char* mstr);
|
||||
int mlx_io_gguf_set_metadata_vector_string(
|
||||
mlx_io_gguf io,
|
||||
const char* key,
|
||||
const mlx_vector_string mvstr);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
128
x/mlxrunner/mlx/include/mlx/c/linalg.h
Normal file
128
x/mlxrunner/mlx/include/mlx/c/linalg.h
Normal file
@@ -0,0 +1,128 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_LINALG_H
|
||||
#define MLX_LINALG_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup linalg Linear algebra operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_linalg_cholesky(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_cholesky_inv(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_cross(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
int axis,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eig(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eigh(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const char* UPLO,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_eigvalsh(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const char* UPLO,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_lu_factor(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
double ord,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm_matrix(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const char* ord,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm_l2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_qr(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_solve(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_solve_triangular(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_svd(
|
||||
mlx_vector_array* res,
|
||||
const mlx_array a,
|
||||
bool compute_uv,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_tri_inv(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
149
x/mlxrunner/mlx/include/mlx/c/map.h
Normal file
149
x/mlxrunner/mlx/include/mlx/c/map.h
Normal file
@@ -0,0 +1,149 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_MAP_H
|
||||
#define MLX_MAP_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_map Maps
|
||||
* MLX map objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A string-to-array map
|
||||
*/
|
||||
typedef struct mlx_map_string_to_array_ {
|
||||
void* ctx;
|
||||
} mlx_map_string_to_array;
|
||||
|
||||
/**
|
||||
* Returns a new empty string-to-array map.
|
||||
*/
|
||||
mlx_map_string_to_array mlx_map_string_to_array_new(void);
|
||||
/**
|
||||
* Set map to provided src map.
|
||||
*/
|
||||
int mlx_map_string_to_array_set(
|
||||
mlx_map_string_to_array* map,
|
||||
const mlx_map_string_to_array src);
|
||||
/**
|
||||
* Free a string-to-array map.
|
||||
*/
|
||||
int mlx_map_string_to_array_free(mlx_map_string_to_array map);
|
||||
/**
|
||||
* Insert a new `value` at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_array_insert(
|
||||
mlx_map_string_to_array map,
|
||||
const char* key,
|
||||
const mlx_array value);
|
||||
/**
|
||||
* Returns the value indexed at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_array_get(
|
||||
mlx_array* value,
|
||||
const mlx_map_string_to_array map,
|
||||
const char* key);
|
||||
|
||||
/**
|
||||
* An iterator over a string-to-array map.
|
||||
*/
|
||||
typedef struct mlx_map_string_to_array_iterator_ {
|
||||
void* ctx;
|
||||
void* map_ctx;
|
||||
} mlx_map_string_to_array_iterator;
|
||||
/**
|
||||
* Returns a new iterator over the given map.
|
||||
*/
|
||||
mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(
|
||||
mlx_map_string_to_array map);
|
||||
/**
|
||||
* Free iterator.
|
||||
*/
|
||||
int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it);
|
||||
/**
|
||||
* Increment iterator.
|
||||
*/
|
||||
int mlx_map_string_to_array_iterator_next(
|
||||
const char** key,
|
||||
mlx_array* value,
|
||||
mlx_map_string_to_array_iterator it);
|
||||
|
||||
/**
|
||||
* A string-to-string map
|
||||
*/
|
||||
typedef struct mlx_map_string_to_string_ {
|
||||
void* ctx;
|
||||
} mlx_map_string_to_string;
|
||||
|
||||
/**
|
||||
* Returns a new empty string-to-string map.
|
||||
*/
|
||||
mlx_map_string_to_string mlx_map_string_to_string_new(void);
|
||||
/**
|
||||
* Set map to provided src map.
|
||||
*/
|
||||
int mlx_map_string_to_string_set(
|
||||
mlx_map_string_to_string* map,
|
||||
const mlx_map_string_to_string src);
|
||||
/**
|
||||
* Free a string-to-string map.
|
||||
*/
|
||||
int mlx_map_string_to_string_free(mlx_map_string_to_string map);
|
||||
/**
|
||||
* Insert a new `value` at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_string_insert(
|
||||
mlx_map_string_to_string map,
|
||||
const char* key,
|
||||
const char* value);
|
||||
/**
|
||||
* Returns the value indexed at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_string_get(
|
||||
const char** value,
|
||||
const mlx_map_string_to_string map,
|
||||
const char* key);
|
||||
|
||||
/**
|
||||
* An iterator over a string-to-string map.
|
||||
*/
|
||||
typedef struct mlx_map_string_to_string_iterator_ {
|
||||
void* ctx;
|
||||
void* map_ctx;
|
||||
} mlx_map_string_to_string_iterator;
|
||||
/**
|
||||
* Returns a new iterator over the given map.
|
||||
*/
|
||||
mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(
|
||||
mlx_map_string_to_string map);
|
||||
/**
|
||||
* Free iterator.
|
||||
*/
|
||||
int mlx_map_string_to_string_iterator_free(
|
||||
mlx_map_string_to_string_iterator it);
|
||||
/**
|
||||
* Increment iterator.
|
||||
*/
|
||||
int mlx_map_string_to_string_iterator_next(
|
||||
const char** key,
|
||||
const char** value,
|
||||
mlx_map_string_to_string_iterator it);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
47
x/mlxrunner/mlx/include/mlx/c/memory.h
Normal file
47
x/mlxrunner/mlx/include/mlx/c/memory.h
Normal file
@@ -0,0 +1,47 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_MEMORY_H
|
||||
#define MLX_MEMORY_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup memory Memory operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_clear_cache(void);
|
||||
int mlx_get_active_memory(size_t* res);
|
||||
int mlx_get_cache_memory(size_t* res);
|
||||
int mlx_get_memory_limit(size_t* res);
|
||||
int mlx_get_peak_memory(size_t* res);
|
||||
int mlx_reset_peak_memory(void);
|
||||
int mlx_set_cache_limit(size_t* res, size_t limit);
|
||||
int mlx_set_memory_limit(size_t* res, size_t limit);
|
||||
int mlx_set_wired_limit(size_t* res, size_t limit);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
41
x/mlxrunner/mlx/include/mlx/c/metal.h
Normal file
41
x/mlxrunner/mlx/include/mlx/c/metal.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_METAL_H
|
||||
#define MLX_METAL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup metal Metal specific operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_metal_is_available(bool* res);
|
||||
int mlx_metal_start_capture(const char* path);
|
||||
int mlx_metal_stop_capture(void);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
35
x/mlxrunner/mlx/include/mlx/c/mlx.h
Normal file
35
x/mlxrunner/mlx/include/mlx/c/mlx.h
Normal file
@@ -0,0 +1,35 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ALL_H
|
||||
#define MLX_ALL_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/compile.h"
|
||||
#include "mlx/c/cuda.h"
|
||||
#include "mlx/c/device.h"
|
||||
#include "mlx/c/distributed.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/error.h"
|
||||
#include "mlx/c/export.h"
|
||||
#include "mlx/c/fast.h"
|
||||
#include "mlx/c/fft.h"
|
||||
#include "mlx/c/graph_utils.h"
|
||||
#include "mlx/c/half.h"
|
||||
#include "mlx/c/io.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/linalg.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/memory.h"
|
||||
#include "mlx/c/metal.h"
|
||||
#include "mlx/c/ops.h"
|
||||
#include "mlx/c/optional.h"
|
||||
#include "mlx/c/random.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/transforms.h"
|
||||
#include "mlx/c/transforms_impl.h"
|
||||
#include "mlx/c/vector.h"
|
||||
#include "mlx/c/version.h"
|
||||
|
||||
#endif
|
||||
1287
x/mlxrunner/mlx/include/mlx/c/ops.h
Normal file
1287
x/mlxrunner/mlx/include/mlx/c/ops.h
Normal file
File diff suppressed because it is too large
Load Diff
51
x/mlxrunner/mlx/include/mlx/c/optional.h
Normal file
51
x/mlxrunner/mlx/include/mlx/c/optional.h
Normal file
@@ -0,0 +1,51 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_OPTIONAL_H
|
||||
#define MLX_OPTIONAL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_optional Optionals
|
||||
* MLX optional scalars.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A int optional.
|
||||
*/
|
||||
typedef struct mlx_optional_int_ {
|
||||
int value;
|
||||
bool has_value;
|
||||
} mlx_optional_int;
|
||||
|
||||
/**
|
||||
* A float optional.
|
||||
*/
|
||||
typedef struct mlx_optional_float_ {
|
||||
float value;
|
||||
bool has_value;
|
||||
} mlx_optional_float;
|
||||
|
||||
/**
|
||||
* A dtype optional.
|
||||
*/
|
||||
typedef struct mlx_optional_dtype_ {
|
||||
mlx_dtype value;
|
||||
bool has_value;
|
||||
} mlx_optional_dtype;
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
166
x/mlxrunner/mlx/include/mlx/c/random.h
Normal file
166
x/mlxrunner/mlx/include/mlx/c/random.h
Normal file
@@ -0,0 +1,166 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_RANDOM_H
|
||||
#define MLX_RANDOM_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup random Random number operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_random_bernoulli(
|
||||
mlx_array* res,
|
||||
const mlx_array p,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_bits(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
int width,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical_shape(
|
||||
mlx_array* res,
|
||||
const mlx_array logits,
|
||||
int axis,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical_num_samples(
|
||||
mlx_array* res,
|
||||
const mlx_array logits_,
|
||||
int axis,
|
||||
int num_samples,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical(
|
||||
mlx_array* res,
|
||||
const mlx_array logits,
|
||||
int axis,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_gumbel(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_key(mlx_array* res, uint64_t seed);
|
||||
int mlx_random_laplace(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
float loc,
|
||||
float scale,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_multivariate_normal(
|
||||
mlx_array* res,
|
||||
const mlx_array mean,
|
||||
const mlx_array cov,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_normal_broadcast(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array loc /* may be null */,
|
||||
const mlx_array scale /* may be null */,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_normal(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
float loc,
|
||||
float scale,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_permutation(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int axis,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_permutation_arange(
|
||||
mlx_array* res,
|
||||
int x,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_randint(
|
||||
mlx_array* res,
|
||||
const mlx_array low,
|
||||
const mlx_array high,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_seed(uint64_t seed);
|
||||
int mlx_random_split_num(
|
||||
mlx_array* res,
|
||||
const mlx_array key,
|
||||
int num,
|
||||
const mlx_stream s);
|
||||
int mlx_random_split(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array key,
|
||||
const mlx_stream s);
|
||||
int mlx_random_truncated_normal(
|
||||
mlx_array* res,
|
||||
const mlx_array lower,
|
||||
const mlx_array upper,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_uniform(
|
||||
mlx_array* res,
|
||||
const mlx_array low,
|
||||
const mlx_array high,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
88
x/mlxrunner/mlx/include/mlx/c/stream.h
Normal file
88
x/mlxrunner/mlx/include/mlx/c/stream.h
Normal file
@@ -0,0 +1,88 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_STREAM_H
|
||||
#define MLX_STREAM_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/device.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_stream Stream
|
||||
* MLX stream object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX stream object.
|
||||
*/
|
||||
typedef struct mlx_stream_ {
|
||||
void* ctx;
|
||||
} mlx_stream;
|
||||
|
||||
/**
|
||||
* Returns a new empty stream.
|
||||
*/
|
||||
mlx_stream mlx_stream_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new stream on a device.
|
||||
*/
|
||||
mlx_stream mlx_stream_new_device(mlx_device dev);
|
||||
/**
|
||||
* Set stream to provided src stream.
|
||||
*/
|
||||
int mlx_stream_set(mlx_stream* stream, const mlx_stream src);
|
||||
/**
|
||||
* Free a stream.
|
||||
*/
|
||||
int mlx_stream_free(mlx_stream stream);
|
||||
/**
|
||||
* Get stream description.
|
||||
*/
|
||||
int mlx_stream_tostring(mlx_string* str, mlx_stream stream);
|
||||
/**
|
||||
* Check if streams are the same.
|
||||
*/
|
||||
bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs);
|
||||
/**
|
||||
* Return the device of the stream.
|
||||
*/
|
||||
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream);
|
||||
/**
|
||||
* Return the index of the stream.
|
||||
*/
|
||||
int mlx_stream_get_index(int* index, mlx_stream stream);
|
||||
/**
|
||||
* Synchronize with the provided stream.
|
||||
*/
|
||||
int mlx_synchronize(mlx_stream stream);
|
||||
/**
|
||||
* Returns the default stream on the given device.
|
||||
*/
|
||||
int mlx_get_default_stream(mlx_stream* stream, mlx_device dev);
|
||||
/**
|
||||
* Set default stream.
|
||||
*/
|
||||
int mlx_set_default_stream(mlx_stream stream);
|
||||
/**
|
||||
* Returns the current default CPU stream.
|
||||
*/
|
||||
mlx_stream mlx_default_cpu_stream_new(void);
|
||||
|
||||
/**
|
||||
* Returns the current default GPU stream.
|
||||
*/
|
||||
mlx_stream mlx_default_gpu_stream_new(void);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
55
x/mlxrunner/mlx/include/mlx/c/string.h
Normal file
55
x/mlxrunner/mlx/include/mlx/c/string.h
Normal file
@@ -0,0 +1,55 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_STRING_H
|
||||
#define MLX_STRING_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_string String
|
||||
* MLX string object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX string object.
|
||||
*/
|
||||
typedef struct mlx_string_ {
|
||||
void* ctx;
|
||||
} mlx_string;
|
||||
|
||||
/**
|
||||
* Returns a new empty string.
|
||||
*/
|
||||
mlx_string mlx_string_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new string, copying contents from `str`, which must end with `\0`.
|
||||
*/
|
||||
mlx_string mlx_string_new_data(const char* str);
|
||||
|
||||
/**
|
||||
* Set string to src string.
|
||||
*/
|
||||
int mlx_string_set(mlx_string* str, const mlx_string src);
|
||||
|
||||
/**
|
||||
* Returns a pointer to the string contents.
|
||||
* The pointer is valid for the life duration of the string.
|
||||
*/
|
||||
const char* mlx_string_data(mlx_string str);
|
||||
|
||||
/**
|
||||
* Free string.
|
||||
*/
|
||||
int mlx_string_free(mlx_string str);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
68
x/mlxrunner/mlx/include/mlx/c/transforms.h
Normal file
68
x/mlxrunner/mlx/include/mlx/c/transforms.h
Normal file
@@ -0,0 +1,68 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_TRANSFORMS_H
|
||||
#define MLX_TRANSFORMS_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup transforms Transform operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_async_eval(const mlx_vector_array outputs);
|
||||
int mlx_checkpoint(mlx_closure* res, const mlx_closure fun);
|
||||
int mlx_custom_function(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
const mlx_closure_custom fun_vjp /* may be null */,
|
||||
const mlx_closure_custom_jvp fun_jvp /* may be null */,
|
||||
const mlx_closure_custom_vmap fun_vmap /* may be null */);
|
||||
int mlx_custom_vjp(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
const mlx_closure_custom fun_vjp);
|
||||
int mlx_eval(const mlx_vector_array outputs);
|
||||
int mlx_jvp(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array tangents);
|
||||
int mlx_value_and_grad(
|
||||
mlx_closure_value_and_grad* res,
|
||||
const mlx_closure fun,
|
||||
const int* argnums,
|
||||
size_t argnums_num);
|
||||
int mlx_vjp(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array cotangents);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
54
x/mlxrunner/mlx/include/mlx/c/transforms_impl.h
Normal file
54
x/mlxrunner/mlx/include/mlx/c/transforms_impl.h
Normal file
@@ -0,0 +1,54 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_TRANSFORMS_IMPL_H
|
||||
#define MLX_TRANSFORMS_IMPL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup transforms_impl Implementation detail operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_detail_vmap_replace(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num);
|
||||
int mlx_detail_vmap_trace(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
133
x/mlxrunner/mlx/include/mlx/c/vector.h
Normal file
133
x/mlxrunner/mlx/include/mlx/c/vector.h
Normal file
@@ -0,0 +1,133 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_VECTOR_H
|
||||
#define MLX_VECTOR_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_vector Vectors
|
||||
* MLX vector objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A vector of array.
|
||||
*/
|
||||
typedef struct mlx_vector_array_ {
|
||||
void* ctx;
|
||||
} mlx_vector_array;
|
||||
mlx_vector_array mlx_vector_array_new(void);
|
||||
int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src);
|
||||
int mlx_vector_array_free(mlx_vector_array vec);
|
||||
mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size);
|
||||
mlx_vector_array mlx_vector_array_new_value(const mlx_array val);
|
||||
int mlx_vector_array_set_data(
|
||||
mlx_vector_array* vec,
|
||||
const mlx_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val);
|
||||
int mlx_vector_array_append_data(
|
||||
mlx_vector_array vec,
|
||||
const mlx_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val);
|
||||
size_t mlx_vector_array_size(mlx_vector_array vec);
|
||||
int mlx_vector_array_get(
|
||||
mlx_array* res,
|
||||
const mlx_vector_array vec,
|
||||
size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of vector_array.
|
||||
*/
|
||||
typedef struct mlx_vector_vector_array_ {
|
||||
void* ctx;
|
||||
} mlx_vector_vector_array;
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new(void);
|
||||
int mlx_vector_vector_array_set(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_vector_array src);
|
||||
int mlx_vector_vector_array_free(mlx_vector_vector_array vec);
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new_data(
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new_value(
|
||||
const mlx_vector_array val);
|
||||
int mlx_vector_vector_array_set_data(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_vector_array_set_value(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_array val);
|
||||
int mlx_vector_vector_array_append_data(
|
||||
mlx_vector_vector_array vec,
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_vector_array_append_value(
|
||||
mlx_vector_vector_array vec,
|
||||
const mlx_vector_array val);
|
||||
size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec);
|
||||
int mlx_vector_vector_array_get(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_vector_array vec,
|
||||
size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of int.
|
||||
*/
|
||||
typedef struct mlx_vector_int_ {
|
||||
void* ctx;
|
||||
} mlx_vector_int;
|
||||
mlx_vector_int mlx_vector_int_new(void);
|
||||
int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src);
|
||||
int mlx_vector_int_free(mlx_vector_int vec);
|
||||
mlx_vector_int mlx_vector_int_new_data(int* data, size_t size);
|
||||
mlx_vector_int mlx_vector_int_new_value(int val);
|
||||
int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size);
|
||||
int mlx_vector_int_set_value(mlx_vector_int* vec, int val);
|
||||
int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size);
|
||||
int mlx_vector_int_append_value(mlx_vector_int vec, int val);
|
||||
size_t mlx_vector_int_size(mlx_vector_int vec);
|
||||
int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of string.
|
||||
*/
|
||||
typedef struct mlx_vector_string_ {
|
||||
void* ctx;
|
||||
} mlx_vector_string;
|
||||
mlx_vector_string mlx_vector_string_new(void);
|
||||
int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src);
|
||||
int mlx_vector_string_free(mlx_vector_string vec);
|
||||
mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size);
|
||||
mlx_vector_string mlx_vector_string_new_value(const char* val);
|
||||
int mlx_vector_string_set_data(
|
||||
mlx_vector_string* vec,
|
||||
const char** data,
|
||||
size_t size);
|
||||
int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val);
|
||||
int mlx_vector_string_append_data(
|
||||
mlx_vector_string vec,
|
||||
const char** data,
|
||||
size_t size);
|
||||
int mlx_vector_string_append_value(mlx_vector_string vec, const char* val);
|
||||
size_t mlx_vector_string_size(mlx_vector_string vec);
|
||||
int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
18
x/mlxrunner/mlx/include/mlx/c/version.h
Normal file
18
x/mlxrunner/mlx/include/mlx/c/version.h
Normal file
@@ -0,0 +1,18 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_VERSION_H
|
||||
#define MLX_VERSION_H
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int mlx_version(mlx_string* str_);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
164
x/mlxrunner/mlx/io.go
Normal file
164
x/mlxrunner/mlx/io.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"runtime"
|
||||
"sort"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// SafetensorsFile represents a loaded safetensors file.
|
||||
type SafetensorsFile struct {
|
||||
arrays C.mlx_map_string_to_array
|
||||
metadata C.mlx_map_string_to_string
|
||||
}
|
||||
|
||||
func loadSafetensorsStream() C.mlx_stream {
|
||||
if runtime.GOOS == "darwin" {
|
||||
return C.mlx_default_cpu_stream_new()
|
||||
}
|
||||
return C.mlx_default_gpu_stream_new()
|
||||
}
|
||||
|
||||
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
|
||||
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
|
||||
var arrays C.mlx_map_string_to_array
|
||||
var metadata C.mlx_map_string_to_string
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
stream := loadSafetensorsStream()
|
||||
defer C.mlx_stream_free(stream)
|
||||
|
||||
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
|
||||
return nil, fmt.Errorf("failed to load safetensors: %s", path)
|
||||
}
|
||||
|
||||
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
|
||||
}
|
||||
|
||||
// Get retrieves a tensor by name.
|
||||
func (s *SafetensorsFile) Get(name string) *Array {
|
||||
cName := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
|
||||
value := C.mlx_array_new()
|
||||
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
|
||||
return nil
|
||||
}
|
||||
if value.ctx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
arr := New(name)
|
||||
arr.ctx = value
|
||||
return arr
|
||||
}
|
||||
|
||||
// GetMetadata retrieves a metadata value by key.
|
||||
func (s *SafetensorsFile) GetMetadata(key string) string {
|
||||
cKey := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cKey))
|
||||
|
||||
var cValue *C.char
|
||||
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
|
||||
return ""
|
||||
}
|
||||
return C.GoString(cValue)
|
||||
}
|
||||
|
||||
// Free releases the loaded safetensors maps.
|
||||
func (s *SafetensorsFile) Free() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
C.mlx_map_string_to_array_free(s.arrays)
|
||||
C.mlx_map_string_to_string_free(s.metadata)
|
||||
}
|
||||
|
||||
func Load(path string) iter.Seq2[string, *Array] {
|
||||
return func(yield func(string, *Array) bool) {
|
||||
sf, err := LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer sf.Free()
|
||||
|
||||
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
|
||||
defer C.mlx_map_string_to_array_iterator_free(it)
|
||||
|
||||
for {
|
||||
var key *C.char
|
||||
value := C.mlx_array_new()
|
||||
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
name := C.GoString(key)
|
||||
arr := New(name)
|
||||
arr.ctx = value
|
||||
if !yield(name, arr) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SaveSafetensors saves arrays to a safetensors file without metadata.
|
||||
func SaveSafetensors(path string, arrays map[string]*Array) error {
|
||||
return SaveSafetensorsWithMetadata(path, arrays, nil)
|
||||
}
|
||||
|
||||
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
|
||||
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
cArrays := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(cArrays)
|
||||
|
||||
arrayNames := make([]string, 0, len(arrays))
|
||||
for name, arr := range arrays {
|
||||
if arr == nil {
|
||||
continue
|
||||
}
|
||||
arrayNames = append(arrayNames, name)
|
||||
}
|
||||
sort.Strings(arrayNames)
|
||||
|
||||
for _, name := range arrayNames {
|
||||
arr := arrays[name]
|
||||
cName := C.CString(name)
|
||||
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
|
||||
C.free(unsafe.Pointer(cName))
|
||||
}
|
||||
|
||||
cMetadata := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(cMetadata)
|
||||
|
||||
metadataKeys := make([]string, 0, len(metadata))
|
||||
for key := range metadata {
|
||||
metadataKeys = append(metadataKeys, key)
|
||||
}
|
||||
sort.Strings(metadataKeys)
|
||||
|
||||
for _, key := range metadataKeys {
|
||||
value := metadata[key]
|
||||
cKey := C.CString(key)
|
||||
cValue := C.CString(value)
|
||||
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
|
||||
C.free(unsafe.Pointer(cKey))
|
||||
C.free(unsafe.Pointer(cValue))
|
||||
}
|
||||
|
||||
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
|
||||
return fmt.Errorf("failed to save safetensors: %s", path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
89
x/mlxrunner/mlx/memory.go
Normal file
89
x/mlxrunner/mlx/memory.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func (b Byte) String() string {
|
||||
return strconv.FormatInt(int64(b), 10) + " B"
|
||||
}
|
||||
|
||||
func (b KibiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
|
||||
}
|
||||
|
||||
func (b MebiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
|
||||
}
|
||||
|
||||
func (b GibiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
|
||||
}
|
||||
|
||||
func (b TebiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
|
||||
}
|
||||
|
||||
func PrettyBytes(n int) fmt.Stringer {
|
||||
switch {
|
||||
case n < 1<<10:
|
||||
return Byte(n)
|
||||
case n < 1<<(2*10):
|
||||
return KibiByte(n)
|
||||
case n < 1<<(3*10):
|
||||
return MebiByte(n)
|
||||
case n < 1<<(4*10):
|
||||
return GibiByte(n)
|
||||
default:
|
||||
return TebiByte(n)
|
||||
}
|
||||
}
|
||||
|
||||
func ActiveMemory() int {
|
||||
var active C.size_t
|
||||
C.mlx_get_active_memory(&active)
|
||||
return int(active)
|
||||
}
|
||||
|
||||
func CacheMemory() int {
|
||||
var cache C.size_t
|
||||
C.mlx_get_cache_memory(&cache)
|
||||
return int(cache)
|
||||
}
|
||||
|
||||
func PeakMemory() int {
|
||||
var peak C.size_t
|
||||
C.mlx_get_peak_memory(&peak)
|
||||
return int(peak)
|
||||
}
|
||||
|
||||
func ResetPeakMemory() {
|
||||
C.mlx_reset_peak_memory()
|
||||
}
|
||||
|
||||
type Memory struct{}
|
||||
|
||||
func (Memory) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Any("active", PrettyBytes(ActiveMemory())),
|
||||
slog.Any("cache", PrettyBytes(CacheMemory())),
|
||||
slog.Any("peak", PrettyBytes(PeakMemory())),
|
||||
)
|
||||
}
|
||||
|
||||
type (
|
||||
Byte int
|
||||
KibiByte int
|
||||
MebiByte int
|
||||
GibiByte int
|
||||
TebiByte int
|
||||
)
|
||||
|
||||
func ClearCache() {
|
||||
C.mlx_clear_cache()
|
||||
}
|
||||
107
x/mlxrunner/mlx/mlx.go
Normal file
107
x/mlxrunner/mlx/mlx.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package mlx
|
||||
|
||||
//go:generate go run generator/main.go -output=. ./include/mlx/c/*.h
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/include
|
||||
// #cgo LDFLAGS: -lstdc++
|
||||
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||
// #include "generated.h"
|
||||
// #include <string.h>
|
||||
//
|
||||
// static __thread char _mlx_last_error_msg[1024] = {0};
|
||||
// static __thread int _mlx_last_error_flag = 0;
|
||||
//
|
||||
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||
// (void)data;
|
||||
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
|
||||
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
|
||||
// _mlx_last_error_flag = 1;
|
||||
// }
|
||||
//
|
||||
// static void mlx_install_capture_handler(void) {
|
||||
// if (mlx_set_error_handler_) {
|
||||
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// static void mlx_clear_last_error(void) {
|
||||
// _mlx_last_error_flag = 0;
|
||||
// _mlx_last_error_msg[0] = '\0';
|
||||
// }
|
||||
//
|
||||
// static const char* mlx_get_last_error(void) {
|
||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
|
||||
// }
|
||||
import "C"
|
||||
|
||||
import "runtime"
|
||||
|
||||
func init() {
|
||||
// Replace the default exit(-1) error handler with one that captures
|
||||
// the error message so we can surface it in Go.
|
||||
C.mlx_install_capture_handler()
|
||||
}
|
||||
|
||||
// Version returns the MLX core library version string.
|
||||
func Version() string {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_version(&str)
|
||||
return C.GoString(C.mlx_string_data(str))
|
||||
}
|
||||
|
||||
// mlxCheck locks the goroutine to its OS thread, clears the captured error
|
||||
// state, calls fn, and panics with the captured message if fn returns non-zero.
|
||||
// The thread lock ensures the thread-local error state is read from the same
|
||||
// thread that executed the call.
|
||||
func mlxCheck(fallback string, fn func() C.int) {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
C.mlx_clear_last_error()
|
||||
if fn() != 0 {
|
||||
msg := C.GoString(C.mlx_get_last_error())
|
||||
if msg == "" {
|
||||
msg = fallback
|
||||
}
|
||||
panic("mlx: " + msg)
|
||||
}
|
||||
}
|
||||
|
||||
func doEval(outputs []*Array, async bool) {
|
||||
if len(outputs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
if output != nil && output.Valid() {
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
mlxCheck("eval failed", func() C.int {
|
||||
if async {
|
||||
return C.mlx_async_eval(vector)
|
||||
}
|
||||
return C.mlx_eval(vector)
|
||||
})
|
||||
}
|
||||
|
||||
func AsyncEval(outputs ...*Array) {
|
||||
doEval(outputs, true)
|
||||
}
|
||||
|
||||
func Eval(outputs ...*Array) {
|
||||
doEval(outputs, false)
|
||||
}
|
||||
|
||||
// MetalIsAvailable returns true if a Metal GPU is available.
|
||||
func MetalIsAvailable() bool {
|
||||
var available C._Bool
|
||||
C.mlx_metal_is_available(&available)
|
||||
return bool(available)
|
||||
}
|
||||
36
x/mlxrunner/mlx/nn.go
Normal file
36
x/mlxrunner/mlx/nn.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package mlx
|
||||
|
||||
type Linear struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
Bias *Array `weight:"bias"`
|
||||
}
|
||||
|
||||
// Forward computes the linear transformation: x @ Weight.T + Bias
|
||||
func (m *Linear) Forward(x *Array) *Array {
|
||||
w := m.Weight.Transpose(1, 0)
|
||||
if m.Bias.Valid() {
|
||||
return m.Bias.Addmm(x, w, 1.0, 1.0)
|
||||
}
|
||||
|
||||
return x.Matmul(w)
|
||||
}
|
||||
|
||||
func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
|
||||
w := m.Weight.Transpose(0, 2, 1)
|
||||
// TODO: bias
|
||||
return x.GatherMM(w, lhs, rhs, sorted)
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
}
|
||||
|
||||
func (e *Embedding) Forward(indices *Array) *Array {
|
||||
return e.Weight.TakeAxis(indices, 0)
|
||||
}
|
||||
|
||||
func (e *Embedding) AsLinear() Linear {
|
||||
return Linear{
|
||||
Weight: e.Weight,
|
||||
}
|
||||
}
|
||||
300
x/mlxrunner/mlx/ops.go
Normal file
300
x/mlxrunner/mlx/ops.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func (t *Array) Abs() *Array {
|
||||
out := New("ABS")
|
||||
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Add(other *Array) *Array {
|
||||
out := New("ADD")
|
||||
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
|
||||
out := New("ADDMM")
|
||||
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Argmax(axis int, keepDims bool) *Array {
|
||||
out := New("ARGMAX")
|
||||
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
|
||||
out := New("ARGPARTITION")
|
||||
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ArgsortAxis(axis int) *Array {
|
||||
out := New("ARGSORT_AXIS")
|
||||
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) AsType(dtype DType) *Array {
|
||||
out := New("AS_TYPE")
|
||||
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
|
||||
cStrides := make([]C.int64_t, len(strides))
|
||||
for i, s := range strides {
|
||||
cStrides[i] = C.int64_t(s)
|
||||
}
|
||||
|
||||
out := New("AS_STRIDED")
|
||||
C.mlx_as_strided(
|
||||
&out.ctx, t.ctx,
|
||||
unsafe.SliceData(cShape), C.size_t(len(shape)),
|
||||
unsafe.SliceData(cStrides), C.size_t(len(strides)),
|
||||
C.size_t(offset),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
|
||||
if len(others) == 0 {
|
||||
return t.Clone()
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
s := append([]*Array{t}, others...)
|
||||
for _, other := range s {
|
||||
C.mlx_vector_array_append_value(vector, other.ctx)
|
||||
}
|
||||
|
||||
out := New("CONCATENATE")
|
||||
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
|
||||
out := New("CUMSUM")
|
||||
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Divide(other *Array) *Array {
|
||||
out := New("DIVIDE")
|
||||
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ExpandDims(axis int) *Array {
|
||||
out := New("EXPAND_DIMS")
|
||||
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Flatten(startAxis, endAxis int) *Array {
|
||||
out := New("FLATTEN")
|
||||
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) FloorDivide(other *Array) *Array {
|
||||
out := New("FLOOR_DIVIDE")
|
||||
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
||||
if lhs == nil {
|
||||
lhs = New("")
|
||||
}
|
||||
if rhs == nil {
|
||||
rhs = New("")
|
||||
}
|
||||
out := New("GATHER_MM")
|
||||
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
|
||||
out := New("LOGSUMEXP_AXIS")
|
||||
C.mlx_logsumexp_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Equal(other *Array) *Array {
|
||||
out := New("EQUAL")
|
||||
C.mlx_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Greater(other *Array) *Array {
|
||||
out := New("GREATER")
|
||||
C.mlx_greater(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Less(other *Array) *Array {
|
||||
out := New("LESS")
|
||||
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) LessEqual(other *Array) *Array {
|
||||
out := New("LESS_EQUAL")
|
||||
C.mlx_less_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
|
||||
out := New("MAX_AXIS")
|
||||
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Matmul(other *Array) *Array {
|
||||
out := New("MATMUL")
|
||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Multiply(other *Array) *Array {
|
||||
out := New("MULTIPLY")
|
||||
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Negative() *Array {
|
||||
out := New("NEGATIVE")
|
||||
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Power(exponent *Array) *Array {
|
||||
out := New("POWER")
|
||||
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
|
||||
out := New("PUT_ALONG_AXIS")
|
||||
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
|
||||
out := New("SCATTER_ADD_AXIS")
|
||||
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Reshape(axes ...int) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i := range axes {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
}
|
||||
|
||||
out := New("RESHAPE")
|
||||
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Sigmoid() *Array {
|
||||
out := New("SIGMOID")
|
||||
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Sqrt() *Array {
|
||||
out := New("SQRT")
|
||||
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Squeeze(axis int) *Array {
|
||||
out := New("SQUEEZE")
|
||||
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
|
||||
vectorData := make([]C.mlx_array, len(others)+1)
|
||||
vectorData[0] = t.ctx
|
||||
for i := range others {
|
||||
vectorData[i+1] = others[i].ctx
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
out := New("STACK_AXIS")
|
||||
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Subtract(other *Array) *Array {
|
||||
out := New("SUBTRACT")
|
||||
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
|
||||
out := New("SUM_AXIS")
|
||||
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
|
||||
out := New("TAKE_AXIS")
|
||||
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
|
||||
out := New("TAKE_ALONG_AXIS")
|
||||
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Tanh() *Array {
|
||||
out := New("TANH")
|
||||
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Transpose(axes ...int) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i, axis := range axes {
|
||||
cAxes[i] = C.int(axis)
|
||||
}
|
||||
|
||||
out := New("TRANSPOSE")
|
||||
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Zeros(dtype DType, shape ...int) *Array {
|
||||
cAxes := make([]C.int, len(shape))
|
||||
for i := range shape {
|
||||
cAxes[i] = C.int(shape[i])
|
||||
}
|
||||
|
||||
t := New("ZEROS")
|
||||
C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return t
|
||||
}
|
||||
666
x/mlxrunner/mlx/ops_extra.go
Normal file
666
x/mlxrunner/mlx/ops_extra.go
Normal file
@@ -0,0 +1,666 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Quantization operations
|
||||
|
||||
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
res := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(res)
|
||||
var globalScale C.mlx_array
|
||||
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
|
||||
|
||||
vecSize := int(C.mlx_vector_array_size(res))
|
||||
w0 := New("QUANTIZE_W")
|
||||
C.mlx_vector_array_get(&w0.ctx, res, 0)
|
||||
w1 := New("QUANTIZE_S")
|
||||
C.mlx_vector_array_get(&w1.ctx, res, 1)
|
||||
if vecSize >= 3 {
|
||||
w2 := New("QUANTIZE_B")
|
||||
C.mlx_vector_array_get(&w2.ctx, res, 2)
|
||||
return w0, w1, w2
|
||||
}
|
||||
return w0, w1, nil
|
||||
}
|
||||
|
||||
func FromFP8(x *Array, dtype DType) *Array {
|
||||
out := New("FROM_FP8")
|
||||
C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ToFP8(x *Array) *Array {
|
||||
out := New("TO_FP8")
|
||||
C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
optDtype := C.mlx_optional_dtype{has_value: false}
|
||||
|
||||
var b C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
}
|
||||
|
||||
out := New("DEQUANTIZE")
|
||||
var globalScale C.mlx_array
|
||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
|
||||
var b C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
}
|
||||
|
||||
out := New("QUANTIZED_MATMUL")
|
||||
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
|
||||
var b, lhs, rhs C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
}
|
||||
if lhsIndices != nil {
|
||||
lhs = lhsIndices.ctx
|
||||
}
|
||||
if rhsIndices != nil {
|
||||
rhs = rhsIndices.ctx
|
||||
}
|
||||
|
||||
out := New("GATHER_QMM")
|
||||
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Missing tensor ops
|
||||
|
||||
func Tile(a *Array, reps []int32) *Array {
|
||||
cReps := make([]C.int, len(reps))
|
||||
for i, r := range reps {
|
||||
cReps[i] = C.int(r)
|
||||
}
|
||||
out := New("TILE")
|
||||
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Tri(n, m int32, k int) *Array {
|
||||
out := New("TRI")
|
||||
C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Where(condition, a, b *Array) *Array {
|
||||
out := New("WHERE")
|
||||
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
||||
out := New("CONV1D")
|
||||
C.mlx_conv1d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(stride),
|
||||
C.int(padding),
|
||||
C.int(dilation),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
if bias != nil && bias.Valid() {
|
||||
out = Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
out := New("CONTIGUOUS")
|
||||
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
|
||||
// MLX uses NHWC layout.
|
||||
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
|
||||
out := New("CONV2D")
|
||||
C.mlx_conv2d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(strideH), C.int(strideW),
|
||||
C.int(padH), C.int(padW),
|
||||
C.int(dilationH), C.int(dilationW),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// Pad pads array a along the given axes with specified low/high pad sizes.
|
||||
// mode should be "constant", "edge", or "reflect".
|
||||
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
cLow := make([]C.int, len(lowPad))
|
||||
cHigh := make([]C.int, len(highPad))
|
||||
for i := range axes {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
cLow[i] = C.int(lowPad[i])
|
||||
cHigh[i] = C.int(highPad[i])
|
||||
}
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
out := New("PAD")
|
||||
C.mlx_pad(
|
||||
&out.ctx,
|
||||
a.ctx,
|
||||
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
|
||||
unsafe.SliceData(cLow), C.size_t(len(cLow)),
|
||||
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
|
||||
padValue.ctx,
|
||||
cMode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// PadConstant pads with zeros along the given axes.
|
||||
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
|
||||
zero := NewScalarArray(float32(0))
|
||||
return Pad(a, axes, lowPad, highPad, zero, "constant")
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Maximum returns element-wise maximum of two arrays.
|
||||
func Maximum(a, b *Array) *Array {
|
||||
out := New("MAXIMUM")
|
||||
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Minimum returns element-wise minimum of two arrays.
|
||||
func Minimum(a, b *Array) *Array {
|
||||
out := New("MINIMUM")
|
||||
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
|
||||
func Softplus(a *Array) *Array {
|
||||
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
|
||||
}
|
||||
|
||||
// ReLU computes max(0, x).
|
||||
func ReLU(a *Array) *Array {
|
||||
return Maximum(a, NewScalarArray(float32(0)))
|
||||
}
|
||||
|
||||
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
|
||||
// returns first * sigmoid(second).
|
||||
func GLU(a *Array) *Array {
|
||||
lastDim := a.NumDims() - 1
|
||||
halfSize := a.Dim(lastDim) / 2
|
||||
first := SliceStartStop(a,
|
||||
make([]int32, lastDim+1), // all zeros for start
|
||||
appendDims(a, lastDim, int32(halfSize)),
|
||||
)
|
||||
second := SliceStartStop(a,
|
||||
appendDimsStart(a, lastDim, int32(halfSize)),
|
||||
appendDims(a, lastDim, int32(a.Dim(lastDim))),
|
||||
)
|
||||
return first.Multiply(second.Sigmoid())
|
||||
}
|
||||
|
||||
// helper: builds stop array for SliceStartStop where the target axis = val
|
||||
func appendDims(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
} else {
|
||||
out[i] = int32(a.Dim(i))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// helper: builds start array for SliceStartStop where the target axis = val
|
||||
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Clamp clamps array values to [min, max].
|
||||
func Clamp(a *Array, minVal, maxVal float32) *Array {
|
||||
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
vectorData := make([]C.mlx_array, len(arrays))
|
||||
for i := range arrays {
|
||||
vectorData[i] = arrays[i].ctx
|
||||
}
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
out := New("STACK")
|
||||
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Neg(a *Array) *Array {
|
||||
return a.Negative()
|
||||
}
|
||||
|
||||
func Sum(a *Array, axis int, keepDims bool) *Array {
|
||||
return a.SumAxis(axis, keepDims)
|
||||
}
|
||||
|
||||
func Argsort(a *Array, axis int) *Array {
|
||||
return a.ArgsortAxis(axis)
|
||||
}
|
||||
|
||||
func Take(a *Array, indices *Array, axis int) *Array {
|
||||
return a.TakeAxis(indices, axis)
|
||||
}
|
||||
|
||||
func RSqrt(a *Array) *Array {
|
||||
out := New("RSQRT")
|
||||
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Mean(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("MEAN_AXIS")
|
||||
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Argpartition(a *Array, kth int, axis int) *Array {
|
||||
return a.ArgpartitionAxis(kth, axis)
|
||||
}
|
||||
|
||||
func TakeAlongAxis(a, indices *Array, axis int) *Array {
|
||||
return a.TakeAlongAxis(indices, axis)
|
||||
}
|
||||
|
||||
// Function-style wrappers matching imagegen API
|
||||
|
||||
func Add(a, b *Array) *Array {
|
||||
return a.Add(b)
|
||||
}
|
||||
|
||||
func Sub(a, b *Array) *Array {
|
||||
return a.Subtract(b)
|
||||
}
|
||||
|
||||
func Mul(a, b *Array) *Array {
|
||||
return a.Multiply(b)
|
||||
}
|
||||
|
||||
func Div(a, b *Array) *Array {
|
||||
return a.Divide(b)
|
||||
}
|
||||
|
||||
func Matmul(a, b *Array) *Array {
|
||||
return a.Matmul(b)
|
||||
}
|
||||
|
||||
func Reshape(a *Array, shape ...int32) *Array {
|
||||
axes := make([]int, len(shape))
|
||||
for i, s := range shape {
|
||||
axes[i] = int(s)
|
||||
}
|
||||
return a.Reshape(axes...)
|
||||
}
|
||||
|
||||
func Transpose(a *Array, axes ...int) *Array {
|
||||
return a.Transpose(axes...)
|
||||
}
|
||||
|
||||
func ExpandDims(a *Array, axis int) *Array {
|
||||
return a.ExpandDims(axis)
|
||||
}
|
||||
|
||||
func Squeeze(a *Array, axis int) *Array {
|
||||
return a.Squeeze(axis)
|
||||
}
|
||||
|
||||
func Flatten(a *Array) *Array {
|
||||
return a.Flatten(0, -1)
|
||||
}
|
||||
|
||||
func Concatenate(arrays []*Array, axis int) *Array {
|
||||
if len(arrays) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(arrays) == 1 {
|
||||
return arrays[0].Clone()
|
||||
}
|
||||
return arrays[0].Concatenate(axis, arrays[1:]...)
|
||||
}
|
||||
|
||||
func SliceStartStop(a *Array, start, stop []int32) *Array {
|
||||
n := len(start)
|
||||
cStart := make([]C.int, n)
|
||||
cStop := make([]C.int, n)
|
||||
cStrides := make([]C.int, n)
|
||||
for i := 0; i < n; i++ {
|
||||
cStart[i] = C.int(start[i])
|
||||
cStop[i] = C.int(stop[i])
|
||||
cStrides[i] = 1
|
||||
}
|
||||
out := New("SLICE")
|
||||
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
|
||||
if lhsIndices == nil {
|
||||
lhsIndices = New("")
|
||||
}
|
||||
if rhsIndices == nil {
|
||||
rhsIndices = New("")
|
||||
}
|
||||
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
||||
}
|
||||
|
||||
// RoPEWithBase applies rotary position embeddings to x. offsets is an
|
||||
// int32 array of shape [B] giving each batch row's starting position;
|
||||
// the kernel applies positions offsets[b] + 0..T-1 per row.
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offsets *Array) *Array {
|
||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offsets, nil)
|
||||
}
|
||||
|
||||
// RoPEWithFreqs applies RoPE with optional custom frequencies.
|
||||
// When freqs is non-nil, it is used instead of computing from base.
|
||||
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
|
||||
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
|
||||
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offsets *Array, freqs *Array) *Array {
|
||||
var freqsCtx C.mlx_array
|
||||
var optBase C.mlx_optional_float
|
||||
if freqs != nil {
|
||||
freqsCtx = freqs.ctx
|
||||
optBase = C.mlx_optional_float{has_value: C.bool(false)}
|
||||
} else {
|
||||
empty := New("")
|
||||
freqsCtx = empty.ctx
|
||||
optBase = C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
}
|
||||
}
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope_dynamic(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C.bool(traditional),
|
||||
optBase,
|
||||
C.float(scale),
|
||||
offsets.ctx,
|
||||
freqsCtx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func Sigmoid(a *Array) *Array {
|
||||
return a.Sigmoid()
|
||||
}
|
||||
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP")
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG")
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Sin(a *Array) *Array {
|
||||
out := New("SIN")
|
||||
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Cos(a *Array) *Array {
|
||||
out := New("COS")
|
||||
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Clip(a, aMin, aMax *Array) *Array {
|
||||
out := New("CLIP")
|
||||
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Logaddexp(a, b *Array) *Array {
|
||||
out := New("LOGADDEXP")
|
||||
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS")
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
var w, b C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
if bias != nil {
|
||||
b = bias.ctx
|
||||
}
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, w, b, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
var w C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
|
||||
return c.Addmm(a, b, alpha, beta)
|
||||
}
|
||||
|
||||
// Scalar helpers
|
||||
|
||||
// scalarWithDtype creates a scalar array matching the dtype of a.
|
||||
// Matching dtype is important for graph fusion and avoiding implicit casts.
|
||||
func scalarWithDtype(s float32, a *Array) C.mlx_array {
|
||||
f32 := C.mlx_array_new_float(C.float(s))
|
||||
dtype := a.DType()
|
||||
if dtype == DTypeFloat32 {
|
||||
return f32
|
||||
}
|
||||
casted := C.mlx_array_new()
|
||||
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
C.mlx_array_free(f32)
|
||||
return casted
|
||||
}
|
||||
|
||||
func AddScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("ADD_SCALAR")
|
||||
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func MulScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("MUL_SCALAR")
|
||||
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func DivScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("DIV_SCALAR")
|
||||
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func FloorDivideScalar(a *Array, s int32) *Array {
|
||||
scalar := FromValue(int(s))
|
||||
return a.FloorDivide(scalar)
|
||||
}
|
||||
|
||||
// Array constructors
|
||||
|
||||
func NewArrayInt32(data []int32, shape []int32) *Array {
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
out := New("NEW_ARRAY_INT32")
|
||||
out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
|
||||
return out
|
||||
}
|
||||
|
||||
func NewScalarArray(value float32) *Array {
|
||||
out := New("SCALAR")
|
||||
out.ctx = C.mlx_array_new_float32(C.float(value))
|
||||
return out
|
||||
}
|
||||
|
||||
func ZerosF32(shape []int32) *Array {
|
||||
return Zeros(DTypeFloat32, func() []int {
|
||||
ints := make([]int, len(shape))
|
||||
for i, s := range shape {
|
||||
ints[i] = int(s)
|
||||
}
|
||||
return ints
|
||||
}()...)
|
||||
}
|
||||
|
||||
// Utility
|
||||
|
||||
func Collect(v any) []*Array {
|
||||
var arrays []*Array
|
||||
seen := make(map[uintptr]bool)
|
||||
collect(reflect.ValueOf(v), &arrays, seen)
|
||||
return arrays
|
||||
}
|
||||
|
||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
ptr := v.Pointer()
|
||||
if seen[ptr] {
|
||||
return
|
||||
}
|
||||
seen[ptr] = true
|
||||
|
||||
if arr, ok := v.Interface().(*Array); ok {
|
||||
if arr != nil && arr.Valid() {
|
||||
*arrays = append(*arrays, arr)
|
||||
}
|
||||
return
|
||||
}
|
||||
collect(v.Elem(), arrays, seen)
|
||||
return
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
// Check if this struct IS an Array (not a pointer to one)
|
||||
if arr, ok := v.Addr().Interface().(*Array); ok {
|
||||
if arr != nil && arr.Valid() {
|
||||
*arrays = append(*arrays, arr)
|
||||
}
|
||||
return
|
||||
}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.CanInterface() {
|
||||
collect(field, arrays, seen)
|
||||
}
|
||||
}
|
||||
case reflect.Slice:
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
collect(v.Index(i), arrays, seen)
|
||||
}
|
||||
case reflect.Map:
|
||||
for _, key := range v.MapKeys() {
|
||||
collect(v.MapIndex(key), arrays, seen)
|
||||
}
|
||||
case reflect.Interface:
|
||||
if !v.IsNil() {
|
||||
collect(v.Elem(), arrays, seen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func EnableCompile() {
|
||||
C.mlx_enable_compile()
|
||||
}
|
||||
|
||||
func DisableCompile() {
|
||||
C.mlx_disable_compile()
|
||||
}
|
||||
44
x/mlxrunner/mlx/random.go
Normal file
44
x/mlxrunner/mlx/random.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
func RandomKey(seed uint64) *Array {
|
||||
out := New("RANDOM_KEY")
|
||||
C.mlx_random_key(&out.ctx, C.uint64_t(seed))
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Categorical(axis int) *Array {
|
||||
return t.CategoricalWithKey(axis, nil)
|
||||
}
|
||||
|
||||
func (t *Array) CategoricalWithKey(axis int, key *Array) *Array {
|
||||
if key == nil {
|
||||
key = New("")
|
||||
}
|
||||
out := New("")
|
||||
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Bernoulli(p *Array) *Array {
|
||||
return BernoulliWithKey(p, nil)
|
||||
}
|
||||
|
||||
func BernoulliWithKey(p *Array, key *Array) *Array {
|
||||
dims := p.Dims()
|
||||
shape := make([]C.int, len(dims))
|
||||
for i, d := range dims {
|
||||
shape[i] = C.int(d)
|
||||
}
|
||||
|
||||
if key == nil {
|
||||
key = New("")
|
||||
}
|
||||
out := New("BERNOULLI")
|
||||
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
100
x/mlxrunner/mlx/slice.go
Normal file
100
x/mlxrunner/mlx/slice.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"math"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// End is a sentinel value meaning "to the end of the dimension",
|
||||
// equivalent to an omitted stop in Python (e.g. a[i:]).
|
||||
const End = math.MaxInt32
|
||||
|
||||
type slice struct {
|
||||
args []int
|
||||
}
|
||||
|
||||
func Slice(args ...int) slice {
|
||||
return slice{args: args}
|
||||
}
|
||||
|
||||
func resolve(val, dim int) C.int {
|
||||
if val == End {
|
||||
return C.int(dim)
|
||||
}
|
||||
if val < 0 {
|
||||
return C.int(dim + val)
|
||||
}
|
||||
return C.int(val)
|
||||
}
|
||||
|
||||
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
||||
if len(slices) != len(dims) {
|
||||
panic("number of slice arguments must match number of tensor dimensions")
|
||||
}
|
||||
|
||||
args := [3][]C.int{
|
||||
make([]C.int, len(slices)),
|
||||
make([]C.int, len(slices)),
|
||||
make([]C.int, len(slices)),
|
||||
}
|
||||
|
||||
for i, s := range slices {
|
||||
dim := dims[i]
|
||||
switch len(s.args) {
|
||||
case 0:
|
||||
// slice[:]
|
||||
args[0][i] = C.int(0)
|
||||
args[1][i] = C.int(dim)
|
||||
args[2][i] = C.int(1)
|
||||
case 1:
|
||||
// slice[i]
|
||||
start := resolve(s.args[0], dim)
|
||||
args[0][i] = start
|
||||
args[1][i] = start + 1
|
||||
args[2][i] = C.int(1)
|
||||
case 2:
|
||||
// slice[i:j]
|
||||
args[0][i] = resolve(s.args[0], dim)
|
||||
args[1][i] = resolve(s.args[1], dim)
|
||||
args[2][i] = C.int(1)
|
||||
case 3:
|
||||
// slice[i:j:k]
|
||||
args[0][i] = resolve(s.args[0], dim)
|
||||
args[1][i] = resolve(s.args[1], dim)
|
||||
args[2][i] = C.int(s.args[2])
|
||||
default:
|
||||
panic("invalid slice arguments")
|
||||
}
|
||||
}
|
||||
|
||||
return args[0], args[1], args[2]
|
||||
}
|
||||
|
||||
func (t *Array) Slice(slices ...slice) *Array {
|
||||
starts, stops, strides := makeSlices(t.Dims(), slices...)
|
||||
out := New("SLICE")
|
||||
C.mlx_slice(
|
||||
&out.ctx, t.ctx,
|
||||
unsafe.SliceData(starts), C.size_t(len(starts)),
|
||||
unsafe.SliceData(stops), C.size_t(len(stops)),
|
||||
unsafe.SliceData(strides), C.size_t(len(strides)),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
|
||||
starts, stops, strides := makeSlices(t.Dims(), slices...)
|
||||
out := New("SLICE_UPDATE")
|
||||
C.mlx_slice_update(
|
||||
&out.ctx, t.ctx, other.ctx,
|
||||
unsafe.SliceData(starts), C.size_t(len(starts)),
|
||||
unsafe.SliceData(stops), C.size_t(len(stops)),
|
||||
unsafe.SliceData(strides), C.size_t(len(strides)),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
79
x/mlxrunner/mlx/stream.go
Normal file
79
x/mlxrunner/mlx/stream.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import "log/slog"
|
||||
|
||||
type Device struct {
|
||||
ctx C.mlx_device
|
||||
}
|
||||
|
||||
func (d Device) LogValue() slog.Value {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_device_tostring(&str, d.ctx)
|
||||
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
var (
|
||||
defaultDevice Device
|
||||
defaultDeviceSet bool
|
||||
defaultStream Stream
|
||||
defaultStreamSet bool
|
||||
)
|
||||
|
||||
func resetDefaultStreamCache() {
|
||||
defaultDeviceSet = false
|
||||
defaultStreamSet = false
|
||||
}
|
||||
|
||||
func DefaultDevice() Device {
|
||||
if !defaultDeviceSet {
|
||||
d := C.mlx_device_new()
|
||||
C.mlx_get_default_device(&d)
|
||||
defaultDevice = Device{d}
|
||||
defaultDeviceSet = true
|
||||
}
|
||||
|
||||
return defaultDevice
|
||||
}
|
||||
|
||||
// GPUIsAvailable returns true if a GPU device is available.
|
||||
func GPUIsAvailable() bool {
|
||||
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
|
||||
defer C.mlx_device_free(dev)
|
||||
var avail C.bool
|
||||
C.mlx_device_is_available(&avail, dev)
|
||||
return bool(avail)
|
||||
}
|
||||
|
||||
// SetDefaultDeviceGPU sets the default MLX device to GPU.
|
||||
func SetDefaultDeviceGPU() {
|
||||
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
|
||||
C.mlx_set_default_device(dev)
|
||||
C.mlx_device_free(dev)
|
||||
resetDefaultStreamCache()
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
ctx C.mlx_stream
|
||||
}
|
||||
|
||||
func (s Stream) LogValue() slog.Value {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_stream_tostring(&str, s.ctx)
|
||||
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
func DefaultStream() Stream {
|
||||
if !defaultStreamSet {
|
||||
s := C.mlx_stream_new()
|
||||
C.mlx_get_default_stream(&s, DefaultDevice().ctx)
|
||||
defaultStream = Stream{s}
|
||||
defaultStreamSet = true
|
||||
}
|
||||
|
||||
return defaultStream
|
||||
}
|
||||
104
x/mlxrunner/mlx/thread_test.go
Normal file
104
x/mlxrunner/mlx/thread_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/internal/mlxthread"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func startMLXThread(t *testing.T) *mlxthread.Thread {
|
||||
t.Helper()
|
||||
|
||||
thread, err := mlxthread.Start("mlx-test", func() error {
|
||||
if err := CheckInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
if GPUIsAvailable() {
|
||||
SetDefaultDeviceGPU()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
|
||||
return thread
|
||||
}
|
||||
|
||||
func stopMLXThread(t *testing.T, thread *mlxthread.Thread) {
|
||||
t.Helper()
|
||||
|
||||
if err := thread.Stop(context.Background(), func() {
|
||||
Sweep()
|
||||
ClearCache()
|
||||
resetDefaultStreamCache()
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func withMLXThread(t *testing.T, fn func()) {
|
||||
t.Helper()
|
||||
|
||||
thread := startMLXThread(t)
|
||||
defer stopMLXThread(t, thread)
|
||||
|
||||
if err := thread.Do(context.Background(), func() error {
|
||||
fn()
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThreadedMLXOperations(t *testing.T) {
|
||||
thread := startMLXThread(t)
|
||||
defer stopMLXThread(t, thread)
|
||||
|
||||
oldProcs := runtime.GOMAXPROCS(8)
|
||||
defer runtime.GOMAXPROCS(oldProcs)
|
||||
|
||||
const goroutines = 8
|
||||
const iterations = 8
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, goroutines)
|
||||
for range goroutines {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for range iterations {
|
||||
if err := thread.Do(context.Background(), func() error {
|
||||
a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
|
||||
b := Matmul(a, a)
|
||||
AsyncEval(b)
|
||||
Eval(b)
|
||||
Sweep()
|
||||
ClearCache()
|
||||
return nil
|
||||
}); err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user