ollama source for Momentry Core verification
This commit is contained in:
76
x/mlxrunner/mlx/array_test.go
Normal file
76
x/mlxrunner/mlx/array_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package mlx
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFromValue(t *testing.T) {
|
||||
withMLXThread(t, func() {
|
||||
for got, want := range map[*Array]DType{
|
||||
FromValue(true): DTypeBool,
|
||||
FromValue(false): DTypeBool,
|
||||
FromValue(int(7)): DTypeInt32,
|
||||
FromValue(float32(3.14)): DTypeFloat32,
|
||||
FromValue(float64(2.71)): DTypeFloat64,
|
||||
FromValue(complex64(1 + 2i)): DTypeComplex64,
|
||||
} {
|
||||
if got.DType() != want {
|
||||
t.Errorf("%s: want %v, got %v", want, want, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFromValues(t *testing.T) {
|
||||
withMLXThread(t, func() {
|
||||
for got, want := range map[*Array]DType{
|
||||
FromValues([]bool{true, false, true}, 3): DTypeBool,
|
||||
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
|
||||
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
|
||||
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
|
||||
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
|
||||
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
|
||||
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
|
||||
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
|
||||
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
|
||||
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
|
||||
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
|
||||
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
|
||||
} {
|
||||
if got.DType() != want {
|
||||
t.Errorf("%s: want %v, got %v", want, want, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestComparisonOpsAndBernoulli(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
a := FromValues([]float32{1, 2, 3}, 3)
|
||||
b := FromValues([]float32{1, 1, 4}, 3)
|
||||
eq := a.Equal(b).AsType(DTypeInt32)
|
||||
gt := a.Greater(b).AsType(DTypeInt32)
|
||||
le := a.LessEqual(b).AsType(DTypeInt32)
|
||||
bern := Bernoulli(FromValues([]float32{1, 0}, 2)).AsType(DTypeInt32)
|
||||
Eval(eq, gt, le, bern)
|
||||
|
||||
for name, tc := range map[string]struct {
|
||||
got []int
|
||||
want []int
|
||||
}{
|
||||
"equal": {eq.Ints(), []int{1, 0, 0}},
|
||||
"greater": {gt.Ints(), []int{0, 1, 0}},
|
||||
"lessEqual": {le.Ints(), []int{1, 0, 1}},
|
||||
"bernoulli": {bern.Ints(), []int{1, 0}},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if len(tc.got) != len(tc.want) {
|
||||
t.Fatalf("got %v, want %v", tc.got, tc.want)
|
||||
}
|
||||
for i := range tc.want {
|
||||
if tc.got[i] != tc.want[i] {
|
||||
t.Fatalf("got %v, want %v", tc.got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user