ollama source for Momentry Core verification
This commit is contained in:
44
x/mlxrunner/mlx/random.go
Normal file
44
x/mlxrunner/mlx/random.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
func RandomKey(seed uint64) *Array {
|
||||
out := New("RANDOM_KEY")
|
||||
C.mlx_random_key(&out.ctx, C.uint64_t(seed))
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Categorical(axis int) *Array {
|
||||
return t.CategoricalWithKey(axis, nil)
|
||||
}
|
||||
|
||||
func (t *Array) CategoricalWithKey(axis int, key *Array) *Array {
|
||||
if key == nil {
|
||||
key = New("")
|
||||
}
|
||||
out := New("")
|
||||
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Bernoulli(p *Array) *Array {
|
||||
return BernoulliWithKey(p, nil)
|
||||
}
|
||||
|
||||
func BernoulliWithKey(p *Array, key *Array) *Array {
|
||||
dims := p.Dims()
|
||||
shape := make([]C.int, len(dims))
|
||||
for i, d := range dims {
|
||||
shape[i] = C.int(d)
|
||||
}
|
||||
|
||||
if key == nil {
|
||||
key = New("")
|
||||
}
|
||||
out := New("BERNOULLI")
|
||||
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
Reference in New Issue
Block a user