ollama source for Momentry Core verification
This commit is contained in:
79
x/mlxrunner/mlx/stream.go
Normal file
79
x/mlxrunner/mlx/stream.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import "log/slog"
|
||||
|
||||
type Device struct {
|
||||
ctx C.mlx_device
|
||||
}
|
||||
|
||||
func (d Device) LogValue() slog.Value {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_device_tostring(&str, d.ctx)
|
||||
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
var (
|
||||
defaultDevice Device
|
||||
defaultDeviceSet bool
|
||||
defaultStream Stream
|
||||
defaultStreamSet bool
|
||||
)
|
||||
|
||||
func resetDefaultStreamCache() {
|
||||
defaultDeviceSet = false
|
||||
defaultStreamSet = false
|
||||
}
|
||||
|
||||
func DefaultDevice() Device {
|
||||
if !defaultDeviceSet {
|
||||
d := C.mlx_device_new()
|
||||
C.mlx_get_default_device(&d)
|
||||
defaultDevice = Device{d}
|
||||
defaultDeviceSet = true
|
||||
}
|
||||
|
||||
return defaultDevice
|
||||
}
|
||||
|
||||
// GPUIsAvailable returns true if a GPU device is available.
|
||||
func GPUIsAvailable() bool {
|
||||
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
|
||||
defer C.mlx_device_free(dev)
|
||||
var avail C.bool
|
||||
C.mlx_device_is_available(&avail, dev)
|
||||
return bool(avail)
|
||||
}
|
||||
|
||||
// SetDefaultDeviceGPU sets the default MLX device to GPU.
|
||||
func SetDefaultDeviceGPU() {
|
||||
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
|
||||
C.mlx_set_default_device(dev)
|
||||
C.mlx_device_free(dev)
|
||||
resetDefaultStreamCache()
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
ctx C.mlx_stream
|
||||
}
|
||||
|
||||
func (s Stream) LogValue() slog.Value {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_stream_tostring(&str, s.ctx)
|
||||
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
func DefaultStream() Stream {
|
||||
if !defaultStreamSet {
|
||||
s := C.mlx_stream_new()
|
||||
C.mlx_get_default_stream(&s, DefaultDevice().ctx)
|
||||
defaultStream = Stream{s}
|
||||
defaultStreamSet = true
|
||||
}
|
||||
|
||||
return defaultStream
|
||||
}
|
||||
Reference in New Issue
Block a user