ollama source for Momentry Core verification
This commit is contained in:
17
x/mlxrunner/mlx/generator/generated.c.gotmpl
Normal file
17
x/mlxrunner/mlx/generator/generated.c.gotmpl
Normal file
@@ -0,0 +1,17 @@
|
||||
// This code is auto-generated; DO NOT EDIT.
|
||||
|
||||
#include "generated.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
{{ range .Functions }}
|
||||
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
|
||||
{{- end }}
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
{{- range .Functions }}
|
||||
{{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }});
|
||||
{{- end }}
|
||||
return 0;
|
||||
}
|
||||
26
x/mlxrunner/mlx/generator/generated.h.gotmpl
Normal file
26
x/mlxrunner/mlx/generator/generated.h.gotmpl
Normal file
@@ -0,0 +1,26 @@
|
||||
// This code is auto-generated; DO NOT EDIT.
|
||||
|
||||
#ifndef MLX_GENERATED_H
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
{{ range .Functions }}
|
||||
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
|
||||
{{- end }}
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
{{- end }}
|
||||
{{ range .Functions }}
|
||||
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
|
||||
{{- end }}
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
|
||||
{{ range .Functions }}
|
||||
static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
|
||||
return {{ .Name }}_({{ .Args }});
|
||||
{{ "}" }}
|
||||
{{- end }}
|
||||
|
||||
#endif // MLX_GENERATED_H
|
||||
157
x/mlxrunner/mlx/generator/main.go
Normal file
157
x/mlxrunner/mlx/generator/main.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
tree_sitter "github.com/tree-sitter/go-tree-sitter"
|
||||
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
|
||||
)
|
||||
|
||||
//go:embed *.gotmpl
|
||||
var fsys embed.FS
|
||||
|
||||
// optionalSymbols lists symbols that may not be present in all builds
|
||||
// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX).
|
||||
var optionalSymbols = map[string]bool{
|
||||
"mlx_array_item_float16": true,
|
||||
"mlx_array_item_bfloat16": true,
|
||||
"mlx_array_data_float16": true,
|
||||
"mlx_array_data_bfloat16": true,
|
||||
}
|
||||
|
||||
type Function struct {
|
||||
Type,
|
||||
Name,
|
||||
Parameters,
|
||||
Args string
|
||||
Optional bool
|
||||
}
|
||||
|
||||
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
|
||||
var fn Function
|
||||
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
|
||||
if params := node.ChildByFieldName("parameters"); params != nil {
|
||||
fn.Parameters = params.Utf8Text(source)
|
||||
fn.Args = ParseParameters(params, tc, source)
|
||||
}
|
||||
|
||||
var types []string
|
||||
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
|
||||
if node.Parent().Kind() == "pointer_declarator" {
|
||||
types = append(types, "*")
|
||||
}
|
||||
node = node.Parent()
|
||||
}
|
||||
|
||||
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
|
||||
types = append(types, sibling.Utf8Text(source))
|
||||
}
|
||||
|
||||
slices.Reverse(types)
|
||||
fn.Type = strings.Join(types, " ")
|
||||
return fn
|
||||
}
|
||||
|
||||
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
|
||||
var s []string
|
||||
for _, child := range node.Children(tc) {
|
||||
if child.IsNamed() {
|
||||
child := child.ChildByFieldName("declarator")
|
||||
for child != nil && child.Kind() != "identifier" {
|
||||
if child.Kind() == "parenthesized_declarator" {
|
||||
child = child.Child(1)
|
||||
} else {
|
||||
child = child.ChildByFieldName("declarator")
|
||||
}
|
||||
}
|
||||
|
||||
if child != nil {
|
||||
s = append(s, child.Utf8Text(source))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(s, ", ")
|
||||
}
|
||||
|
||||
func main() {
|
||||
var output string
|
||||
flag.StringVar(&output, "output", ".", "Output directory for generated files")
|
||||
flag.Parse()
|
||||
|
||||
parser := tree_sitter.NewParser()
|
||||
defer parser.Close()
|
||||
|
||||
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
|
||||
parser.SetLanguage(language)
|
||||
|
||||
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
|
||||
defer query.Close()
|
||||
|
||||
qc := tree_sitter.NewQueryCursor()
|
||||
defer qc.Close()
|
||||
|
||||
var files []string
|
||||
for _, arg := range flag.Args() {
|
||||
matches, err := filepath.Glob(arg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error expanding glob %s: %v\n", arg, err)
|
||||
continue
|
||||
}
|
||||
files = append(files, matches...)
|
||||
}
|
||||
|
||||
var funs []Function
|
||||
for _, arg := range files {
|
||||
bts, err := os.ReadFile(arg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
|
||||
continue
|
||||
}
|
||||
|
||||
tree := parser.Parse(bts, nil)
|
||||
defer tree.Close()
|
||||
|
||||
tc := tree.Walk()
|
||||
defer tc.Close()
|
||||
|
||||
matches := qc.Matches(query, tree.RootNode(), bts)
|
||||
for match := matches.Next(); match != nil; match = matches.Next() {
|
||||
for _, capture := range match.Captures {
|
||||
fn := ParseFunction(&capture.Node, tc, bts)
|
||||
fn.Optional = optionalSymbols[fn.Name]
|
||||
funs = append(funs, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, tmpl := range tmpl.Templates() {
|
||||
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
|
||||
|
||||
fmt.Println("Generating", name)
|
||||
f, err := os.Create(name)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := tmpl.Execute(f, map[string]any{
|
||||
"Functions": funs,
|
||||
}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user