ollama source for Momentry Core verification

This commit is contained in:
Accusys
2026-05-22 17:19:10 +08:00
commit 0b31ff9135
2020 changed files with 1413145 additions and 0 deletions

View File

@@ -0,0 +1,42 @@
package batch
import "github.com/ollama/ollama/x/mlxrunner/mlx"
// Batch is the per-forward-pass input handed to a model.
type Batch struct {
// InputIDs is the input token IDs for this forward pass, shape (B, L).
InputIDs *mlx.Array
// SeqOffsets gives each row's current position within its sequence —
// where the chunk in InputIDs starts. Length equals the batch dimension
// of InputIDs.
SeqOffsets []int32
// SeqQueryLens is each row's real query length in this forward. Values
// less than L mean the row's tail is padding that must be masked out.
// Length equals the batch dimension of InputIDs.
SeqQueryLens []int32
// Memo is per-forward memoization used to cache results, such as masks,
// which are often the same across layers.
Memo Memo
}
type Memo struct {
entries map[any]any
}
// Get returns the memoized value for key and true if present, or nil
// and false otherwise.
func (m *Memo) Get(key any) (any, bool) {
v, ok := m.entries[key]
return v, ok
}
// Put stores value under key, allocating on first use.
func (m *Memo) Put(key, value any) {
if m.entries == nil {
m.entries = map[any]any{}
}
m.entries[key] = value
}

632
x/mlxrunner/cache.go Normal file
View File

@@ -0,0 +1,632 @@
// cache.go manages a shared KV cache across conversations using a compressed
// prefix trie. Each trie node stores a token sequence (edge) and optional
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
//
// Key properties:
// - Only one path through the trie is "active" (backed by live MLX arrays)
// at a time. Switching paths pages out the frontier node and pages in the
// new path.
// - Snapshots are only captured at the frontier (end) of the active path.
// Intermediate node snapshots come from split prefill.
// - All cache layers must stay at the same token offset.
// - Sibling edges must not share a common token prefix (compressed trie
// invariant).
// - begin() always re-evaluates at least one token so the pipeline can seed
// generation, even on a full prefix match.
package mlxrunner
import (
"cmp"
"fmt"
"log/slog"
"slices"
"time"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
)
const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory
type kvCache struct {
root *trieNode // root of the prefix trie
activePath []*trieNode // current root→leaf path with live MLX arrays
caches []cache.Cache
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
}
type cacheFactoryFunc func() []cache.Cache
// pendingSnapshot is a snapshot scheduled to be taken during prefill.
type pendingSnapshot struct {
offset int
user bool
}
// cacheSession manages caches for a single pipeline run.
// Callers should append generated tokens to outputs and
// defer close to save the cache state.
type cacheSession struct {
cache *kvCache
inputs []int32
outputs []int32
caches []cache.Cache
remaining []int32
// pendingSnapshots lists offsets where snapshots should be captured
// during prefill, sorted by offset. Entries are consumed as the
// cache advances past them.
pendingSnapshots []pendingSnapshot
}
func (c *kvCache) ensureCachesWithFactory(newCaches cacheFactoryFunc) {
if len(c.caches) != 0 {
return
}
c.caches = newCaches()
}
func newModelCaches(m base.Model) []cache.Cache {
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
return cacheFactory.NewCaches()
}
caches := make([]cache.Cache, m.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func (c *kvCache) ensureRoot() {
if c.root == nil {
c.root = &trieNode{
lastUsed: time.Now(),
}
c.activePath = []*trieNode{c.root}
}
}
// begin prepares caches for a new request. It finds the nearest
// matching cache or creates new caches if none match.
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
return c.beginWithFactory(inputs, func() []cache.Cache { return newModelCaches(m) }, "")
}
func (c *kvCache) beginWithFactory(inputs []int32, newCaches cacheFactoryFunc, logPrefix string) *cacheSession {
return c.beginWithFactoryLimit(inputs, newCaches, logPrefix, -1, true)
}
func (c *kvCache) beginWithFactoryLimit(inputs []int32, newCaches cacheFactoryFunc, logPrefix string, maxCachedPrefix int, keepSeedToken bool) *cacheSession {
c.ensureCachesWithFactory(newCaches)
c.ensureRoot()
matchPath, matched := findBestMatch(c.root, inputs)
originalMatched := matched
if maxCachedPrefix >= 0 {
maxCachedPrefix = min(maxCachedPrefix, len(inputs))
if matched > maxCachedPrefix {
matchPath, matched = findBestMatch(c.root, inputs[:maxCachedPrefix])
}
}
// Always keep at least one token to re-evaluate so the
// pipeline can seed token generation from it.
if keepSeedToken && matched == len(inputs) && matched > 0 {
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
}
// Switch to the matched path, paging in/out as needed.
c.switchToPath(matchPath, matched)
// switchToPath aligns caches to a common offset
prefix := c.minCacheOffset()
remaining := inputs[prefix:]
session := &cacheSession{
cache: c,
inputs: inputs,
caches: c.caches,
remaining: remaining,
}
// Schedule a snapshot at the branch point during prefill so future
// requests diverging here can restore instead of re-evaluating.
if prefix < matched {
session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false})
}
msg := "cache hit"
if prefix == 0 {
msg = "cache miss"
}
if logPrefix != "" {
msg = logPrefix + " " + msg
}
slog.Info(msg, "total", len(inputs), "matched", originalMatched, "cached", prefix, "left", len(remaining))
return session
}
// switchToPath transitions from the current active path to a new path,
// paging out diverging segments and paging in the new path.
func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
defer c.enforceEvictionPolicy()
// Find common ancestor index.
commonLen := 0
for commonLen < len(c.activePath) && commonLen < len(newPath) {
if c.activePath[commonLen] != newPath[commonLen] {
break
}
commonLen++
}
ancestorOffset := 0
if commonLen > 0 {
ancestorOffset = c.activePath[commonLen-1].endOffset
}
var pageOutCount, pageInCount int
// Page out the leaf of the old path. Only the leaf's live cache
// state is correct — intermediate nodes already have snapshots
// captured during their creation (splitNode + prefill). Snapshotting
// non-leaf nodes here would produce wrong results for non-rewindable
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
// the intermediate boundary.
leaf := len(c.activePath) - 1
leafDiverges := leaf >= commonLen
leafNeedsRewind := matched < c.activePath[leaf].endOffset
if leafDiverges || leafNeedsRewind {
node := c.activePath[leaf]
if !node.hasAllSnapshots() {
fromOffset := node.startOffset()
snaps := make([]cache.Snapshot, len(c.caches))
for j, kv := range c.caches {
if kv == nil {
continue
}
snaps[j] = kv.Snapshot(fromOffset)
}
node.setSnapshots(snaps, &c.pagedOutBytes)
pageOutCount++
logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
}
}
// Rewind each cache to the target offset or free it. When matched
// falls within the ancestor's range (same-path case), we rewind
// directly to the match point. Otherwise we rewind to the ancestor
// and let page-in bring us forward to matched.
rewindTarget := min(ancestorOffset, matched)
for _, kv := range c.caches {
if kv == nil {
continue
}
if !kv.Restore(nil, rewindTarget) {
kv.Free()
}
}
// Page in — walk the full new path, restoring from snapshots.
// Freed caches naturally pick up the first available snapshot.
// Caches already past a node skip it via offset check.
pageIn:
for _, node := range newPath {
if !node.hasSnapshots() {
continue
}
nodeTarget := min(node.endOffset, matched)
for j, kv := range c.caches {
if kv == nil {
continue
}
if j >= len(node.snapshots) || node.snapshots[j] == nil {
continue
}
if kv.Offset() >= nodeTarget {
continue
}
if !kv.Restore(node.snapshots[j], nodeTarget) {
// Restore failed — stop page-in and let alignment
// bring all caches to a consistent offset.
break pageIn
}
}
if node.endOffset > ancestorOffset {
pageInCount++
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
}
}
// Align all caches to the minimum offset.
c.activePath = newPath
minOff := c.minCacheOffset()
for _, kv := range c.caches {
if kv != nil && kv.Offset() != minOff {
if !kv.Restore(nil, minOff) {
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
c.freeAll()
break
}
}
}
for i := len(c.activePath) - 1; i >= 0; i-- {
if c.activePath[i].endOffset <= minOff {
c.activePath = c.activePath[:i+1]
break
}
}
// Update last-used time on only the final used node. For recurrent
// caches we don't need the intermediate snapshots and for KV caches
// we can reslice the data out of merged edges.
if len(c.activePath) > 0 {
c.activePath[len(c.activePath)-1].lastUsed = time.Now()
}
if pageOutCount > 0 || pageInCount > 0 {
slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
}
}
// requestSnapshot schedules a user snapshot at the given absolute token
// offset. The snapshot will be captured during prefill when the cache
// reaches this offset.
func (s *cacheSession) requestSnapshot(offset int) {
baseOffset := len(s.inputs) - len(s.remaining)
if offset <= baseOffset || offset > len(s.inputs) {
return
}
// Deduplicate: if this offset already exists, upgrade to user.
for i := range s.pendingSnapshots {
if s.pendingSnapshots[i].offset == offset {
s.pendingSnapshots[i].user = true
return
}
}
s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true})
slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int {
return a.offset - b.offset
})
}
// nextPendingSnapshot returns the offset of the next pending snapshot,
// or 0 if there are none.
func (s *cacheSession) nextPendingSnapshot() int {
if len(s.pendingSnapshots) == 0 {
return 0
}
return s.pendingSnapshots[0].offset
}
// snapshot creates a snapshot at the current cache position. It determines
// whether this is a user snapshot by consuming pending entries whose offset
// has been reached.
func (s *cacheSession) snapshot() {
c := s.cache
cacheOffset := c.minCacheOffset()
if cacheOffset <= 0 {
return
}
// Consume pending snapshots up to the current offset and derive
// the user flag from them.
user := false
for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset {
if s.pendingSnapshots[0].user {
user = true
}
s.pendingSnapshots = s.pendingSnapshots[1:]
}
// The last node in activePath is the frontier where caches are advancing.
// cacheOffset is always >= its endOffset: begin() restores caches to this
// boundary and prefill advances monotonically forward.
frontier := c.activePath[len(c.activePath)-1]
// If the frontier already ends at cacheOffset, just ensure it has snapshots.
if frontier.endOffset == cacheOffset {
if user {
frontier.user = true
}
if !frontier.hasAllSnapshots() {
s.attachSnapshots(frontier, cacheOffset)
}
return
}
if frontier.endOffset > cacheOffset {
slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
return
}
// Advance the trie to cacheOffset — find or create a node there.
edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset]
frontier = c.advancePath(frontier, edgeTokens, cacheOffset)
// Attach fresh snapshots from the live caches. Always use fresh
// snapshots even if the node already has some (e.g. from splitNode's
// Cache.Split which may be incomplete for non-splittable caches
// like RecurrentCache).
if user {
frontier.user = true
}
s.attachSnapshots(frontier, cacheOffset)
}
// advancePath advances the active path from the current frontier by matching
// tokens against existing trie children, splitting partial matches, and
// appending any remaining tokens as new nodes. Returns the new frontier.
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
// Check if existing children already cover some or all of tokens.
// tokens may span multiple trie nodes when extending a previous run's
// leaf and this snapshot now overlaps that same range.
matchPath, matched := findBestMatch(frontier, tokens)
// matchPath[0] is frontier itself; the rest are newly traversed nodes.
remaining := tokens[matched:]
// Check for a partial match within the last node's edge — if so, split it.
if len(matchPath) > 1 {
lastNode := matchPath[len(matchPath)-1]
matchedInEdge := frontier.endOffset + matched - lastNode.startOffset()
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes)
}
}
// Append traversed nodes (excluding frontier) to the active path.
c.activePath = append(c.activePath, matchPath[1:]...)
dest := matchPath[len(matchPath)-1]
if len(remaining) > 0 {
// Drop non-user snapshots so appendTokens can extend in-place
// rather than creating a new child node.
if len(dest.children) == 0 && !dest.user {
dest.setSnapshots(nil, &c.pagedOutBytes)
}
newDest := dest.appendTokens(c.root, remaining, endOffset)
if newDest != dest {
c.activePath = append(c.activePath, newDest)
}
dest = newDest
}
return dest
}
// attachSnapshots attaches cache snapshots to a trie node at the given offset.
// The node must be on the active path (and thus protected from eviction;
// lastUsed is updated in close()). All non-nil caches must be at the same
// offset (cacheOffset); a mismatch indicates a bug in the caller.
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
c := s.cache
if c.activePath[len(c.activePath)-1] != node {
slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
return
}
snaps := make([]cache.Snapshot, len(c.caches))
for i, kv := range c.caches {
if kv != nil {
if kv.Offset() != cacheOffset {
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
}
snaps[i] = kv.Snapshot(node.startOffset())
}
}
node.setSnapshots(snaps, &c.pagedOutBytes)
node.lastUsed = time.Now()
slog.Debug("created snapshot", "offset", cacheOffset)
c.enforceEvictionPolicy()
}
// freeAll releases all cache layers.
func (c *kvCache) freeAll() {
for _, kv := range c.caches {
if kv != nil {
kv.Free()
}
}
}
func (c *kvCache) minCacheOffset() int {
offset := 0
found := false
for _, kv := range c.caches {
if kv == nil {
continue
}
if off := kv.Offset(); !found || off < offset {
offset = off
found = true
}
}
return offset
}
// close saves the token state if the forward pass ran.
func (s *cacheSession) close() {
offset := s.cache.minCacheOffset()
if offset <= 0 {
return
}
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
for _, kv := range s.caches {
if kv == nil {
continue
}
arrays = append(arrays, kv.State()...)
}
// Ensure that if we have run the forward pass and set the metadata
// that we also actually have the data.
mlx.AsyncEval(arrays...)
// Advance the trie frontier with any newly generated tokens.
c := s.cache
if len(c.activePath) > 0 {
frontier := c.activePath[len(c.activePath)-1]
stored := append(s.inputs, s.outputs...)
if offset > frontier.endOffset {
newTokens := stored[frontier.endOffset:offset]
c.advancePath(frontier, newTokens, offset)
}
c.activePath[len(c.activePath)-1].lastUsed = time.Now()
}
}
// enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits.
func (c *kvCache) enforceEvictionPolicy() {
if c.pagedOutBytes <= maxPagedOutBytes {
return
}
activeSet := make(map[*trieNode]bool, len(c.activePath))
for _, n := range c.activePath {
activeSet[n] = true
}
for c.pagedOutBytes > maxPagedOutBytes {
var best *trieNode
walkNodes(c.root, func(n *trieNode) bool {
if n == c.root || activeSet[n] || len(n.children) > 1 {
return true
}
// Evict: oldest, then deepest, then largest.
if best == nil || cmp.Or(
n.lastUsed.Compare(best.lastUsed),
cmp.Compare(best.endOffset, n.endOffset),
cmp.Compare(best.snapshotBytes(), n.snapshotBytes()),
) < 0 {
best = n
}
return true
})
if best == nil {
break
}
c.evictNode(best)
}
}
// evictNode evicts a single node from the trie, freeing its snapshot memory.
func (c *kvCache) evictNode(node *trieNode) {
if len(node.children) == 0 {
// Leaf: remove entirely.
slog.Debug("evicting leaf", "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
removeNode(node, &c.pagedOutBytes)
} else if len(node.children) == 1 {
// Interior node with one child: merge with child.
before := c.pagedOutBytes
tokens := len(node.tokens)
mergeWithChild(node, c.caches, &c.pagedOutBytes)
slog.Debug("evicting interior node", "offset", node.startOffset(), "tokens", tokens, "freed", mlx.PrettyBytes(int(before-c.pagedOutBytes)))
} else {
panic("evictNode called on multi-child branch point")
}
}
func (c *kvCache) dumpTree() {
// Summary stats
var cacheBytes int
for _, kv := range c.caches {
if kv == nil {
continue
}
for _, a := range kv.State() {
if a != nil {
cacheBytes += a.NumBytes()
}
}
}
// Build active path set for marking.
active := make(map[*trieNode]bool, len(c.activePath))
for _, n := range c.activePath {
active[n] = true
}
var nodeCount, snapshotCount int
var pagedBytes int64
var lines []string
var dump func(n *trieNode, prefix string, isLast bool)
dump = func(n *trieNode, prefix string, isLast bool) {
if n == nil {
return
}
nodeCount++
// Build connector
var connector string
if n.parent == nil {
connector = ""
} else if isLast {
connector = prefix + "`-- "
} else {
connector = prefix + "|-- "
}
// Node label
nodeBytes := n.snapshotBytes()
pagedBytes += nodeBytes
label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
if nodeBytes > 0 {
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
}
if !n.lastUsed.IsZero() {
label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
}
var flags []string
if n.user {
flags = append(flags, "user")
}
if n.hasAllSnapshots() {
snapshotCount++
flags = append(flags, "snap")
}
if active[n] {
flags = append(flags, "active")
}
if len(flags) > 0 {
label += " (" + flags[0]
for _, f := range flags[1:] {
label += ", " + f
}
label += ")"
}
lines = append(lines, connector+label)
// Recurse children
childPrefix := prefix
if n.parent != nil {
if isLast {
childPrefix += " "
} else {
childPrefix += "| "
}
}
for i, child := range n.children {
dump(child, childPrefix, i == len(n.children)-1)
}
}
dump(c.root, "", true)
offset := c.minCacheOffset()
logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount))
for i, l := range lines {
if i == 0 {
logutil.Trace("cache trie: " + l)
} else {
logutil.Trace(" " + l)
}
}
}

846
x/mlxrunner/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,846 @@
package cache
import (
"fmt"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// Cache is common state management shared by every cache kind. Writers
// live on the specific caches
type Cache interface {
// State returns the cache-owned state roots that should be kept/evaluated.
State() []*mlx.Array
Free()
Offset() int
// Snapshot copies cache state from fromOffset to current offset into
// pinned VRAM arrays. The active cache is unchanged.
Snapshot(fromOffset int) Snapshot
// Restore brings the cache to target. If snapshot is nil, rewinds
// using the cache's own live state. Returns false if the target is
// unreachable (e.g. target > current offset, or negative).
Restore(snapshot Snapshot, target int) bool
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
// Takes ownership of both inputs.
Merge(parent, child Snapshot) Snapshot
// Split divides a snapshot [a,c) at offset b into [a,b) and [b,c).
// Takes ownership of the input. Cache types that cannot split
// (e.g. recurrent) return (nil, snapshot).
Split(snapshot Snapshot, at int) (parent, child Snapshot)
}
// Snapshot is paged-out cache state that can be restored later.
type Snapshot interface {
// Size returns the byte size of the paged-out data (in VRAM).
Size() int
// Close unpins the snapshot's arrays so they can be freed by Sweep.
Close()
}
// Attention is the contract for caches that back attention layers
// (KVCache, RotatingKVCache).
type Attention interface {
Cache
// Update appends (k, v) and returns an opaque nn.KVHistory for
// this layer's SDPA.
Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory
}
// Viewer exposes a read-only attention history for a cache.
type Viewer interface {
View(b *batch.Batch) *nn.KVHistory
}
type speculativeCommitter interface {
Cache
commit(n int)
}
// Speculation is an isolated cache transaction for speculative target
// validation. Updates record generated K/V without mutating the live caches;
// Commit appends only the accepted prefix to the live caches.
type Speculation struct {
layers []speculativeCommitter
}
// BeginSpeculation returns cache wrappers suitable for a speculative target
// forward. The returned caches must only be used for that forward.
func BeginSpeculation(caches []Cache) ([]Cache, *Speculation, bool) {
specCaches := make([]Cache, len(caches))
layers := make([]speculativeCommitter, len(caches))
for i, c := range caches {
switch c := c.(type) {
case nil:
case *RotatingKVCache:
sc := newSpeculativeRotatingKVCache(c)
specCaches[i] = sc
layers[i] = sc
case *KVCache:
sc := newSpeculativeKVCache(c)
specCaches[i] = sc
layers[i] = sc
case *RecurrentCache:
sc := newSpeculativeRecurrentCache(c)
specCaches[i] = sc
layers[i] = sc
default:
return nil, nil, false
}
}
return specCaches, &Speculation{layers: layers}, true
}
// BeginIsolatedSpeculation returns cache wrappers that never mutate live cache
// state. It is intended for correctness instrumentation, not the hot path.
func BeginIsolatedSpeculation(caches []Cache) ([]Cache, bool) {
specCaches := make([]Cache, len(caches))
for i, c := range caches {
switch c := c.(type) {
case nil:
case *RotatingKVCache:
specCaches[i] = newSpeculativeRotatingKVCache(c)
case *KVCache:
specCaches[i] = newIsolatedKVCache(c)
case *RecurrentCache:
specCaches[i] = newSpeculativeRecurrentCache(c)
default:
return nil, false
}
}
return specCaches, true
}
// Commit appends the accepted prefix from the speculative forward to the live
// caches. The target bonus token is intentionally not committed.
func (s *Speculation) Commit(n int) {
if s == nil {
return
}
for _, layer := range s.layers {
if layer != nil {
layer.commit(n)
}
}
}
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
// Assumes B = 1; heterogeneous batches are not supported.
func (c *KVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
newK, newV := c.appendKV(keys, values)
return nn.NewKVHistory(newK, newV, nil)
}
// View returns the current cache contents as attention history without writing.
func (c *KVCache) View(_ *batch.Batch) *nn.KVHistory {
state := c.State()
if len(state) < 2 {
return nil
}
return nn.NewKVHistory(state[0], state[1], nil)
}
// appendKV is the raw write path shared by Update and Restore.
func (c *KVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if needed
if c.keys == nil || (prev+L) > c.keys.Dim(2) {
steps := (c.step + L - 1) / c.step
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
if c.keys != nil {
if prev%c.step != 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
}
c.offset += L
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
}
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil
}
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
}
}
// kvSnapshot holds paged-out KV data for a range [fromOffset, toOffset).
type kvSnapshot struct {
keys, values *mlx.Array
fromOffset, toOffset int
}
func (s *kvSnapshot) Size() int { return s.keys.NumBytes() + s.values.NumBytes() }
func (s *kvSnapshot) Close() { mlx.Unpin(s.keys, s.values) }
func (c *KVCache) Snapshot(fromOffset int) Snapshot {
if c.keys == nil || c.offset <= fromOffset {
return nil
}
from := max(0, fromOffset)
to := c.offset
kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
kCopy := mlx.Contiguous(kSlice, false)
vCopy := mlx.Contiguous(vSlice, false)
mlx.Pin(kCopy, vCopy)
mlx.AsyncEval(kCopy, vCopy)
return &kvSnapshot{
keys: kCopy,
values: vCopy,
fromOffset: from,
toOffset: to,
}
}
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
if target > c.offset {
return false
}
c.offset = target
return true
}
snap := snapshot.(*kvSnapshot)
if target > snap.toOffset || c.offset < snap.fromOffset {
return false
}
// Rewind to snapshot start, then feed snapshot.
c.offset = snap.fromOffset
c.appendKV(snap.keys, snap.values)
// Clamp to target if needed (target may be less than full snapshot).
if target < c.offset {
c.offset = target
}
return true
}
func (c *KVCache) Merge(parent, child Snapshot) Snapshot {
if parent == nil || child == nil {
if parent != nil {
parent.Close()
}
if child != nil {
child.Close()
}
return nil
}
p := parent.(*kvSnapshot)
ch := child.(*kvSnapshot)
mk := p.keys.Concatenate(2, ch.keys)
mv := p.values.Concatenate(2, ch.values)
mlx.Pin(mk, mv)
mlx.AsyncEval(mk, mv)
p.Close()
ch.Close()
return &kvSnapshot{
keys: mk,
values: mv,
fromOffset: p.fromOffset,
toOffset: ch.toOffset,
}
}
func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
if snapshot == nil {
return nil, nil
}
snap := snapshot.(*kvSnapshot)
splitIdx := at - snap.fromOffset
seqLen := snap.toOffset - snap.fromOffset
if splitIdx <= 0 {
return nil, snapshot
}
if splitIdx >= seqLen {
return snapshot, nil
}
pk := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
pv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
ck := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
cv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
mlx.Pin(pk, pv, ck, cv)
mlx.AsyncEval(pk, pv, ck, cv)
snap.Close()
p := &kvSnapshot{
keys: pk,
values: pv,
fromOffset: snap.fromOffset,
toOffset: at,
}
ch := &kvSnapshot{
keys: ck,
values: cv,
fromOffset: at,
toOffset: snap.toOffset,
}
return p, ch
}
func (c *KVCache) Free() {
mlx.Unpin(c.keys, c.values)
c.keys, c.values = nil, nil
c.offset = 0
}
func (c *KVCache) Offset() int { return c.offset }
type speculativeBase struct {
offset int
}
func (s *speculativeBase) Free() {}
func (s *speculativeBase) Offset() int { return s.offset }
func (s *speculativeBase) Snapshot(int) Snapshot { return nil }
func (s *speculativeBase) Restore(Snapshot, int) bool { return false }
func (s *speculativeBase) Merge(parent, child Snapshot) Snapshot { return nil }
func (s *speculativeBase) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
return nil, snapshot
}
type speculativeKVCache struct {
speculativeBase
target *KVCache
start int
end int
}
func newSpeculativeKVCache(target *KVCache) *speculativeKVCache {
return &speculativeKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
start: target.Offset(),
end: target.Offset(),
}
}
func (c *speculativeKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
history := c.target.Update(b, keys, values)
c.offset = c.target.Offset()
c.end = c.target.Offset()
return history
}
func (c *speculativeKVCache) State() []*mlx.Array {
return c.target.State()
}
func (c *speculativeKVCache) commit(n int) {
target := max(c.start, c.start+n)
if target > c.end {
target = c.end
}
c.target.offset = target
c.offset = target
}
type isolatedKVCache struct {
speculativeBase
target *KVCache
keys, values *mlx.Array
}
func newIsolatedKVCache(target *KVCache) *isolatedKVCache {
return &isolatedKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
}
}
func (c *isolatedKVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
c.keys = concatKV(c.keys, keys)
c.values = concatKV(c.values, values)
c.offset += keys.Dim(2)
state := c.target.State()
if len(state) < 2 {
return nn.NewKVHistory(c.keys, c.values, nil)
}
return nn.NewKVHistory(state[0].Concatenate(2, c.keys), state[1].Concatenate(2, c.values), nil)
}
func (c *isolatedKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return c.target.State()
}
state := c.target.State()
if len(state) < 2 {
return []*mlx.Array{c.keys, c.values}
}
return []*mlx.Array{
state[0].Concatenate(2, c.keys),
state[1].Concatenate(2, c.values),
}
}
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
maxSize int
idx int
*KVCache
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
}
// Assumes B = 1; heterogeneous batches are not supported.
func (c *RotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
newK, newV := c.appendKV(keys, values)
return nn.NewKVHistory(newK, newV, rotatingApplier{
b: b,
K: newK.Dim(2),
L: keys.Dim(2),
window: c.maxSize,
ringIdx: c.idx,
dtype: keys.DType(),
})
}
// View returns the current rotating cache contents in logical order for
// assistant KV sharing.
func (c *RotatingKVCache) View(_ *batch.Batch) *nn.KVHistory {
k, v := c.logicalTail(c.maxSize - 1)
if k == nil || v == nil {
return nil
}
return nn.NewKVHistory(k, v, nil)
}
func (c *RotatingKVCache) logicalTail(keep int) (*mlx.Array, *mlx.Array) {
state := c.State()
if len(state) < 2 || keep <= 0 {
return nil, nil
}
keys, values := state[0], state[1]
K := keys.Dim(2)
if K == 0 {
return nil, nil
}
keep = min(keep, K)
if K > c.maxSize || c.offset < c.maxSize {
start := K - keep
return keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
}
oldest := c.idx % K
var logicalK, logicalV *mlx.Array
if oldest == 0 {
logicalK, logicalV = keys, values
} else {
tailK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
tailV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
headK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
headV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
logicalK = tailK.Concatenate(2, headK)
logicalV = tailV.Concatenate(2, headV)
}
start := K - keep
return logicalK.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
logicalV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
}
type speculativeRotatingKVCache struct {
speculativeBase
target *RotatingKVCache
keys, values *mlx.Array
}
func newSpeculativeRotatingKVCache(target *RotatingKVCache) *speculativeRotatingKVCache {
return &speculativeRotatingKVCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
}
}
func (c *speculativeRotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
c.keys = concatKV(c.keys, keys)
c.values = concatKV(c.values, values)
c.offset += keys.Dim(2)
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
histK, histV := c.keys, c.values
if oldK != nil && oldV != nil {
histK = oldK.Concatenate(2, c.keys)
histV = oldV.Concatenate(2, c.values)
}
return nn.NewKVHistory(histK, histV, logicalSlidingApplier{
b: b,
K: histK.Dim(2),
window: c.target.maxSize,
dtype: keys.DType(),
})
}
func (c *speculativeRotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return c.target.State()
}
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
if oldK == nil || oldV == nil {
return []*mlx.Array{c.keys, c.values}
}
return []*mlx.Array{oldK.Concatenate(2, c.keys), oldV.Concatenate(2, c.values)}
}
func (c *speculativeRotatingKVCache) commit(n int) {
if c.keys == nil || c.values == nil || n <= 0 {
return
}
n = min(n, c.keys.Dim(2))
c.target.appendKV(prefixKV(c.keys, n), prefixKV(c.values, n))
}
type logicalSlidingApplier struct {
b *batch.Batch
K int
window int
dtype mlx.DType
}
func (a logicalSlidingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask {
return logical.Intersect(nn.SlidingWindowMask(a.b, a.K, a.window, a.dtype))
}
func concatKV(prev, next *mlx.Array) *mlx.Array {
if prev == nil {
return next
}
return prev.Concatenate(2, next)
}
func prefixKV(a *mlx.Array, n int) *mlx.Array {
return a.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, n), mlx.Slice())
}
// appendKV is the raw write path shared by Update and Restore —
// routes to concat for prefill (L > 1) and update for decode.
func (c *RotatingKVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
if keys.Dim(2) > 1 {
return c.concat(keys, values)
}
return c.update(keys, values)
}
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
logutil.Trace("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
if c.keys == nil {
c.keys, c.values = keys.Clone(), values.Clone()
mlx.Pin(c.keys, c.values)
} else {
if c.idx < c.keys.Dim(2) {
if c.offset <= c.maxSize {
// Not yet wrapped: slots [c.idx, Dim) are grow padding
// or stale post-rewind data, not live window content.
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
} else {
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
// Linearize so the trim + concat below operate on contiguous
// positions and preserve the last (maxSize - 1) old tokens.
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
c.keys.Set(tailK.Concatenate(2, headK))
c.values.Set(tailV.Concatenate(2, headV))
c.idx = c.keys.Dim(2)
}
}
// Trim to max_size to maintain sliding window
if trim := c.idx - c.maxSize + 1; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
}
c.keys.Set(c.keys.Concatenate(2, keys))
c.values.Set(c.values.Concatenate(2, values))
c.idx = c.keys.Dim(2)
}
c.offset += keys.Dim(2)
c.idx = c.keys.Dim(2)
return c.keys, c.values
}
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
logutil.Trace("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
prev := c.offset
// Grow buffer if not yet at max
if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
newSize := min(c.step, c.maxSize-prev)
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
if c.keys != nil {
c.keys.Set(c.keys.Concatenate(2, newKeys))
c.values.Set(c.values.Concatenate(2, newValues))
} else {
c.keys, c.values = newKeys, newValues
mlx.Pin(c.keys, c.values)
}
c.idx = prev
}
// Trim to max_size to maintain sliding window
if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
c.idx = c.maxSize
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
c.offset += L
c.idx += L
validLen := min(c.offset, c.maxSize)
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
}
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil || c.values == nil {
return nil
}
liveLen := min(c.offset, c.keys.Dim(2))
return []*mlx.Array{
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
}
}
// rotatingSnapshot holds paged-out data for a RotatingKVCache.
type rotatingSnapshot struct {
kvSnapshot // embedded KV data
idx int // buffer write position at snapshot time
}
func (s *rotatingSnapshot) Size() int { return s.kvSnapshot.Size() }
func (s *rotatingSnapshot) Close() { s.kvSnapshot.Close() }
func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
if c.keys == nil || c.offset <= fromOffset {
return nil
}
state := c.State()
k := state[0].Clone()
v := state[1].Clone()
mlx.Pin(k, v)
return &rotatingSnapshot{
kvSnapshot: kvSnapshot{
keys: k,
values: v,
fromOffset: fromOffset,
toOffset: c.offset,
},
idx: c.idx,
}
}
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
if target >= c.offset {
return target == c.offset
}
// Live rewind is only safe when the buffer hasn't filled yet
// (offset <= maxSize). Once the window has shifted, rewinding
// leaves fewer than maxSize trailing tokens to attend to —
// a snapshot is required to restore the full window.
if c.offset > c.maxSize {
return false
}
c.offset = target
c.idx = target
return true
}
snap := snapshot.(*rotatingSnapshot)
if target > snap.toOffset {
return false
}
// Reject if clamping would leave an incomplete window.
if target < snap.toOffset && snap.toOffset > c.maxSize {
return false
}
// Restore from snapshot: rebuild buffer state.
// Free existing state first.
if c.keys != nil {
mlx.Unpin(c.keys, c.values)
}
c.keys = snap.keys.Clone()
c.values = snap.values.Clone()
mlx.Pin(c.keys, c.values)
c.offset = snap.toOffset
c.idx = snap.idx
// Clamp to target if needed.
if target < c.offset {
c.offset = target
c.idx = target
}
return true
}
func (c *RotatingKVCache) Merge(parent, child Snapshot) Snapshot {
// For rotating caches, the child snapshot supersedes the parent
// since it contains the full window state.
if parent != nil {
parent.Close()
}
return child
}
func (c *RotatingKVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
// Rotating cache snapshots contain the full window state.
// Cannot cleanly split a ring buffer at an arbitrary point.
return nil, snapshot
}
func (c *RotatingKVCache) Free() {
c.KVCache.Free()
c.idx = 0
}
// rotatingApplier composes the sliding-window storage restriction
// onto the caller's logical mask.
//
// ringIdx is the cache's write cursor at Update time. At L=1 decode
// the ring buffer is not position-ordered — logical col j lives at
// storage slot (ringIdx+j) mod K — so tensor masks built in
// logical space must be gathered into this layout before the kernel
// sees them. At L>1 prefill the concat path has already linearised
// storage, so the gather is identity and ringIdx is unused.
type rotatingApplier struct {
b *batch.Batch
K int
L int
window int
ringIdx int
dtype mlx.DType
}
func (r rotatingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask {
if r.L == 1 {
// Single-query decode: storage already enforces the window
// (Update keeps the last maxSize tokens, all within
// [absQ-window+1, absQ]), and every stored key's absolute
// position <= absQ. For a zero or plain-causal logical mask
// both constraints reduce to "no mask", so return the zero
// mask and let SDPA dispatch to mode="".
if logical.IsZero() || logical.IsCausal() {
return nn.AttentionMask{}
}
// Tensor-backed mask (user ArrayMask, causal+Relax, causal
// with accumulated array): materialize in logical-position
// order then gather K cols into ring-slot order so they
// align with the cache output the kernel will index.
arr := logical.AsArray(r.b, r.K, r.dtype)
arr = gatherRingCols(arr, r.ringIdx, r.K)
return nn.ArrayMask(arr)
}
return logical.Intersect(nn.SlidingWindowMask(r.b, r.K, r.window, r.dtype))
}
// gatherRingCols reorders a [B, 1, L, K] mask's K axis from
// logical-position order (col 0 = oldest stored position) into the
// cache's ring-slot order (col 0 = buffer slot 0). Logical col j
// lives at slot (ringIdx+j) mod K, so storage slot s reads from
// logical col (s-ringIdx+K) mod K. Returns arr unchanged when the
// permutation is a no-op: ringIdx % K == 0 (layouts coincide), or
// the K axis broadcasts (dim 3 == 1, i.e. Q-padding-shaped masks
// where every key shares the same value).
func gatherRingCols(arr *mlx.Array, ringIdx, K int) *mlx.Array {
if w := arr.Dim(3); w != 1 && w != K {
panic(fmt.Sprintf("gatherRingCols: K-axis width %d must be 1 or %d", w, K))
}
ringIdx %= K
if ringIdx == 0 || arr.Dim(3) == 1 {
return arr
}
shift := K - ringIdx
tail := arr.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(shift, K))
head := arr.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, shift))
return tail.Concatenate(3, head)
}

283
x/mlxrunner/cache/cache_test.go vendored Normal file
View File

@@ -0,0 +1,283 @@
package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
// newKVBatch builds a B=1 batch at SeqOffsets=off with all-real
// queries (SeqQueryLens=L) — the standard single-sequence cache
// test shape.
func newKVBatch(off, L int) *batch.Batch {
return &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, L),
SeqOffsets: []int32{int32(off)},
SeqQueryLens: []int32{int32(L)},
}
}
func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
// Snapshot [5, 10).
snap := c.Snapshot(5)
// Free the cache completely — offset is now 0.
c.Free()
// Restore should fail because cache doesn't have data up to fromOffset=5.
if c.Restore(snap, 10) {
t.Fatal("expected Restore to fail with no base data")
}
}
// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data
// is preserved through a snapshot→free→restore cycle.
func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
snap := c.Snapshot(0)
if snap == nil {
t.Fatal("Snapshot returned nil")
}
// Free and restore to a fresh cache.
c2 := NewKVCache()
if !c2.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c2.Offset())
}
// Verify State() returns arrays with correct sequence dimension.
state := c2.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
// keys shape: [B, H, seqLen, Dk]
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
}
if state[1].Dim(2) != 10 {
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
}
}
// TestKVCacheSplitPreservesData verifies that split produces two snapshots
// that can be sequentially restored to rebuild the original cache state.
func TestKVCacheSplitPreservesData(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
snap := c.Snapshot(0)
parent, child := c.Split(snap, 5)
if parent == nil || child == nil {
t.Fatal("Split returned nil")
}
// Restore parent → offset=5, seq dim=5.
c2 := NewKVCache()
if !c2.Restore(parent, 5) {
t.Fatal("Restore(parent) failed")
}
if c2.Offset() != 5 {
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
}
state := c2.State()
if state[0].Dim(2) != 5 {
t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2))
}
// Restore child on top → offset=10, seq dim=10.
if !c2.Restore(child, 10) {
t.Fatal("Restore(child) failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset after child = %d, want 10", c2.Offset())
}
state = c2.State()
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2))
}
}
// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back
// produces a snapshot equivalent to the original.
func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
skipIfNoMLX(t)
c := NewKVCache()
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
snap := c.Snapshot(0)
parent, child := c.Split(snap, 6)
merged := c.Merge(parent, child)
if merged == nil {
t.Fatal("Merge returned nil")
}
c2 := NewKVCache()
if !c2.Restore(merged, 10) {
t.Fatal("Restore(merged) failed")
}
if c2.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c2.Offset())
}
state := c2.State()
if state[0].Dim(2) != 10 {
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
}
if state[1].Dim(2) != 10 {
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
}
}
func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
// Feed 10 tokens (window size 4, so positions 0-5 are evicted).
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
// Offset 3 is outside the window.
if c.Restore(nil, 3) {
t.Fatal("Restore(nil, 3) should fail when outside window")
}
}
// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring
// from a snapshot, the rotating cache has the correct window of data.
func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
// Feed 10 tokens one at a time. Window size 4, so only last 4 are kept.
for range 10 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
snap := c.Snapshot(0)
if snap == nil {
t.Fatal("Snapshot returned nil")
}
// Feed 5 more tokens.
for range 5 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
// Restore to offset 10.
if !c.Restore(snap, 10) {
t.Fatal("Restore failed")
}
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset())
}
state := c.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
// Seq dim should be min(offset, maxSize) = min(10, 4) = 4.
seqDim := state[0].Dim(2)
if seqDim != 4 {
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
}
}
// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a
// snapshot correctly preserves the write position (idx), so subsequent
// single-token updates land in the right buffer slot.
func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
// Fill the window: 6 tokens into a size-4 window.
// After this, idx has wrapped and the buffer has rotated.
for range 6 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
if c.Offset() != 6 {
t.Fatalf("offset = %d, want 6", c.Offset())
}
snap := c.Snapshot(0)
// Mutate the cache further so live state diverges from snapshot.
for range 3 {
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
// Restore to snapshot state.
if !c.Restore(snap, 6) {
t.Fatal("Restore failed")
}
if c.Offset() != 6 {
t.Fatalf("offset after restore = %d, want 6", c.Offset())
}
// Feed one more token. If idx was restored correctly, this should
// produce a valid window of size 4 at offset 7.
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
if c.Offset() != 7 {
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
}
state := c.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
seqDim := state[0].Dim(2)
if seqDim != 4 {
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
}
}

285
x/mlxrunner/cache/recurrent.go vendored Normal file
View File

@@ -0,0 +1,285 @@
package cache
import (
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// Recurrent is the contract for caches that back recurrent linear-attention layers.
type Recurrent interface {
Cache
Get(b *batch.Batch, dtype mlx.DType) *nn.RecurrentHistory
Put(b *batch.Batch, newConv, newDelta *mlx.Array)
}
// RecurrentRecorder records the per-token scan inputs needed to commit an
// accepted prefix after a speculative recurrent forward.
type RecurrentRecorder interface {
Record(qkv, q, k, v, gDecay, beta *mlx.Array)
}
// RecurrentCache stores state for linear-recurrent layers.
//
// Conv state shape: [B, convTail, convDim]
// Delta state shape: [B, numVHeads, headVDim, headKDim]
type RecurrentCache struct {
convState *mlx.Array
deltaState *mlx.Array
offset int
convTail int
convDim int
numVHeads int
headVDim int
headKDim int
}
func (c *RecurrentCache) setState(old, v *mlx.Array, contiguous bool) *mlx.Array {
if v == nil || !v.Valid() {
return old
}
if contiguous {
v = mlx.Contiguous(v, false)
}
v = v.Clone()
mlx.Pin(v)
mlx.Unpin(old)
return v
}
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
return &RecurrentCache{
convTail: int(convTail),
convDim: int(convDim),
numVHeads: int(numVHeads),
headVDim: int(headVDim),
headKDim: int(headKDim),
}
}
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
if batch <= 0 {
batch = 1
}
// Keep the gated-delta recurrent state in float32 even when activations are
// bf16/fp16. The convolution tail stays in the activation dtype.
deltaDType := mlx.DTypeFloat32
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != deltaDType ||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
if !needConv && !needDelta {
return
}
if needConv {
c.convState = c.setState(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim), false)
}
if needDelta {
c.deltaState = c.setState(c.deltaState, mlx.Zeros(deltaDType, batch, c.numVHeads, c.headVDim, c.headKDim), false)
}
}
// Get returns the current conv/delta state for the SSM layer's read
// phase. Lazy-initializes zero-filled state tensors using b.InputIDs
// for the batch size; reallocates if the existing state's batch size
// or dtype no longer matches.
func (c *RecurrentCache) Get(b *batch.Batch, dtype mlx.DType) *nn.RecurrentHistory {
c.ensure(b.InputIDs.Dim(0), dtype)
return nn.NewRecurrentHistory(c.convState, c.deltaState)
}
// Put stores the post-computation conv/delta states for the SSM
// layer's write phase and advances the cache offset by the current
// forward's real token count.
//
// Assumes B = 1; heterogeneous batches are not supported.
func (c *RecurrentCache) Put(b *batch.Batch, newConv, newDelta *mlx.Array) {
c.convState = c.setState(c.convState, newConv, true)
c.deltaState = c.setState(c.deltaState, newDelta, false)
c.offset += int(b.SeqQueryLens[0])
}
func (c *RecurrentCache) State() []*mlx.Array {
return []*mlx.Array{c.convState, c.deltaState}
}
// recurrentSnapshot holds paged-out recurrent state. Self-contained —
// does not depend on any parent state.
type recurrentSnapshot struct {
convState, deltaState *mlx.Array
offset int
}
func (s *recurrentSnapshot) Size() int { return s.convState.NumBytes() + s.deltaState.NumBytes() }
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
func (c *RecurrentCache) Snapshot(fromOffset int) Snapshot {
// Recurrent state is not position-sliceable — always snapshot the full state.
if c.convState == nil && c.deltaState == nil {
return nil
}
snap := &recurrentSnapshot{offset: c.offset}
snap.convState = c.convState.Clone()
snap.deltaState = c.deltaState.Clone()
mlx.Pin(snap.convState, snap.deltaState)
return snap
}
func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
if snapshot == nil {
// Recurrent state is cumulative and can't rewind. Only succeed
// if we're already at the target (no-op).
return target == c.offset
}
snap := snapshot.(*recurrentSnapshot)
// Recurrent snapshots encode cumulative state up to exactly
// snap.offset. Target must match — rewinding would leave stale
// state, and advancing isn't possible without feeding tokens.
if target != snap.offset {
return false
}
c.convState = c.setState(c.convState, snap.convState, false)
c.deltaState = c.setState(c.deltaState, snap.deltaState, false)
c.offset = snap.offset
return true
}
func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
// Recurrent snapshots are self-contained — child supersedes parent.
if parent != nil {
parent.Close()
}
return child
}
func (c *RecurrentCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
// Recurrent state is cumulative and not position-sliceable.
// Cannot recover intermediate state at the split point.
return nil, snapshot
}
func (c *RecurrentCache) Free() {
mlx.Unpin(c.convState, c.deltaState)
c.convState, c.deltaState = nil, nil
c.offset = 0
}
func (c *RecurrentCache) Offset() int { return c.offset }
type speculativeRecurrentCache struct {
speculativeBase
target *RecurrentCache
start int
initialConv *mlx.Array
initialDelta *mlx.Array
qkv, q, k, v, gDecay, beta *mlx.Array
fullConv, fullDelta *mlx.Array
length int
}
func newSpeculativeRecurrentCache(target *RecurrentCache) *speculativeRecurrentCache {
return &speculativeRecurrentCache{
speculativeBase: speculativeBase{offset: target.Offset()},
target: target,
start: target.Offset(),
}
}
func (c *speculativeRecurrentCache) Get(b *batch.Batch, dtype mlx.DType) *nn.RecurrentHistory {
if c.fullConv != nil && c.fullDelta != nil {
return nn.NewRecurrentHistory(c.fullConv, c.fullDelta)
}
history := c.target.Get(b, dtype)
if c.initialConv == nil {
c.initialConv = history.ConvState()
}
if c.initialDelta == nil {
c.initialDelta = history.DeltaState()
}
return history
}
func (c *speculativeRecurrentCache) Record(qkv, q, k, v, gDecay, beta *mlx.Array) {
c.qkv, c.q, c.k, c.v, c.gDecay, c.beta = qkv, q, k, v, gDecay, beta
if qkv != nil {
c.length = qkv.Dim(1)
}
}
func (c *speculativeRecurrentCache) Put(b *batch.Batch, newConv, newDelta *mlx.Array) {
c.fullConv, c.fullDelta = newConv, newDelta
c.offset += int(b.SeqQueryLens[0])
}
func (c *speculativeRecurrentCache) State() []*mlx.Array {
if c.fullConv != nil && c.fullDelta != nil {
return []*mlx.Array{c.fullConv, c.fullDelta}
}
return c.target.State()
}
func (c *speculativeRecurrentCache) commit(n int) {
if n <= 0 {
return
}
if c.length > 0 && n > c.length {
n = c.length
}
if c.length > 0 && n == c.length && c.fullConv != nil && c.fullDelta != nil {
c.target.convState = c.target.setState(c.target.convState, c.fullConv, true)
c.target.deltaState = c.target.setState(c.target.deltaState, c.fullDelta, false)
c.target.offset = c.start + n
return
}
if c.initialConv == nil || c.initialDelta == nil || c.qkv == nil || c.q == nil || c.k == nil || c.v == nil || c.gDecay == nil || c.beta == nil {
return
}
qkv := sliceSeq(c.qkv, n)
convConcat := mlx.Concatenate([]*mlx.Array{c.initialConv, qkv}, 1)
total := convConcat.Dim(1)
nextConv := convConcat.Slice(mlx.Slice(), mlx.Slice(total-c.target.convTail, total), mlx.Slice())
_, delta := mlx.FastGatedDelta(
sliceSeq(c.q, n),
sliceSeq(c.k, n),
sliceSeq(c.v, n),
sliceSeq(c.gDecay, n),
sliceSeq(c.beta, n),
c.initialDelta,
nil,
)
c.target.convState = c.target.setState(c.target.convState, nextConv, true)
c.target.deltaState = c.target.setState(c.target.deltaState, delta, false)
c.target.offset = c.start + n
}
func sliceSeq(a *mlx.Array, n int) *mlx.Array {
switch a.NumDims() {
case 3:
return a.Slice(mlx.Slice(), mlx.Slice(0, n), mlx.Slice())
case 4:
return a.Slice(mlx.Slice(), mlx.Slice(0, n), mlx.Slice(), mlx.Slice())
default:
panic("recurrent speculative sequence tensor must be rank 3 or 4")
}
}

284
x/mlxrunner/cache/recurrent_test.go vendored Normal file
View File

@@ -0,0 +1,284 @@
package cache
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
// only succeeds when target exactly matches the snapshot's offset. Recurrent
// state is cumulative, so it can't be rewound or fast-forwarded.
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
skipIfNoMLX(t)
c := NewRecurrentCache(3, 12, 4, 8, 8)
b1 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1)}
c.Get(b1, mlx.DTypeFloat16) // lazy-init
b10 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 10), SeqQueryLens: []int32{10}}
c.Put(b10, nil, nil) // advance to 10
snap := c.Snapshot(0) // snap.offset == 10
b5 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 5), SeqQueryLens: []int32{5}}
c.Put(b5, nil, nil) // cache now at 15
// target < snap.offset: fails (can't rewind past snapshot)
if c.Restore(snap, 5) {
t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
}
// target > snap.offset: fails (can't advance without feeding tokens)
if c.Restore(snap, 15) {
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
}
// target == snap.offset: succeeds
if !c.Restore(snap, 10) {
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
}
if c.Offset() != 10 {
t.Fatalf("offset = %d, want 10", c.Offset())
}
}
func TestRecurrentCacheGetLazyInit(t *testing.T) {
skipIfNoMLX(t)
c := NewRecurrentCache(3, 4, 2, 4, 4)
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1),
SeqOffsets: []int32{0},
SeqQueryLens: []int32{1},
}
h := c.Get(b, mlx.DTypeBFloat16)
if c.Offset() != 0 {
t.Fatalf("Get should not advance; got offset %d", c.Offset())
}
if h.ConvState() == nil || h.DeltaState() == nil {
t.Fatal("history should expose conv/delta tensors")
}
if got := h.ConvState().DType(); got != mlx.DTypeBFloat16 {
t.Fatalf("conv state dtype = %v, want %v", got, mlx.DTypeBFloat16)
}
if got := h.DeltaState().DType(); got != mlx.DTypeFloat32 {
t.Fatalf("delta state dtype = %v, want %v", got, mlx.DTypeFloat32)
}
}
func TestSpeculativeRecurrentCacheUsesStagedState(t *testing.T) {
skipIfNoMLX(t)
target := NewRecurrentCache(2, 3, 1, 2, 3)
caches, ok := BeginIsolatedSpeculation([]Cache{target})
if !ok {
t.Fatal("BeginIsolatedSpeculation failed")
}
c := caches[0].(*speculativeRecurrentCache)
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1),
SeqOffsets: []int32{0},
SeqQueryLens: []int32{1},
}
c.Get(b, mlx.DTypeFloat32)
convVals := []float32{1, 2, 3, 4, 5, 6}
deltaVals := []float32{7, 8, 9, 10, 11, 12}
nextConv := mlx.FromValues(convVals, 1, 2, 3)
nextDelta := mlx.FromValues(deltaVals, 1, 1, 2, 3)
c.Put(b, nextConv, nextDelta)
h := c.Get(b, mlx.DTypeFloat32)
state := c.State()
if len(state) != 2 {
t.Fatalf("State() returned %d arrays, want 2", len(state))
}
assertArray := func(name string, got, want *mlx.Array) {
t.Helper()
if got != want {
t.Fatalf("%s = %p, want %p", name, got, want)
}
}
assertArray("history conv", h.ConvState(), nextConv)
assertArray("history delta", h.DeltaState(), nextDelta)
assertArray("state conv", state[0], nextConv)
assertArray("state delta", state[1], nextDelta)
if got := c.Offset(); got != 1 {
t.Fatalf("speculative offset = %d, want 1", got)
}
if got := target.Offset(); got != 0 {
t.Fatalf("target offset = %d, want 0", got)
}
}
// TestRecurrentCachePaddedRoundTrip runs Get → CausalConv1D →
// GatedDelta → Put on a B=1 batch with qLen<L, then again on a
// fresh cache with an unpadded length-qLen batch using the same
// real prefix. After the call, Offset() must equal qLen (not L),
// and the resulting cache state must match the unpadded equivalent.
// Pins the recurrent contract: a forward with padding produces the
// same end-state as a forward with the real-prefix-only input.
func TestRecurrentCachePaddedRoundTrip(t *testing.T) {
skipIfNoMLX(t)
const convTail, convDim = 2, 6
const numVHeads, headVDim, headKDim = 1, 4, 6
const L = 4
const qLen = 2
// Use distinct values for the real prefix and the padded tail so
// we can detect any leak from padded positions into the result.
makeQKV := func(seed float32, T int) (q, k, v *mlx.Array) {
mkLast := func(off float32, T, n, d int) *mlx.Array {
vals := make([]float32, 1*T*n*d)
for i := range vals {
vals[i] = off + 0.05*float32(i)
}
return mlx.FromValues(vals, 1, T, n, d)
}
q = mkLast(seed, T, 1, headKDim)
k = mkLast(seed+0.1, T, 1, headKDim)
v = mkLast(seed+0.2, T, numVHeads, headVDim)
return
}
makeGB := func(seed float32, T int) (g, beta *mlx.Array) {
gVals := make([]float32, 1*T*numVHeads)
bVals := make([]float32, 1*T*numVHeads)
for i := range gVals {
gVals[i] = seed + 0.01*float32(i)
bVals[i] = seed - 0.02*float32(i)
}
g = mlx.FromValues(gVals, 1, T, numVHeads)
beta = mlx.FromValues(bVals, 1, T, numVHeads)
return
}
makeQKVPadded := func() (q, k, v *mlx.Array) {
qReal, kReal, vReal := makeQKV(0.3, qLen)
// Distinct, large junk values in the padded tail to surface
// any leak (real outputs are O(1)).
qPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, 1, headKDim), 99)
kPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, 1, headKDim), 99)
vPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, numVHeads, headVDim), 99)
q = mlx.Concatenate([]*mlx.Array{qReal, qPad}, 1)
k = mlx.Concatenate([]*mlx.Array{kReal, kPad}, 1)
v = mlx.Concatenate([]*mlx.Array{vReal, vPad}, 1)
return
}
makeGBPadded := func() (g, beta *mlx.Array) {
gReal, betaReal := makeGB(0.1, qLen)
gPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, numVHeads), 99)
betaPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, numVHeads), 99)
g = mlx.Concatenate([]*mlx.Array{gReal, gPad}, 1)
beta = mlx.Concatenate([]*mlx.Array{betaReal, betaPad}, 1)
return
}
// The conv input dimension must match the cache's convDim.
mkConvInput := func(seed float32, T int) *mlx.Array {
vals := make([]float32, 1*T*convDim)
for i := range vals {
vals[i] = seed + 0.05*float32(i)
}
return mlx.FromValues(vals, 1, T, convDim)
}
mkWeight := func(seed float32) *mlx.Array {
vals := make([]float32, convDim*(convTail+1))
for i := range vals {
vals[i] = seed + 0.1*float32(i)
}
return mlx.FromValues(vals, convDim, convTail+1)
}
weight := mkWeight(0.2)
runForward := func(c *RecurrentCache, b *batch.Batch, T int) (*mlx.Array, *mlx.Array) {
var convInput *mlx.Array
if T == L {
realPart := mkConvInput(0.4, qLen)
padPart := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, T-qLen, convDim), 99)
convInput = mlx.Concatenate([]*mlx.Array{realPart, padPart}, 1)
} else {
convInput = mkConvInput(0.4, T)
}
history := c.Get(b, mlx.DTypeFloat32)
_, nextConv := nn.CausalConv1D(b, convInput, nil, weight, convTail,
nn.WithRecurrentHistory(history))
var q, k, v, g, beta *mlx.Array
if T == L {
q, k, v = makeQKVPadded()
g, beta = makeGBPadded()
} else {
q, k, v = makeQKV(0.3, T)
g, beta = makeGB(0.1, T)
}
_, newDelta := nn.GatedDelta(b, q, k, v, g, beta,
nn.WithRecurrentHistory(history))
c.Put(b, nextConv, newDelta)
return nextConv, newDelta
}
// Padded forward.
cPad := NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim)
bPad := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, L),
SeqOffsets: []int32{0},
SeqQueryLens: []int32{int32(qLen)},
}
nextConvPad, deltaPad := runForward(cPad, bPad, L)
mlx.Eval(nextConvPad, deltaPad)
if got := cPad.Offset(); got != qLen {
t.Fatalf("padded forward: Offset() = %d, want %d (must advance by SeqQueryLens, not L)", got, qLen)
}
// Unpadded reference.
cRef := NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim)
bRef := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, qLen),
SeqOffsets: []int32{0},
SeqQueryLens: []int32{int32(qLen)},
}
nextConvRef, deltaRef := runForward(cRef, bRef, qLen)
mlx.Eval(nextConvRef, deltaRef)
if got := cRef.Offset(); got != qLen {
t.Fatalf("unpadded forward: Offset() = %d, want %d", got, qLen)
}
gp := nextConvPad.Floats()
gr := nextConvRef.Floats()
if len(gp) != len(gr) {
t.Fatalf("nextConv shape mismatch: padded %d vs unpadded %d", len(gp), len(gr))
}
for i := range gp {
if math.Abs(float64(gp[i]-gr[i])) > 1e-4 {
t.Fatalf("nextConv[%d]: padded=%v unpadded=%v (padding leaked into conv state)", i, gp[i], gr[i])
}
}
dp := deltaPad.Floats()
dr := deltaRef.Floats()
if len(dp) != len(dr) {
t.Fatalf("delta state shape mismatch: padded %d vs unpadded %d", len(dp), len(dr))
}
for i := range dp {
if math.Abs(float64(dp[i]-dr[i])) > 1e-3 {
t.Fatalf("delta state[%d]: padded=%v unpadded=%v (padding leaked into recurrent state)", i, dp[i], dr[i])
}
}
}
func TestRecurrentCachePutAdvances(t *testing.T) {
skipIfNoMLX(t)
c := NewRecurrentCache(3, 4, 2, 4, 4)
b := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 2), SeqQueryLens: []int32{2}}
newConv := mlx.Zeros(mlx.DTypeFloat16, 1, 3, 4)
newDelta := mlx.Zeros(mlx.DTypeFloat16, 1, 2, 4, 4)
c.Put(b, newConv, newDelta)
if c.Offset() != 2 {
t.Fatalf("cache offset not advanced: %d", c.Offset())
}
}

View File

@@ -0,0 +1,292 @@
package cache
import (
"math"
"testing"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// TestRotatingKVCacheDecodeParity drives a rotating cache past its
// wrap point with single-token writes, runs an L=1 decode through
// SDPA, and compares against a reference computed from the same per-
// position K/V in logical-position order with the same caller mask.
//
// Attention is permutation-invariant when K, V, and the mask are
// permuted together, so the cache's storage-order output (with the
// applier's gather composing the caller's logical mask back into
// storage order) must equal the logical-order reference.
func TestRotatingKVCacheDecodeParity(t *testing.T) {
skipIfNoMLX(t)
const H, D = 1, 4
const window = 4
const totalWrites = 7 // past wrap (window=4); last write is the L=1 decode
const scale = 1.0
// Per-position k, v values. Use distinct seeds so the per-position
// values are clearly distinguishable.
perPosKV := func(pos int) (k, v *mlx.Array) {
kVals := make([]float32, H*D)
vVals := make([]float32, H*D)
for i := range kVals {
kVals[i] = 0.1*float32(pos+1) + 0.01*float32(i)
vVals[i] = -0.1*float32(pos+1) + 0.01*float32(i)
}
k = mlx.FromValues(kVals, 1, H, 1, D)
v = mlx.FromValues(vVals, 1, H, 1, D)
return
}
q := mlx.FromValues([]float32{0.7, -0.4, 0.2, 0.9}, 1, H, 1, D)
mlx.Eval(q)
// Drive the cache: write positions 0..totalWrites-2 as a "history",
// then position totalWrites-1 is the actual L=1 decode under test.
c := NewRotatingKVCache(window)
for pos := range totalWrites - 1 {
k, v := perPosKV(pos)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
finalPos := totalWrites - 1
kFinal, vFinal := perPosKV(finalPos)
b := &batch.Batch{
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1),
SeqOffsets: []int32{int32(finalPos)},
SeqQueryLens: []int32{1},
}
history := c.Update(b, kFinal, vFinal)
// Reference: the in-window logical-position-ordered K and V are
// the last `window` per-position values (positions
// [finalPos-window+1, finalPos]). Build them in that order.
startPos := max(finalPos-window+1, 0)
logicalKs := make([]*mlx.Array, 0, window)
logicalVs := make([]*mlx.Array, 0, window)
for pos := startPos; pos <= finalPos; pos++ {
kp, vp := perPosKV(pos)
logicalKs = append(logicalKs, kp)
logicalVs = append(logicalVs, vp)
}
kLogical := mlx.Concatenate(logicalKs, 2)
vLogical := mlx.Concatenate(logicalVs, 2)
// A logical-order ArrayMask with distinct, non-trivial values per
// key column. Picked so each column's contribution to softmax is
// distinct — the test fails if the cache's gather permutes the
// columns wrong before the kernel sees them.
maskVals := []float32{0.1, -0.3, 0.7, -0.2}
logicalMask := mlx.FromValues(maskVals, 1, 1, 1, window)
cases := []struct {
name string
model nn.AttentionMask
// reference mask uses the same coordinates the model mask
// represents; for ArrayMask it's the same tensor (since the
// reference K/V is in logical order).
refMode string
refMask *mlx.Array
}{
{"zero", nn.AttentionMask{}, "", nil},
{"causal-at-L1", nn.CausalMask(), "", nil},
{"array", nn.ArrayMask(logicalMask), "array", logicalMask},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := nn.ScaledDotProductAttention(b, q, scale,
nn.WithKVHistory(history),
nn.WithMask(tc.model))
want := mlx.FastScaledDotProductAttention(q, kLogical, vLogical, scale,
tc.refMode, tc.refMask)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if math.Abs(float64(gs[i]-ws[i])) > 1e-5 {
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
}
}
})
}
}
func TestAssistantSharedHistoryL1MasksMatchNoMask(t *testing.T) {
skipIfNoMLX(t)
if !mlx.MetalIsAvailable() {
t.Skip("MLX Metal not available")
}
const H, D = 1, 4
const window = 4
const total = 7
const scale = 1.0
q := mlx.FromValues([]float32{0.7, -0.4, 0.2, 0.9}, 1, H, 1, D)
mlx.Eval(q)
full := NewKVCache()
sliding := NewRotatingKVCache(window)
for pos := range total {
kVals := make([]float32, H*D)
vVals := make([]float32, H*D)
for i := range kVals {
kVals[i] = 0.1*float32(pos+1) + 0.01*float32(i)
vVals[i] = -0.1*float32(pos+1) + 0.01*float32(i)
}
k := mlx.FromValues(kVals, 1, H, 1, D)
v := mlx.FromValues(vVals, 1, H, 1, D)
full.Update(newKVBatch(full.Offset(), 1), k, v)
sliding.Update(newKVBatch(sliding.Offset(), 1), k, v)
}
b := newKVBatch(total-1, 1)
slidingHistory := sliding.View(b)
cases := []struct {
name string
h *nn.KVHistory
mask nn.AttentionMask
}{
{name: "full", h: full.View(b), mask: nn.CausalMask()},
{name: "sliding", h: slidingHistory, mask: nn.CausalMask().Intersect(nn.SlidingWindowMask(b, slidingHistory.K().Dim(2), window, q.DType()))},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(tc.h), nn.WithMask(tc.mask))
want := mlx.FastScaledDotProductAttention(q, tc.h.K(), tc.h.V(), scale, "", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if math.Abs(float64(gs[i]-ws[i])) > 1e-5 {
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
}
}
})
}
}
// TestRotatingKVCachePrefillParity drives an L>1 prefill into a
// rotating cache and verifies SDPA output through WithKVHistory
// matches a reference computed from the same K/V with the model mask
// and window restriction composed manually.
func TestRotatingKVCachePrefillParity(t *testing.T) {
skipIfNoMLX(t)
const H, L, D = 1, 6, 4
const window = 4
const scale = 1.0
qVals := make([]float32, 1*H*L*D)
kVals := make([]float32, 1*H*L*D)
vVals := make([]float32, 1*H*L*D)
for i := range qVals {
qVals[i] = 0.5 + 0.05*float32(i)
kVals[i] = -0.3 + 0.07*float32(i)
vVals[i] = 0.3 + 0.03*float32(i)
}
q := mlx.FromValues(qVals, 1, H, L, D)
k := mlx.FromValues(kVals, 1, H, L, D)
v := mlx.FromValues(vVals, 1, H, L, D)
b := newKVBatch(0, L)
cases := []struct {
name string
mask nn.AttentionMask
// rect arguments matching nn.AttentionMask.Relax (qLo, qHi, kLo, kHi)
relax [][4]int
causal bool
}{
{"zero", nn.AttentionMask{}, nil, false},
{"causal", nn.CausalMask(), nil, true},
{"causal+relax", nn.CausalMask().Relax(0, 1, 4, 2, 5), [][4]int{{1, 4, 2, 5}}, true},
}
negInf := float32(math.Inf(-1))
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c := NewRotatingKVCache(window)
history := c.Update(b, k, v)
got := nn.ScaledDotProductAttention(b, q, scale,
nn.WithKVHistory(history),
nn.WithMask(tc.mask))
// Reference mask: causal blocks k > absQ; relax rectangles
// release causal-blocked cells; window blocks k < absQ - window + 1.
refVals := make([]float32, L*L)
for qi := range L {
absQ := qi
for ki := range L {
blocked := false
if tc.causal && ki > absQ {
blocked = true
}
for _, r := range tc.relax {
qLo, qHi, kLo, kHi := r[0], r[1], r[2], r[3]
if absQ >= qLo && absQ < qHi && ki >= kLo && ki < kHi {
blocked = false
}
}
if window > 0 && ki < absQ-window+1 {
blocked = true
}
if blocked {
refVals[qi*L+ki] = negInf
}
}
}
refMask := mlx.FromValues(refVals, 1, 1, L, L)
want := mlx.FastScaledDotProductAttention(q, k, v, scale, "array", refMask)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if math.Abs(float64(gs[i]-ws[i])) > 1e-4 {
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
}
}
})
}
}
// TestRotatingKVCacheMLAParity drives a rotating cache with the MLA
// shape — K = [kvLatent, kPE] concatenated, V = zero-width — then
// uses WithMLAHistory to slice V from K and compares output against
// a manual reference. Pins the cache+MLA integration that
// glm4_moe_lite uses in production.
func TestRotatingKVCacheMLAParity(t *testing.T) {
skipIfNoMLX(t)
const H, L, D, valueDim = 1, 3, 6, 4
const scale = 1.0
const window = 8 // window >= L so no window restriction
kVals := make([]float32, 1*H*L*D)
for i := range kVals {
kVals[i] = 0.1 * float32(i+1)
}
k := mlx.FromValues(kVals, 1, H, L, D)
v := mlx.Zeros(mlx.DTypeFloat32, 1, H, L, 0)
q := mlx.Zeros(mlx.DTypeFloat32, 1, H, L, D)
b := newKVBatch(0, L)
c := NewRotatingKVCache(window)
history := c.Update(b, k, v)
got := nn.ScaledDotProductAttention(b, q, scale,
nn.WithMLAHistory(history, valueDim),
nn.WithMask(nn.CausalMask()))
vRef := k.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, valueDim))
want := mlx.FastScaledDotProductAttention(q, k, vRef, scale, "causal", nil)
mlx.Eval(got, want)
gs, ws := got.Floats(), want.Floats()
for i := range ws {
if math.Abs(float64(gs[i]-ws[i])) > 1e-5 {
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
}
}
}

View File

@@ -0,0 +1,338 @@
package cache
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
// tensors whose channel value is the token id, so stateIDs can recover
// which ids survived in the cache.
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
return k, v
}
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
data := make([]float32, 0, 2*len(ids))
for _, id := range ids {
data = append(data, id, id)
}
k := mlx.FromValues(data, 1, 1, len(ids), 2)
v := mlx.FromValues(data, 1, 1, len(ids), 2)
return k, v
}
// stateIDs returns the ids currently in the cache in slot order (logical
// after a concat, physical/rotated after a single-token update).
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
t.Helper()
state := c.State()
if state == nil {
return nil
}
mlx.Eval(state[0])
flat := state[0].Floats()
n := state[0].Dim(2)
out := make([]float32, n)
for i := range n {
out[i] = flat[i*2]
}
return out
}
func equalSlice(a, b []float32) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
ids := make([]float32, n)
for i := range ids {
ids[i] = startID + float32(i)
}
k, v := multiTokenKV(ids)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
return startID + float32(n)
}
func feedSingle(c *RotatingKVCache, id float32) {
k, v := singleTokenKV(id)
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
}
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
// pre-existing tokens in logical order so the first Q of the new batch
// has a full sliding window.
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
nextID := feedMulti(c, 1, 3)
for range 6 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 9 {
t.Fatalf("setup: offset=%d want 9", c.Offset())
}
if c.idx >= c.maxSize {
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
}
feedMulti(c, 10, 2)
got := stateIDs(t, c)
want := []float32{7, 8, 9, 10, 11}
if !equalSlice(got, want) {
t.Fatalf("post-concat window=%v want %v", got, want)
}
if c.Offset() != 11 {
t.Fatalf("offset=%d want 11", c.Offset())
}
}
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
// tokens plus the full new batch. This is the chunked-prefill contract
// x/mlxrunner/pipeline.go relies on.
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
feedMulti(c, 1, 6)
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
// so the first new Q has its full window in scope for this forward.
feedMulti(c, 7, 3)
got := stateIDs(t, c)
want := []float32{4, 5, 6, 7, 8, 9}
if !equalSlice(got, want) {
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
}
// The next decode trims oversize back to maxSize; order may be
// physical (rotated), so check as a set.
feedSingle(c, 10)
got = stateIDs(t, c)
if len(got) != window {
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
}
seen := map[float32]bool{}
for _, v := range got {
seen[v] = true
}
for _, w := range []float32{7, 8, 9, 10} {
if !seen[w] {
t.Fatalf("post-decode window missing %v (got %v)", w, got)
}
}
}
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
// underlying buffer by `step` slots via mlx.Zeros before writing, so
// after one decode on a short prefill c.idx < Dim even though the cache
// has not wrapped. Those trailing slots are zero padding and must not
// be pulled back into the live window on the next concat.
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
skipIfNoMLX(t)
const window = 512
c := NewRotatingKVCache(window)
feedMulti(c, 1, 3)
feedSingle(c, 4)
feedMulti(c, 5, 3)
got := stateIDs(t, c)
want := []float32{1, 2, 3, 4, 5, 6, 7}
if !equalSlice(got, want) {
t.Fatalf("growing-buffer concat=%v want %v", got, want)
}
}
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
// Restore(nil, target) between conversation turns to rewind the cache to
// the matched prefix. Restore moves c.offset/c.idx without trimming the
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
// tokens. A subsequent concat must drop those, not treat them as wrapped
// window content.
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
skipIfNoMLX(t)
const window = 8
c := NewRotatingKVCache(window)
// Grow the buffer to exactly maxSize without wrapping.
feedMulti(c, 1, 2)
for id := float32(3); id <= 8; id++ {
feedSingle(c, id)
}
if c.Offset() != window {
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
}
if !c.Restore(nil, 2) {
t.Fatalf("live rewind to 2 failed")
}
if c.Offset() != 2 {
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
}
feedMulti(c, 9, 3)
got := stateIDs(t, c)
want := []float32{1, 2, 9, 10, 11}
if !equalSlice(got, want) {
t.Fatalf("post-rewind concat=%v want %v", got, want)
}
if c.Offset() != 5 {
t.Fatalf("offset=%d want 5", c.Offset())
}
}
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
// formula drops to non-positive and all pre-existing tokens are kept.
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
feedMulti(c, 1, 2)
feedMulti(c, 3, 2)
got := stateIDs(t, c)
want := []float32{1, 2, 3, 4}
if !equalSlice(got, want) {
t.Fatalf("growing buffer=%v want %v", got, want)
}
}
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
feedMulti(c, 1, 8)
if c.Offset() != 8 {
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
}
feedMulti(c, 9, 8)
got := stateIDs(t, c)
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
if !equalSlice(got, want) {
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
}
feedMulti(c, 17, 4)
got = stateIDs(t, c)
want = []float32{14, 15, 16, 17, 18, 19, 20}
if !equalSlice(got, want) {
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
}
// Decode trims oversize back to maxSize; order may be physical.
feedSingle(c, 21)
got = stateIDs(t, c)
if len(got) != window {
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
}
seen := map[float32]bool{}
for _, v := range got {
seen[v] = true
}
for _, w := range []float32{18, 19, 20, 21} {
if !seen[w] {
t.Fatalf("post-decode window missing %v (got %v)", w, got)
}
}
}
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
// prefill sequence and checks that each new prefill retains the last
// (maxSize-1) pre-existing tokens in logical order.
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
skipIfNoMLX(t)
const window = 4
c := NewRotatingKVCache(window)
nextID := feedMulti(c, 1, 2)
for range 5 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 7 {
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
}
feedMulti(c, nextID, 3)
nextID += 3
got := stateIDs(t, c)
want := []float32{5, 6, 7, 8, 9, 10}
if !equalSlice(got, want) {
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
}
for range 4 {
feedSingle(c, nextID)
nextID++
}
if c.Offset() != 14 {
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
}
feedMulti(c, nextID, 2)
got = stateIDs(t, c)
want = []float32{12, 13, 14, 15, 16}
if !equalSlice(got, want) {
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
}
}
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
// token count through any mix of Update() calls — Gemma 4 uses
// donorEntry.Offset - L for the consumer's RoPE offset.
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
skipIfNoMLX(t)
c := NewRotatingKVCache(4)
nextID := feedMulti(c, 1, 3)
if c.Offset() != 3 {
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
}
for i := range 5 {
feedSingle(c, nextID)
nextID++
if c.Offset() != 3+i+1 {
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
}
}
nextID = feedMulti(c, nextID, 2)
if c.Offset() != 10 {
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
}
// L > maxSize concat.
feedMulti(c, nextID, 7)
if c.Offset() != 17 {
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
}
}

982
x/mlxrunner/cache_test.go Normal file
View File

@@ -0,0 +1,982 @@
package mlxrunner
import (
"slices"
"testing"
"time"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// snapshotTracker records every fakeSnapshot created and every Close() call
// so tests can detect leaked (created but never closed) or double-closed snapshots.
type snapshotTracker struct {
all []*fakeSnapshot
}
func (tr *snapshotTracker) track(s *fakeSnapshot) {
if s == nil {
return
}
s.tracker = tr
tr.all = append(tr.all, s)
}
// Fake caches that store actual token sequences so tests can verify the right
// data was restored, not just the right offset.
// fakeSnapshot stores a copy of the token sub-sequence it covers.
type fakeSnapshot struct {
tokens []int32
from, to int
byteSize int // configurable for eviction tests
tracker *snapshotTracker
closeCount int
}
func (s *fakeSnapshot) Size() int { return s.byteSize }
func (s *fakeSnapshot) Close() {
s.closeCount++
}
// fakeRewindableCache tracks the full token sequence and supports
// arbitrary rewind via Restore(nil, target).
type fakeRewindableCache struct {
tokens []int32
tracker *snapshotTracker
}
func (c *fakeRewindableCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
}
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }
func (c *fakeRewindableCache) Free() {
c.tokens = nil
}
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
if fromOffset >= len(c.tokens) {
return nil
}
from := fromOffset
if from < 0 {
from = 0
}
s := &fakeSnapshot{
tokens: slices.Clone(c.tokens[from:]),
from: from,
to: len(c.tokens),
}
c.tracker.track(s)
return s
}
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
if target > len(c.tokens) {
return false
}
c.tokens = c.tokens[:target]
return true
}
s := snapshot.(*fakeSnapshot)
if target > s.to || len(c.tokens) < s.from {
return false
}
c.tokens = append(c.tokens[:s.from], s.tokens...)
if target < len(c.tokens) {
c.tokens = c.tokens[:target]
}
return true
}
func (c *fakeRewindableCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
if parent == nil || child == nil {
if parent != nil {
parent.Close()
}
if child != nil {
child.Close()
}
return nil
}
p := parent.(*fakeSnapshot)
ch := child.(*fakeSnapshot)
merged := make([]int32, len(p.tokens)+len(ch.tokens))
copy(merged, p.tokens)
copy(merged[len(p.tokens):], ch.tokens)
s := &fakeSnapshot{
tokens: merged,
from: p.from,
to: ch.to,
byteSize: p.byteSize + ch.byteSize,
}
c.tracker.track(s)
p.Close()
ch.Close()
return s
}
func (c *fakeRewindableCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
if snapshot == nil {
return nil, nil
}
s := snapshot.(*fakeSnapshot)
relAt := at - s.from
if relAt <= 0 {
return nil, snapshot
}
if relAt >= len(s.tokens) {
return snapshot, nil
}
p := &fakeSnapshot{
tokens: slices.Clone(s.tokens[:relAt]),
from: s.from,
to: at,
byteSize: s.byteSize,
}
ch := &fakeSnapshot{
tokens: slices.Clone(s.tokens[relAt:]),
from: at,
to: s.to,
byteSize: s.byteSize,
}
c.tracker.track(p)
c.tracker.track(ch)
s.Close()
return p, ch
}
func TestKVCacheBeginWithFactoryLimitCapsPrefix(t *testing.T) {
inputs := []int32{1, 2, 3, 4, 5}
tracker := &snapshotTracker{}
var kc kvCache
factory := func() []cache.Cache {
return []cache.Cache{&fakeRewindableCache{tracker: tracker}}
}
session := kc.beginWithFactoryLimit(inputs, factory, "test", -1, false)
session.caches[0].(*fakeRewindableCache).feed(inputs)
session.close()
session = kc.beginWithFactoryLimit(inputs, factory, "test", 3, false)
if got, want := session.caches[0].Offset(), 3; got != want {
t.Fatalf("cache offset = %d, want %d", got, want)
}
if got, want := session.remaining, inputs[3:]; !slices.Equal(got, want) {
t.Fatalf("remaining = %v, want %v", got, want)
}
}
func TestKVCacheBeginWithFactoryLimitDoesNotKeepSeedToken(t *testing.T) {
inputs := []int32{1, 2, 3, 4, 5}
tracker := &snapshotTracker{}
var kc kvCache
factory := func() []cache.Cache {
return []cache.Cache{&fakeRewindableCache{tracker: tracker}}
}
session := kc.beginWithFactoryLimit(inputs, factory, "test", -1, false)
session.caches[0].(*fakeRewindableCache).feed(inputs)
session.close()
session = kc.beginWithFactoryLimit(inputs, factory, "test", len(inputs), false)
if got, want := session.caches[0].Offset(), len(inputs); got != want {
t.Fatalf("cache offset = %d, want %d", got, want)
}
if len(session.remaining) != 0 {
t.Fatalf("remaining = %v, want empty", session.remaining)
}
}
// fakeSlidingWindowCache models RotatingKVCache semantics: stores the full
// token sequence but only the trailing maxSize tokens are "live" in the window.
// Once the window fills, live rewind is impossible without a snapshot.
type fakeSlidingWindowCache struct {
tokens []int32
maxSize int
tracker *snapshotTracker
}
func (c *fakeSlidingWindowCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
}
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }
func (c *fakeSlidingWindowCache) Free() {
c.tokens = nil
}
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
return nil
}
// Snapshot captures the full window state (like RotatingKVCache.Snapshot).
s := &fakeSnapshot{
tokens: slices.Clone(c.tokens),
from: 0,
to: len(c.tokens),
}
c.tracker.track(s)
return s
}
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
if target < 0 {
return false
}
if snapshot == nil {
if target >= len(c.tokens) {
return target == len(c.tokens)
}
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
if len(c.tokens) > c.maxSize {
return false
}
c.tokens = c.tokens[:target]
return true
}
s := snapshot.(*fakeSnapshot)
if target > s.to {
return false
}
// Reject if clamping would leave an incomplete window
// (matches RotatingKVCache behavior).
if target < s.to && s.to > c.maxSize {
return false
}
c.tokens = slices.Clone(s.tokens)
if target < len(c.tokens) {
c.tokens = c.tokens[:target]
}
return true
}
func (c *fakeSlidingWindowCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
// Child supersedes parent for sliding window (full window state).
if parent != nil {
parent.Close()
}
return child
}
func (c *fakeSlidingWindowCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
// Can't split a ring buffer at an arbitrary point.
return nil, snapshot
}
// fakeRecurrentCache models RecurrentCache semantics: stores tokens
// but cannot rewind without a snapshot.
type fakeRecurrentCache struct {
tokens []int32
tracker *snapshotTracker
}
func (c *fakeRecurrentCache) feed(tokens []int32) {
c.tokens = append(c.tokens, tokens...)
}
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
return nil, nil
}
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }
func (c *fakeRecurrentCache) Free() {
c.tokens = nil
}
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
// Recurrent state is cumulative; snapshot captures the full state.
if len(c.tokens) == 0 {
return nil
}
s := &fakeSnapshot{
tokens: slices.Clone(c.tokens),
from: 0,
to: len(c.tokens),
}
c.tracker.track(s)
return s
}
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
if snapshot == nil {
return target == len(c.tokens) // can only no-op
}
s := snapshot.(*fakeSnapshot)
if target != s.to {
return false // cumulative state requires exact match
}
c.tokens = slices.Clone(s.tokens)
return true
}
func (c *fakeRecurrentCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
// Child supersedes parent for cumulative state.
if parent != nil {
parent.Close()
}
return child
}
func (c *fakeRecurrentCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
return nil, snapshot // can't split cumulative state
}
type feedableCache interface {
cache.Cache
feed(tokens []int32)
}
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
type testEnv struct {
kvc *kvCache
caches []cache.Cache // typed references for assertions
tracker *snapshotTracker
rewindable bool // true when all caches support arbitrary Restore(nil, target)
}
// newTransformerEnv creates a test environment with a single rewindable cache
// (pure transformer model).
func newTransformerEnv() *testEnv {
tracker := &snapshotTracker{}
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tracker,
rewindable: true,
}
}
// newSlidingWindowEnv creates a test environment with one rewindable cache and
// one sliding window cache (Mistral-style architecture). The sliding window
// maxSize is set small enough that test sequences fill it, making
// Restore(nil, target) fail — the same behavior as production models where
// the window fills after a few turns.
func newSlidingWindowEnv() *testEnv {
tr := &snapshotTracker{}
rc := &fakeRewindableCache{tracker: tr}
sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
caches := []cache.Cache{rc, sw}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tr,
rewindable: false,
}
}
// newRecurrentEnv creates a test environment with one rewindable cache and one
// non-rewindable cache (Jamba-style architecture).
func newRecurrentEnv() *testEnv {
tr := &snapshotTracker{}
rc := &fakeRewindableCache{tracker: tr}
nrc := &fakeRecurrentCache{tracker: tr}
caches := []cache.Cache{rc, nrc}
return &testEnv{
kvc: &kvCache{caches: caches},
caches: caches,
tracker: tr,
rewindable: false,
}
}
// assertAllTokens checks that every cache in the environment contains exactly
// the expected token sequence.
func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) {
t.Helper()
for i, c := range e.caches {
assertTokens(t, label, c, expected)
// Verify all caches report the same offset.
if i > 0 && c.Offset() != e.caches[0].Offset() {
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
label, i, c.Offset(), e.caches[0].Offset())
}
}
}
// simulateRequest mirrors the production pipeline lifecycle:
// begin -> prefill with snapshot(false) at branch points -> generate -> close
type requestResult struct {
remaining []int32
pendingSnapshots int
}
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
// a user snapshot is requested at that offset during prefill.
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
t.Helper()
session := kvc.begin(nil, inputs)
for _, at := range userSnapshotAt {
if at > 0 {
session.requestSnapshot(at)
}
}
result := requestResult{
remaining: slices.Clone(session.remaining),
pendingSnapshots: len(session.pendingSnapshots),
}
assertCacheOffsetAlignment(t, kvc, "after begin")
baseOffset := kvc.minCacheOffset()
remaining := inputs[baseOffset:]
// Prefill: feed tokens, pausing at each pending snapshot.
for len(session.pendingSnapshots) > 0 {
sp := session.pendingSnapshots[0]
count := sp.offset - baseOffset
if count > len(remaining) {
break
}
if count > 0 {
feedAll(kvc.caches, remaining[:count])
remaining = remaining[count:]
baseOffset = sp.offset
}
assertCacheOffsetAlignment(t, kvc, "at snapshot point")
session.snapshot()
}
// Feed rest of input tokens.
if len(remaining) > 0 {
feedAll(kvc.caches, remaining)
}
assertCacheOffsetAlignment(t, kvc, "after prefill")
// Generate tokens.
if len(generated) > 0 {
session.outputs = generated
feedAll(kvc.caches, generated)
}
assertCacheOffsetAlignment(t, kvc, "before close")
session.close()
return result
}
func feedAll(caches []cache.Cache, tokens []int32) {
for _, c := range caches {
if fc, ok := c.(feedableCache); ok {
fc.feed(tokens)
}
}
}
// assertCacheOffsetAlignment verifies all caches report the same offset.
func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
t.Helper()
if len(kvc.caches) < 2 {
return
}
expected := kvc.caches[0].Offset()
for i := 1; i < len(kvc.caches); i++ {
if got := kvc.caches[i].Offset(); got != expected {
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
}
}
}
// assertTokens checks that a feedable cache contains the expected token sequence.
// For sliding window caches, only the trailing maxSize tokens are checked.
func assertTokens(t *testing.T, label string, c cache.Cache, expected []int32) {
t.Helper()
switch fc := c.(type) {
case *fakeRewindableCache:
if !slices.Equal(fc.tokens, expected) {
t.Errorf("%s: rewindable tokens = %v, want %v", label, fc.tokens, expected)
}
case *fakeSlidingWindowCache:
// Sliding window stores full history but only trailing maxSize are live.
// Verify the full token sequence matches (the window semantics are
// enforced by Snapshot/Restore, not by the token log).
if !slices.Equal(fc.tokens, expected) {
t.Errorf("%s: sliding window tokens = %v, want %v", label, fc.tokens, expected)
}
case *fakeRecurrentCache:
if !slices.Equal(fc.tokens, expected) {
t.Errorf("%s: non-rewindable tokens = %v, want %v", label, fc.tokens, expected)
}
default:
t.Fatalf("%s: unknown cache type %T", label, c)
}
}
// checkTrieInvariants walks the trie and checks structural invariants.
func checkTrieInvariants(t *testing.T, root *trieNode) {
t.Helper()
walkNodes(root, func(n *trieNode) bool {
if n.parent != nil {
if n.startOffset() != n.parent.endOffset {
t.Errorf("node [%d,%d): startOffset %d != parent endOffset %d",
n.startOffset(), n.endOffset, n.startOffset(), n.parent.endOffset)
}
}
if len(n.tokens) != n.endOffset-n.startOffset() {
t.Errorf("node [%d,%d): token count %d != offset span %d",
n.startOffset(), n.endOffset, len(n.tokens), n.endOffset-n.startOffset())
}
for _, c := range n.children {
if c.parent != n {
t.Errorf("child [%d,%d) parent mismatch", c.startOffset(), c.endOffset)
}
}
// No two siblings should start with the same token.
seen := make(map[int32]bool)
for _, c := range n.children {
if len(c.tokens) > 0 {
first := c.tokens[0]
if seen[first] {
t.Errorf("node [%d,%d): duplicate sibling first token %d",
n.startOffset(), n.endOffset, first)
}
seen[first] = true
}
}
return true
})
}
// checkSnapshotLeaks verifies that every tracked snapshot is either still live
// in the trie (closeCount == 0) or has been closed exactly once. It reports
// leaked snapshots (not in trie, never closed) and double-closes.
func checkSnapshotLeaks(t *testing.T, tracker *snapshotTracker, root *trieNode) {
t.Helper()
if tracker == nil {
return
}
// Collect all live snapshots still referenced by trie nodes.
live := make(map[*fakeSnapshot]bool)
walkNodes(root, func(n *trieNode) bool {
for _, s := range n.snapshots {
if s != nil {
if fs, ok := s.(*fakeSnapshot); ok {
live[fs] = true
}
}
}
return true
})
for i, s := range tracker.all {
if live[s] {
if s.closeCount != 0 {
t.Errorf("snapshot #%d [%d,%d) is still in trie but was closed %d time(s)",
i, s.from, s.to, s.closeCount)
}
} else {
if s.closeCount == 0 {
t.Errorf("snapshot #%d [%d,%d) leaked: created but never closed and not in trie",
i, s.from, s.to)
} else if s.closeCount > 1 {
t.Errorf("snapshot #%d [%d,%d) double-closed: closed %d times",
i, s.from, s.to, s.closeCount)
}
}
}
}
// forEachEnv runs fn as subtests for three realistic model configurations:
// pure transformer, transformer + sliding window (Mistral-style), and
// transformer + recurrent (Jamba-style). Leak checking runs automatically
// at the end of each subtest.
func forEachEnv(t *testing.T, fn func(t *testing.T, env *testEnv)) {
t.Helper()
run := func(t *testing.T, env *testEnv) {
t.Cleanup(func() {
checkSnapshotLeaks(t, env.tracker, env.kvc.root)
})
fn(t, env)
}
t.Run("Transformer", func(t *testing.T) { run(t, newTransformerEnv()) })
t.Run("SlidingWindow", func(t *testing.T) { run(t, newSlidingWindowEnv()) })
t.Run("Recurrent", func(t *testing.T) { run(t, newRecurrentEnv()) })
}
// TestBranchCreationAndReuse exercises the core multi-conversation lifecycle:
// two conversations share a prefix and diverge, creating a branch point.
// A third conversation extends the first. Verifies trie structure, cache
// hit lengths, and that semantic caches contain the correct token sequences.
func TestBranchCreationAndReuse(t *testing.T) {
forEachEnv(t, func(t *testing.T, env *testEnv) {
kvc := env.kvc
// Request A: [1,2,3,4,5,6,7,8] + generate [20,21] — full miss.
resA := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8}, []int32{20, 21})
if len(resA.remaining) != 8 {
t.Fatalf("A: remaining = %d, want 8 (full miss)", len(resA.remaining))
}
env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
// Verify trie was populated by close().
_, mA := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
if mA != 10 {
t.Fatalf("A findable: expected 10 matched, got %d", mA)
}
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
// For rewindable caches, switchToPath rewinds to the match point
// so only the non-matching suffix needs evaluation. For non-rewindable
// caches (RecurrentCache), the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
if env.rewindable {
if resB.pendingSnapshots != 0 {
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
}
if len(resB.remaining) != 3 {
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
}
} else {
if resB.pendingSnapshots != 1 {
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
}
if len(resB.remaining) != 8 {
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
}
}
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
// Both A and B should be findable in the trie.
_, mA2 := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
if mA2 < 5 {
t.Fatalf("A still findable: expected >= 5 matched, got %d", mA2)
}
_, mB := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
if mB < 5 {
t.Fatalf("B findable: expected >= 5 matched, got %d", mB)
}
// Request C: [1,2,3,4,5,6,7,8,40,41] — extends A's prefix.
// Should get a cache hit for the shared prefix.
resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}, nil)
if len(resC.remaining) >= 10 {
t.Fatalf("C: remaining = %d, want < 10 (should get cache hit)", len(resC.remaining))
}
env.assertAllTokens(t, "after C", []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41})
checkTrieInvariants(t, kvc.root)
})
}
// TestExactMatchSeedBehavior verifies the holdback mechanism: when the exact
// same prompt is requested twice, the cache does not overclaim cached work.
// The last token must be re-evaluated to seed generation.
func TestExactMatchSeedBehavior(t *testing.T) {
forEachEnv(t, func(t *testing.T, env *testEnv) {
kvc := env.kvc
// Request A: first time.
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
// Request B: identical prompt. Holdback means matched=4, partial in
// the 5-token edge. For rewindable caches, switchToPath rewinds to
// offset 4, so only the held-back token needs re-evaluation. For
// non-rewindable caches, the rewind fails and freeAll fires.
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
if env.rewindable {
if len(resB.remaining) != 1 {
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
}
if resB.pendingSnapshots != 0 {
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
}
} else {
if len(resB.remaining) != 5 {
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
}
if resB.pendingSnapshots != 1 {
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
}
}
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
checkTrieInvariants(t, kvc.root)
})
}
// TestConversationResumption tests the most common pattern: user sends a message,
// gets a response, then sends a follow-up. The follow-up should reuse the cached
// prefix (system prompt + first turn + assistant response).
func TestConversationResumption(t *testing.T) {
forEachEnv(t, func(t *testing.T, env *testEnv) {
kvc := env.kvc
// Turn 1: system prompt + user message, assistant generates response.
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11, 12})
env.assertAllTokens(t, "turn 1", []int32{1, 2, 3, 4, 5, 10, 11, 12})
// Turn 2: full history + new user message. Should get a cache hit on
// the prefix [1,2,3,4,5,10,11,12] and only need to evaluate [20,21].
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21}, []int32{30})
if len(resB.remaining) > 5 {
t.Fatalf("turn 2: remaining = %d, want <= 5 (should reuse most of history)", len(resB.remaining))
}
env.assertAllTokens(t, "turn 2", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30})
// Turn 3: even longer history.
resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}, nil)
if len(resC.remaining) > 5 {
t.Fatalf("turn 3: remaining = %d, want <= 5", len(resC.remaining))
}
env.assertAllTokens(t, "turn 3", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41})
checkTrieInvariants(t, kvc.root)
})
}
// TestEvictionPreservesActiveConversations creates multiple conversations sharing
// a system prompt, triggers eviction via large snapshot sizes, and verifies the
// active path and shared prefix survive while memory stays bounded.
func TestEvictionPreservesActiveConversations(t *testing.T) {
forEachEnv(t, func(t *testing.T, env *testEnv) {
kvc := env.kvc
systemPrompt := []int32{1, 2, 3, 4, 5}
// Create 5 conversations with unique suffixes.
for i := range 5 {
suffix := []int32{int32(100 + i*10), int32(101 + i*10), int32(102 + i*10)}
inputs := append(slices.Clone(systemPrompt), suffix...)
simulateRequest(t, kvc, inputs, []int32{int32(200 + i)})
}
// Inflate snapshot sizes to trigger eviction.
walkNodes(kvc.root, func(n *trieNode) bool {
if !n.hasSnapshots() {
return true
}
snaps := make([]cache.Snapshot, len(n.snapshots))
for i, s := range n.snapshots {
if s != nil {
snaps[i] = &fakeSnapshot{byteSize: 2 * 1024 * 1024 * 1024} // 2 GiB per snapshot
}
}
n.setSnapshots(snaps, &kvc.pagedOutBytes)
return true
})
// Run eviction.
kvc.enforceEvictionPolicy()
// Memory should be within limits.
if kvc.pagedOutBytes > maxPagedOutBytes {
t.Fatalf("pagedOutBytes = %d, want <= %d", kvc.pagedOutBytes, maxPagedOutBytes)
}
// Active path should be untouched.
if len(kvc.activePath) < 2 {
t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath))
}
// System prompt prefix should still be findable (multi-child
// branch points are protected from eviction entirely).
_, matched := findBestMatch(kvc.root, systemPrompt)
if matched < len(systemPrompt) {
t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt))
}
checkTrieInvariants(t, kvc.root)
})
}
// TestUserSnapshotPreservesRestorePoint verifies that user-created snapshots
// (snapshot(true)) resist structural changes that would destroy them:
// - A user node forces new tokens into a child instead of extending in-place
// - The snapshot remains restorable after other branches are added
func TestUserSnapshotPreservesRestorePoint(t *testing.T) {
forEachEnv(t, func(t *testing.T, env *testEnv) {
kvc := env.kvc
// Request A: user snapshot at offset 5, then generate.
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}, 5)
assertUserNodeExists(t, kvc, "after A")
// Request B: extends A's prefix. The user node at offset 5 should
// force tokens into a child rather than extending in-place.
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}, nil)
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21})
assertUserNodeExists(t, kvc, "after B")
// Request C: diverge from the user node.
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 30, 31}, []int32{40})
// Request D: switch back to A's branch — user snapshot still restorable.
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}, nil)
env.assertAllTokens(t, "back to A", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50})
checkTrieInvariants(t, kvc.root)
})
}
// TestUserSnapshotResistsAutoMerge verifies that when a sibling leaf is evicted,
// a user-marked parent node is not auto-merged with its remaining single child.
func TestUserSnapshotResistsAutoMerge(t *testing.T) {
forEachEnv(t, func(t *testing.T, env *testEnv) {
kvc := env.kvc
// Request A: user snapshot at offset 3, then continue to offset 5.
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10}, 3)
// Request B: diverges at the user node, creating a second child.
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20})
userNode := findUserNode(t, kvc)
if len(userNode.children) != 2 {
t.Fatalf("user node children = %d, want 2", len(userNode.children))
}
// Inflate snapshot sizes and evict. The non-active branch should be
// evicted, leaving the user node with one child.
walkNodes(kvc.root, func(n *trieNode) bool {
if !n.hasSnapshots() {
return true
}
snaps := make([]cache.Snapshot, len(n.snapshots))
for i, s := range n.snapshots {
if s != nil {
snaps[i] = &fakeSnapshot{byteSize: 5 * 1024 * 1024 * 1024}
}
}
n.setSnapshots(snaps, &kvc.pagedOutBytes)
return true
})
kvc.enforceEvictionPolicy()
// The user node should still exist (not auto-merged) even with one child.
assertUserNodeExists(t, kvc, "after eviction")
checkTrieInvariants(t, kvc.root)
})
}
func findUserNode(t *testing.T, kvc *kvCache) *trieNode {
t.Helper()
var found *trieNode
walkNodes(kvc.root, func(n *trieNode) bool {
if n.user {
found = n
}
return true
})
if found == nil {
t.Fatal("no user-marked node found")
}
return found
}
func assertUserNodeExists(t *testing.T, kvc *kvCache, label string) {
t.Helper()
var exists bool
walkNodes(kvc.root, func(n *trieNode) bool {
if n.user {
exists = true
}
return true
})
if !exists {
t.Fatalf("%s: no user-marked node found", label)
}
}
// TestBranchSwitchRestoresCorrectState exercises switching back to an older
// branch after working on a different one, verifying that the restored cache
// state contains the correct token sequence for both rewindable and
// non-rewindable caches.
func TestBranchSwitchRestoresCorrectState(t *testing.T) {
forEachEnv(t, func(t *testing.T, env *testEnv) {
kvc := env.kvc
// Request A: [1,2,3,4,5] + generate [10,11]
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 10, 11})
// Request B: [1,2,3,6,7] — diverges at token 4
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{12, 13})
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 6, 7, 12, 13})
// Request C: switch back to A's branch [1,2,3,4,5,10,11,20]
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20}, nil)
env.assertAllTokens(t, "after C (back to A)", []int32{1, 2, 3, 4, 5, 10, 11, 20})
checkTrieInvariants(t, kvc.root)
})
}
// TestLRUOnlyUpdatesUsedNodes verifies that intermediate nodes on the active
// path whose snapshots were not actually restored don't get their lastUsed
// refreshed, allowing them to age out and collapse.
func TestLRUOnlyUpdatesUsedNodes(t *testing.T) {
forEachEnv(t, func(t *testing.T, env *testEnv) {
kvc := env.kvc
// Request A: creates path [1,2,3,4,5] + generate [10,11]
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
// Request B: diverges at token 4, creating a branch point at offset 3
// with a split snapshot.
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20, 21})
// Set all lastUsed to a known old time.
oldTime := time.Now().Add(-1 * time.Hour)
walkNodes(kvc.root, func(n *trieNode) bool {
n.lastUsed = oldTime
return true
})
// Request C: continue on B's branch. This will match B's path
// and extend it. The branch point's snapshot may be paged in
// for some cache types but not others.
beforeRequest := time.Now()
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7, 20, 21, 30}, nil)
// The path must have enough depth to exercise intermediate nodes.
if len(kvc.activePath) < 3 {
t.Fatalf("activePath too short to test intermediate nodes: got %d nodes", len(kvc.activePath))
}
// The frontier (deepest node on the active path) must be updated.
frontier := kvc.activePath[len(kvc.activePath)-1]
if frontier.lastUsed.Before(beforeRequest) {
t.Errorf("frontier lastUsed was not updated: got %v, want >= %v",
frontier.lastUsed, beforeRequest)
}
// Every non-frontier node on the active path (including root)
// should retain its old lastUsed — only the frontier gets refreshed.
for i, node := range kvc.activePath[:len(kvc.activePath)-1] {
if !node.lastUsed.Before(beforeRequest) {
t.Errorf("activePath[%d] (endOffset=%d) lastUsed was refreshed: got %v, want < %v",
i, node.endOffset, node.lastUsed, beforeRequest)
}
}
checkTrieInvariants(t, kvc.root)
})
}

296
x/mlxrunner/cache_trie.go Normal file
View File

@@ -0,0 +1,296 @@
package mlxrunner
import (
"fmt"
"slices"
"time"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
// trieNode represents a node in the compressed prefix trie for KV cache branching.
// Each node stores a compressed edge (multiple tokens) and optional paged-out
// snapshot data per cache layer.
type trieNode struct {
tokens []int32 // compressed edge — multiple tokens per node
endOffset int // cumulative tokens from root to end of this node
parent *trieNode
children []*trieNode
lastUsed time.Time // for LRU eviction
snapshots []cache.Snapshot // per-layer paged-out snapshot data (nil if not paged out)
user bool // true = explicit restore point (resist auto-merge)
}
// startOffset returns the cumulative token offset at the start of this node's edge.
func (n *trieNode) startOffset() int {
return n.endOffset - len(n.tokens)
}
// snapshotBytes returns the total bytes of paged-out snapshots on this node.
func (n *trieNode) snapshotBytes() int64 {
var total int64
for _, s := range n.snapshots {
if s != nil {
total += int64(s.Size())
}
}
return total
}
// setSnapshots replaces this node's snapshots with snaps and closes the old ones.
// If counter is non-nil, the net byte delta is applied to it.
func (n *trieNode) setSnapshots(snaps []cache.Snapshot, counter *int64) {
old := n.swapSnapshots(snaps, counter)
for _, s := range old {
if s != nil {
s.Close()
}
}
}
// swapSnapshots is like setSnapshots but returns the previous snapshots
// without closing them. Use this when the old snapshots will be consumed
// (e.g. by Split/Merge).
func (n *trieNode) swapSnapshots(snaps []cache.Snapshot, counter *int64) []cache.Snapshot {
old := n.snapshots
if counter != nil {
*counter -= n.snapshotBytes()
}
n.snapshots = snaps
if counter != nil {
*counter += n.snapshotBytes()
}
return old
}
// hasSnapshots returns true if any layer has snapshot data.
func (n *trieNode) hasSnapshots() bool {
return slices.ContainsFunc(n.snapshots, func(s cache.Snapshot) bool { return s != nil })
}
// hasAllSnapshots returns true if every layer has snapshot data.
func (n *trieNode) hasAllSnapshots() bool {
return len(n.snapshots) > 0 && !slices.Contains(n.snapshots, nil)
}
// findBestMatch walks the trie matching input tokens, returning the path of
// nodes traversed and the total number of tokens matched.
func findBestMatch(root *trieNode, tokens []int32) (path []*trieNode, matched int) {
if root == nil {
return nil, 0
}
path = []*trieNode{root}
pos := 0
node := root
for pos < len(tokens) {
// When multiple children share the same first token (e.g. after
// a split), prefer the child whose full edge matches over one
// that only partially matches. This is just being defensive - it
// shouldn't actually happen.
var best *trieNode
bestMatched := 0
bestFull := false
for _, child := range node.children {
edge := child.tokens
if len(edge) == 0 {
continue
}
if edge[0] != tokens[pos] {
continue
}
// Count matching tokens in this child's edge.
j := 0
for j < len(edge) && pos+j < len(tokens) && edge[j] == tokens[pos+j] {
j++
}
full := j == len(edge)
// Prefer full edge matches; among same type, prefer longer.
if best == nil || (full && !bestFull) || (full == bestFull && j > bestMatched) {
best = child
bestMatched = j
bestFull = full
}
}
if best == nil {
break
}
pos += bestMatched
path = append(path, best)
if !bestFull {
// Partial match within this edge
break
}
node = best
}
return path, pos
}
// appendTokens either creates a new child node or extends the leaf in place,
// returning the node that now holds the tokens.
func (n *trieNode) appendTokens(root *trieNode, tokens []int32, endOffset int) *trieNode {
if n == root || len(n.children) > 0 || n.hasSnapshots() {
child := &trieNode{
tokens: make([]int32, len(tokens)),
endOffset: endOffset,
parent: n,
lastUsed: n.lastUsed,
}
copy(child.tokens, tokens)
n.children = append(n.children, child)
return child
}
n.tokens = append(n.tokens, tokens...)
n.endOffset = endOffset
return n
}
// removeNode removes a leaf node from the trie.
func removeNode(node *trieNode, counter *int64) {
if node.parent == nil {
panic("removeNode called on root")
}
if len(node.children) != 0 {
panic("removeNode called on non-leaf node")
}
p := node.parent
for i, child := range p.children {
if child == node {
p.children = append(p.children[:i], p.children[i+1:]...)
break
}
}
node.parent = nil
node.setSnapshots(nil, counter)
}
// splitNode splits a node at the given token offset within its edge,
// creating a new parent node. Returns the new parent.
// `at` is relative to the node's edge (0-based index into node.tokens).
// If caches are provided, snapshots are split between parent and child
// using Cache.Split; otherwise snapshots are invalidated.
func splitNode(node *trieNode, at int, caches []cache.Cache, counter *int64) *trieNode {
if at <= 0 || at >= len(node.tokens) {
panic(fmt.Sprintf("splitNode: invalid split offset %d for node with %d tokens", at, len(node.tokens)))
}
// Create new parent with the prefix of the edge.
newParent := &trieNode{
tokens: make([]int32, at),
endOffset: node.startOffset() + at,
parent: node.parent,
children: []*trieNode{node},
lastUsed: node.lastUsed,
}
copy(newParent.tokens, node.tokens[:at])
// Update the original node to have only the suffix.
node.tokens = node.tokens[at:]
// endOffset stays the same for the original node.
// Split snapshots between parent and child using Cache.Split.
// Split consumes the old snapshots, so we remove them first (adjusting
// the counter), then assign the split halves (adjusting it back).
if node.hasSnapshots() {
oldSnaps := node.swapSnapshots(nil, counter)
parentSnaps := make([]cache.Snapshot, len(oldSnaps))
childSnaps := make([]cache.Snapshot, len(oldSnaps))
for i, snap := range oldSnaps {
if snap != nil {
parentSnaps[i], childSnaps[i] = caches[i].Split(snap, newParent.endOffset)
}
}
newParent.setSnapshots(parentSnaps, counter)
node.setSnapshots(childSnaps, counter)
}
// Reparent: replace node with newParent in the old parent's children.
if node.parent != nil {
for i, child := range node.parent.children {
if child == node {
node.parent.children[i] = newParent
break
}
}
}
node.parent = newParent
return newParent
}
// mergeWithChild merges a node with its single child: concatenates tokens,
// merges snapshot data via Cache.Merge, and removes the child.
func mergeWithChild(node *trieNode, caches []cache.Cache, counter *int64) {
if len(node.children) != 1 {
panic(fmt.Sprintf("mergeWithChild called on node with %d children", len(node.children)))
}
child := node.children[0]
// Concatenate tokens.
node.tokens = append(node.tokens, child.tokens...)
node.endOffset = child.endOffset
// Merge snapshots per layer. Merge consumes the old snapshots, so we
// remove them first (adjusting the counter), then assign the merged
// result (adjusting it back).
if len(node.snapshots) > 0 || len(child.snapshots) > 0 {
nodeSnaps := node.swapSnapshots(nil, counter)
childSnaps := child.swapSnapshots(nil, counter)
merged := make([]cache.Snapshot, len(caches))
for i := range caches {
var ps, cs cache.Snapshot
if nodeSnaps != nil {
ps = nodeSnaps[i]
}
if childSnaps != nil {
cs = childSnaps[i]
}
merged[i] = caches[i].Merge(ps, cs)
}
node.setSnapshots(merged, counter)
}
// Adopt grandchildren.
node.children = child.children
for _, gc := range node.children {
gc.parent = node
}
// Inherit user flag from child if child was a user-created snapshot node.
node.user = child.user
// Update lastUsed to the more recent of the two.
if child.lastUsed.After(node.lastUsed) {
node.lastUsed = child.lastUsed
}
child.parent = nil
child.children = nil
}
// walkNodes calls fn for every node in the trie (depth-first).
// If fn returns false, the walk stops.
func walkNodes(root *trieNode, fn func(*trieNode) bool) {
if root == nil {
return
}
var walk func(*trieNode) bool
walk = func(n *trieNode) bool {
if !fn(n) {
return false
}
for _, child := range n.children {
if !walk(child) {
return false
}
}
return true
}
walk(root)
}

View File

@@ -0,0 +1,455 @@
package mlxrunner
import (
"slices"
"testing"
"time"
"github.com/ollama/ollama/x/mlxrunner/cache"
)
func newTestTrie(tokens []int32) *trieNode {
root := &trieNode{lastUsed: time.Now()}
if len(tokens) > 0 {
child := &trieNode{
tokens: slices.Clone(tokens),
endOffset: len(tokens),
parent: root,
lastUsed: time.Now(),
}
root.children = []*trieNode{child}
}
return root
}
func TestFindBestMatchMultipleBranches(t *testing.T) {
root := &trieNode{lastUsed: time.Now()}
branch1 := &trieNode{
tokens: []int32{1, 2, 3},
endOffset: 3,
parent: root,
lastUsed: time.Now(),
}
branch2 := &trieNode{
tokens: []int32{4, 5, 6},
endOffset: 3,
parent: root,
lastUsed: time.Now(),
}
root.children = []*trieNode{branch1, branch2}
// Match branch 1.
path, matched := findBestMatch(root, []int32{1, 2, 3, 7})
if matched != 3 {
t.Fatalf("expected 3 matched, got %d", matched)
}
if len(path) != 2 || path[1] != branch1 {
t.Fatal("expected to match branch1")
}
// Match branch 2.
path, matched = findBestMatch(root, []int32{4, 5, 6, 8})
if matched != 3 {
t.Fatalf("expected 3 matched, got %d", matched)
}
if len(path) != 2 || path[1] != branch2 {
t.Fatal("expected to match branch2")
}
// Match neither.
_, matched = findBestMatch(root, []int32{7, 8, 9})
if matched != 0 {
t.Fatalf("expected 0 matched, got %d", matched)
}
}
func TestFindBestMatchPrefersFullEdge(t *testing.T) {
root := &trieNode{lastUsed: time.Now()}
shared := &trieNode{
tokens: []int32{1, 2, 3},
endOffset: 3,
parent: root,
lastUsed: time.Now(),
}
root.children = []*trieNode{shared}
longer := &trieNode{
tokens: []int32{10, 11, 12, 13, 14},
endOffset: 8,
parent: shared,
lastUsed: time.Now(),
}
shorter := &trieNode{
tokens: []int32{10, 11, 12},
endOffset: 6,
parent: shared,
lastUsed: time.Now(),
}
// Put longer first so naive first-match would pick it.
shared.children = []*trieNode{longer, shorter}
input := []int32{1, 2, 3, 10, 11, 12, 99, 100}
path, matched := findBestMatch(root, input)
if matched != 6 {
t.Fatalf("expected 6 matched, got %d", matched)
}
if len(path) != 3 {
t.Fatalf("expected 3 nodes in path, got %d", len(path))
}
if path[2] != shorter {
t.Fatal("expected findBestMatch to pick shorter (full edge match), not longer (partial)")
}
}
func TestFindBestMatchPrefersLongerPartial(t *testing.T) {
root := &trieNode{lastUsed: time.Now()}
child1 := &trieNode{
tokens: []int32{1, 2, 3, 4, 5},
endOffset: 5,
parent: root,
lastUsed: time.Now(),
}
child2 := &trieNode{
tokens: []int32{1, 2, 9},
endOffset: 3,
parent: root,
lastUsed: time.Now(),
}
root.children = []*trieNode{child2, child1}
input := []int32{1, 2, 3, 7, 8}
path, matched := findBestMatch(root, input)
if matched != 3 {
t.Fatalf("expected 3 matched, got %d", matched)
}
if path[1] != child1 {
t.Fatal("expected findBestMatch to pick child1 (longer partial match)")
}
}
func TestSplitNodeWithSnapshots(t *testing.T) {
root := newTestTrie([]int32{1, 2, 3, 4, 5})
child := root.children[0]
rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
child.snapshots = []cache.Snapshot{rc.Snapshot(0)}
child.user = true
caches := []cache.Cache{rc}
newParent := splitNode(child, 3, caches, nil)
if !newParent.hasSnapshots() {
t.Fatal("newParent should have snapshots after split")
}
if newParent.user {
t.Fatal("newParent should not be a user snapshot after splitNode")
}
if !child.hasSnapshots() {
t.Fatal("child should have snapshots after split")
}
if !child.user {
t.Fatal("child should remain a user snapshot")
}
}
func TestFindSplitAppendSequence(t *testing.T) {
root := newTestTrie([]int32{1, 2, 3, 4, 5})
path, matched := findBestMatch(root, []int32{1, 2, 3, 6, 7})
if matched != 3 {
t.Fatalf("expected 3 matched, got %d", matched)
}
lastNode := path[len(path)-1]
matchedInEdge := matched - lastNode.startOffset()
split := splitNode(lastNode, matchedInEdge, nil, nil)
split.appendTokens(root, []int32{6, 7}, 5)
if len(root.children) != 1 {
t.Fatalf("root should have 1 child, got %d", len(root.children))
}
shared := root.children[0]
if !slices.Equal(shared.tokens, []int32{1, 2, 3}) {
t.Fatalf("shared tokens = %v, want [1,2,3]", shared.tokens)
}
if len(shared.children) != 2 {
t.Fatalf("shared should have 2 children, got %d", len(shared.children))
}
_, m1 := findBestMatch(root, []int32{1, 2, 3, 4, 5})
if m1 != 5 {
t.Fatalf("original branch: expected 5 matched, got %d", m1)
}
_, m2 := findBestMatch(root, []int32{1, 2, 3, 6, 7})
if m2 != 5 {
t.Fatalf("new branch: expected 5 matched, got %d", m2)
}
_, m3 := findBestMatch(root, []int32{1, 2, 3, 9, 9})
if m3 != 3 {
t.Fatalf("unrelated input: expected 3 matched, got %d", m3)
}
}
func TestRepeatedBranching(t *testing.T) {
root := &trieNode{lastUsed: time.Now()}
root.appendTokens(root, []int32{1, 2, 3, 4, 5}, 5)
_, matchedB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
if matchedB != 3 {
t.Fatalf("B: expected 3 matched, got %d", matchedB)
}
nodeA := root.children[0]
split1 := splitNode(nodeA, 3, nil, nil)
split1.appendTokens(root, []int32{6, 7}, 5)
_, matchedC := findBestMatch(root, []int32{1, 2, 8, 9})
if matchedC != 2 {
t.Fatalf("C: expected 2 matched, got %d", matchedC)
}
split2 := splitNode(split1, 2, nil, nil)
split2.appendTokens(root, []int32{8, 9}, 4)
_, mA := findBestMatch(root, []int32{1, 2, 3, 4, 5})
if mA != 5 {
t.Fatalf("A: expected 5 matched, got %d", mA)
}
_, mB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
if mB != 5 {
t.Fatalf("B: expected 5 matched, got %d", mB)
}
_, mC := findBestMatch(root, []int32{1, 2, 8, 9})
if mC != 4 {
t.Fatalf("C: expected 4 matched, got %d", mC)
}
checkTrieInvariants(t, root)
}
func TestMergeWithChild(t *testing.T) {
t.Run("Basic", func(t *testing.T) {
// root -> A[1,2,3] -> B[4,5] -> {C[6], D[7]}
now := time.Now()
root := &trieNode{lastUsed: now}
a := &trieNode{
tokens: []int32{1, 2, 3},
endOffset: 3,
parent: root,
lastUsed: now,
snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3}, from: 0, to: 3}},
}
b := &trieNode{
tokens: []int32{4, 5},
endOffset: 5,
parent: a,
lastUsed: now,
snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{4, 5}, from: 3, to: 5}},
}
c := &trieNode{tokens: []int32{6}, endOffset: 6, parent: b, lastUsed: now}
d := &trieNode{tokens: []int32{7}, endOffset: 6, parent: b, lastUsed: now}
root.children = []*trieNode{a}
a.children = []*trieNode{b}
b.children = []*trieNode{c, d}
mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
mergeWithChild(a, []cache.Cache{mc}, nil)
// Tokens concatenated.
if !slices.Equal(a.tokens, []int32{1, 2, 3, 4, 5}) {
t.Fatalf("merged tokens = %v, want [1,2,3,4,5]", a.tokens)
}
if a.endOffset != 5 {
t.Fatalf("merged endOffset = %d, want 5", a.endOffset)
}
// Grandchildren reparented.
if len(a.children) != 2 {
t.Fatalf("merged children count = %d, want 2", len(a.children))
}
if c.parent != a || d.parent != a {
t.Fatal("grandchildren should be reparented to merged node")
}
// B detached.
if b.parent != nil || b.children != nil || b.snapshots != nil {
t.Fatal("child B should be fully detached after merge")
}
// Merged snapshot should cover [0,5).
if !a.hasSnapshots() {
t.Fatal("merged node should have snapshots")
}
ms := a.snapshots[0].(*fakeSnapshot)
if ms.from != 0 || ms.to != 5 {
t.Fatalf("merged snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
}
checkTrieInvariants(t, root)
})
t.Run("UserFlag", func(t *testing.T) {
root := &trieNode{lastUsed: time.Now()}
parent := &trieNode{
tokens: []int32{1, 2}, endOffset: 2, parent: root,
lastUsed: time.Now(), user: false,
}
child := &trieNode{
tokens: []int32{3, 4}, endOffset: 4, parent: parent,
lastUsed: time.Now(), user: true,
}
root.children = []*trieNode{parent}
parent.children = []*trieNode{child}
mergeWithChild(parent, nil, nil)
if !parent.user {
t.Fatal("merged node should inherit user=true from child")
}
})
t.Run("LastUsed", func(t *testing.T) {
now := time.Now()
root := &trieNode{lastUsed: now}
parent := &trieNode{
tokens: []int32{1}, endOffset: 1, parent: root,
lastUsed: now.Add(-1 * time.Hour),
}
child := &trieNode{
tokens: []int32{2}, endOffset: 2, parent: parent,
lastUsed: now.Add(1 * time.Hour),
}
root.children = []*trieNode{parent}
parent.children = []*trieNode{child}
mergeWithChild(parent, nil, nil)
if !parent.lastUsed.Equal(now.Add(1 * time.Hour)) {
t.Fatal("merged node should pick the more recent lastUsed")
}
})
t.Run("PanicOnMultipleChildren", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic on node with 2 children")
}
}()
root := &trieNode{lastUsed: time.Now()}
node := &trieNode{
tokens: []int32{1}, endOffset: 1, parent: root, lastUsed: time.Now(),
children: []*trieNode{
{tokens: []int32{2}, endOffset: 2, lastUsed: time.Now()},
{tokens: []int32{3}, endOffset: 2, lastUsed: time.Now()},
},
}
root.children = []*trieNode{node}
mergeWithChild(node, nil, nil)
})
}
func TestSplitMergeRoundTrip(t *testing.T) {
root := &trieNode{lastUsed: time.Now()}
leaf := &trieNode{
tokens: []int32{1, 2, 3, 4, 5},
endOffset: 5,
parent: root,
lastUsed: time.Now(),
snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3, 4, 5}, from: 0, to: 5}},
}
root.children = []*trieNode{leaf}
mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
caches := []cache.Cache{mc}
// Split at 3: [1,2,3] -> [4,5]
newParent := splitNode(leaf, 3, caches, nil)
if !slices.Equal(newParent.tokens, []int32{1, 2, 3}) {
t.Fatalf("after split: parent tokens = %v, want [1,2,3]", newParent.tokens)
}
if !slices.Equal(leaf.tokens, []int32{4, 5}) {
t.Fatalf("after split: child tokens = %v, want [4,5]", leaf.tokens)
}
checkTrieInvariants(t, root)
// Merge back: should restore [1,2,3,4,5]
mergeWithChild(newParent, caches, nil)
if !slices.Equal(newParent.tokens, []int32{1, 2, 3, 4, 5}) {
t.Fatalf("after merge: tokens = %v, want [1,2,3,4,5]", newParent.tokens)
}
if newParent.endOffset != 5 {
t.Fatalf("after merge: endOffset = %d, want 5", newParent.endOffset)
}
if len(newParent.children) != 0 {
t.Fatalf("after merge: children count = %d, want 0", len(newParent.children))
}
// Merged snapshot should cover [0,5).
if !newParent.hasSnapshots() {
t.Fatal("after merge: should have snapshots")
}
ms := newParent.snapshots[0].(*fakeSnapshot)
if ms.from != 0 || ms.to != 5 {
t.Fatalf("after merge: snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
}
checkTrieInvariants(t, root)
}
func TestRemoveNode(t *testing.T) {
t.Run("Leaf", func(t *testing.T) {
root := &trieNode{lastUsed: time.Now()}
shared := &trieNode{
tokens: []int32{1, 2, 3}, endOffset: 3, parent: root, lastUsed: time.Now(),
}
leafA := &trieNode{
tokens: []int32{4, 5}, endOffset: 5, parent: shared, lastUsed: time.Now(),
snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
}
leafB := &trieNode{
tokens: []int32{6, 7}, endOffset: 5, parent: shared, lastUsed: time.Now(),
snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
}
root.children = []*trieNode{shared}
shared.children = []*trieNode{leafA, leafB}
removeNode(leafA, nil)
if len(shared.children) != 1 {
t.Fatalf("parent should have 1 child, got %d", len(shared.children))
}
if shared.children[0] != leafB {
t.Fatal("remaining child should be leafB")
}
if leafA.parent != nil {
t.Fatal("removed node parent should be nil")
}
if leafA.snapshots != nil {
t.Fatal("removed node snapshots should be nil")
}
checkTrieInvariants(t, root)
})
t.Run("PanicOnRoot", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic when removing root")
}
}()
removeNode(&trieNode{}, nil)
})
t.Run("PanicOnNonLeaf", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic when removing non-leaf")
}
}()
parent := &trieNode{parent: &trieNode{}}
parent.children = []*trieNode{{}}
removeNode(parent, nil)
})
}

476
x/mlxrunner/client.go Normal file
View File

@@ -0,0 +1,476 @@
package mlxrunner
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
"net/http"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
type Client struct {
port int
modelName string
contextLength atomic.Int64
memory atomic.Uint64
done chan struct{}
doneErr error // valid after done is closed
client *http.Client
status *llm.StatusWriter
mu sync.Mutex
cmd *exec.Cmd
}
// NewClient prepares a new MLX runner client for LLM models.
// The subprocess is not started until Load() is called.
func NewClient(modelName string) (*Client, error) {
if err := imagegen.CheckPlatformSupport(); err != nil {
return nil, err
}
c := &Client{
modelName: modelName,
done: make(chan struct{}),
client: http.DefaultClient,
}
modelManifest, err := manifest.LoadManifest(modelName)
if err != nil {
return nil, err
}
c.memory.Store(uint64(modelManifest.TotalTensorSize()))
return c, nil
}
// WaitUntilRunning waits for the subprocess to be ready.
func (c *Client) WaitUntilRunning(ctx context.Context) error {
timeout := time.After(2 * time.Minute)
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-c.done:
if msg := c.status.LastError(); msg != "" {
return fmt.Errorf("mlx runner failed: %s (exit: %v)", msg, c.doneErr)
}
return fmt.Errorf("mlx runner exited unexpectedly: %w", c.doneErr)
case <-timeout:
if msg := c.status.LastError(); msg != "" {
return fmt.Errorf("timeout waiting for mlx runner: %s", msg)
}
return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := c.Ping(ctx); err == nil {
slog.Info("mlx runner is ready", "port", c.port)
return nil
}
}
}
}
type CompletionRequest struct {
Prompt string
Options api.Options
Logprobs bool
TopLogprobs int
}
type CompletionResponse struct {
Content string
Done bool
DoneReason int
PromptEvalCount int
PromptEvalDuration time.Duration
EvalCount int
EvalDuration time.Duration
Logprobs []llm.Logprob
Error *api.StatusError
}
// Close terminates the subprocess.
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cmd != nil && c.cmd.Process != nil {
slog.Info("stopping mlx runner subprocess", "pid", c.cmd.Process.Pid)
c.cmd.Process.Signal(os.Interrupt)
select {
case <-c.done:
case <-time.After(5 * time.Second):
c.cmd.Process.Kill()
}
c.cmd = nil
}
return nil
}
// Completion implements llm.LlamaServer.
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
creq := CompletionRequest{
Prompt: req.Prompt,
Logprobs: req.Logprobs,
TopLogprobs: req.TopLogprobs,
}
if req.Options != nil {
creq.Options = *req.Options
}
body, err := json.Marshal(creq)
if err != nil {
return err
}
httpURL := fmt.Sprintf("http://127.0.0.1:%d/completion", c.port)
httpReq, err := http.NewRequestWithContext(ctx, "POST", httpURL, strings.NewReader(string(body)))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
resp, err := c.client.Do(httpReq)
if err != nil {
if errMsg := c.status.LastError(); errMsg != "" {
return fmt.Errorf("mlx runner failed: %s", errMsg)
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
return api.StatusError{StatusCode: resp.StatusCode, ErrorMessage: strings.TrimSpace(string(respBody))}
}
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
var raw CompletionResponse
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
if raw.Error != nil {
return *raw.Error
}
cresp := llm.CompletionResponse{
Content: raw.Content,
Done: raw.Done,
DoneReason: llm.DoneReason(raw.DoneReason),
PromptEvalCount: raw.PromptEvalCount,
PromptEvalDuration: raw.PromptEvalDuration,
EvalCount: raw.EvalCount,
EvalDuration: raw.EvalDuration,
Logprobs: raw.Logprobs,
}
fn(cresp)
if cresp.Done {
return nil
}
}
if err := scanner.Err(); err != nil {
if errMsg := c.status.LastError(); errMsg != "" {
return fmt.Errorf("mlx runner failed: %s", errMsg)
}
return err
}
return nil
}
func (c *Client) ContextLength() int {
return int(c.contextLength.Load())
}
// Detokenize implements llm.LlamaServer.
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
return "", errors.New("not supported")
}
// Embedding implements llm.LlamaServer.
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
return nil, 0, errors.New("not supported")
}
// GetDeviceInfos implements llm.LlamaServer.
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
return nil
}
// GetPort implements llm.LlamaServer.
func (c *Client) GetPort() int {
return c.port
}
// HasExited implements llm.LlamaServer.
func (c *Client) HasExited() bool {
select {
case <-c.done:
return true
default:
return false
}
}
// Load checks whether the model fits in GPU memory and starts the subprocess.
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
if len(gpus) > 0 {
modelSize := c.memory.Load()
// We currently only use the first GPU with MLX
available := gpus[0].FreeMemory
overhead := gpus[0].MinimumMemory() + envconfig.GpuOverhead()
if available > overhead {
available -= overhead
} else {
available = 0
}
if modelSize > available {
if requireFull {
return nil, llm.ErrLoadRequiredFull
}
return nil, fmt.Errorf("model requires %s but only %s are available (after %s overhead)", format.HumanBytes2(modelSize), format.HumanBytes2(available), format.HumanBytes2(overhead))
}
}
// Find a free port
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
if l, err := net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152
}
c.port = port
// Get the current executable path
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// Spawn subprocess: ollama runner --mlx-engine --model <name> --port <port>
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", c.modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// Set library path environment variable for MLX libraries
// Linux: LD_LIBRARY_PATH, Windows: PATH
var libPathEnvVar string
switch runtime.GOOS {
case "linux":
libPathEnvVar = "LD_LIBRARY_PATH"
case "windows":
libPathEnvVar = "PATH"
}
if libPathEnvVar != "" {
libraryPaths := []string{ml.LibOllamaPath}
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
libraryPaths = append(libraryPaths, mlxDirs...)
}
if existingPath, ok := os.LookupEnv(libPathEnvVar); ok {
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
}
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
found := false
for i := range cmd.Env {
envName := cmd.Env[i]
if runtime.GOOS == "windows" {
envName = strings.ToUpper(envName)
}
if strings.HasPrefix(envName, libPathEnvVar+"=") {
cmd.Env[i] = libPathEnvVar + "=" + pathEnvVal
found = true
break
}
}
if !found {
cmd.Env = append(cmd.Env, libPathEnvVar+"="+pathEnvVal)
}
slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal)
}
// Point MLX's JIT compiler at our bundled CUDA runtime headers.
// MLX resolves headers via $CUDA_PATH/include/*.h (and checks CUDA_HOME first).
// Always use bundled headers to avoid version mismatches with any
// system-installed CUDA toolkit.
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_cuda_*")); err == nil {
for _, d := range mlxDirs {
if _, err := os.Stat(filepath.Join(d, "include")); err == nil {
setEnv(cmd, "CUDA_PATH", d)
setEnv(cmd, "CUDA_HOME", d)
slog.Debug("mlx subprocess CUDA headers", "CUDA_PATH", d)
break
}
}
}
c.cmd = cmd
status := llm.NewStatusWriter(os.Stderr)
c.status = status
// os/exec serializes Write calls when shared, which keeps the status writer
// from seeing concurrent stdout/stderr fragments.
cmd.Stdout = status
cmd.Stderr = status
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
go func() {
c.doneErr = cmd.Wait()
close(c.done)
}()
return nil, nil
}
// ModelPath implements llm.LlamaServer.
func (c *Client) ModelPath() string {
return c.modelName
}
// Pid implements llm.LlamaServer.
func (c *Client) Pid() int {
c.mu.Lock()
defer c.mu.Unlock()
if c.cmd != nil && c.cmd.Process != nil {
return c.cmd.Process.Pid
}
return -1
}
type statusResponse struct {
Status int
Progress int
ContextLength int
Memory uint64
}
// Ping implements llm.LlamaServer.
func (c *Client) Ping(ctx context.Context) error {
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
if err != nil {
return err
}
resp, err := c.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("health check failed: %d", resp.StatusCode)
}
var status statusResponse
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
return err
}
c.contextLength.Store(int64(status.ContextLength))
c.memory.Store(status.Memory)
return nil
}
// Tokenize implements llm.LlamaServer.
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/tokenize", c.port)
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(content))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "text/plain")
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
var tokens []int
if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
return nil, err
}
return tokens, nil
}
func (c *Client) currentMemory() uint64 {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
c.Ping(ctx) //nolint:errcheck
return c.memory.Load()
}
// MemorySize implements llm.LlamaServer.
func (c *Client) MemorySize() (total, vram uint64) {
mem := c.currentMemory()
return mem, mem
}
// VRAMByGPU implements llm.LlamaServer.
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
return c.currentMemory()
}
var _ llm.LlamaServer = (*Client)(nil)
// setEnv sets or replaces an environment variable in cmd.Env.
func setEnv(cmd *exec.Cmd, key, value string) {
entry := key + "=" + value
prefix := strings.ToUpper(key + "=")
for i, e := range cmd.Env {
if strings.HasPrefix(strings.ToUpper(e), prefix) {
cmd.Env[i] = entry
return
}
}
cmd.Env = append(cmd.Env, entry)
}

715
x/mlxrunner/dflash.go Normal file
View File

@@ -0,0 +1,715 @@
package mlxrunner
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
)
type dflashStats struct {
iterations int
drafted int
accepted int
mismatches int
allAccepted int
batched int
serial int
targetDuration time.Duration
draftDuration time.Duration
validateDuration time.Duration
}
type dflashDecodeMode string
const (
dflashDecodeDisabled dflashDecodeMode = ""
dflashDecodeGreedy dflashDecodeMode = "greedy"
dflashDecodeSample dflashDecodeMode = "sample"
)
func (m dflashDecodeMode) enabled() bool {
return m != dflashDecodeDisabled
}
func newDFlashTargetCaches(m base.Model) []cache.Cache {
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
return cacheFactory.NewCaches()
}
caches := make([]cache.Cache, m.NumLayers())
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
func freeCacheSet(caches []cache.Cache) {
for _, c := range caches {
if c != nil {
c.Free()
}
}
}
func (r *Runner) dflashGate(opts sampler.Options) (dflashDecodeMode, string) {
if r.Draft == nil {
return dflashDecodeDisabled, "no_draft"
}
if _, ok := r.Draft.(base.DFlashDraftModel); !ok {
return dflashDecodeDisabled, "draft_not_dflash"
}
if _, ok := r.Model.(base.DFlashTargetModel); !ok {
return dflashDecodeDisabled, "target_not_dflash"
}
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
return dflashDecodeDisabled, "target_embeddings_missing"
}
if opts.Logprobs || opts.TopLogprobs > 0 {
return dflashDecodeDisabled, "logprobs_requested"
}
if opts.Temperature > 0 || dflashUsesSamplerHistory(opts) {
return dflashDecodeSample, ""
}
return dflashDecodeGreedy, ""
}
func dflashUsesSamplerHistory(opts sampler.Options) bool {
if opts.RepeatLastN == 0 {
return false
}
repeatPenalty := opts.RepeatPenalty
if repeatPenalty <= 0 {
repeatPenalty = 1
}
return repeatPenalty != 1 || opts.PresencePenalty != 0 || opts.FrequencyPenalty != 0
}
func (r *Runner) runGreedyDFlashDecode(ctx context.Context, request Request, session *cacheSession, targetCaches []cache.Cache, draftCaches []cache.Cache, seed []int32, position *int, started time.Time) error {
target := r.Model.(base.DFlashTargetModel)
draft := r.Draft.(base.DFlashDraftModel)
stats := dflashStats{}
slog.Info("DFlash greedy decode enabled", "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs())
targetForward := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, targetCaches, draft.TargetLayerIDs())
*position += token.Dim(1)
return hidden, targetHidden
}
t0 := time.Now()
hidden, targetHidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
draft.AppendContext(targetHidden, draftCaches)
current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
mlx.Pin(current.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
stats.targetDuration += time.Since(t0)
defer func() {
mlx.Unpin(current.Arrays()...)
}()
dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
now := started
generated := 0
for generated < request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
if generated == 0 {
mlx.Eval(current.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
if err != nil {
return err
}
if !done {
generated++
}
if done || generated >= request.Options.NumPredict {
break
}
draftCount := min(draft.BlockSize()-1, request.Options.NumPredict-generated)
if draftCount <= 0 {
t0 = time.Now()
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
draft.AppendContext(targetHidden, draftCaches)
stats.targetDuration += time.Since(t0)
next := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
mlx.Pin(next.Arrays()...)
old := current
current = next
mlx.Unpin(old.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
continue
}
stats.iterations++
t0 = time.Now()
draftTokens := r.generateDFlashDrafts(draft, current.Token, draftCaches, draftCount)
mlx.Pin(draftTokens)
mlx.Eval(draftTokens)
stats.draftDuration += time.Since(t0)
stats.drafted += draftCount
t0 = time.Now()
next, accepted, done, err := r.acceptDFlashDrafts(ctx, request, session, &dec, target, draft, targetCaches, draftCaches, position, current, draftTokens, &final, &generated, &stats)
stats.validateDuration += time.Since(t0)
mlx.Unpin(draftTokens)
if err != nil {
return err
}
stats.accepted += accepted
if accepted == draftCount {
stats.allAccepted++
} else {
stats.mismatches++
}
if done || generated >= request.Options.NumPredict {
break
}
mlx.Pin(next.Arrays()...)
old := current
current = next
mlx.Unpin(old.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
if generated%256 == 0 {
mlx.ClearCache()
}
}
final.EvalCount = generated
final.EvalDuration = time.Since(now)
acceptance := 0.0
if stats.drafted > 0 {
acceptance = float64(stats.accepted) / float64(stats.drafted)
}
avgDraft := 0.0
avgAccepted := 0.0
if stats.iterations > 0 {
avgDraft = float64(stats.drafted) / float64(stats.iterations)
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
}
slog.Info("DFlash decode stats", "mode", "greedy", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", draft.BlockSize()-1, "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
func (r *Runner) runSampleDFlashDecode(ctx context.Context, request Request, session *cacheSession, targetCaches []cache.Cache, draftCaches []cache.Cache, seed []int32, position *int, started time.Time) error {
target := r.Model.(base.DFlashTargetModel)
draft := r.Draft.(base.DFlashDraftModel)
stats := dflashStats{}
slog.Info("DFlash sample decode enabled",
"block_size", draft.BlockSize(),
"target_layers", draft.TargetLayerIDs(),
"temperature", request.SamplerOpts.Temperature,
"top_p", request.SamplerOpts.TopP,
"top_k", request.SamplerOpts.TopK,
"min_p", request.SamplerOpts.MinP,
"repeat_penalty", request.SamplerOpts.RepeatPenalty,
"presence_penalty", request.SamplerOpts.PresencePenalty,
"frequency_penalty", request.SamplerOpts.FrequencyPenalty,
)
targetForward := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, targetCaches, draft.TargetLayerIDs())
*position += token.Dim(1)
return hidden, targetHidden
}
t0 := time.Now()
hidden, targetHidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
draft.AppendContext(targetHidden, draftCaches)
current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
mlx.Pin(current.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
stats.targetDuration += time.Since(t0)
defer func() {
mlx.Unpin(current.Arrays()...)
}()
dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
now := started
generated := 0
for generated < request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
if generated == 0 {
mlx.Eval(current.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
if err != nil {
return err
}
if !done {
generated++
}
if done || generated >= request.Options.NumPredict {
break
}
draftCount := min(draft.BlockSize()-1, request.Options.NumPredict-generated)
if draftCount <= 0 {
t0 = time.Now()
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
draft.AppendContext(targetHidden, draftCaches)
stats.targetDuration += time.Since(t0)
next := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
mlx.Pin(next.Arrays()...)
old := current
current = next
mlx.Unpin(old.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
continue
}
stats.iterations++
t0 = time.Now()
candidates := r.generateDFlashDraftCandidates(draft, current.Token, draftCaches, draftCount)
var candidateArrays []*mlx.Array
if candidates != nil {
draftCount = candidates.tokens.Dim(1)
candidateArrays = candidates.Arrays()
mlx.Pin(candidateArrays...)
mlx.Sweep()
}
stats.draftDuration += time.Since(t0)
stats.drafted += draftCount
var next sampler.Result
if draftCount == 0 {
t0 = time.Now()
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
draft.AppendContext(targetHidden, draftCaches)
stats.targetDuration += time.Since(t0)
next = r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
} else {
var accepted int
t0 = time.Now()
next, accepted, done, err = r.acceptSampleDFlashDrafts(ctx, request, session, &dec, target, draft, targetCaches, draftCaches, position, current, candidates, &final, &generated, &stats)
stats.validateDuration += time.Since(t0)
mlx.Unpin(candidateArrays...)
if err != nil {
return err
}
stats.accepted += accepted
if accepted == draftCount {
stats.allAccepted++
} else {
stats.mismatches++
}
if next.Token == nil {
mlx.Sweep()
}
if done || generated >= request.Options.NumPredict {
break
}
}
mlx.Pin(next.Arrays()...)
old := current
current = next
mlx.Unpin(old.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
if generated%256 == 0 {
mlx.ClearCache()
}
}
final.EvalCount = generated
final.EvalDuration = time.Since(now)
acceptance := 0.0
if stats.drafted > 0 {
acceptance = float64(stats.accepted) / float64(stats.drafted)
}
avgDraft := 0.0
avgAccepted := 0.0
if stats.iterations > 0 {
avgDraft = float64(stats.drafted) / float64(stats.iterations)
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
}
slog.Info("DFlash decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", draft.BlockSize()-1, "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
func (r *Runner) dflashDraftLogits(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *mlx.Array {
blockLen := draftCount + 1
values := make([]int32, blockLen)
values[0] = int32(tokenID(current))
for i := 1; i < blockLen; i++ {
values[i] = draft.MaskTokenID()
}
block := mlx.FromValues(values, 1, blockLen)
logits := draft.Draft(block, caches)
return logits.Slice(mlx.Slice(), mlx.Slice(1, blockLen), mlx.Slice())
}
func (r *Runner) generateDFlashDrafts(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *mlx.Array {
logits := r.dflashDraftLogits(draft, current, caches, draftCount)
return logits.Argmax(-1, false).AsType(mlx.DTypeInt32)
}
type dflashDraftCandidates struct {
tokens *mlx.Array
dist sampler.Distribution
}
func (c *dflashDraftCandidates) Arrays() []*mlx.Array {
if c == nil {
return nil
}
return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...)
}
func (r *Runner) generateDFlashDraftCandidates(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *dflashDraftCandidates {
if draftCount <= 0 {
return nil
}
logits := r.dflashDraftLogits(draft, current, caches, draftCount)
draftTokens := make([]*mlx.Array, 0, draftCount)
draftDists := make([]sampler.Distribution, 0, draftCount)
var prefix *mlx.Array
for i := range draftCount {
rows := logits.Slice(mlx.Slice(), mlx.Slice(0, i+1), mlx.Slice())
dist := r.Sampler.Distribution(pipelineSlot, rows, prefix).SliceRows(i, i+1)
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist))
nextInput := mtpTokenInput(nextToken)
draftTokens = append(draftTokens, nextInput)
draftDists = append(draftDists, dist)
if prefix == nil {
prefix = nextInput
} else {
prefix = prefix.Concatenate(1, nextInput)
}
}
if len(draftTokens) == 0 {
return nil
}
return &dflashDraftCandidates{
tokens: mlx.Concatenate(draftTokens, 1),
dist: sampler.ConcatenateDistributions(draftDists),
}
}
func (r *Runner) acceptDFlashDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *dflashStats) (sampler.Result, int, bool, error) {
specCaches, spec, ok := cache.BeginSpeculation(targetCaches)
if !ok {
stats.serial++
return r.acceptDFlashDraftsSerial(ctx, request, session, dec, target, draft, targetCaches, draftCaches, position, current, draftTokens, final, generated)
}
stats.batched++
return r.acceptDFlashDraftsBatched(ctx, request, session, dec, target, draft, specCaches, spec, draftCaches, position, current, draftTokens, final, generated)
}
func (r *Runner) acceptDFlashDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, specCaches []cache.Cache, spec *cache.Speculation, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
before := *position
draftCount := draftTokens.Dim(1)
verifyInput := mtpTokenInput(current.Token).Concatenate(1, draftTokens)
hiddenSeq, targetHiddenSeq := target.ForwardDFlash(&batch.Batch{
InputIDs: verifyInput,
SeqOffsets: []int32{int32(before)},
SeqQueryLens: []int32{int32(verifyInput.Dim(1))},
}, specCaches, draft.TargetLayerIDs())
selectedTokens := r.Model.Unembed(hiddenSeq).Argmax(-1, false).AsType(mlx.DTypeInt32)
mlx.Eval(draftTokens, selectedTokens)
draftIDs := draftTokens.Ints()
selectedIDs := selectedTokens.Ints()
if len(selectedIDs) < draftCount+1 {
spec.Commit(0)
return sampler.Result{}, 0, false, fmt.Errorf("dflash validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount)
}
accepted := 0
for i, id := range draftIDs {
if selectedIDs[i] != id {
break
}
accepted++
if r.Tokenizer.IsEOS(int32(id)) {
break
}
}
commitN := accepted + 1
spec.Commit(0)
done := false
for _, id := range draftIDs[:accepted] {
if *generated >= request.Options.NumPredict {
done = true
break
}
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
var err error
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done {
break
}
}
spec.Commit(commitN)
*position = before + commitN
draft.AppendContext(targetHiddenSeq.Slice(mlx.Slice(), mlx.Slice(0, commitN), mlx.Slice()), draftCaches)
if done || *generated >= request.Options.NumPredict {
return sampler.Result{}, accepted, true, nil
}
nextIndex := accepted
if nextIndex >= len(selectedIDs) {
nextIndex = len(selectedIDs) - 1
}
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedIDs[nextIndex])}, 1)}, accepted, false, nil
}
func (r *Runner) acceptDFlashDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
targetForward := func(token *mlx.Array) *mlx.Array {
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, targetCaches, draft.TargetLayerIDs())
*position += token.Dim(1)
draft.AppendContext(targetHidden, draftCaches)
return r.lastLogits(hidden)
}
logits := targetForward(mtpTokenInput(current.Token))
accepted := 0
for _, id := range draftTokens.Ints() {
selected := greedyTokenFromLogits(logits)
mlx.Eval(selected)
selectedID := tokenID(selected)
if selectedID != id {
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil
}
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
accepted++
if !done {
(*generated)++
}
if done || *generated >= request.Options.NumPredict {
return sampler.Result{}, accepted, true, nil
}
logits = targetForward(mtpTokenInput(res.Token))
}
return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil
}
func (r *Runner) acceptSampleDFlashDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int, stats *dflashStats) (sampler.Result, int, bool, error) {
specCaches, spec, ok := cache.BeginSpeculation(targetCaches)
if !ok {
stats.serial++
return r.acceptSampleDFlashDraftsSerial(ctx, request, session, dec, target, draft, targetCaches, draftCaches, position, current, candidates, final, generated)
}
stats.batched++
return r.acceptSampleDFlashDraftsBatched(ctx, request, session, dec, target, draft, specCaches, spec, draftCaches, position, current, candidates, final, generated)
}
func (r *Runner) acceptSampleDFlashDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, specCaches []cache.Cache, spec *cache.Speculation, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
before := *position
draftCount := candidates.tokens.Dim(1)
verifyInput := mtpTokenInput(current.Token).Concatenate(1, candidates.tokens)
hiddenSeq, targetHiddenSeq := target.ForwardDFlash(&batch.Batch{
InputIDs: verifyInput,
SeqOffsets: []int32{int32(before)},
SeqQueryLens: []int32{int32(verifyInput.Dim(1))},
}, specCaches, draft.TargetLayerIDs())
targetDist := r.Sampler.Distribution(pipelineSlot, r.Model.Unembed(hiddenSeq), candidates.tokens)
draftDist := candidates.dist
acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens)
mlx.Eval(candidates.tokens, acceptedMask)
draftIDs := candidates.tokens.Ints()
acceptedFlags := acceptedMask.Ints()
accepted := 0
for _, ok := range acceptedFlags {
if ok == 0 {
break
}
accepted++
}
if accepted > draftCount {
spec.Commit(0)
return sampler.Result{}, 0, false, fmt.Errorf("dflash sample validation accepted %d tokens for %d draft tokens", accepted, draftCount)
}
commitIDs := make([]int32, 0, accepted+1)
done := false
for i, id := range draftIDs[:accepted] {
commitIDs = append(commitIDs, int32(id))
if r.Tokenizer.IsEOS(int32(id)) {
done = true
accepted = i + 1
commitIDs = commitIDs[:accepted]
break
}
}
commitN := accepted + 1
spec.Commit(0)
for _, id := range draftIDs[:accepted] {
if *generated >= request.Options.NumPredict {
done = true
break
}
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
var err error
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done {
break
}
}
spec.Commit(commitN)
*position = before + commitN
draft.AppendContext(targetHiddenSeq.Slice(mlx.Slice(), mlx.Slice(0, commitN), mlx.Slice()), draftCaches)
if done || *generated >= request.Options.NumPredict {
r.Sampler.Commit(pipelineSlot, commitIDs)
return sampler.Result{}, accepted, true, nil
}
var nextToken *mlx.Array
if accepted == draftCount {
nextToken = r.mtpSampleTokenAt(targetDist, draftCount)
} else {
nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted)
}
mlx.Eval(nextToken)
nextID := int32(tokenID(nextToken))
commitIDs = append(commitIDs, nextID)
r.Sampler.Commit(pipelineSlot, commitIDs)
return sampler.Result{Token: nextToken}, accepted, false, nil
}
func (r *Runner) acceptSampleDFlashDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
targetForward := func(token *mlx.Array) *mlx.Array {
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, targetCaches, draft.TargetLayerIDs())
*position += token.Dim(1)
draft.AppendContext(targetHidden, draftCaches)
return r.lastLogits(hidden)
}
mlx.Eval(candidates.tokens)
draftIDs := candidates.tokens.Ints()
logits := targetForward(mtpTokenInput(current.Token))
accepted := 0
for i, id := range draftIDs {
targetDist := r.Sampler.Distribution(pipelineSlot, logits, nil)
draftDist := candidates.dist.SliceRows(i, i+1)
draftToken := mlx.FromValues([]int32{int32(id)}, 1)
acceptedMask := r.mtpSampleAcceptedMask(targetDist, draftDist, draftToken)
mlx.Eval(acceptedMask)
if acceptedMask.Ints()[0] == 0 {
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, targetDist.ResidualAgainst(draftDist)))
mlx.Eval(nextToken)
r.Sampler.Commit(pipelineSlot, []int32{int32(tokenID(nextToken))})
return sampler.Result{Token: nextToken}, accepted, false, nil
}
accepted++
r.Sampler.Commit(pipelineSlot, []int32{int32(id)})
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done || *generated >= request.Options.NumPredict {
return sampler.Result{}, accepted, true, nil
}
logits = targetForward(mtpTokenInput(res.Token))
}
targetDist := r.Sampler.Distribution(pipelineSlot, logits, nil)
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, targetDist))
mlx.Eval(nextToken)
r.Sampler.Commit(pipelineSlot, []int32{int32(tokenID(nextToken))})
return sampler.Result{Token: nextToken}, accepted, false, nil
}

13
x/mlxrunner/imports.go Normal file
View File

@@ -0,0 +1,13 @@
package mlxrunner
import (
_ "github.com/ollama/ollama/x/models/dflash"
_ "github.com/ollama/ollama/x/models/gemma3"
_ "github.com/ollama/ollama/x/models/gemma4"
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
_ "github.com/ollama/ollama/x/models/laguna"
_ "github.com/ollama/ollama/x/models/llama"
_ "github.com/ollama/ollama/x/models/qwen3"
_ "github.com/ollama/ollama/x/models/qwen3_5"
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
)

3
x/mlxrunner/mlx/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
_deps
build
dist

View File

@@ -0,0 +1,32 @@
cmake_minimum_required(VERSION 3.5)
project(mlx)
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
set(CMAKE_INSTALL_RPATH "@loader_path")
include(FetchContent)
# Read MLX-C version from top-level file (shared with imagegen CMakeLists)
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG ${MLX_C_GIT_TAG}
)
FetchContent_MakeAvailable(mlx-c)
# Sync vendored headers with fetched version
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/include/mlx/c/")

99
x/mlxrunner/mlx/act.go Normal file
View File

@@ -0,0 +1,99 @@
package mlx
import "math"
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// as a fused kernel.
var GELUApprox = Compile1(
"GELUApprox",
func(x *Array) *Array {
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
dt := x.DType()
half := FromValue[float32](0.5).AsType(dt)
coeff := FromValue(geluCoeff).AsType(dt)
c := FromValue[float32](0.044715).AsType(dt)
one := FromValue[float32](1.0).AsType(dt)
// x^3 via x*x*x (avoids general Power which is slower).
x3 := x.Multiply(x).Multiply(x)
inner := x.Add(c.Multiply(x3))
tanh := coeff.Multiply(inner).Tanh()
return half.Multiply(x).Multiply(one.Add(tanh))
},
Shapeless(),
)
// SiLU returns a * sigmoid(a) as a fused kernel.
var SiLU = Compile1(
"SiLU",
func(a *Array) *Array {
return a.Multiply(a.Sigmoid())
},
Shapeless(),
)
// SoftplusF32 returns softplus(x) computed in float32 precision and cast back
// to x's original dtype, as a fused kernel. Matches the laguna attention
// output-gate formula: softplus(cast_f32(x)).cast(orig_dtype).
var SoftplusF32 = Compile1(
"SoftplusF32",
func(x *Array) *Array {
dt := x.DType()
zero := FromValue[float32](0)
return Logaddexp(x.AsType(DTypeFloat32), zero).AsType(dt)
},
Shapeless(),
)
// SwiGLU returns silu(gate) * up as a fused kernel.
var SwiGLU = Compile2(
"SwiGLU",
func(gate, up *Array) *Array {
return SiLU(gate).Multiply(up)
},
Shapeless(),
)
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
// geglu, used by Gemma-family MLP and MoE paths.
var GeGLU = Compile2(
"GeGLU",
func(gate, up *Array) *Array {
return GELUApprox(gate).Multiply(up)
},
Shapeless(),
)
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
// mlx_lm's logit_softcap. cap must have the same dtype as x.
var LogitSoftcap = Compile2(
"LogitSoftcap",
func(x, cap *Array) *Array {
return x.Divide(cap).Tanh().Multiply(cap)
},
Shapeless(),
)
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
// per-expert scores after top-k) and the post-bias negation (used as the
// argpartition key for top-k) share a single kernel.
var sigmoidRouterFused = Compile(
"SigmoidRouter",
func(in ...*Array) []*Array {
gates, bias := in[0], in[1]
orig := gates.Sigmoid()
neg := orig.Add(bias).Negative()
return []*Array{orig, neg}
},
Shapeless(),
)
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
out := sigmoidRouterFused(gates, bias)
return out[0], out[1]
}

295
x/mlxrunner/mlx/array.go Normal file
View File

@@ -0,0 +1,295 @@
package mlx
// #include "generated.h"
import "C"
import (
"encoding/binary"
"fmt"
"log/slog"
"reflect"
"sort"
"strings"
"sync"
"sync/atomic"
"unsafe"
"github.com/ollama/ollama/logutil"
)
type Array struct {
ctx C.mlx_array
name string
pinned atomic.Int32
}
var (
arrays []*Array
arraysMu sync.Mutex
)
// constructor utilities
func New(name string) *Array {
t := &Array{name: name}
if tracing {
traceScratch = append(traceScratch, t)
} else {
arraysMu.Lock()
defer arraysMu.Unlock()
arrays = append(arrays, t)
}
return t
}
type scalarTypes interface {
~bool | ~int | ~float32 | ~float64 | ~complex64
}
func FromValue[T scalarTypes](t T) *Array {
tt := New("")
switch v := any(t).(type) {
case bool:
tt.ctx = C.mlx_array_new_bool(C.bool(v))
case int:
tt.ctx = C.mlx_array_new_int(C.int(v))
case float32:
tt.ctx = C.mlx_array_new_float32(C.float(v))
case float64:
tt.ctx = C.mlx_array_new_float64(C.double(v))
case complex64:
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
default:
panic("unsupported type")
}
return tt
}
type arrayTypes interface {
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
~int8 | ~int16 | ~int32 | ~int64 |
~float32 | ~float64 |
~complex64
}
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
if len(shape) == 0 {
panic("shape must be provided for non-scalar tensors")
}
cShape := make([]C.int, len(shape))
for i := range shape {
cShape[i] = C.int(shape[i])
}
var dtype DType
switch reflect.TypeOf(s).Elem().Kind() {
case reflect.Bool:
dtype = DTypeBool
case reflect.Uint8:
dtype = DTypeUint8
case reflect.Uint16:
dtype = DTypeUint16
case reflect.Uint32:
dtype = DTypeUint32
case reflect.Uint64:
dtype = DTypeUint64
case reflect.Int8:
dtype = DTypeInt8
case reflect.Int16:
dtype = DTypeInt16
case reflect.Int32:
dtype = DTypeInt32
case reflect.Int64:
dtype = DTypeInt64
case reflect.Float32:
dtype = DTypeFloat32
case reflect.Float64:
dtype = DTypeFloat64
case reflect.Complex64:
dtype = DTypeComplex64
default:
panic("unsupported type")
}
bts := make([]byte, binary.Size(s))
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
panic(err)
}
tt := New("")
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
return tt
}
func (t *Array) Set(other *Array) {
C.mlx_array_set(&t.ctx, other.ctx)
}
func (t *Array) Clone() *Array {
tt := New(t.name)
C.mlx_array_set(&tt.ctx, t.ctx)
return tt
}
// lifecycle utilities
// Pin marks arrays as in-use so they are retained during Sweep.
func Pin(s ...*Array) {
for _, t := range s {
if t != nil {
t.pinned.Add(1)
}
}
}
// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
func Unpin(s ...*Array) {
for _, t := range s {
if t != nil {
if t.pinned.Add(-1) < 0 {
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
}
}
}
}
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
// free them when there are no other references, including dependencies in the graph.
func Sweep() {
arraysMu.Lock()
defer arraysMu.Unlock()
n := 0
for _, t := range arrays {
if t.pinned.Load() > 0 && t.Valid() {
arrays[n] = t
n++
} else if t.Valid() {
C.mlx_array_free(t.ctx)
t.ctx.ctx = nil
}
}
arrays = arrays[:n]
}
// misc. utilities
func (t *Array) Valid() bool {
return t.ctx.ctx != nil
}
func (t *Array) String() string {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_array_tostring(&str, t.ctx)
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
func (t *Array) LogValue() slog.Value {
attrs := []slog.Attr{
slog.String("name", t.name),
slog.Int("pinned", int(t.pinned.Load())),
}
if t.Valid() {
attrs = append(attrs,
slog.Any("dtype", t.DType()),
slog.Any("shape", t.Dims()),
slog.Int("num_bytes", t.NumBytes()),
)
}
return slog.GroupValue(attrs...)
}
// shape utilities
func (t *Array) Size() int {
return int(C.mlx_array_size(t.ctx))
}
func (t *Array) NumBytes() int {
return int(C.mlx_array_nbytes(t.ctx))
}
func (t *Array) NumDims() int {
return int(C.mlx_array_ndim(t.ctx))
}
func (t *Array) Dims() []int {
dims := make([]int, t.NumDims())
for i := range dims {
dims[i] = t.Dim(i)
}
return dims
}
func (t *Array) Dim(dim int) int {
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
}
func (t *Array) DType() DType {
return DType(C.mlx_array_dtype(t.ctx))
}
// data utilities
func (t *Array) Int() int {
var item C.int64_t
C.mlx_array_item_int64(&item, t.ctx)
return int(item)
}
func (t *Array) Float() float64 {
var item C.double
C.mlx_array_item_float64(&item, t.ctx)
return float64(item)
}
func (t *Array) Ints() []int {
if dt := t.DType(); dt != DTypeInt32 {
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
}
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
ints[i] = int(f)
}
return ints
}
func (t *Array) Floats() []float32 {
if dt := t.DType(); dt != DTypeFloat32 {
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
}
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
floats[i] = float32(f)
}
return floats
}
func (t *Array) Save(name string) error {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
C.mlx_save(cName, t.ctx)
return nil
}
// LogArrays logs all live arrays, sorted by size
func LogArrays() {
arraysMu.Lock()
defer arraysMu.Unlock()
sort.Slice(arrays, func(i, j int) bool {
return arrays[i].NumBytes() > arrays[j].NumBytes()
})
var total int
for _, t := range arrays {
nb := t.NumBytes()
total += nb
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
}
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
}

View File

@@ -0,0 +1,76 @@
package mlx
import "testing"
func TestFromValue(t *testing.T) {
withMLXThread(t, func() {
for got, want := range map[*Array]DType{
FromValue(true): DTypeBool,
FromValue(false): DTypeBool,
FromValue(int(7)): DTypeInt32,
FromValue(float32(3.14)): DTypeFloat32,
FromValue(float64(2.71)): DTypeFloat64,
FromValue(complex64(1 + 2i)): DTypeComplex64,
} {
if got.DType() != want {
t.Errorf("%s: want %v, got %v", want, want, got)
}
}
})
}
func TestFromValues(t *testing.T) {
withMLXThread(t, func() {
for got, want := range map[*Array]DType{
FromValues([]bool{true, false, true}, 3): DTypeBool,
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
} {
if got.DType() != want {
t.Errorf("%s: want %v, got %v", want, want, got)
}
}
})
}
func TestComparisonOpsAndBernoulli(t *testing.T) {
skipIfNoMLX(t)
a := FromValues([]float32{1, 2, 3}, 3)
b := FromValues([]float32{1, 1, 4}, 3)
eq := a.Equal(b).AsType(DTypeInt32)
gt := a.Greater(b).AsType(DTypeInt32)
le := a.LessEqual(b).AsType(DTypeInt32)
bern := Bernoulli(FromValues([]float32{1, 0}, 2)).AsType(DTypeInt32)
Eval(eq, gt, le, bern)
for name, tc := range map[string]struct {
got []int
want []int
}{
"equal": {eq.Ints(), []int{1, 0, 0}},
"greater": {gt.Ints(), []int{0, 1, 0}},
"lessEqual": {le.Ints(), []int{1, 0, 1}},
"bernoulli": {bern.Ints(), []int{1, 0}},
} {
t.Run(name, func(t *testing.T) {
if len(tc.got) != len(tc.want) {
t.Fatalf("got %v, want %v", tc.got, tc.want)
}
for i := range tc.want {
if tc.got[i] != tc.want[i] {
t.Fatalf("got %v, want %v", tc.got, tc.want)
}
}
})
}
}

192
x/mlxrunner/mlx/compile.go Normal file
View File

@@ -0,0 +1,192 @@
package mlx
// #include <stdlib.h>
// #include "generated.h"
//
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// extern void closureDestructor(void* payload);
import "C"
import (
"log/slog"
"runtime/cgo"
"sync"
"unsafe"
)
// CompileFunc is the signature of a function that can be compiled.
type CompileFunc func(inputs ...*Array) []*Array
// CompileOption configures Compile behavior.
type CompileOption func(*compileConfig)
type compileConfig struct {
shapeless bool
}
// Shapeless traces the function once against symbolic shapes so the compiled
// graph accepts any input shape afterwards. Without this option, MLX re-traces
// on each new (shape, dtype) combination and caches each specialization.
func Shapeless() CompileOption {
return func(c *compileConfig) { c.shapeless = true }
}
// Compile returns a compiled version of fn. When called during another
// compile's trace, fn is inlined directly so outer compiles can fuse through
// inner ones.
//
// Compiled functions must not have side effects outside of the function. Do
// not access data other than the arguments passed in (either Go data or MLX
// arrays) unless it is a constant.
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
var cfg compileConfig
for _, o := range opts {
o(&cfg)
}
var closure C.mlx_closure
var once sync.Once
return func(inputs ...*Array) []*Array {
if tracing {
return fn(inputs...)
}
once.Do(func() {
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
*payload = cgo.NewHandle(fn)
src := C.mlx_closure_new_func_payload(
(*[0]byte)(C.closureCallback),
unsafe.Pointer(payload),
(*[0]byte)(C.closureDestructor),
)
defer C.mlx_closure_free(src)
closure = C.mlx_closure_new()
mlxCheck(name+": compile failed", func() C.int {
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
})
})
inVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(inVec)
for _, in := range inputs {
C.mlx_vector_array_append_value(inVec, in.ctx)
}
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
mlxCheck(name+": closure apply failed", func() C.int {
return C.mlx_closure_apply(&outVec, closure, inVec)
})
n := int(C.mlx_vector_array_size(outVec))
outputs := make([]*Array, n)
for i := range n {
outputs[i] = New(name)
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
}
return outputs
}
}
// Compile1 compiles a unary function. See Compile.
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0])}
}, opts...)
return func(a *Array) *Array {
return cf(a)[0]
}
}
// Compile2 compiles a binary function. See Compile.
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0], in[1])}
}, opts...)
return func(a, b *Array) *Array {
return cf(a, b)[0]
}
}
// Compile3 compiles a ternary function. See Compile.
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
cf := Compile(name, func(in ...*Array) []*Array {
return []*Array{fn(in[0], in[1], in[2])}
}, opts...)
return func(a, b, c *Array) *Array {
return cf(a, b, c)[0]
}
}
// tracing is true while a compile callback is running. Since MLX is
// single-threaded at this level a plain Go bool suffices.
var tracing bool
// traceScratch collects arrays created during a compile trace so they can be
// freed as a group when the callback returns.
var traceScratch []*Array
//export closureCallback
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
defer func() {
if r := recover(); r != nil {
slog.Error("mlx closure callback panicked", "panic", r)
rc = 1
}
}()
handle := *(*cgo.Handle)(payload)
fn := handle.Value().(CompileFunc)
// When tracing, we track all of the intermediates that are created and free them separately at the end of
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
if tracing {
panic("mlx: nested compile trace")
}
tracing = true
traceScratch = nil
defer func() {
for _, a := range traceScratch {
if a.pinned.Load() > 0 {
panic("mlx: traced array was pinned during compilation")
}
if a.Valid() {
C.mlx_array_free(a.ctx)
a.ctx.ctx = nil
}
}
tracing = false
traceScratch = nil
}()
n := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, n)
for i := range n {
a := New("")
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
inputs[i] = a
}
outputs := fn(inputs...)
var arrPtr *C.mlx_array
if len(outputs) > 0 {
handles := make([]C.mlx_array, len(outputs))
for i, out := range outputs {
handles[i] = out.ctx
}
arrPtr = &handles[0]
}
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
return 0
}
//export closureDestructor
func closureDestructor(payload unsafe.Pointer) {
handle := *(*cgo.Handle)(payload)
handle.Delete()
C.free(payload)
}

View File

@@ -0,0 +1,147 @@
package mlx
import (
"testing"
)
func TestCompileFusion(t *testing.T) {
skipIfNoMLX(t)
// Compile fuses the ops inside a function body into a single kernel,
// eliminating intermediate buffers. Use a diamond-shaped graph where
// two branches must be materialized simultaneously without fusion,
// then compare peak memory against the compiled version which fuses
// everything into one kernel with no intermediates.
const n = 1024 * 1024 // 4MB per float32 array
data := make([]float32, n)
for i := range data {
data[i] = float32(i + 1)
}
// Diamond: both a*b and a+b must be live for the final multiply.
// Without fusion: peak includes both intermediates (~8MB extra).
// With fusion: single kernel, no intermediates.
body := func(a, b *Array) *Array {
return a.Multiply(b).Multiply(a.Add(b))
}
a := FromValues(data, n)
b := FromValues(data, n)
Pin(a, b)
defer Unpin(a, b)
// Compiled: ops fused into a single kernel.
EnableCompile()
fn := Compile2("diamond", body, Shapeless())
warm := fn(a, b)
Eval(warm)
Sweep()
ClearCache()
ResetPeakMemory()
y := fn(a, b)
Eval(y)
compiledPeak := PeakMemory()
Sweep()
// Uncompiled: ops evaluated individually, intermediates materialized.
ClearCache()
ResetPeakMemory()
z := body(a, b)
Eval(z)
uncompiledPeak := PeakMemory()
Sweep()
if compiledPeak == 0 && uncompiledPeak == 0 {
t.Skip("peak memory tracking not available")
}
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
if compiledPeak >= uncompiledPeak {
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
}
}
func TestCompileNested(t *testing.T) {
skipIfNoMLX(t)
// A compiled function that calls another compiled function should
// produce correct results. The inner function inlines via isTracing()
// during the outer's trace.
inner := Compile1("silu", func(a *Array) *Array {
return a.Multiply(a.Sigmoid())
}, Shapeless())
outer := Compile2("swiglu", func(gate, up *Array) *Array {
return inner(gate).Multiply(up)
}, Shapeless())
gate := FromValues([]float32{0, 1, 2}, 3)
up := FromValues([]float32{1, 1, 1}, 3)
Pin(gate, up)
defer Unpin(gate, up)
y := outer(gate, up)
Eval(y)
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
got := y.Floats()
want := []float32{0, 0.7310586, 1.7615942}
for i, v := range got {
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
}
}
}
func TestCompileCallbackPanicRecovers(t *testing.T) {
skipIfNoMLX(t)
boom := Compile1("boom", func(a *Array) *Array {
panic("intentional test panic")
})
x := FromValues([]float32{1}, 1)
Pin(x)
defer Unpin(x)
defer func() {
r := recover()
if r == nil {
t.Fatal("expected panic from Call, got none")
}
if _, ok := r.(string); !ok {
t.Fatalf("expected string panic, got %T: %v", r, r)
}
}()
boom(x)
}
func TestCompileNoTrackingGrowth(t *testing.T) {
skipIfNoMLX(t)
// Repeated invocations of a compiled kernel should not grow the
// tracked-arrays list — the callback's traceScratch collects
// intermediates during tracing and frees them when the callback returns.
fn := Compile2("mul_add", func(a, b *Array) *Array {
return a.Multiply(b).Add(b)
})
a := FromValues([]float32{1, 2}, 2)
b := FromValues([]float32{3, 4}, 2)
Pin(a, b)
defer Unpin(a, b)
Sweep()
before := len(arrays)
for range 100 {
_ = fn(a, b)
Sweep()
}
after := len(arrays)
if after > before+2 {
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
}
}

94
x/mlxrunner/mlx/dtype.go Normal file
View File

@@ -0,0 +1,94 @@
package mlx
// #include "generated.h"
import "C"
type DType int
func (t DType) String() string {
switch t {
case DTypeBool:
return "BOOL"
case DTypeUint8:
return "U8"
case DTypeUint16:
return "U16"
case DTypeUint32:
return "U32"
case DTypeUint64:
return "U64"
case DTypeInt8:
return "I8"
case DTypeInt16:
return "I16"
case DTypeInt32:
return "I32"
case DTypeInt64:
return "I64"
case DTypeFloat16:
return "F16"
case DTypeFloat32:
return "F32"
case DTypeFloat64:
return "F64"
case DTypeBFloat16:
return "BF16"
case DTypeComplex64:
return "C64"
default:
return "Unknown"
}
}
func (t *DType) UnmarshalJSON(b []byte) error {
switch string(b) {
case `"BOOL"`:
*t = DTypeBool
case `"U8"`:
*t = DTypeUint8
case `"U16"`:
*t = DTypeUint16
case `"U32"`:
*t = DTypeUint32
case `"U64"`:
*t = DTypeUint64
case `"I8"`:
*t = DTypeInt8
case `"I16"`:
*t = DTypeInt16
case `"I32"`:
*t = DTypeInt32
case `"I64"`:
*t = DTypeInt64
case `"F16"`:
*t = DTypeFloat16
case `"F64"`:
*t = DTypeFloat64
case `"F32"`:
*t = DTypeFloat32
case `"BF16"`:
*t = DTypeBFloat16
case `"C64"`:
*t = DTypeComplex64
default:
return nil
}
return nil
}
const (
DTypeBool DType = C.MLX_BOOL
DTypeUint8 DType = C.MLX_UINT8
DTypeUint16 DType = C.MLX_UINT16
DTypeUint32 DType = C.MLX_UINT32
DTypeUint64 DType = C.MLX_UINT64
DTypeInt8 DType = C.MLX_INT8
DTypeInt16 DType = C.MLX_INT16
DTypeInt32 DType = C.MLX_INT32
DTypeInt64 DType = C.MLX_INT64
DTypeFloat16 DType = C.MLX_FLOAT16
DTypeFloat32 DType = C.MLX_FLOAT32
DTypeFloat64 DType = C.MLX_FLOAT64
DTypeBFloat16 DType = C.MLX_BFLOAT16
DTypeComplex64 DType = C.MLX_COMPLEX64
)

36
x/mlxrunner/mlx/dynamic.c Normal file
View File

@@ -0,0 +1,36 @@
#include "dynamic.h"
#include <stdio.h>
#ifdef _WIN32
#include <windows.h>
#define DLOPEN(path) LoadLibraryA(path)
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
#else
#ifdef __APPLE__
#include <mach-o/dyld.h>
#include <libgen.h>
#endif
#include <dlfcn.h>
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
#define DLCLOSE(handle) dlclose(handle)
#endif
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
handle->ctx = (void*) DLOPEN(path);
if (handle->ctx == NULL) {
return 1;
}
return 0;
}
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
return mlx_dynamic_open(handle, path);
}
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
if (handle->ctx) {
DLCLOSE(handle->ctx);
handle->ctx = NULL;
}
}

253
x/mlxrunner/mlx/dynamic.go Normal file
View File

@@ -0,0 +1,253 @@
package mlx
// #include "dynamic.h"
// #include "generated.h"
// #include <stdlib.h>
import "C"
import (
"fmt"
"io/fs"
"log/slog"
"os"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"unsafe"
)
var initError error
var initLoadError string
var initLoadedPath string
// CheckInit returns any error that occurred during MLX dynamic library initialization.
func CheckInit() error {
if initLoadedPath != "" {
slog.Debug("MLX dynamic library loaded", "path", initLoadedPath)
}
if initError != nil && initLoadError != "" {
slog.Error(initLoadError)
}
return initError
}
// tryLoadFromDir searches a directory for the mlxc shared library and loads it.
func tryLoadFromDir(dir string) bool {
// On Windows, MSVC produces mlxc.dll (no lib prefix)
// On Unix, it's libmlxc.so or libmlxc.dylib
pattern := "libmlxc.*"
if runtime.GOOS == "windows" {
pattern = "mlxc.*"
}
matches, err := fs.Glob(os.DirFS(dir), pattern)
if err != nil || len(matches) == 0 {
return false
}
for _, match := range matches {
path := filepath.Join(dir, match)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var handle C.mlx_dynamic_handle
if C.mlx_dynamic_load(&handle, cPath) != 0 {
initLoadError = fmt.Sprintf("failed to load MLX dynamic library: path=%s", path)
continue
}
if C.mlx_dynamic_load_symbols(handle) != 0 {
initLoadError = fmt.Sprintf("failed to load MLX dynamic library symbols: path=%s", path)
C.mlx_dynamic_unload(&handle)
continue
}
initLoadedPath = path
return true
}
return false
}
// libOllamaRoots returns candidate directories for MLX dynamic libraries.
// Production: exe_dir/lib/ollama (dist tarball) and exe_dir (app bundle).
// Development: build/lib/ollama and build/*/lib/ollama.
func libOllamaRoots() []string {
var roots []string
// Production paths relative to executable
if exe, err := os.Executable(); err == nil {
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
exeDir := filepath.Dir(exe)
switch runtime.GOOS {
case "darwin":
roots = append(roots, filepath.Join(exeDir, "lib", "ollama"))
roots = append(roots, exeDir) // app bundle: Contents/Resources/
case "linux":
roots = append(roots, filepath.Join(exeDir, "..", "lib", "ollama"))
case "windows":
roots = append(roots, filepath.Join(exeDir, "lib", "ollama"))
}
}
// Development paths: build/lib/ollama and build/*/lib/ollama.
// Reverse-sort and filter the glob results so higher-versioned Metal
// builds (e.g., metal-v4) are tried before lower ones (metal-v3),
// and incompatible variants are skipped. Without this, alphabetical
// order would always pick v3 over v4 in dev builds.
for _, base := range repoBuildDirs() {
roots = append(roots, filepath.Join(base, "lib", "ollama"))
if matches, err := filepath.Glob(filepath.Join(base, "*", "lib", "ollama")); err == nil {
sort.Sort(sort.Reverse(sort.StringSlice(matches)))
for _, m := range matches {
// Extract the build dir name (e.g., "metal-v4" from "build/metal-v4/lib/ollama")
rel, _ := filepath.Rel(base, m)
variant := strings.SplitN(rel, string(filepath.Separator), 2)[0]
if isCompatibleMLXVariant(variant) {
roots = append(roots, m)
}
}
}
}
return roots
}
// repoBuildDirs returns candidate build/ directories relative to cwd and repo root.
func repoBuildDirs() []string {
var dirs []string
if cwd, err := os.Getwd(); err == nil {
dirs = append(dirs, filepath.Join(cwd, "build"))
for dir := cwd; ; {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
if dir != cwd {
dirs = append(dirs, filepath.Join(dir, "build"))
}
break
}
parent := filepath.Dir(dir)
if parent == dir {
break
}
dir = parent
}
}
return dirs
}
// prependLibraryPath prepends dir to the platform's dynamic library search
// path so the linker finds colocated libmlx before any stale copies.
// Called once after successful library load.
func prependLibraryPath(dir string) {
var envVar string
switch runtime.GOOS {
case "darwin":
envVar = "DYLD_LIBRARY_PATH"
case "linux":
envVar = "LD_LIBRARY_PATH"
default:
return
}
if existing := os.Getenv(envVar); existing != "" {
os.Setenv(envVar, dir+string(filepath.ListSeparator)+existing)
} else {
os.Setenv(envVar, dir)
}
}
func init() {
switch runtime.GOOS {
case "darwin", "linux", "windows":
default:
return
}
// OLLAMA_LLM_LIBRARY overrides variant selection (e.g., "mlx_metal_v3").
// When set to an mlx_* value, only that specific subdir is tried.
// The GGML runner ignores mlx_* values (see discover/runner.go).
forcedVariant, _ := os.LookupEnv("OLLAMA_LLM_LIBRARY")
if forcedVariant != "" && !strings.HasPrefix(forcedVariant, "mlx_") {
forcedVariant = "" // not an MLX variant, ignore
}
found := findMLXLibrary(forcedVariant)
if !found {
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", libOllamaRoots())
return
}
prependLibraryPath(filepath.Dir(initLoadedPath))
}
func findMLXLibrary(forcedVariant string) bool {
for _, root := range libOllamaRoots() {
if forcedVariant != "" {
if tryLoadFromDir(filepath.Join(root, forcedVariant)) {
return true
}
} else {
if tryLoadFromMLXSubdirs(root) {
return true
}
if tryLoadFromDir(root) {
return true
}
}
}
return false
}
// tryLoadFromMLXSubdirs globs for mlx_* subdirs within dir, filters out
// incompatible variants, tries the remainder in reverse sorted order (so
// higher-versioned variants are preferred), and returns true on first
// successful load.
func tryLoadFromMLXSubdirs(dir string) bool {
mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*"))
if err != nil || len(mlxDirs) == 0 {
return false
}
// Reverse sort: mlx_metal_v4 before mlx_metal_v3, mlx_cuda_v13 before v12
sort.Sort(sort.Reverse(sort.StringSlice(mlxDirs)))
for _, mlxDir := range mlxDirs {
if !isCompatibleMLXVariant(filepath.Base(mlxDir)) {
slog.Debug("skipping incompatible MLX variant", "dir", mlxDir)
continue
}
if tryLoadFromDir(mlxDir) {
return true
}
}
return false
}
// isCompatibleMLXVariant checks whether an MLX variant directory is
// compatible with the current OS. On macOS, dlopen does NOT enforce
// the deployment target for dynamically loaded libraries, so we must
// check compatibility ourselves to avoid loading Metal 4.x shaders
// on a Metal 3.x driver.
func isCompatibleMLXVariant(name string) bool {
if runtime.GOOS != "darwin" {
return true // non-macOS variants use dlopen failure for filtering
}
// Metal variant naming:
// Production: mlx_metal_v3, mlx_metal_v4
// Dev build: metal-v3, metal-v4
var verStr string
switch {
case strings.HasPrefix(name, "mlx_metal_v"):
verStr = strings.TrimPrefix(name, "mlx_metal_v")
case strings.HasPrefix(name, "metal-v"):
verStr = strings.TrimPrefix(name, "metal-v")
}
if verStr != "" {
metalVer, err := strconv.Atoi(verStr)
if err != nil {
return true // unknown format, try it
}
// Metal 4.x requires macOS 26+
if metalVer >= 4 && macOSMajorVersion() < 26 {
return false
}
}
return true
}

47
x/mlxrunner/mlx/dynamic.h Normal file
View File

@@ -0,0 +1,47 @@
#ifndef MLX_DYNAMIC_H
#define MLX_DYNAMIC_H
#ifdef _WIN32
#include <windows.h>
#define DLSYM(handle, symbol) (void*)GetProcAddress((HMODULE)(handle.ctx), symbol)
#else
#include <dlfcn.h>
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
#endif
#include <stdint.h>
// Provide fallback typedefs for float16_t and bfloat16_t on non-ARM64
// platforms where arm_fp16.h and arm_bf16.h are not available. These are
// only used as function pointer signature placeholders since MLX requires
// Apple Silicon at runtime.
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
typedef uint16_t float16_t;
#endif
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_BF16)
typedef uint16_t bfloat16_t;
#endif
// Undef ERROR to avoid conflict with wingdi.h on Windows
#ifdef ERROR
#undef ERROR
#endif
#define MLX_ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
#define CHECK(x) if (!(x)) { MLX_ERROR("CHECK failed: " #x); }
#define CHECK_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x); CHECK(x##_)
// OPTIONAL_LOAD: load symbol if available, leave function pointer NULL otherwise
#define OPTIONAL_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x)
typedef struct {
void* ctx;
} mlx_dynamic_handle;
int mlx_dynamic_load(
mlx_dynamic_handle* handle,
const char *path);
void mlx_dynamic_unload(
mlx_dynamic_handle* handle);
#endif // MLX_DYNAMIC_H

View File

@@ -0,0 +1,17 @@
package mlx
import (
"strconv"
"strings"
"syscall"
)
func macOSMajorVersion() int {
ver, err := syscall.Sysctl("kern.osproductversion")
if err != nil {
return 0
}
parts := strings.SplitN(ver, ".", 2)
major, _ := strconv.Atoi(parts[0])
return major
}

View File

@@ -0,0 +1,5 @@
//go:build !darwin
package mlx
func macOSMajorVersion() int { return 0 }

47
x/mlxrunner/mlx/fast.go Normal file
View File

@@ -0,0 +1,47 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func FastScaledDotProductAttention(q, k, v *Array, scale float32, mode string, mask *Array) *Array {
sinks := New("")
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
var maskCtx C.mlx_array
if mask != nil {
maskCtx = mask.ctx
} else {
empty := New("")
maskCtx = empty.ctx
}
out := New("FAST_SDPA")
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, maskCtx, sinks.ctx, DefaultStream().ctx)
return out
}
type LayerNorm struct {
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_LAYERNORM")
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
return out
}
type RMSNorm struct {
Weight *Array `weight:"weight"`
}
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
return out
}

View File

@@ -0,0 +1,663 @@
package mlx
// #include <stdlib.h>
// #include "generated.h"
import "C"
import (
"sync"
"unsafe"
)
var (
gatedDeltaMetalKernelOnce sync.Once
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
gatedDeltaMetalDisabled bool
gatedDeltaCUDAKernelOnce sync.Once
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
gatedDeltaCUDADisabled bool
)
const gatedDeltaMetalKernelSource = `
auto n = thread_position_in_grid.z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
y += b_idx * T * Hv * Dv + hv_idx * Dv;
auto dk_idx = thread_position_in_threadgroup.x;
auto dv_idx = thread_position_in_grid.y;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T * Hv;
auto beta_ = beta + b_idx * T * Hv;
for (int t = 0; t < T; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * g_[hv_idx];
kv_mem += state[i] * k_[s_idx];
}
kv_mem = simd_sum(kv_mem);
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + k_[s_idx] * delta;
out += state[i] * q_[s_idx];
}
out = simd_sum(out);
if (thread_index_in_simdgroup == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<StT>(state[i]);
}
`
const gatedDeltaCUDAKernelSource = `
auto tid_x = threadIdx.x;
auto tid_y = threadIdx.y;
auto grid_y = blockIdx.y * blockDim.y + tid_y;
auto grid_z = blockIdx.z;
int T_val = static_cast<int>(*T);
auto n = grid_z;
auto b_idx = n / Hv;
auto hv_idx = n % Hv;
auto hk_idx = hv_idx / (Hv / Hk);
constexpr int n_per_t = Dk / 32;
// q, k: [B, T, Hk, Dk]
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;
// v, y: [B, T, Hv, Dv]
auto dv_idx = grid_y;
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;
auto dk_idx = tid_x;
// state_in, state_out: [B, Hv, Dv, Dk]
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
float state[n_per_t];
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = static_cast<float>(i_state[s_idx]);
}
// g: [B, T, Hv]
auto g_ = g + b_idx * T_val * Hv;
auto beta_ = beta + b_idx * T_val * Hv;
for (int t = 0; t < T_val; ++t) {
float kv_mem = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] * static_cast<float>(g_[hv_idx]);
kv_mem += state[i] * static_cast<float>(k_[s_idx]);
}
// Warp reduction (full warp, 32 threads in x)
for (int offset = 16; offset > 0; offset >>= 1)
kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);
auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);
float out = 0.0f;
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
out += state[i] * static_cast<float>(q_[s_idx]);
}
// Warp reduction
for (int offset = 16; offset > 0; offset >>= 1)
out += __shfl_down_sync(0xffffffff, out, offset);
if (tid_x == 0) {
y[dv_idx] = static_cast<InT>(out);
}
q_ += Hk * Dk;
k_ += Hk * Dk;
v_ += Hv * Dv;
y += Hv * Dv;
g_ += Hv;
beta_ += Hv;
}
for (int i = 0; i < n_per_t; ++i) {
auto s_idx = n_per_t * dk_idx + i;
o_state[s_idx] = static_cast<StT>(state[i]);
}
`
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
vec := C.mlx_vector_string_new()
ok := true
for _, s := range values {
cs := C.CString(s)
if C.mlx_vector_string_append_value(vec, cs) != 0 {
ok = false
}
C.free(unsafe.Pointer(cs))
if !ok {
break
}
}
cleanup := func() {
C.mlx_vector_string_free(vec)
}
return vec, cleanup, ok
}
func initGatedDeltaMetalKernel() {
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaMetalDisabled = true
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaMetalDisabled = true
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaMetalKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.bool(false),
)
}
// gatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaMetalDisabled {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
inputDType := q.DType()
stateDType := state.DType()
if k.DType() != inputDType || v.DType() != inputDType || g.DType() != inputDType || beta.DType() != inputDType {
return nil, nil, false
}
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
if gatedDeltaMetalDisabled {
return nil, nil, false
}
cfg := C.mlx_fast_metal_kernel_config_new()
defer C.mlx_fast_metal_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(inputDType)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
cStT := C.CString("StT")
defer C.free(unsafe.Pointer(cStT))
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(stateDType)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(inputDType)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(stateDType)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaMetalDisabled = true
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_METAL_Y")
nextState = New("GATED_DELTA_METAL_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}
func repeatHeadsForGatedDelta(x *Array, repeatFactor int) *Array {
if repeatFactor <= 1 {
return x
}
shape := x.Dims()
x = ExpandDims(x, 3)
x = Tile(x, []int32{1, 1, 1, int32(repeatFactor), 1})
return Reshape(x, int32(shape[0]), int32(shape[1]), int32(shape[2]*repeatFactor), int32(shape[3]))
}
func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil
}
B, T, Hk, Dk := int32(qd[0]), int32(qd[1]), int32(qd[2]), int32(qd[3])
Hv, Dv := int32(vd[2]), int32(vd[3])
if T <= 0 || Hk <= 0 || Dk <= 0 || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil
}
if kd[0] != int(B) || kd[1] != int(T) || kd[2] != int(Hk) || kd[3] != int(Dk) {
return nil, nil
}
if vd[0] != int(B) || vd[1] != int(T) {
return nil, nil
}
if gd[0] != int(B) || gd[1] != int(T) || gd[2] != int(Hv) {
return nil, nil
}
if bd[0] != int(B) || bd[1] != int(T) || bd[2] != int(Hv) {
return nil, nil
}
if sd[0] != int(B) || sd[1] != int(Hv) || sd[2] != int(Dv) || sd[3] != int(Dk) {
return nil, nil
}
repeatFactor := int(Hv / Hk)
q = repeatHeadsForGatedDelta(q, repeatFactor)
k = repeatHeadsForGatedDelta(k, repeatFactor)
nextState = state
if T == 1 {
qt := Squeeze(q, 1)
kt := Squeeze(k, 1)
vt := Squeeze(v, 1)
gt := Squeeze(g, 1)
bt := Squeeze(beta, 1)
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
return ExpandDims(yt, 1), nextState
}
outs := make([]*Array, 0, T)
for t := int32(0); t < T; t++ {
qt := Squeeze(SliceStartStop(q, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
kt := Squeeze(SliceStartStop(k, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
vt := Squeeze(SliceStartStop(v, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dv}), 1)
gt := Squeeze(SliceStartStop(g, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
bt := Squeeze(SliceStartStop(beta, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
outs = append(outs, ExpandDims(yt, 1))
}
return Concatenate(outs, 1), nextState
}
func initGatedDeltaCUDAKernel() {
var cudaAvail C.bool
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
gatedDeltaCUDADisabled = true
return
}
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
if !ok {
gatedDeltaCUDADisabled = true
freeInputs()
return
}
defer freeInputs()
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
if !ok {
gatedDeltaCUDADisabled = true
freeOutputs()
return
}
defer freeOutputs()
cName := C.CString("gated_delta_step")
defer C.free(unsafe.Pointer(cName))
cSource := C.CString(gatedDeltaCUDAKernelSource)
defer C.free(unsafe.Pointer(cSource))
cHeader := C.CString("")
defer C.free(unsafe.Pointer(cHeader))
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
cName,
inputs,
outputs,
cSource,
cHeader,
C.bool(true),
C.int(0),
)
}
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
if gatedDeltaCUDADisabled {
return nil, nil, false
}
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
return nil, nil, false
}
qd := q.Dims()
kd := k.Dims()
vd := v.Dims()
gd := g.Dims()
bd := beta.Dims()
sd := state.Dims()
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
return nil, nil, false
}
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
return nil, nil, false
}
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
return nil, nil, false
}
Hv, Dv := vd[2], vd[3]
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
return nil, nil, false
}
if gd[0] != B || gd[1] != T || gd[2] != Hv {
return nil, nil, false
}
if bd[0] != B || bd[1] != T || bd[2] != Hv {
return nil, nil, false
}
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
return nil, nil, false
}
inputDType := q.DType()
stateDType := state.DType()
if k.DType() != inputDType || v.DType() != inputDType || g.DType() != inputDType || beta.DType() != inputDType {
return nil, nil, false
}
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
if gatedDeltaCUDADisabled {
return nil, nil, false
}
cfg := C.mlx_fast_cuda_kernel_config_new()
defer C.mlx_fast_cuda_kernel_config_free(cfg)
cInT := C.CString("InT")
defer C.free(unsafe.Pointer(cInT))
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(inputDType)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
cStT := C.CString("StT")
defer C.free(unsafe.Pointer(cStT))
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(stateDType)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
for _, tpl := range []struct {
name string
value int
}{
{name: "Dk", value: Dk},
{name: "Dv", value: Dv},
{name: "Hk", value: Hk},
{name: "Hv", value: Hv},
} {
cn := C.CString(tpl.name)
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
C.free(unsafe.Pointer(cn))
if rc != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
}
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(inputDType)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(stateDType)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
threadY := Dv
if threadY > 4 {
threadY = 4
}
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
tScalar := FromValue(T)
inputs := []C.mlx_array{
q.ctx,
k.ctx,
v.ctx,
g.ctx,
beta.ctx,
state.ctx,
tScalar.ctx,
}
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
defer C.mlx_vector_array_free(inVec)
outVec := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(outVec)
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
gatedDeltaCUDADisabled = true
return nil, nil, false
}
if int(C.mlx_vector_array_size(outVec)) < 2 {
return nil, nil, false
}
y = New("GATED_DELTA_CUDA_Y")
nextState = New("GATED_DELTA_CUDA_STATE")
C.mlx_vector_array_get(&y.ctx, outVec, 0)
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
return y, nextState, true
}
// FastGatedDelta runs the recurrent update operation.
//
// When mask is non-nil, it must be a [B, T] bool tensor identifying real
// (true) vs. padded (false) positions in q/k/v/g/beta. Padded positions
// are substituted with neutral values (q=k=v=beta=0, g=1) so each padded
// kernel iteration is a no-op — state passes through unchanged and the
// final state equals the state after the last real token of each row.
//
// It tries the fused CUDA kernel first, then Metal, then falls back to a
// backend-agnostic MLX implementation with identical inputs/outputs.
func FastGatedDelta(q, k, v, g, beta, state, mask *Array) (y, nextState *Array) {
// TODO: handle this more efficiently with a masked kernel (MLX-LM has one).
if mask != nil {
B := int32(mask.Dim(0))
T := int32(mask.Dim(1))
m4 := Reshape(mask, B, T, 1, 1)
m3 := Reshape(mask, B, T, 1)
zeroQ := FromValue(float32(0)).AsType(q.DType())
zeroK := FromValue(float32(0)).AsType(k.DType())
zeroV := FromValue(float32(0)).AsType(v.DType())
zeroBeta := FromValue(float32(0)).AsType(beta.DType())
oneG := FromValue(float32(1)).AsType(g.DType())
q = Where(m4, q, zeroQ)
k = Where(m4, k, zeroK)
v = Where(m4, v, zeroV)
beta = Where(m3, beta, zeroBeta)
g = Where(m3, g, oneG)
}
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
return y, nextState
}
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
return y, nextState
}
y, nextState = gatedDeltaFallback(q, k, v, g, beta, state)
if y == nil || nextState == nil {
panic("mlx.FastGatedDelta: fallback failed (invalid inputs or unsupported shapes)")
}
return y, nextState
}

3012
x/mlxrunner/mlx/generated.c Normal file

File diff suppressed because it is too large Load Diff

7256
x/mlxrunner/mlx/generated.h Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,17 @@
// This code is auto-generated; DO NOT EDIT.
#include "generated.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
{{ range .Functions }}
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
{{- range .Functions }}
{{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }});
{{- end }}
return 0;
}

View File

@@ -0,0 +1,26 @@
// This code is auto-generated; DO NOT EDIT.
#ifndef MLX_GENERATED_H
#define MLX_GENERATED_H
#include "dynamic.h"
{{ range .Functions }}
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
{{- end }}
#include "mlx/c/mlx.h"
{{ range .Functions }}
#undef {{ .Name }}
{{- end }}
{{ range .Functions }}
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
{{- end }}
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
{{ range .Functions }}
static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
return {{ .Name }}_({{ .Args }});
{{ "}" }}
{{- end }}
#endif // MLX_GENERATED_H

View File

@@ -0,0 +1,157 @@
package main
import (
"embed"
"flag"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"text/template"
tree_sitter "github.com/tree-sitter/go-tree-sitter"
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
)
//go:embed *.gotmpl
var fsys embed.FS
// optionalSymbols lists symbols that may not be present in all builds
// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX).
var optionalSymbols = map[string]bool{
"mlx_array_item_float16": true,
"mlx_array_item_bfloat16": true,
"mlx_array_data_float16": true,
"mlx_array_data_bfloat16": true,
}
type Function struct {
Type,
Name,
Parameters,
Args string
Optional bool
}
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
var fn Function
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
if params := node.ChildByFieldName("parameters"); params != nil {
fn.Parameters = params.Utf8Text(source)
fn.Args = ParseParameters(params, tc, source)
}
var types []string
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
if node.Parent().Kind() == "pointer_declarator" {
types = append(types, "*")
}
node = node.Parent()
}
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
types = append(types, sibling.Utf8Text(source))
}
slices.Reverse(types)
fn.Type = strings.Join(types, " ")
return fn
}
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
var s []string
for _, child := range node.Children(tc) {
if child.IsNamed() {
child := child.ChildByFieldName("declarator")
for child != nil && child.Kind() != "identifier" {
if child.Kind() == "parenthesized_declarator" {
child = child.Child(1)
} else {
child = child.ChildByFieldName("declarator")
}
}
if child != nil {
s = append(s, child.Utf8Text(source))
}
}
}
return strings.Join(s, ", ")
}
func main() {
var output string
flag.StringVar(&output, "output", ".", "Output directory for generated files")
flag.Parse()
parser := tree_sitter.NewParser()
defer parser.Close()
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
parser.SetLanguage(language)
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
defer query.Close()
qc := tree_sitter.NewQueryCursor()
defer qc.Close()
var files []string
for _, arg := range flag.Args() {
matches, err := filepath.Glob(arg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error expanding glob %s: %v\n", arg, err)
continue
}
files = append(files, matches...)
}
var funs []Function
for _, arg := range files {
bts, err := os.ReadFile(arg)
if err != nil {
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
continue
}
tree := parser.Parse(bts, nil)
defer tree.Close()
tc := tree.Walk()
defer tc.Close()
matches := qc.Matches(query, tree.RootNode(), bts)
for match := matches.Next(); match != nil; match = matches.Next() {
for _, capture := range match.Captures {
fn := ParseFunction(&capture.Node, tc, bts)
fn.Optional = optionalSymbols[fn.Name]
funs = append(funs, fn)
}
}
}
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
return
}
for _, tmpl := range tmpl.Templates() {
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
fmt.Println("Generating", name)
f, err := os.Create(name)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
continue
}
defer f.Close()
if err := tmpl.Execute(f, map[string]any{
"Functions": funs,
}); err != nil {
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
}
}
}

View File

@@ -0,0 +1,12 @@
# Vendored MLX-C Headers
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
The pinned version is in `MLX_C_VERSION` at the repo root.
Headers are automatically refreshed when you run a CMake build:
```shell
cmake --preset 'MLX CUDA 13'
```
See the [MLX Engine](../../../../../../../docs/development.md#mlx-engine-optional) section of the development docs for full build instructions.

View File

@@ -0,0 +1,420 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_ARRAY_H
#define MLX_ARRAY_H
#include "mlx/c/string.h"
#include <float.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
// Complex number support
#ifdef _MSC_VER
#define _CRT_USE_C_COMPLEX_H
#include <complex.h>
typedef _Fcomplex mlx_complex64_t;
#else
#include <complex.h>
typedef float _Complex mlx_complex64_t;
#endif
#include "half.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_array Array
* MLX N-dimensional array object.
*/
/**@{*/
/**
* A N-dimensional array object.
*/
typedef struct mlx_array_ {
void* ctx;
} mlx_array;
static mlx_array mlx_array_empty;
/**
* Array element type.
*/
typedef enum mlx_dtype_ {
MLX_BOOL,
MLX_UINT8,
MLX_UINT16,
MLX_UINT32,
MLX_UINT64,
MLX_INT8,
MLX_INT16,
MLX_INT32,
MLX_INT64,
MLX_FLOAT16,
MLX_FLOAT32,
MLX_FLOAT64,
MLX_BFLOAT16,
MLX_COMPLEX64,
} mlx_dtype;
/**
* Size of given mlx_dtype datatype in bytes.
*/
size_t mlx_dtype_size(mlx_dtype dtype);
/**
* Get array description.
*/
int mlx_array_tostring(mlx_string* str, const mlx_array arr);
/**
* New empty array.
*/
mlx_array mlx_array_new(void);
/**
* Free an array.
*/
int mlx_array_free(mlx_array arr);
/**
* New array from a bool scalar.
*/
mlx_array mlx_array_new_bool(bool val);
/**
* New array from a int scalar.
*/
mlx_array mlx_array_new_int(int val);
/**
* New array from a float32 scalar.
*/
mlx_array mlx_array_new_float32(float val);
/**
* New array from a float scalar.
* Same as float32.
*/
mlx_array mlx_array_new_float(float val);
/**
* New array from a float64 scalar.
*/
mlx_array mlx_array_new_float64(double val);
/**
* New array from a double scalar.
* Same as float64.
*/
mlx_array mlx_array_new_double(double val);
/**
* New array from a complex scalar.
*/
mlx_array mlx_array_new_complex(float real_val, float imag_val);
/**
* New array from existing buffer.
* @param data A buffer which will be copied.
* @param shape Shape of the array.
* @param dim Number of dimensions (size of `shape`).
* @param dtype Type of array elements.
*/
mlx_array mlx_array_new_data(
const void* data,
const int* shape,
int dim,
mlx_dtype dtype);
/**
* New array from existing buffer.
* @param data A buffer which will be copied.
* @param shape Shape of the array.
* @param dim Number of dimensions (size of `shape`).
* @param dtype Type of array elements.
* @param dtor Callback for when the buffer is no longer needed.
*/
mlx_array mlx_array_new_data_managed(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void (*dtor)(void*));
/**
* New array from existing buffer.
* @param data A buffer which will be copied.
* @param shape Shape of the array.
* @param dim Number of dimensions (size of `shape`).
* @param dtype Type of array elements.
* @param payload Payload pointer passed to the `dtor` callback instead of
* `data`.
* @param dtor Callback for when the buffer is no longer needed.
*/
mlx_array mlx_array_new_data_managed_payload(
void* data,
const int* shape,
int dim,
mlx_dtype dtype,
void* payload,
void (*dtor)(void*));
/**
* Set array to provided src array.
*/
int mlx_array_set(mlx_array* arr, const mlx_array src);
/**
* Set array to a bool scalar.
*/
int mlx_array_set_bool(mlx_array* arr, bool val);
/**
* Set array to a int scalar.
*/
int mlx_array_set_int(mlx_array* arr, int val);
/**
* Set array to a float32 scalar.
*/
int mlx_array_set_float32(mlx_array* arr, float val);
/**
* Set array to a float scalar.
*/
int mlx_array_set_float(mlx_array* arr, float val);
/**
* Set array to a float64 scalar.
*/
int mlx_array_set_float64(mlx_array* arr, double val);
/**
* Set array to a double scalar.
*/
int mlx_array_set_double(mlx_array* arr, double val);
/**
* Set array to a complex scalar.
*/
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val);
/**
* Set array to specified data and shape.
* @param arr Destination array.
* @param data A buffer which will be copied.
* @param shape Shape of the array.
* @param dim Number of dimensions (size of `shape`).
* @param dtype Type of array elements.
*/
int mlx_array_set_data(
mlx_array* arr,
const void* data,
const int* shape,
int dim,
mlx_dtype dtype);
/**
* The size of the array's datatype in bytes.
*/
size_t mlx_array_itemsize(const mlx_array arr);
/**
* Number of elements in the array.
*/
size_t mlx_array_size(const mlx_array arr);
/**
* The number of bytes in the array.
*/
size_t mlx_array_nbytes(const mlx_array arr);
/**
* The array's dimension.
*/
size_t mlx_array_ndim(const mlx_array arr);
/**
* The shape of the array.
* Returns: a pointer to the sizes of each dimension.
*/
const int* mlx_array_shape(const mlx_array arr);
/**
* The strides of the array.
* Returns: a pointer to the sizes of each dimension.
*/
const size_t* mlx_array_strides(const mlx_array arr);
/**
* The shape of the array in a particular dimension.
*/
int mlx_array_dim(const mlx_array arr, int dim);
/**
* The array element type.
*/
mlx_dtype mlx_array_dtype(const mlx_array arr);
/**
* Evaluate the array.
*/
int mlx_array_eval(mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_bool(bool* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_int8(int8_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_int16(int16_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_int32(int32_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_int64(int64_t* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_float32(float* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_float64(double* res, const mlx_array arr);
/**
* Access the value of a scalar array.
*/
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
#ifdef HAS_FLOAT16
/**
* Access the value of a scalar array.
*/
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
#endif
#ifdef HAS_BFLOAT16
/**
* Access the value of a scalar array.
*/
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr);
#endif
/**
* Returns a pointer to the array data, cast to `bool*`.
* Array must be evaluated, otherwise returns NULL.
*/
const bool* mlx_array_data_bool(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `uint8_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const uint8_t* mlx_array_data_uint8(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `uint16_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const uint16_t* mlx_array_data_uint16(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `uint32_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const uint32_t* mlx_array_data_uint32(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `uint64_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const uint64_t* mlx_array_data_uint64(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `int8_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const int8_t* mlx_array_data_int8(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `int16_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const int16_t* mlx_array_data_int16(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `int32_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const int32_t* mlx_array_data_int32(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `int64_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const int64_t* mlx_array_data_int64(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `float32*`.
* Array must be evaluated, otherwise returns NULL.
*/
const float* mlx_array_data_float32(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `float64*`.
* Array must be evaluated, otherwise returns NULL.
*/
const double* mlx_array_data_float64(const mlx_array arr);
/**
* Returns a pointer to the array data, cast to `_Complex*`.
* Array must be evaluated, otherwise returns NULL.
*/
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
#ifdef HAS_FLOAT16
/**
* Returns a pointer to the array data, cast to `float16_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const float16_t* mlx_array_data_float16(const mlx_array arr);
#endif
#ifdef HAS_BFLOAT16
/**
* Returns a pointer to the array data, cast to `bfloat16_t*`.
* Array must be evaluated, otherwise returns NULL.
*/
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr);
#endif
/**
* Check if the array is available.
* Internal function: use at your own risk.
*/
int _mlx_array_is_available(bool* res, const mlx_array arr);
/**
* Wait on the array to be available. After this `_mlx_array_is_available`
* returns `true`. Internal function: use at your own risk.
*/
int _mlx_array_wait(const mlx_array arr);
/**
* Whether the array is contiguous in memory.
* Internal function: use at your own risk.
*/
int _mlx_array_is_contiguous(bool* res, const mlx_array arr);
/**
* Whether the array's rows are contiguous in memory.
* Internal function: use at your own risk.
*/
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr);
/**
* Whether the array's columns are contiguous in memory.
* Internal function: use at your own risk.
*/
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,197 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_CLOSURE_H
#define MLX_CLOSURE_H
#include "mlx/c/array.h"
#include "mlx/c/map.h"
#include "mlx/c/optional.h"
#include "mlx/c/stream.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_closure Closures
* MLX closure objects.
*/
/**@{*/
typedef struct mlx_closure_ {
void* ctx;
} mlx_closure;
mlx_closure mlx_closure_new(void);
int mlx_closure_free(mlx_closure cls);
mlx_closure mlx_closure_new_func(
int (*fun)(mlx_vector_array*, const mlx_vector_array));
mlx_closure mlx_closure_new_func_payload(
int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_set(mlx_closure* cls, const mlx_closure src);
int mlx_closure_apply(
mlx_vector_array* res,
mlx_closure cls,
const mlx_vector_array input);
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
typedef struct mlx_closure_kwargs_ {
void* ctx;
} mlx_closure_kwargs;
mlx_closure_kwargs mlx_closure_kwargs_new(void);
int mlx_closure_kwargs_free(mlx_closure_kwargs cls);
mlx_closure_kwargs mlx_closure_kwargs_new_func(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array));
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_map_string_to_array,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_kwargs_set(
mlx_closure_kwargs* cls,
const mlx_closure_kwargs src);
int mlx_closure_kwargs_apply(
mlx_vector_array* res,
mlx_closure_kwargs cls,
const mlx_vector_array input_0,
const mlx_map_string_to_array input_1);
typedef struct mlx_closure_value_and_grad_ {
void* ctx;
} mlx_closure_value_and_grad;
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void);
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls);
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
int (*fun)(
mlx_vector_array*,
mlx_vector_array*,
const mlx_vector_array,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_value_and_grad_set(
mlx_closure_value_and_grad* cls,
const mlx_closure_value_and_grad src);
int mlx_closure_value_and_grad_apply(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
mlx_closure_value_and_grad cls,
const mlx_vector_array input);
typedef struct mlx_closure_custom_ {
void* ctx;
} mlx_closure_custom;
mlx_closure_custom mlx_closure_custom_new(void);
int mlx_closure_custom_free(mlx_closure_custom cls);
mlx_closure_custom mlx_closure_custom_new_func(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array));
mlx_closure_custom mlx_closure_custom_new_func_payload(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const mlx_vector_array,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_custom_set(
mlx_closure_custom* cls,
const mlx_closure_custom src);
int mlx_closure_custom_apply(
mlx_vector_array* res,
mlx_closure_custom cls,
const mlx_vector_array input_0,
const mlx_vector_array input_1,
const mlx_vector_array input_2);
typedef struct mlx_closure_custom_jvp_ {
void* ctx;
} mlx_closure_custom_jvp;
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void);
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls);
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num));
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
int (*fun)(
mlx_vector_array*,
const mlx_vector_array,
const mlx_vector_array,
const int*,
size_t _num,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_custom_jvp_set(
mlx_closure_custom_jvp* cls,
const mlx_closure_custom_jvp src);
int mlx_closure_custom_jvp_apply(
mlx_vector_array* res,
mlx_closure_custom_jvp cls,
const mlx_vector_array input_0,
const mlx_vector_array input_1,
const int* input_2,
size_t input_2_num);
typedef struct mlx_closure_custom_vmap_ {
void* ctx;
} mlx_closure_custom_vmap;
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void);
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls);
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num));
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
int (*fun)(
mlx_vector_array*,
mlx_vector_int*,
const mlx_vector_array,
const int*,
size_t _num,
void*),
void* payload,
void (*dtor)(void*));
int mlx_closure_custom_vmap_set(
mlx_closure_custom_vmap* cls,
const mlx_closure_custom_vmap src);
int mlx_closure_custom_vmap_apply(
mlx_vector_array* res_0,
mlx_vector_int* res_1,
mlx_closure_custom_vmap cls,
const mlx_vector_array input_0,
const int* input_1,
size_t input_1_num);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,58 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_COMPILE_H
#define MLX_COMPILE_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup compile Compilation operations
*/
/**@{*/
typedef enum mlx_compile_mode_ {
MLX_COMPILE_MODE_DISABLED,
MLX_COMPILE_MODE_NO_SIMPLIFY,
MLX_COMPILE_MODE_NO_FUSE,
MLX_COMPILE_MODE_ENABLED
} mlx_compile_mode;
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless);
int mlx_detail_compile(
mlx_closure* res,
const mlx_closure fun,
uintptr_t fun_id,
bool shapeless,
const uint64_t* constants,
size_t constants_num);
int mlx_detail_compile_clear_cache(void);
int mlx_detail_compile_erase(uintptr_t fun_id);
int mlx_disable_compile(void);
int mlx_enable_compile(void);
int mlx_set_compile_mode(mlx_compile_mode mode);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,39 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_CUDA_H
#define MLX_CUDA_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup cuda Cuda specific operations
*/
/**@{*/
int mlx_cuda_is_available(bool* res);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,154 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_DEVICE_H
#define MLX_DEVICE_H
#include <stdbool.h>
#include <stddef.h>
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_device Device
* MLX device object.
*/
/**@{*/
/**
* A MLX device object.
*/
typedef struct mlx_device_ {
void* ctx;
} mlx_device;
/**
* Device type.
*/
typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type;
/**
* Returns a new empty device.
*/
mlx_device mlx_device_new(void);
/**
* Returns a new device of specified `type`, with specified `index`.
*/
mlx_device mlx_device_new_type(mlx_device_type type, int index);
/**
* Free a device.
*/
int mlx_device_free(mlx_device dev);
/**
* Set device to provided src device.
*/
int mlx_device_set(mlx_device* dev, const mlx_device src);
/**
* Get device description.
*/
int mlx_device_tostring(mlx_string* str, mlx_device dev);
/**
* Check if devices are the same.
*/
bool mlx_device_equal(mlx_device lhs, mlx_device rhs);
/**
* Returns the index of the device.
*/
int mlx_device_get_index(int* index, mlx_device dev);
/**
* Returns the type of the device.
*/
int mlx_device_get_type(mlx_device_type* type, mlx_device dev);
/**
* Returns the default MLX device.
*/
int mlx_get_default_device(mlx_device* dev);
/**
* Set the default MLX device.
*/
int mlx_set_default_device(mlx_device dev);
/**
* Check if device is available.
*/
int mlx_device_is_available(bool* avail, mlx_device dev);
/**
* Get the number of available devices for a device type.
*/
int mlx_device_count(int* count, mlx_device_type type);
/**
* A MLX device info object.
* Contains key-value pairs with device properties.
* Keys vary by backend but common keys include:
* - device_name (string): Device name
* - architecture (string): Architecture identifier
* Additional keys may be present depending on the backend.
*/
typedef struct mlx_device_info_ {
void* ctx;
} mlx_device_info;
/**
* Returns a new empty device info object.
*/
mlx_device_info mlx_device_info_new(void);
/**
* Get device information for a device.
*/
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
/**
* Free a device info object.
*/
int mlx_device_info_free(mlx_device_info info);
/**
* Check if a key exists in the device info.
* Returns 0 on success, 1 on error.
* Sets *exists to true if the key exists, false otherwise.
*/
int mlx_device_info_has_key(
bool* exists,
mlx_device_info info,
const char* key);
/**
* Check if a value is a string type.
* Returns 0 on success, 1 on error.
* Sets *is_string to true if the value is a string, false if it's a size_t.
*/
int mlx_device_info_is_string(
bool* is_string,
mlx_device_info info,
const char* key);
/**
* Get a string value from device info.
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
*/
int mlx_device_info_get_string(
const char** value,
mlx_device_info info,
const char* key);
/**
* Get a size_t value from device info.
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
*/
int mlx_device_info_get_size(
size_t* value,
mlx_device_info info,
const char* key);
/**
* Get all keys from device info.
* Returns 0 on success, 1 on error.
*/
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,83 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_DISTRIBUTED_H
#define MLX_DISTRIBUTED_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup distributed Distributed collectives
*/
/**@{*/
int mlx_distributed_all_gather(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream S);
int mlx_distributed_all_max(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_all_min(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_all_sum(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_recv(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
int src,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_recv_like(
mlx_array* res,
const mlx_array x,
int src,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_send(
mlx_array* res,
const mlx_array x,
int dst,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
int mlx_distributed_sum_scatter(
mlx_array* res,
const mlx_array x,
const mlx_distributed_group group /* may be null */,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,74 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_DISTRIBUTED_GROUP_H
#define MLX_DISTRIBUTED_GROUP_H
#include <stdbool.h>
#include "mlx/c/stream.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_distributed_group MLX distributed
*/
/**@{*/
/**
* A MLX distributed group object.
*/
typedef struct mlx_distributed_group_ {
void* ctx;
} mlx_distributed_group;
/**
* Create an empty group.
*/
mlx_distributed_group mlx_distributed_group_new(void);
/**
* Free the group.
*/
int mlx_distributed_group_free(mlx_distributed_group group);
/**
* Initialize distributed.
*/
int mlx_distributed_init(
mlx_distributed_group* res,
bool strict,
const char* bk /* may be null */);
/**
* Get the rank.
*/
int mlx_distributed_group_rank(mlx_distributed_group group);
/**
* Get the group size.
*/
int mlx_distributed_group_size(mlx_distributed_group group);
/**
* Split the group.
*/
int mlx_distributed_group_split(
mlx_distributed_group* res,
mlx_distributed_group group,
int color,
int key);
/**
* Check if distributed is available.
*/
bool mlx_distributed_is_available(const char* bk /* may be null */);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,41 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_ERROR_H
#define MLX_ERROR_H
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_error Error management
*/
/**@{*/
typedef void (*mlx_error_handler_func)(const char* msg, void* data);
/**
* Set the error handler.
*/
void mlx_set_error_handler(
mlx_error_handler_func handler,
void* data,
void (*dtor)(void*));
/**
* Throw an error.
*/
void _mlx_error(const char* file, const int line, const char* fmt, ...);
/**
* Throw an error. Macro which passes file name and line number to _mlx_error().
*/
#define mlx_error(...) _mlx_error(__FILE__, __LINE__, __VA_ARGS__)
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,75 @@
/* Copyright © 2023-2025 Apple Inc. */
#ifndef MLX_EXPORT_H
#define MLX_EXPORT_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup export Function serialization
*/
/**@{*/
int mlx_export_function(
const char* file,
const mlx_closure fun,
const mlx_vector_array args,
bool shapeless);
int mlx_export_function_kwargs(
const char* file,
const mlx_closure_kwargs fun,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs,
bool shapeless);
typedef struct mlx_function_exporter_ {
void* ctx;
} mlx_function_exporter;
mlx_function_exporter mlx_function_exporter_new(
const char* file,
const mlx_closure fun,
bool shapeless);
int mlx_function_exporter_free(mlx_function_exporter xfunc);
int mlx_function_exporter_apply(
const mlx_function_exporter xfunc,
const mlx_vector_array args);
int mlx_function_exporter_apply_kwargs(
const mlx_function_exporter xfunc,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs);
typedef struct mlx_imported_function_ {
void* ctx;
} mlx_imported_function;
mlx_imported_function mlx_imported_function_new(const char* file);
int mlx_imported_function_free(mlx_imported_function xfunc);
int mlx_imported_function_apply(
mlx_vector_array* res,
const mlx_imported_function xfunc,
const mlx_vector_array args);
int mlx_imported_function_apply_kwargs(
mlx_vector_array* res,
const mlx_imported_function xfunc,
const mlx_vector_array args,
const mlx_map_string_to_array kwargs);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,206 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_FAST_H
#define MLX_FAST_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup fast Fast custom operations
*/
/**@{*/
typedef struct mlx_fast_cuda_kernel_config_ {
void* ctx;
} mlx_fast_cuda_kernel_config;
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void);
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls);
int mlx_fast_cuda_kernel_config_add_output_arg(
mlx_fast_cuda_kernel_config cls,
const int* shape,
size_t size,
mlx_dtype dtype);
int mlx_fast_cuda_kernel_config_set_grid(
mlx_fast_cuda_kernel_config cls,
int grid1,
int grid2,
int grid3);
int mlx_fast_cuda_kernel_config_set_thread_group(
mlx_fast_cuda_kernel_config cls,
int thread1,
int thread2,
int thread3);
int mlx_fast_cuda_kernel_config_set_init_value(
mlx_fast_cuda_kernel_config cls,
float value);
int mlx_fast_cuda_kernel_config_set_verbose(
mlx_fast_cuda_kernel_config cls,
bool verbose);
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
mlx_fast_cuda_kernel_config cls,
const char* name,
mlx_dtype dtype);
int mlx_fast_cuda_kernel_config_add_template_arg_int(
mlx_fast_cuda_kernel_config cls,
const char* name,
int value);
int mlx_fast_cuda_kernel_config_add_template_arg_bool(
mlx_fast_cuda_kernel_config cls,
const char* name,
bool value);
typedef struct mlx_fast_cuda_kernel_ {
void* ctx;
} mlx_fast_cuda_kernel;
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
const char* name,
const mlx_vector_string input_names,
const mlx_vector_string output_names,
const char* source,
const char* header,
bool ensure_row_contiguous,
int shared_memory);
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls);
int mlx_fast_cuda_kernel_apply(
mlx_vector_array* outputs,
mlx_fast_cuda_kernel cls,
const mlx_vector_array inputs,
const mlx_fast_cuda_kernel_config config,
const mlx_stream stream);
int mlx_fast_layer_norm(
mlx_array* res,
const mlx_array x,
const mlx_array weight /* may be null */,
const mlx_array bias /* may be null */,
float eps,
const mlx_stream s);
typedef struct mlx_fast_metal_kernel_config_ {
void* ctx;
} mlx_fast_metal_kernel_config;
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void);
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls);
int mlx_fast_metal_kernel_config_add_output_arg(
mlx_fast_metal_kernel_config cls,
const int* shape,
size_t size,
mlx_dtype dtype);
int mlx_fast_metal_kernel_config_set_grid(
mlx_fast_metal_kernel_config cls,
int grid1,
int grid2,
int grid3);
int mlx_fast_metal_kernel_config_set_thread_group(
mlx_fast_metal_kernel_config cls,
int thread1,
int thread2,
int thread3);
int mlx_fast_metal_kernel_config_set_init_value(
mlx_fast_metal_kernel_config cls,
float value);
int mlx_fast_metal_kernel_config_set_verbose(
mlx_fast_metal_kernel_config cls,
bool verbose);
int mlx_fast_metal_kernel_config_add_template_arg_dtype(
mlx_fast_metal_kernel_config cls,
const char* name,
mlx_dtype dtype);
int mlx_fast_metal_kernel_config_add_template_arg_int(
mlx_fast_metal_kernel_config cls,
const char* name,
int value);
int mlx_fast_metal_kernel_config_add_template_arg_bool(
mlx_fast_metal_kernel_config cls,
const char* name,
bool value);
typedef struct mlx_fast_metal_kernel_ {
void* ctx;
} mlx_fast_metal_kernel;
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
const char* name,
const mlx_vector_string input_names,
const mlx_vector_string output_names,
const char* source,
const char* header,
bool ensure_row_contiguous,
bool atomic_outputs);
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
int mlx_fast_metal_kernel_apply(
mlx_vector_array* outputs,
mlx_fast_metal_kernel cls,
const mlx_vector_array inputs,
const mlx_fast_metal_kernel_config config,
const mlx_stream stream);
int mlx_fast_rms_norm(
mlx_array* res,
const mlx_array x,
const mlx_array weight /* may be null */,
float eps,
const mlx_stream s);
int mlx_fast_rope(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
int offset,
const mlx_array freqs /* may be null */,
const mlx_stream s);
int mlx_fast_rope_dynamic(
mlx_array* res,
const mlx_array x,
int dims,
bool traditional,
mlx_optional_float base,
float scale,
const mlx_array offset,
const mlx_array freqs /* may be null */,
const mlx_stream s);
int mlx_fast_scaled_dot_product_attention(
mlx_array* res,
const mlx_array queries,
const mlx_array keys,
const mlx_array values,
float scale,
const char* mask_mode,
const mlx_array mask_arr /* may be null */,
const mlx_array sinks /* may be null */,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,158 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_FFT_H
#define MLX_FFT_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup fft FFT operations
*/
/**@{*/
typedef enum mlx_fft_norm_ {
MLX_FFT_NORM_BACKWARD,
MLX_FFT_NORM_ORTHO,
MLX_FFT_NORM_FORWARD
} mlx_fft_norm;
int mlx_fft_fft(
mlx_array* res,
const mlx_array a,
int n,
int axis,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_fft2(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_fftfreq(mlx_array* res, int n, double d, const mlx_stream s);
int mlx_fft_fftn(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_fftshift(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_ifft(
mlx_array* res,
const mlx_array a,
int n,
int axis,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_ifft2(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_ifftn(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_ifftshift(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_irfft(
mlx_array* res,
const mlx_array a,
int n,
int axis,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_irfft2(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_irfftn(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_rfft(
mlx_array* res,
const mlx_array a,
int n,
int axis,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_rfft2(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
int mlx_fft_rfftfreq(mlx_array* res, int n, double d, const mlx_stream s);
int mlx_fft_rfftn(
mlx_array* res,
const mlx_array a,
const int* n,
size_t n_num,
const int* axes,
size_t axes_num,
mlx_fft_norm norm,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,61 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_GRAPH_UTILS_H
#define MLX_GRAPH_UTILS_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup graph_utils Graph Utils
*/
/**@{*/
typedef struct mlx_node_namer_ {
void* ctx;
} mlx_node_namer;
mlx_node_namer mlx_node_namer_new();
int mlx_node_namer_free(mlx_node_namer namer);
int mlx_node_namer_set_name(
mlx_node_namer namer,
const mlx_array arr,
const char* name);
int mlx_node_namer_get_name(
const char** name,
mlx_node_namer namer,
const mlx_array arr);
int mlx_export_to_dot(
FILE* os,
const mlx_node_namer namer,
const mlx_vector_array outputs);
int mlx_print_graph(
FILE* os,
const mlx_node_namer namer,
const mlx_vector_array outputs);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,26 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_HALF_H
#define MLX_HALF_H
#ifdef __cplusplus
extern "C" {
#endif
#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__)
#define HAS_FLOAT16
#include <arm_fp16.h>
typedef __fp16 float16_t;
#endif
#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__)
#define HAS_BFLOAT16
#include <arm_bf16.h>
typedef __bf16 bfloat16_t;
#endif
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,68 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_IO_H
#define MLX_IO_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup io IO operations
*/
/**@{*/
int mlx_load_reader(
mlx_array* res,
mlx_io_reader in_stream,
const mlx_stream s);
int mlx_load(mlx_array* res, const char* file, const mlx_stream s);
int mlx_load_gguf(mlx_io_gguf* gguf, const char* file, const mlx_stream s);
int mlx_load_safetensors_reader(
mlx_map_string_to_array* res_0,
mlx_map_string_to_string* res_1,
mlx_io_reader in_stream,
const mlx_stream s);
int mlx_load_safetensors(
mlx_map_string_to_array* res_0,
mlx_map_string_to_string* res_1,
const char* file,
const mlx_stream s);
int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a);
int mlx_save(const char* file, const mlx_array a);
int mlx_save_gguf(const char* file, mlx_io_gguf gguf);
int mlx_save_safetensors_writer(
mlx_io_writer in_stream,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata);
int mlx_save_safetensors(
const char* file,
const mlx_map_string_to_array param,
const mlx_map_string_to_string metadata);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,150 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_IO_TYPES_H
#define MLX_IO_TYPES_H
#include <stdbool.h>
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_io_types IO Types
* MLX IO type objects.
*/
/**@{*/
/**
* A MLX IO reader object.
*/
typedef struct mlx_io_reader_ {
void* ctx;
} mlx_io_reader;
/**
* A MLX IO writer object.
*/
typedef struct mlx_io_writer_ {
void* ctx;
} mlx_io_writer;
/**
* Virtual table for custom IO reader and writer objects.
*/
typedef struct mlx_io_vtable_ {
bool (*is_open)(void*);
bool (*good)(void*);
size_t (*tell)(void*);
void (*seek)(void*, int64_t off, int whence);
void (*read)(void*, char* data, size_t n);
void (*read_at_offset)(void*, char* data, size_t n, size_t off);
void (*write)(void*, const char* data, size_t n);
const char* (*label)(void*);
void (*free)(void*);
} mlx_io_vtable;
/**
* Returns a new custom IO reader.
* `vtable` operates on user descriptor `desc`.
*/
mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable);
/**
* Get IO reader user descriptor.
*/
int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io);
/**
* Get IO reader description.
*/
int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io);
/**
* Free IO reader.
*
* Note that MLX arrays are lazily evaluated, so the underlying object may
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
* will be called when the underlying object is actually freed.
*/
int mlx_io_reader_free(mlx_io_reader io);
/**
* Returns a new custom IO writer.
* `vtable` operates on user descriptor `desc`.
*/
mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable);
/**
* Get IO writer user descriptor.
*/
int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io);
/**
* Get IO writer description.
*/
int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io);
/**
* Free IO writer.
*
* Note that MLX arrays are lazily evaluated, so the underlying object may
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
* will be called when the underlying object is actually freed.
*/
int mlx_io_writer_free(mlx_io_writer io);
/**
* A MLX GGUF object.
*/
typedef struct mlx_io_gguf_ {
void* ctx;
} mlx_io_gguf;
mlx_io_gguf mlx_io_gguf_new(void);
int mlx_io_gguf_free(mlx_io_gguf io);
int mlx_io_gguf_get_keys(mlx_vector_string* keys, mlx_io_gguf io);
int mlx_io_gguf_get_array(mlx_array* arr, mlx_io_gguf io, const char* key);
int mlx_io_gguf_get_metadata_array(
mlx_array* arr,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_get_metadata_string(
mlx_string* str,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_get_metadata_vector_string(
mlx_vector_string* vstr,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_has_metadata_array(bool* flag, mlx_io_gguf io, const char* key);
int mlx_io_gguf_has_metadata_string(
bool* flag,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_has_metadata_vector_string(
bool* flag,
mlx_io_gguf io,
const char* key);
int mlx_io_gguf_set_array(mlx_io_gguf io, const char* key, const mlx_array arr);
int mlx_io_gguf_set_metadata_array(
mlx_io_gguf io,
const char* key,
const mlx_array marr);
int mlx_io_gguf_set_metadata_string(
mlx_io_gguf io,
const char* key,
const char* mstr);
int mlx_io_gguf_set_metadata_vector_string(
mlx_io_gguf io,
const char* key,
const mlx_vector_string mvstr);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,128 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_LINALG_H
#define MLX_LINALG_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup linalg Linear algebra operations
*/
/**@{*/
int mlx_linalg_cholesky(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
int mlx_linalg_cholesky_inv(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
int mlx_linalg_cross(
mlx_array* res,
const mlx_array a,
const mlx_array b,
int axis,
const mlx_stream s);
int mlx_linalg_eig(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
int mlx_linalg_eigh(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const char* UPLO,
const mlx_stream s);
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_eigvalsh(
mlx_array* res,
const mlx_array a,
const char* UPLO,
const mlx_stream s);
int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_lu_factor(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
int mlx_linalg_norm(
mlx_array* res,
const mlx_array a,
double ord,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
int mlx_linalg_norm_matrix(
mlx_array* res,
const mlx_array a,
const char* ord,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
int mlx_linalg_norm_l2(
mlx_array* res,
const mlx_array a,
const int* axis /* may be null */,
size_t axis_num,
bool keepdims,
const mlx_stream s);
int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_qr(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
int mlx_linalg_solve(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_stream s);
int mlx_linalg_solve_triangular(
mlx_array* res,
const mlx_array a,
const mlx_array b,
bool upper,
const mlx_stream s);
int mlx_linalg_svd(
mlx_vector_array* res,
const mlx_array a,
bool compute_uv,
const mlx_stream s);
int mlx_linalg_tri_inv(
mlx_array* res,
const mlx_array a,
bool upper,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,149 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_MAP_H
#define MLX_MAP_H
#include "mlx/c/array.h"
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_map Maps
* MLX map objects.
*/
/**@{*/
/**
* A string-to-array map
*/
typedef struct mlx_map_string_to_array_ {
void* ctx;
} mlx_map_string_to_array;
/**
* Returns a new empty string-to-array map.
*/
mlx_map_string_to_array mlx_map_string_to_array_new(void);
/**
* Set map to provided src map.
*/
int mlx_map_string_to_array_set(
mlx_map_string_to_array* map,
const mlx_map_string_to_array src);
/**
* Free a string-to-array map.
*/
int mlx_map_string_to_array_free(mlx_map_string_to_array map);
/**
* Insert a new `value` at the specified `key` in the map.
*/
int mlx_map_string_to_array_insert(
mlx_map_string_to_array map,
const char* key,
const mlx_array value);
/**
* Returns the value indexed at the specified `key` in the map.
*/
int mlx_map_string_to_array_get(
mlx_array* value,
const mlx_map_string_to_array map,
const char* key);
/**
* An iterator over a string-to-array map.
*/
typedef struct mlx_map_string_to_array_iterator_ {
void* ctx;
void* map_ctx;
} mlx_map_string_to_array_iterator;
/**
* Returns a new iterator over the given map.
*/
mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(
mlx_map_string_to_array map);
/**
* Free iterator.
*/
int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it);
/**
* Increment iterator.
*/
int mlx_map_string_to_array_iterator_next(
const char** key,
mlx_array* value,
mlx_map_string_to_array_iterator it);
/**
* A string-to-string map
*/
typedef struct mlx_map_string_to_string_ {
void* ctx;
} mlx_map_string_to_string;
/**
* Returns a new empty string-to-string map.
*/
mlx_map_string_to_string mlx_map_string_to_string_new(void);
/**
* Set map to provided src map.
*/
int mlx_map_string_to_string_set(
mlx_map_string_to_string* map,
const mlx_map_string_to_string src);
/**
* Free a string-to-string map.
*/
int mlx_map_string_to_string_free(mlx_map_string_to_string map);
/**
* Insert a new `value` at the specified `key` in the map.
*/
int mlx_map_string_to_string_insert(
mlx_map_string_to_string map,
const char* key,
const char* value);
/**
* Returns the value indexed at the specified `key` in the map.
*/
int mlx_map_string_to_string_get(
const char** value,
const mlx_map_string_to_string map,
const char* key);
/**
* An iterator over a string-to-string map.
*/
typedef struct mlx_map_string_to_string_iterator_ {
void* ctx;
void* map_ctx;
} mlx_map_string_to_string_iterator;
/**
* Returns a new iterator over the given map.
*/
mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(
mlx_map_string_to_string map);
/**
* Free iterator.
*/
int mlx_map_string_to_string_iterator_free(
mlx_map_string_to_string_iterator it);
/**
* Increment iterator.
*/
int mlx_map_string_to_string_iterator_next(
const char** key,
const char** value,
mlx_map_string_to_string_iterator it);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,47 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_MEMORY_H
#define MLX_MEMORY_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup memory Memory operations
*/
/**@{*/
int mlx_clear_cache(void);
int mlx_get_active_memory(size_t* res);
int mlx_get_cache_memory(size_t* res);
int mlx_get_memory_limit(size_t* res);
int mlx_get_peak_memory(size_t* res);
int mlx_reset_peak_memory(void);
int mlx_set_cache_limit(size_t* res, size_t limit);
int mlx_set_memory_limit(size_t* res, size_t limit);
int mlx_set_wired_limit(size_t* res, size_t limit);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,41 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_METAL_H
#define MLX_METAL_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup metal Metal specific operations
*/
/**@{*/
int mlx_metal_is_available(bool* res);
int mlx_metal_start_capture(const char* path);
int mlx_metal_stop_capture(void);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,35 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_ALL_H
#define MLX_ALL_H
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/compile.h"
#include "mlx/c/cuda.h"
#include "mlx/c/device.h"
#include "mlx/c/distributed.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/error.h"
#include "mlx/c/export.h"
#include "mlx/c/fast.h"
#include "mlx/c/fft.h"
#include "mlx/c/graph_utils.h"
#include "mlx/c/half.h"
#include "mlx/c/io.h"
#include "mlx/c/io_types.h"
#include "mlx/c/linalg.h"
#include "mlx/c/map.h"
#include "mlx/c/memory.h"
#include "mlx/c/metal.h"
#include "mlx/c/ops.h"
#include "mlx/c/optional.h"
#include "mlx/c/random.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/transforms.h"
#include "mlx/c/transforms_impl.h"
#include "mlx/c/vector.h"
#include "mlx/c/version.h"
#endif

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,51 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_OPTIONAL_H
#define MLX_OPTIONAL_H
#include <stdbool.h>
#include "mlx/c/array.h"
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_optional Optionals
* MLX optional scalars.
*/
/**@{*/
/**
* A int optional.
*/
typedef struct mlx_optional_int_ {
int value;
bool has_value;
} mlx_optional_int;
/**
* A float optional.
*/
typedef struct mlx_optional_float_ {
float value;
bool has_value;
} mlx_optional_float;
/**
* A dtype optional.
*/
typedef struct mlx_optional_dtype_ {
mlx_dtype value;
bool has_value;
} mlx_optional_dtype;
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,166 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_RANDOM_H
#define MLX_RANDOM_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup random Random number operations
*/
/**@{*/
int mlx_random_bernoulli(
mlx_array* res,
const mlx_array p,
const int* shape,
size_t shape_num,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_bits(
mlx_array* res,
const int* shape,
size_t shape_num,
int width,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_categorical_shape(
mlx_array* res,
const mlx_array logits,
int axis,
const int* shape,
size_t shape_num,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_categorical_num_samples(
mlx_array* res,
const mlx_array logits_,
int axis,
int num_samples,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_categorical(
mlx_array* res,
const mlx_array logits,
int axis,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_gumbel(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_key(mlx_array* res, uint64_t seed);
int mlx_random_laplace(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
float loc,
float scale,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_multivariate_normal(
mlx_array* res,
const mlx_array mean,
const mlx_array cov,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_normal_broadcast(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array loc /* may be null */,
const mlx_array scale /* may be null */,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_normal(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
float loc,
float scale,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_permutation(
mlx_array* res,
const mlx_array x,
int axis,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_permutation_arange(
mlx_array* res,
int x,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_randint(
mlx_array* res,
const mlx_array low,
const mlx_array high,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_seed(uint64_t seed);
int mlx_random_split_num(
mlx_array* res,
const mlx_array key,
int num,
const mlx_stream s);
int mlx_random_split(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array key,
const mlx_stream s);
int mlx_random_truncated_normal(
mlx_array* res,
const mlx_array lower,
const mlx_array upper,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_uniform(
mlx_array* res,
const mlx_array low,
const mlx_array high,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,88 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_STREAM_H
#define MLX_STREAM_H
#include <stdbool.h>
#include "mlx/c/device.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_stream Stream
* MLX stream object.
*/
/**@{*/
/**
* A MLX stream object.
*/
typedef struct mlx_stream_ {
void* ctx;
} mlx_stream;
/**
* Returns a new empty stream.
*/
mlx_stream mlx_stream_new(void);
/**
* Returns a new stream on a device.
*/
mlx_stream mlx_stream_new_device(mlx_device dev);
/**
* Set stream to provided src stream.
*/
int mlx_stream_set(mlx_stream* stream, const mlx_stream src);
/**
* Free a stream.
*/
int mlx_stream_free(mlx_stream stream);
/**
* Get stream description.
*/
int mlx_stream_tostring(mlx_string* str, mlx_stream stream);
/**
* Check if streams are the same.
*/
bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs);
/**
* Return the device of the stream.
*/
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream);
/**
* Return the index of the stream.
*/
int mlx_stream_get_index(int* index, mlx_stream stream);
/**
* Synchronize with the provided stream.
*/
int mlx_synchronize(mlx_stream stream);
/**
* Returns the default stream on the given device.
*/
int mlx_get_default_stream(mlx_stream* stream, mlx_device dev);
/**
* Set default stream.
*/
int mlx_set_default_stream(mlx_stream stream);
/**
* Returns the current default CPU stream.
*/
mlx_stream mlx_default_cpu_stream_new(void);
/**
* Returns the current default GPU stream.
*/
mlx_stream mlx_default_gpu_stream_new(void);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,55 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_STRING_H
#define MLX_STRING_H
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_string String
* MLX string object.
*/
/**@{*/
/**
* A MLX string object.
*/
typedef struct mlx_string_ {
void* ctx;
} mlx_string;
/**
* Returns a new empty string.
*/
mlx_string mlx_string_new(void);
/**
* Returns a new string, copying contents from `str`, which must end with `\0`.
*/
mlx_string mlx_string_new_data(const char* str);
/**
* Set string to src string.
*/
int mlx_string_set(mlx_string* str, const mlx_string src);
/**
* Returns a pointer to the string contents.
* The pointer is valid for the life duration of the string.
*/
const char* mlx_string_data(mlx_string str);
/**
* Free string.
*/
int mlx_string_free(mlx_string str);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,68 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_TRANSFORMS_H
#define MLX_TRANSFORMS_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup transforms Transform operations
*/
/**@{*/
int mlx_async_eval(const mlx_vector_array outputs);
int mlx_checkpoint(mlx_closure* res, const mlx_closure fun);
int mlx_custom_function(
mlx_closure* res,
const mlx_closure fun,
const mlx_closure_custom fun_vjp /* may be null */,
const mlx_closure_custom_jvp fun_jvp /* may be null */,
const mlx_closure_custom_vmap fun_vmap /* may be null */);
int mlx_custom_vjp(
mlx_closure* res,
const mlx_closure fun,
const mlx_closure_custom fun_vjp);
int mlx_eval(const mlx_vector_array outputs);
int mlx_jvp(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array tangents);
int mlx_value_and_grad(
mlx_closure_value_and_grad* res,
const mlx_closure fun,
const int* argnums,
size_t argnums_num);
int mlx_vjp(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array primals,
const mlx_vector_array cotangents);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,54 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_TRANSFORMS_IMPL_H
#define MLX_TRANSFORMS_IMPL_H
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "mlx/c/array.h"
#include "mlx/c/closure.h"
#include "mlx/c/distributed_group.h"
#include "mlx/c/io_types.h"
#include "mlx/c/map.h"
#include "mlx/c/stream.h"
#include "mlx/c/string.h"
#include "mlx/c/vector.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup transforms_impl Implementation detail operations
*/
/**@{*/
int mlx_detail_vmap_replace(
mlx_vector_array* res,
const mlx_vector_array inputs,
const mlx_vector_array s_inputs,
const mlx_vector_array s_outputs,
const int* in_axes,
size_t in_axes_num,
const int* out_axes,
size_t out_axes_num);
int mlx_detail_vmap_trace(
mlx_vector_array* res_0,
mlx_vector_array* res_1,
const mlx_closure fun,
const mlx_vector_array inputs,
const int* in_axes,
size_t in_axes_num);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,133 @@
/* Copyright © 2023-2024 Apple Inc. */
/* */
/* This file is auto-generated. Do not edit manually. */
/* */
#ifndef MLX_VECTOR_H
#define MLX_VECTOR_H
#include "mlx/c/array.h"
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
/**
* \defgroup mlx_vector Vectors
* MLX vector objects.
*/
/**@{*/
/**
* A vector of array.
*/
typedef struct mlx_vector_array_ {
void* ctx;
} mlx_vector_array;
mlx_vector_array mlx_vector_array_new(void);
int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src);
int mlx_vector_array_free(mlx_vector_array vec);
mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size);
mlx_vector_array mlx_vector_array_new_value(const mlx_array val);
int mlx_vector_array_set_data(
mlx_vector_array* vec,
const mlx_array* data,
size_t size);
int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val);
int mlx_vector_array_append_data(
mlx_vector_array vec,
const mlx_array* data,
size_t size);
int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val);
size_t mlx_vector_array_size(mlx_vector_array vec);
int mlx_vector_array_get(
mlx_array* res,
const mlx_vector_array vec,
size_t idx);
/**
* A vector of vector_array.
*/
typedef struct mlx_vector_vector_array_ {
void* ctx;
} mlx_vector_vector_array;
mlx_vector_vector_array mlx_vector_vector_array_new(void);
int mlx_vector_vector_array_set(
mlx_vector_vector_array* vec,
const mlx_vector_vector_array src);
int mlx_vector_vector_array_free(mlx_vector_vector_array vec);
mlx_vector_vector_array mlx_vector_vector_array_new_data(
const mlx_vector_array* data,
size_t size);
mlx_vector_vector_array mlx_vector_vector_array_new_value(
const mlx_vector_array val);
int mlx_vector_vector_array_set_data(
mlx_vector_vector_array* vec,
const mlx_vector_array* data,
size_t size);
int mlx_vector_vector_array_set_value(
mlx_vector_vector_array* vec,
const mlx_vector_array val);
int mlx_vector_vector_array_append_data(
mlx_vector_vector_array vec,
const mlx_vector_array* data,
size_t size);
int mlx_vector_vector_array_append_value(
mlx_vector_vector_array vec,
const mlx_vector_array val);
size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec);
int mlx_vector_vector_array_get(
mlx_vector_array* res,
const mlx_vector_vector_array vec,
size_t idx);
/**
* A vector of int.
*/
typedef struct mlx_vector_int_ {
void* ctx;
} mlx_vector_int;
mlx_vector_int mlx_vector_int_new(void);
int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src);
int mlx_vector_int_free(mlx_vector_int vec);
mlx_vector_int mlx_vector_int_new_data(int* data, size_t size);
mlx_vector_int mlx_vector_int_new_value(int val);
int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size);
int mlx_vector_int_set_value(mlx_vector_int* vec, int val);
int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size);
int mlx_vector_int_append_value(mlx_vector_int vec, int val);
size_t mlx_vector_int_size(mlx_vector_int vec);
int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx);
/**
* A vector of string.
*/
typedef struct mlx_vector_string_ {
void* ctx;
} mlx_vector_string;
mlx_vector_string mlx_vector_string_new(void);
int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src);
int mlx_vector_string_free(mlx_vector_string vec);
mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size);
mlx_vector_string mlx_vector_string_new_value(const char* val);
int mlx_vector_string_set_data(
mlx_vector_string* vec,
const char** data,
size_t size);
int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val);
int mlx_vector_string_append_data(
mlx_vector_string vec,
const char** data,
size_t size);
int mlx_vector_string_append_value(mlx_vector_string vec, const char* val);
size_t mlx_vector_string_size(mlx_vector_string vec);
int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx);
/**@}*/
#ifdef __cplusplus
}
#endif
#endif

View File

@@ -0,0 +1,18 @@
/* Copyright © 2023-2024 Apple Inc. */
#ifndef MLX_VERSION_H
#define MLX_VERSION_H
#include "mlx/c/string.h"
#ifdef __cplusplus
extern "C" {
#endif
int mlx_version(mlx_string* str_);
#ifdef __cplusplus
}
#endif
#endif

164
x/mlxrunner/mlx/io.go Normal file
View File

@@ -0,0 +1,164 @@
package mlx
// #include "generated.h"
import "C"
import (
"fmt"
"iter"
"runtime"
"sort"
"unsafe"
)
// SafetensorsFile represents a loaded safetensors file.
type SafetensorsFile struct {
arrays C.mlx_map_string_to_array
metadata C.mlx_map_string_to_string
}
func loadSafetensorsStream() C.mlx_stream {
if runtime.GOOS == "darwin" {
return C.mlx_default_cpu_stream_new()
}
return C.mlx_default_gpu_stream_new()
}
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
var arrays C.mlx_map_string_to_array
var metadata C.mlx_map_string_to_string
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
stream := loadSafetensorsStream()
defer C.mlx_stream_free(stream)
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
return nil, fmt.Errorf("failed to load safetensors: %s", path)
}
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
}
// Get retrieves a tensor by name.
func (s *SafetensorsFile) Get(name string) *Array {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
value := C.mlx_array_new()
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
return nil
}
if value.ctx == nil {
return nil
}
arr := New(name)
arr.ctx = value
return arr
}
// GetMetadata retrieves a metadata value by key.
func (s *SafetensorsFile) GetMetadata(key string) string {
cKey := C.CString(key)
defer C.free(unsafe.Pointer(cKey))
var cValue *C.char
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
return ""
}
return C.GoString(cValue)
}
// Free releases the loaded safetensors maps.
func (s *SafetensorsFile) Free() {
if s == nil {
return
}
C.mlx_map_string_to_array_free(s.arrays)
C.mlx_map_string_to_string_free(s.metadata)
}
func Load(path string) iter.Seq2[string, *Array] {
return func(yield func(string, *Array) bool) {
sf, err := LoadSafetensorsNative(path)
if err != nil {
return
}
defer sf.Free()
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
defer C.mlx_map_string_to_array_iterator_free(it)
for {
var key *C.char
value := C.mlx_array_new()
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
break
}
name := C.GoString(key)
arr := New(name)
arr.ctx = value
if !yield(name, arr) {
break
}
}
}
}
// SaveSafetensors saves arrays to a safetensors file without metadata.
func SaveSafetensors(path string, arrays map[string]*Array) error {
return SaveSafetensorsWithMetadata(path, arrays, nil)
}
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
cArrays := C.mlx_map_string_to_array_new()
defer C.mlx_map_string_to_array_free(cArrays)
arrayNames := make([]string, 0, len(arrays))
for name, arr := range arrays {
if arr == nil {
continue
}
arrayNames = append(arrayNames, name)
}
sort.Strings(arrayNames)
for _, name := range arrayNames {
arr := arrays[name]
cName := C.CString(name)
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
C.free(unsafe.Pointer(cName))
}
cMetadata := C.mlx_map_string_to_string_new()
defer C.mlx_map_string_to_string_free(cMetadata)
metadataKeys := make([]string, 0, len(metadata))
for key := range metadata {
metadataKeys = append(metadataKeys, key)
}
sort.Strings(metadataKeys)
for _, key := range metadataKeys {
value := metadata[key]
cKey := C.CString(key)
cValue := C.CString(value)
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
C.free(unsafe.Pointer(cKey))
C.free(unsafe.Pointer(cValue))
}
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
return fmt.Errorf("failed to save safetensors: %s", path)
}
return nil
}

89
x/mlxrunner/mlx/memory.go Normal file
View File

@@ -0,0 +1,89 @@
package mlx
// #include "generated.h"
import "C"
import (
"fmt"
"log/slog"
"strconv"
)
func (b Byte) String() string {
return strconv.FormatInt(int64(b), 10) + " B"
}
func (b KibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
}
func (b MebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
}
func (b GibiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
}
func (b TebiByte) String() string {
return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
}
func PrettyBytes(n int) fmt.Stringer {
switch {
case n < 1<<10:
return Byte(n)
case n < 1<<(2*10):
return KibiByte(n)
case n < 1<<(3*10):
return MebiByte(n)
case n < 1<<(4*10):
return GibiByte(n)
default:
return TebiByte(n)
}
}
func ActiveMemory() int {
var active C.size_t
C.mlx_get_active_memory(&active)
return int(active)
}
func CacheMemory() int {
var cache C.size_t
C.mlx_get_cache_memory(&cache)
return int(cache)
}
func PeakMemory() int {
var peak C.size_t
C.mlx_get_peak_memory(&peak)
return int(peak)
}
func ResetPeakMemory() {
C.mlx_reset_peak_memory()
}
type Memory struct{}
func (Memory) LogValue() slog.Value {
return slog.GroupValue(
slog.Any("active", PrettyBytes(ActiveMemory())),
slog.Any("cache", PrettyBytes(CacheMemory())),
slog.Any("peak", PrettyBytes(PeakMemory())),
)
}
type (
Byte int
KibiByte int
MebiByte int
GibiByte int
TebiByte int
)
func ClearCache() {
C.mlx_clear_cache()
}

107
x/mlxrunner/mlx/mlx.go Normal file
View File

@@ -0,0 +1,107 @@
package mlx
//go:generate go run generator/main.go -output=. ./include/mlx/c/*.h
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/include
// #cgo LDFLAGS: -lstdc++
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
// #include "generated.h"
// #include <string.h>
//
// static __thread char _mlx_last_error_msg[1024] = {0};
// static __thread int _mlx_last_error_flag = 0;
//
// static void _mlx_capture_error_handler(const char* msg, void* data) {
// (void)data;
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
// _mlx_last_error_flag = 1;
// }
//
// static void mlx_install_capture_handler(void) {
// if (mlx_set_error_handler_) {
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
// }
// }
//
// static void mlx_clear_last_error(void) {
// _mlx_last_error_flag = 0;
// _mlx_last_error_msg[0] = '\0';
// }
//
// static const char* mlx_get_last_error(void) {
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
// }
import "C"
import "runtime"
func init() {
// Replace the default exit(-1) error handler with one that captures
// the error message so we can surface it in Go.
C.mlx_install_capture_handler()
}
// Version returns the MLX core library version string.
func Version() string {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_version(&str)
return C.GoString(C.mlx_string_data(str))
}
// mlxCheck locks the goroutine to its OS thread, clears the captured error
// state, calls fn, and panics with the captured message if fn returns non-zero.
// The thread lock ensures the thread-local error state is read from the same
// thread that executed the call.
func mlxCheck(fallback string, fn func() C.int) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.mlx_clear_last_error()
if fn() != 0 {
msg := C.GoString(C.mlx_get_last_error())
if msg == "" {
msg = fallback
}
panic("mlx: " + msg)
}
}
func doEval(outputs []*Array, async bool) {
if len(outputs) == 0 {
return
}
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
for _, output := range outputs {
if output != nil && output.Valid() {
C.mlx_vector_array_append_value(vector, output.ctx)
}
}
mlxCheck("eval failed", func() C.int {
if async {
return C.mlx_async_eval(vector)
}
return C.mlx_eval(vector)
})
}
func AsyncEval(outputs ...*Array) {
doEval(outputs, true)
}
func Eval(outputs ...*Array) {
doEval(outputs, false)
}
// MetalIsAvailable returns true if a Metal GPU is available.
func MetalIsAvailable() bool {
var available C._Bool
C.mlx_metal_is_available(&available)
return bool(available)
}

36
x/mlxrunner/mlx/nn.go Normal file
View File

@@ -0,0 +1,36 @@
package mlx
type Linear struct {
Weight *Array `weight:"weight"`
Bias *Array `weight:"bias"`
}
// Forward computes the linear transformation: x @ Weight.T + Bias
func (m *Linear) Forward(x *Array) *Array {
w := m.Weight.Transpose(1, 0)
if m.Bias.Valid() {
return m.Bias.Addmm(x, w, 1.0, 1.0)
}
return x.Matmul(w)
}
func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
w := m.Weight.Transpose(0, 2, 1)
// TODO: bias
return x.GatherMM(w, lhs, rhs, sorted)
}
type Embedding struct {
Weight *Array `weight:"weight"`
}
func (e *Embedding) Forward(indices *Array) *Array {
return e.Weight.TakeAxis(indices, 0)
}
func (e *Embedding) AsLinear() Linear {
return Linear{
Weight: e.Weight,
}
}

300
x/mlxrunner/mlx/ops.go Normal file
View File

@@ -0,0 +1,300 @@
package mlx
// #include "generated.h"
import "C"
import (
"unsafe"
)
func (t *Array) Abs() *Array {
out := New("ABS")
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Add(other *Array) *Array {
out := New("ADD")
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
out := New("ADDMM")
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
return out
}
func (t *Array) Argmax(axis int, keepDims bool) *Array {
out := New("ARGMAX")
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
out := New("ARGPARTITION")
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ArgsortAxis(axis int) *Array {
out := New("ARGSORT_AXIS")
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) AsType(dtype DType) *Array {
out := New("AS_TYPE")
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
cStrides := make([]C.int64_t, len(strides))
for i, s := range strides {
cStrides[i] = C.int64_t(s)
}
out := New("AS_STRIDED")
C.mlx_as_strided(
&out.ctx, t.ctx,
unsafe.SliceData(cShape), C.size_t(len(shape)),
unsafe.SliceData(cStrides), C.size_t(len(strides)),
C.size_t(offset),
DefaultStream().ctx,
)
return out
}
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
if len(others) == 0 {
return t.Clone()
}
vector := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(vector)
s := append([]*Array{t}, others...)
for _, other := range s {
C.mlx_vector_array_append_value(vector, other.ctx)
}
out := New("CONCATENATE")
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
out := New("CUMSUM")
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
return out
}
func (t *Array) Divide(other *Array) *Array {
out := New("DIVIDE")
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) ExpandDims(axis int) *Array {
out := New("EXPAND_DIMS")
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Flatten(startAxis, endAxis int) *Array {
out := New("FLATTEN")
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
return out
}
func (t *Array) FloorDivide(other *Array) *Array {
out := New("FLOOR_DIVIDE")
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
if lhs == nil {
lhs = New("")
}
if rhs == nil {
rhs = New("")
}
out := New("GATHER_MM")
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
return out
}
func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
out := New("LOGSUMEXP_AXIS")
C.mlx_logsumexp_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Equal(other *Array) *Array {
out := New("EQUAL")
C.mlx_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Greater(other *Array) *Array {
out := New("GREATER")
C.mlx_greater(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Less(other *Array) *Array {
out := New("LESS")
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) LessEqual(other *Array) *Array {
out := New("LESS_EQUAL")
C.mlx_less_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
out := New("MAX_AXIS")
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) Matmul(other *Array) *Array {
out := New("MATMUL")
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Multiply(other *Array) *Array {
out := New("MULTIPLY")
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Negative() *Array {
out := New("NEGATIVE")
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Power(exponent *Array) *Array {
out := New("POWER")
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
return out
}
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS")
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
out := New("SCATTER_ADD_AXIS")
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Reshape(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i := range axes {
cAxes[i] = C.int(axes[i])
}
out := New("RESHAPE")
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func (t *Array) Sigmoid() *Array {
out := New("SIGMOID")
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Sqrt() *Array {
out := New("SQRT")
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Squeeze(axis int) *Array {
out := New("SQUEEZE")
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
vectorData := make([]C.mlx_array, len(others)+1)
vectorData[0] = t.ctx
for i := range others {
vectorData[i+1] = others[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK_AXIS")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Subtract(other *Array) *Array {
out := New("SUBTRACT")
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
return out
}
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
out := New("SUM_AXIS")
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
out := New("TAKE_AXIS")
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS")
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
func (t *Array) Tanh() *Array {
out := New("TANH")
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
return out
}
func (t *Array) Transpose(axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i, axis := range axes {
cAxes[i] = C.int(axis)
}
out := New("TRANSPOSE")
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
func Zeros(dtype DType, shape ...int) *Array {
cAxes := make([]C.int, len(shape))
for i := range shape {
cAxes[i] = C.int(shape[i])
}
t := New("ZEROS")
C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
return t
}

View File

@@ -0,0 +1,666 @@
package mlx
// #include "generated.h"
import "C"
import (
"reflect"
"unsafe"
)
// Quantization operations
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new()
defer C.mlx_vector_array_free(res)
var globalScale C.mlx_array
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
vecSize := int(C.mlx_vector_array_size(res))
w0 := New("QUANTIZE_W")
C.mlx_vector_array_get(&w0.ctx, res, 0)
w1 := New("QUANTIZE_S")
C.mlx_vector_array_get(&w1.ctx, res, 1)
if vecSize >= 3 {
w2 := New("QUANTIZE_B")
C.mlx_vector_array_get(&w2.ctx, res, 2)
return w0, w1, w2
}
return w0, w1, nil
}
func FromFP8(x *Array, dtype DType) *Array {
out := New("FROM_FP8")
C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
func ToFP8(x *Array) *Array {
out := New("TO_FP8")
C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
return out
}
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
var b C.mlx_array
if biases != nil {
b = biases.ctx
}
out := New("DEQUANTIZE")
var globalScale C.mlx_array
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
return out
}
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var b C.mlx_array
if biases != nil {
b = biases.ctx
}
out := New("QUANTIZED_MATMUL")
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
return out
}
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var b, lhs, rhs C.mlx_array
if biases != nil {
b = biases.ctx
}
if lhsIndices != nil {
lhs = lhsIndices.ctx
}
if rhsIndices != nil {
rhs = rhsIndices.ctx
}
out := New("GATHER_QMM")
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
return out
}
// Missing tensor ops
func Tile(a *Array, reps []int32) *Array {
cReps := make([]C.int, len(reps))
for i, r := range reps {
cReps[i] = C.int(r)
}
out := New("TILE")
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
return out
}
func Tri(n, m int32, k int) *Array {
out := New("TRI")
C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
return out
}
func Where(condition, a, b *Array) *Array {
out := New("WHERE")
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
out := New("CONV1D")
C.mlx_conv1d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(stride),
C.int(padding),
C.int(dilation),
C.int(groups),
DefaultStream().ctx,
)
if bias != nil && bias.Valid() {
out = Add(out, bias)
}
return out
}
func Contiguous(a *Array, allowColMajor bool) *Array {
out := New("CONTIGUOUS")
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
return out
}
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
// MLX uses NHWC layout.
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
out := New("CONV2D")
C.mlx_conv2d(
&out.ctx,
x.ctx,
weight.ctx,
C.int(strideH), C.int(strideW),
C.int(padH), C.int(padW),
C.int(dilationH), C.int(dilationW),
C.int(groups),
DefaultStream().ctx,
)
return out
}
// Pad pads array a along the given axes with specified low/high pad sizes.
// mode should be "constant", "edge", or "reflect".
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
cAxes := make([]C.int, len(axes))
cLow := make([]C.int, len(lowPad))
cHigh := make([]C.int, len(highPad))
for i := range axes {
cAxes[i] = C.int(axes[i])
cLow[i] = C.int(lowPad[i])
cHigh[i] = C.int(highPad[i])
}
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("PAD")
C.mlx_pad(
&out.ctx,
a.ctx,
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
unsafe.SliceData(cLow), C.size_t(len(cLow)),
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
padValue.ctx,
cMode,
DefaultStream().ctx,
)
return out
}
// PadConstant pads with zeros along the given axes.
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
zero := NewScalarArray(float32(0))
return Pad(a, axes, lowPad, highPad, zero, "constant")
}
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
groups := int32(x.Dim(x.NumDims() - 1))
return Conv1d(x, weight, bias, 1, 0, 1, groups)
}
// Maximum returns element-wise maximum of two arrays.
func Maximum(a, b *Array) *Array {
out := New("MAXIMUM")
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Minimum returns element-wise minimum of two arrays.
func Minimum(a, b *Array) *Array {
out := New("MINIMUM")
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
func Softplus(a *Array) *Array {
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
}
// ReLU computes max(0, x).
func ReLU(a *Array) *Array {
return Maximum(a, NewScalarArray(float32(0)))
}
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
// returns first * sigmoid(second).
func GLU(a *Array) *Array {
lastDim := a.NumDims() - 1
halfSize := a.Dim(lastDim) / 2
first := SliceStartStop(a,
make([]int32, lastDim+1), // all zeros for start
appendDims(a, lastDim, int32(halfSize)),
)
second := SliceStartStop(a,
appendDimsStart(a, lastDim, int32(halfSize)),
appendDims(a, lastDim, int32(a.Dim(lastDim))),
)
return first.Multiply(second.Sigmoid())
}
// helper: builds stop array for SliceStartStop where the target axis = val
func appendDims(a *Array, targetAxis int, val int32) []int32 {
n := a.NumDims()
out := make([]int32, n)
for i := range n {
if i == targetAxis {
out[i] = val
} else {
out[i] = int32(a.Dim(i))
}
}
return out
}
// helper: builds start array for SliceStartStop where the target axis = val
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
n := a.NumDims()
out := make([]int32, n)
for i := range n {
if i == targetAxis {
out[i] = val
}
}
return out
}
// Clamp clamps array values to [min, max].
func Clamp(a *Array, minVal, maxVal float32) *Array {
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
}
// Convenience wrappers (function-style for the model code)
func Stack(arrays []*Array, axis int) *Array {
vectorData := make([]C.mlx_array, len(arrays))
for i := range arrays {
vectorData[i] = arrays[i].ctx
}
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
defer C.mlx_vector_array_free(vector)
out := New("STACK")
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
func Neg(a *Array) *Array {
return a.Negative()
}
func Sum(a *Array, axis int, keepDims bool) *Array {
return a.SumAxis(axis, keepDims)
}
func Argsort(a *Array, axis int) *Array {
return a.ArgsortAxis(axis)
}
func Take(a *Array, indices *Array, axis int) *Array {
return a.TakeAxis(indices, axis)
}
func RSqrt(a *Array) *Array {
out := New("RSQRT")
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN_AXIS")
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
return out
}
func Argpartition(a *Array, kth int, axis int) *Array {
return a.ArgpartitionAxis(kth, axis)
}
func TakeAlongAxis(a, indices *Array, axis int) *Array {
return a.TakeAlongAxis(indices, axis)
}
// Function-style wrappers matching imagegen API
func Add(a, b *Array) *Array {
return a.Add(b)
}
func Sub(a, b *Array) *Array {
return a.Subtract(b)
}
func Mul(a, b *Array) *Array {
return a.Multiply(b)
}
func Div(a, b *Array) *Array {
return a.Divide(b)
}
func Matmul(a, b *Array) *Array {
return a.Matmul(b)
}
func Reshape(a *Array, shape ...int32) *Array {
axes := make([]int, len(shape))
for i, s := range shape {
axes[i] = int(s)
}
return a.Reshape(axes...)
}
func Transpose(a *Array, axes ...int) *Array {
return a.Transpose(axes...)
}
func ExpandDims(a *Array, axis int) *Array {
return a.ExpandDims(axis)
}
func Squeeze(a *Array, axis int) *Array {
return a.Squeeze(axis)
}
func Flatten(a *Array) *Array {
return a.Flatten(0, -1)
}
func Concatenate(arrays []*Array, axis int) *Array {
if len(arrays) == 0 {
return nil
}
if len(arrays) == 1 {
return arrays[0].Clone()
}
return arrays[0].Concatenate(axis, arrays[1:]...)
}
func SliceStartStop(a *Array, start, stop []int32) *Array {
n := len(start)
cStart := make([]C.int, n)
cStop := make([]C.int, n)
cStrides := make([]C.int, n)
for i := 0; i < n; i++ {
cStart[i] = C.int(start[i])
cStop[i] = C.int(stop[i])
cStrides[i] = 1
}
out := New("SLICE")
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
return out
}
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
if lhsIndices == nil {
lhsIndices = New("")
}
if rhsIndices == nil {
rhsIndices = New("")
}
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
}
// RoPEWithBase applies rotary position embeddings to x. offsets is an
// int32 array of shape [B] giving each batch row's starting position;
// the kernel applies positions offsets[b] + 0..T-1 per row.
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offsets *Array) *Array {
return RoPEWithFreqs(x, dims, traditional, base, scale, offsets, nil)
}
// RoPEWithFreqs applies RoPE with optional custom frequencies.
// When freqs is non-nil, it is used instead of computing from base.
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offsets *Array, freqs *Array) *Array {
var freqsCtx C.mlx_array
var optBase C.mlx_optional_float
if freqs != nil {
freqsCtx = freqs.ctx
optBase = C.mlx_optional_float{has_value: C.bool(false)}
} else {
empty := New("")
freqsCtx = empty.ctx
optBase = C.mlx_optional_float{
value: C.float(base),
has_value: C.bool(func() bool { return base != 0 }()),
}
}
out := New("FAST_ROPE")
C.mlx_fast_rope_dynamic(
&out.ctx,
x.ctx,
C.int(dims),
C.bool(traditional),
optBase,
C.float(scale),
offsets.ctx,
freqsCtx,
DefaultStream().ctx,
)
return out
}
func Sigmoid(a *Array) *Array {
return a.Sigmoid()
}
func Exp(a *Array) *Array {
out := New("EXP")
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Log(a *Array) *Array {
out := New("LOG")
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Sin(a *Array) *Array {
out := New("SIN")
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Cos(a *Array) *Array {
out := New("COS")
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
func Clip(a, aMin, aMax *Array) *Array {
out := New("CLIP")
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
return out
}
func Logaddexp(a, b *Array) *Array {
out := New("LOGADDEXP")
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
out := New("SOFTMAX_AXIS")
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
return out
}
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
out := New("FAST_LAYERNORM")
var w, b C.mlx_array
if weight != nil {
w = weight.ctx
}
if bias != nil {
b = bias.ctx
}
C.mlx_fast_layer_norm(&out.ctx, x.ctx, w, b, C.float(eps), DefaultStream().ctx)
return out
}
func RMSNormFn(x, weight *Array, eps float32) *Array {
out := New("FAST_RMSNORM")
var w C.mlx_array
if weight != nil {
w = weight.ctx
}
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
return out
}
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
return c.Addmm(a, b, alpha, beta)
}
// Scalar helpers
// scalarWithDtype creates a scalar array matching the dtype of a.
// Matching dtype is important for graph fusion and avoiding implicit casts.
func scalarWithDtype(s float32, a *Array) C.mlx_array {
f32 := C.mlx_array_new_float(C.float(s))
dtype := a.DType()
if dtype == DTypeFloat32 {
return f32
}
casted := C.mlx_array_new()
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
C.mlx_array_free(f32)
return casted
}
func AddScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("ADD_SCALAR")
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func MulScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("MUL_SCALAR")
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func DivScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
out := New("DIV_SCALAR")
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
C.mlx_array_free(scalar)
return out
}
func FloorDivideScalar(a *Array, s int32) *Array {
scalar := FromValue(int(s))
return a.FloorDivide(scalar)
}
// Array constructors
func NewArrayInt32(data []int32, shape []int32) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
out := New("NEW_ARRAY_INT32")
out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
return out
}
func NewScalarArray(value float32) *Array {
out := New("SCALAR")
out.ctx = C.mlx_array_new_float32(C.float(value))
return out
}
func ZerosF32(shape []int32) *Array {
return Zeros(DTypeFloat32, func() []int {
ints := make([]int, len(shape))
for i, s := range shape {
ints[i] = int(s)
}
return ints
}()...)
}
// Utility
func Collect(v any) []*Array {
var arrays []*Array
seen := make(map[uintptr]bool)
collect(reflect.ValueOf(v), &arrays, seen)
return arrays
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return
}
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return
}
ptr := v.Pointer()
if seen[ptr] {
return
}
seen[ptr] = true
if arr, ok := v.Interface().(*Array); ok {
if arr != nil && arr.Valid() {
*arrays = append(*arrays, arr)
}
return
}
collect(v.Elem(), arrays, seen)
return
}
switch v.Kind() {
case reflect.Struct:
// Check if this struct IS an Array (not a pointer to one)
if arr, ok := v.Addr().Interface().(*Array); ok {
if arr != nil && arr.Valid() {
*arrays = append(*arrays, arr)
}
return
}
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.CanInterface() {
collect(field, arrays, seen)
}
}
case reflect.Slice:
for i := 0; i < v.Len(); i++ {
collect(v.Index(i), arrays, seen)
}
case reflect.Map:
for _, key := range v.MapKeys() {
collect(v.MapIndex(key), arrays, seen)
}
case reflect.Interface:
if !v.IsNil() {
collect(v.Elem(), arrays, seen)
}
}
}
func EnableCompile() {
C.mlx_enable_compile()
}
func DisableCompile() {
C.mlx_disable_compile()
}

44
x/mlxrunner/mlx/random.go Normal file
View File

@@ -0,0 +1,44 @@
package mlx
// #include "generated.h"
import "C"
import "unsafe"
func RandomKey(seed uint64) *Array {
out := New("RANDOM_KEY")
C.mlx_random_key(&out.ctx, C.uint64_t(seed))
return out
}
func (t *Array) Categorical(axis int) *Array {
return t.CategoricalWithKey(axis, nil)
}
func (t *Array) CategoricalWithKey(axis int, key *Array) *Array {
if key == nil {
key = New("")
}
out := New("")
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
return out
}
func Bernoulli(p *Array) *Array {
return BernoulliWithKey(p, nil)
}
func BernoulliWithKey(p *Array, key *Array) *Array {
dims := p.Dims()
shape := make([]C.int, len(dims))
for i, d := range dims {
shape[i] = C.int(d)
}
if key == nil {
key = New("")
}
out := New("BERNOULLI")
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
return out
}

100
x/mlxrunner/mlx/slice.go Normal file
View File

@@ -0,0 +1,100 @@
package mlx
// #include "generated.h"
import "C"
import (
"math"
"unsafe"
)
// End is a sentinel value meaning "to the end of the dimension",
// equivalent to an omitted stop in Python (e.g. a[i:]).
const End = math.MaxInt32
type slice struct {
args []int
}
func Slice(args ...int) slice {
return slice{args: args}
}
func resolve(val, dim int) C.int {
if val == End {
return C.int(dim)
}
if val < 0 {
return C.int(dim + val)
}
return C.int(val)
}
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
if len(slices) != len(dims) {
panic("number of slice arguments must match number of tensor dimensions")
}
args := [3][]C.int{
make([]C.int, len(slices)),
make([]C.int, len(slices)),
make([]C.int, len(slices)),
}
for i, s := range slices {
dim := dims[i]
switch len(s.args) {
case 0:
// slice[:]
args[0][i] = C.int(0)
args[1][i] = C.int(dim)
args[2][i] = C.int(1)
case 1:
// slice[i]
start := resolve(s.args[0], dim)
args[0][i] = start
args[1][i] = start + 1
args[2][i] = C.int(1)
case 2:
// slice[i:j]
args[0][i] = resolve(s.args[0], dim)
args[1][i] = resolve(s.args[1], dim)
args[2][i] = C.int(1)
case 3:
// slice[i:j:k]
args[0][i] = resolve(s.args[0], dim)
args[1][i] = resolve(s.args[1], dim)
args[2][i] = C.int(s.args[2])
default:
panic("invalid slice arguments")
}
}
return args[0], args[1], args[2]
}
func (t *Array) Slice(slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE")
C.mlx_slice(
&out.ctx, t.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
starts, stops, strides := makeSlices(t.Dims(), slices...)
out := New("SLICE_UPDATE")
C.mlx_slice_update(
&out.ctx, t.ctx, other.ctx,
unsafe.SliceData(starts), C.size_t(len(starts)),
unsafe.SliceData(stops), C.size_t(len(stops)),
unsafe.SliceData(strides), C.size_t(len(strides)),
DefaultStream().ctx,
)
return out
}

79
x/mlxrunner/mlx/stream.go Normal file
View File

@@ -0,0 +1,79 @@
package mlx
// #include "generated.h"
import "C"
import "log/slog"
type Device struct {
ctx C.mlx_device
}
func (d Device) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_device_tostring(&str, d.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
var (
defaultDevice Device
defaultDeviceSet bool
defaultStream Stream
defaultStreamSet bool
)
func resetDefaultStreamCache() {
defaultDeviceSet = false
defaultStreamSet = false
}
func DefaultDevice() Device {
if !defaultDeviceSet {
d := C.mlx_device_new()
C.mlx_get_default_device(&d)
defaultDevice = Device{d}
defaultDeviceSet = true
}
return defaultDevice
}
// GPUIsAvailable returns true if a GPU device is available.
func GPUIsAvailable() bool {
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
defer C.mlx_device_free(dev)
var avail C.bool
C.mlx_device_is_available(&avail, dev)
return bool(avail)
}
// SetDefaultDeviceGPU sets the default MLX device to GPU.
func SetDefaultDeviceGPU() {
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
C.mlx_set_default_device(dev)
C.mlx_device_free(dev)
resetDefaultStreamCache()
}
type Stream struct {
ctx C.mlx_stream
}
func (s Stream) LogValue() slog.Value {
str := C.mlx_string_new()
defer C.mlx_string_free(str)
C.mlx_stream_tostring(&str, s.ctx)
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
}
func DefaultStream() Stream {
if !defaultStreamSet {
s := C.mlx_stream_new()
C.mlx_get_default_stream(&s, DefaultDevice().ctx)
defaultStream = Stream{s}
defaultStreamSet = true
}
return defaultStream
}

View File

@@ -0,0 +1,104 @@
package mlx
import (
"context"
"runtime"
"sync"
"testing"
"github.com/ollama/ollama/x/internal/mlxthread"
)
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
func startMLXThread(t *testing.T) *mlxthread.Thread {
t.Helper()
thread, err := mlxthread.Start("mlx-test", func() error {
if err := CheckInit(); err != nil {
return err
}
if GPUIsAvailable() {
SetDefaultDeviceGPU()
}
return nil
})
if err != nil {
t.Skipf("MLX not available: %v", err)
}
return thread
}
func stopMLXThread(t *testing.T, thread *mlxthread.Thread) {
t.Helper()
if err := thread.Stop(context.Background(), func() {
Sweep()
ClearCache()
resetDefaultStreamCache()
}); err != nil {
t.Fatal(err)
}
}
func withMLXThread(t *testing.T, fn func()) {
t.Helper()
thread := startMLXThread(t)
defer stopMLXThread(t, thread)
if err := thread.Do(context.Background(), func() error {
fn()
return nil
}); err != nil {
t.Fatal(err)
}
}
func TestThreadedMLXOperations(t *testing.T) {
thread := startMLXThread(t)
defer stopMLXThread(t, thread)
oldProcs := runtime.GOMAXPROCS(8)
defer runtime.GOMAXPROCS(oldProcs)
const goroutines = 8
const iterations = 8
var wg sync.WaitGroup
errCh := make(chan error, goroutines)
for range goroutines {
wg.Add(1)
go func() {
defer wg.Done()
for range iterations {
if err := thread.Do(context.Background(), func() error {
a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
b := Matmul(a, a)
AsyncEval(b)
Eval(b)
Sweep()
ClearCache()
return nil
}); err != nil {
errCh <- err
return
}
}
}()
}
wg.Wait()
close(errCh)
for err := range errCh {
t.Fatal(err)
}
}

View File

@@ -0,0 +1,202 @@
package base
import (
"encoding/json"
"fmt"
"log/slog"
"sync"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/tokenizer"
)
// Model is the interface that model implementations must satisfy.
type Model interface {
Forward(b *batch.Batch, cache []cache.Cache) *mlx.Array
Unembed(x *mlx.Array) *mlx.Array
NumLayers() int
Tokenizer() *tokenizer.Tokenizer
MaxContextLength() int
// LoadWeights receives all tensors loaded from the manifest and assigns
// them to model fields. Model-specific logic (MLA absorption, expert
// stacking, quantized layer creation) happens here.
LoadWeights(tensors map[string]*mlx.Array) error
}
// DraftModel is an auxiliary model stored alongside a target model.
type DraftModel interface {
LoadWeights(tensors map[string]*mlx.Array) error
}
// MTPDefaults holds model-provided draft-token defaults for speculative
// decoding. Environment settings in the runner may override these values.
type MTPDefaults struct {
InitialDraftTokens int
MaxDraftTokens int
Enabled bool
}
// MTPDefaultsProvider lets a model provide MTP policy defaults from its own
// config without teaching the runner model-specific shape heuristics.
type MTPDefaultsProvider interface {
MTPDraftDefaults(sample bool) MTPDefaults
}
// MTPDraftModel is a draft model capable of Gemma-style multi-token
// prediction from target token embeddings, target hidden states, and target KV.
type MTPDraftModel interface {
Draft(inputEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array)
}
// MTPEmbeddingModel exposes the target token embedding path used by MTP drafts.
type MTPEmbeddingModel interface {
TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array
}
// DFlashTargetModel exposes target-layer hidden states for DFlash drafts.
type DFlashTargetModel interface {
ForwardDFlash(b *batch.Batch, caches []cache.Cache, layerIDs []int) (hidden, targetHidden *mlx.Array)
}
// DFlashDraftModel is a block-diffusion speculative draft model.
type DFlashDraftModel interface {
DraftModel
TargetLayerIDs() []int
BlockSize() int
MaskTokenID() int32
NewCaches() []cache.Cache
AppendContext(targetHidden *mlx.Array, caches []cache.Cache)
Draft(inputIDs *mlx.Array, caches []cache.Cache) *mlx.Array
}
var (
mu sync.Mutex
registry = make(map[string]func(root *model.Root) (Model, error))
draftRegistry = make(map[string]func(root *model.Root, target Model) (DraftModel, error))
)
// Register registers a model constructor by architecture name.
// Called from init() in model packages. Panics on duplicate registration.
func Register(arch string, fn func(root *model.Root) (Model, error)) {
mu.Lock()
defer mu.Unlock()
if _, exists := registry[arch]; exists {
panic(fmt.Sprintf("model architecture %q already registered", arch))
}
registry[arch] = fn
}
// RegisterDraft registers a draft model constructor by architecture name.
func RegisterDraft(arch string, fn func(root *model.Root, target Model) (DraftModel, error)) {
mu.Lock()
defer mu.Unlock()
if _, exists := draftRegistry[arch]; exists {
panic(fmt.Sprintf("draft model architecture %q already registered", arch))
}
draftRegistry[arch] = fn
}
// New reads config.json from the manifest, detects the architecture, looks up
// the registered constructor, and calls it to create the model (with config
// parsed and struct created, but weights not yet loaded).
func New(root *model.Root) (Model, error) {
configData, err := root.Manifest.ReadConfig("config.json")
if err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
var archConfig struct {
Architectures []string `json:"architectures"`
}
if err := json.Unmarshal(configData, &archConfig); err != nil {
return nil, fmt.Errorf("failed to parse config.json: %w", err)
}
if len(archConfig.Architectures) == 0 {
return nil, fmt.Errorf("no architectures found in config.json")
}
arch := archConfig.Architectures[0]
slog.Info("Model architecture", "arch", arch)
mu.Lock()
fn, ok := registry[arch]
mu.Unlock()
if !ok {
return nil, fmt.Errorf("unsupported architecture: %s", arch)
}
return fn(root)
}
// NewDraft constructs the draft model described by the manifest config, if any.
func NewDraft(root *model.Root, target Model) (DraftModel, error) {
if root == nil || root.Draft == nil {
return nil, nil
}
configPath := root.Draft.Config
if configPath == "" {
configPath = "draft/config.json"
}
configData, err := root.Manifest.ReadConfig(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read %s: %w", configPath, err)
}
var archConfig struct {
Architectures []string `json:"architectures"`
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(configData, &archConfig); err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", configPath, err)
}
arch := root.Draft.Architecture
if arch == "" && len(archConfig.Architectures) > 0 {
arch = archConfig.Architectures[0]
}
if arch == "" {
arch = archConfig.ModelType
}
if arch == "" {
return nil, fmt.Errorf("no draft architecture found in %s", configPath)
}
slog.Info("Draft model architecture", "arch", arch)
mu.Lock()
fn, ok := draftRegistry[arch]
mu.Unlock()
if !ok {
return nil, fmt.Errorf("unsupported draft architecture: %s", arch)
}
return fn(root, target)
}
// Weights returns a function that loads model weights, then pins all
// arrays reachable from the model struct and sweeps everything else.
func Weights(m Model) func(map[string]*mlx.Array) error {
return func(tensors map[string]*mlx.Array) error {
if err := m.LoadWeights(tensors); err != nil {
return err
}
collected := mlx.Collect(m)
for _, arr := range collected {
mlx.Pin(arr)
}
mlx.Sweep()
mlx.Eval(collected...)
return nil
}
}

View File

@@ -0,0 +1,42 @@
package model
import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// MakeEmbeddingLayer constructs an embedding layer from a tensor map.
//
// For quantized tensors (path.weight + path.weight_scale), it returns a
// QuantizedEmbedding using the same quant metadata path that linear layers use.
// For non-quantized tensors, it returns a standard dense embedding.
func MakeEmbeddingLayer(
tensors map[string]*mlx.Array,
path string,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) nn.EmbeddingLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
groupSize, bits, mode := ResolveLinearQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
path+".weight",
w,
scales,
)
return nn.NewQuantizedEmbedding(w, scales, qbiases, groupSize, bits, mode)
}
return nn.NewEmbedding(w)
}

View File

@@ -0,0 +1,78 @@
package model
import (
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
func TestMakeEmbeddingLayerDense(t *testing.T) {
skipIfNoMLX(t)
weight := mlx.FromValues([]float32{
1, 2, 3, 4,
5, 6, 7, 8,
}, 2, 4).AsType(mlx.DTypeBFloat16)
emb := MakeEmbeddingLayer(map[string]*mlx.Array{
"model.embed_tokens.weight": weight,
}, "model.embed_tokens", 0, 0, "", nil)
dense, ok := emb.(*nn.Embedding)
if !ok {
t.Fatalf("embedding type = %T, want *nn.Embedding", emb)
}
if dense.Weight.DType() != mlx.DTypeBFloat16 {
t.Fatalf("embedding dtype = %v, want %v", dense.Weight.DType(), mlx.DTypeBFloat16)
}
if _, ok := emb.AsLinear().(*nn.Linear); !ok {
t.Fatalf("AsLinear type = %T, want *nn.Linear", emb.AsLinear())
}
}
func TestMakeEmbeddingLayerQuantized(t *testing.T) {
skipIfNoMLX(t)
denseWeight := mlx.FromValues(func() []float32 {
out := make([]float32, 2*64)
for i := range out {
out[i] = float32(i%17) / 8
}
return out
}(), 2, 64).AsType(mlx.DTypeBFloat16)
qw, scales, qbiases := mlx.Quantize(denseWeight, 64, 4, "affine")
mlx.Eval(qw, scales, qbiases)
emb := MakeEmbeddingLayer(map[string]*mlx.Array{
"model.embed_tokens.weight": qw,
"model.embed_tokens.weight_scale": scales,
"model.embed_tokens.weight_qbias": qbiases,
}, "model.embed_tokens", 64, 4, "affine", nil)
qemb, ok := emb.(*nn.QuantizedEmbedding)
if !ok {
t.Fatalf("embedding type = %T, want *nn.QuantizedEmbedding", emb)
}
if qemb.GroupSize != 64 || qemb.Bits != 4 || qemb.Mode != "affine" {
t.Fatalf("quant params = (%d, %d, %q), want (64, 4, %q)", qemb.GroupSize, qemb.Bits, qemb.Mode, "affine")
}
indices := mlx.FromValues([]int32{1, 0}, 2)
out := emb.Forward(indices)
mlx.Eval(out)
if dims := out.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 64 {
t.Fatalf("embedding output dims = %v, want [2 64]", dims)
}
if _, ok := emb.AsLinear().(*nn.QuantizedLinear); !ok {
t.Fatalf("AsLinear type = %T, want *nn.QuantizedLinear", emb.AsLinear())
}
}

View File

@@ -0,0 +1,99 @@
package model
import (
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/models/nn"
)
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
type LinearFactory struct {
tensors map[string]*mlx.Array
defaultGroupSize int
defaultBits int
defaultMode string
tensorQuant map[string]*TensorQuantInfo
}
// NewLinearFactory creates a reusable constructor for model linear layers.
func NewLinearFactory(
tensors map[string]*mlx.Array,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) LinearFactory {
return LinearFactory{
tensors: tensors,
defaultGroupSize: defaultGroupSize,
defaultBits: defaultBits,
defaultMode: defaultMode,
tensorQuant: tensorQuant,
}
}
// Make constructs a linear layer at path.
func (f LinearFactory) Make(path string) nn.LinearLayer {
return MakeLinearLayer(
f.tensors,
path,
f.defaultGroupSize,
f.defaultBits,
f.defaultMode,
f.tensorQuant,
)
}
// MakeLinearLayer constructs a linear layer from a tensor map.
//
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
// quant params via TensorQuant metadata (with shape-based affine fallback).
// For non-quantized tensors, it returns a standard nn.Linear.
func MakeLinearLayer(
tensors map[string]*mlx.Array,
path string,
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
) nn.LinearLayer {
w := tensors[path+".weight"]
if w == nil {
return nil
}
scales := tensors[path+".weight_scale"]
if scales != nil {
qbiases := tensors[path+".weight_qbias"]
bias := tensors[path+".bias"]
groupSize, bits, mode := ResolveLinearQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
path+".weight",
w,
scales,
)
// Check for per-tensor global scale (NVIDIA double-scale nvfp4).
// NVIDIA ModelOpt stores this as "weight_scale_2"; our import
// pipeline maps it to "weight.global_scale".
globalScale := tensors[path+".weight.global_scale"]
if globalScale == nil {
globalScale = tensors[path+".weight_scale_2"]
}
return &nn.QuantizedLinear{
Weight: w,
Scales: scales,
QBiases: qbiases,
Bias: bias,
GlobalScale: globalScale,
GroupSize: groupSize,
Bits: bits,
Mode: mode,
}
}
bias := tensors[path+".bias"]
return nn.NewLinear(w, bias)
}

132
x/mlxrunner/model/quant.go Normal file
View File

@@ -0,0 +1,132 @@
package model
import (
"strings"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
case "NVFP4":
return 16, 4, "nvfp4"
case "MXFP4":
return 32, 4, "mxfp4"
case "FP4", "Q4", "INT4":
return 64, 4, "affine"
case "MXFP8":
return 32, 8, "mxfp8"
case "FP8", "Q8", "INT8":
return 64, 8, "affine"
case "":
return 0, 0, ""
default:
return 32, 8, "affine"
}
}
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
// when available, otherwise falling back to the provided model defaults.
func TensorQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
) (groupSize, bits int, mode string, fromTensor bool) {
if tensorQuant != nil {
if tq := tensorQuant[tensorName]; tq != nil {
groupSize, bits, mode = QuantizationParams(tq.QuantType)
if tq.GroupSize > 0 {
groupSize = tq.GroupSize
}
return groupSize, bits, mode, true
}
}
return defaultGroupSize, defaultBits, defaultMode, false
}
// ResolveLinearQuantParams resolves quantization params for a quantized linear
// tensor, preferring per-tensor metadata and falling back to shape-based
// inference for affine packed tensors.
func ResolveLinearQuantParams(
defaultGroupSize, defaultBits int,
defaultMode string,
tensorQuant map[string]*TensorQuantInfo,
tensorName string,
weight, scales *mlx.Array,
) (groupSize, bits int, mode string) {
groupSize, bits, mode, fromTensor := TensorQuantParams(
defaultGroupSize,
defaultBits,
defaultMode,
tensorQuant,
tensorName,
)
if mode == "affine" {
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
if !fromTensor || groupSize == 0 || bits == 0 {
groupSize = inferredGroupSize
bits = inferredBits
}
}
}
return groupSize, bits, mode
}
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
// tensors from packed weight and scale shapes.
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
if weight == nil || scales == nil {
return 0, 0, false
}
weightShape := weight.Dims()
scaleShape := scales.Dims()
if len(weightShape) == 0 || len(scaleShape) == 0 {
return 0, 0, false
}
weightCols := weightShape[len(weightShape)-1]
scalesCols := scaleShape[len(scaleShape)-1]
if weightCols <= 0 || scalesCols <= 0 {
return 0, 0, false
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return 32, 4, true
case groupSize8 == 64:
return 64, 8, true
case groupSize4 == 64 && groupSize8 == 32:
if hintBits == 8 {
return 32, 8, true
}
if hintBits == 4 {
return 64, 4, true
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return groupSize4, 4, true
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return groupSize8, 8, true
}
return 0, 0, false
}
func isCommonGroupSize(v int) bool {
switch v {
case 16, 32, 64, 128:
return true
default:
return false
}
}

299
x/mlxrunner/model/root.go Normal file
View File

@@ -0,0 +1,299 @@
package model
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"sort"
"strconv"
"strings"
modeltypes "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen/manifest"
)
// TensorQuantInfo describes per-tensor quantization metadata.
type TensorQuantInfo struct {
QuantType string
GroupSize int
}
// Root wraps a ModelManifest with pre-scanned quantization metadata.
type Root struct {
Manifest *manifest.ModelManifest
Draft *modeltypes.Draft
// Backwards-compatible model-level quant metadata (first tensor blob).
quantType string
groupSize int
// Per-tensor quantization metadata.
tensorQuant map[string]*TensorQuantInfo
}
// Open loads a manifest for the given model name and scans tensor blobs for
// quantization metadata.
func Open(modelName string) (*Root, error) {
m, err := manifest.LoadManifest(modelName)
if err != nil {
return nil, err
}
root := &Root{
Manifest: m,
tensorQuant: make(map[string]*TensorQuantInfo),
}
root.Draft = readDraftConfig(m)
for _, layer := range m.GetTensorLayers("") {
blobPath := m.BlobPath(layer.Digest)
infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
if err != nil {
continue
}
for name, info := range infos {
root.tensorQuant[name] = info
}
if root.quantType == "" && blobQuantType != "" {
root.quantType = strings.ToUpper(blobQuantType)
root.groupSize = blobGroupSize
if root.groupSize == 0 {
root.groupSize = defaultGroupSize(root.quantType)
}
}
}
return root, nil
}
func readDraftConfig(m *manifest.ModelManifest) *modeltypes.Draft {
if m == nil || m.Manifest == nil || m.Manifest.Config.Digest == "" {
return nil
}
data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
if err != nil {
return nil
}
var cfg modeltypes.ConfigV2
if err := json.Unmarshal(data, &cfg); err != nil {
return nil
}
if cfg.Draft != nil {
return cfg.Draft
}
if m.GetConfigLayer("draft/config.json") != nil {
return &modeltypes.Draft{
ModelFormat: "safetensors",
TensorPrefix: "draft.",
Config: "draft/config.json",
}
}
return nil
}
// Close is a no-op for now (future: release resources).
func (r *Root) Close() {}
// QuantType returns the quantization type detected from the first tensor blob metadata.
func (r *Root) QuantType() string { return r.quantType }
// GroupSize returns the quantization group size detected from the first tensor blob metadata.
func (r *Root) GroupSize() int { return r.groupSize }
// TensorQuant returns per-tensor quantization metadata if available.
func (r *Root) TensorQuant(name string) *TensorQuantInfo {
if r == nil {
return nil
}
return r.tensorQuant[name]
}
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
for k, v := range r.tensorQuant {
if v == nil {
continue
}
copy := *v
out[k] = &copy
}
return out
}
func defaultGroupSize(quantType string) int {
groupSize, _, _ := QuantizationParams(quantType)
return groupSize
}
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
f, err := os.Open(path)
if err != nil {
return nil, "", 0, err
}
defer f.Close()
var headerSize uint64
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
return nil, "", 0, err
}
if headerSize > 100*1024*1024 {
return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
}
data := make([]byte, headerSize)
if _, err := io.ReadFull(f, data); err != nil {
return nil, "", 0, err
}
var header map[string]json.RawMessage
if err := json.Unmarshal(data, &header); err != nil {
return nil, "", 0, err
}
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
globalQuantType = strings.ToUpper(globalQuantType)
// Parse full metadata for per-tensor quant info
var metaMap map[string]string
if metaRaw, ok := header["__metadata__"]; ok {
json.Unmarshal(metaRaw, &metaMap)
}
mainNames := mainTensorNames(header)
infos := make(map[string]*TensorQuantInfo)
for _, name := range mainNames {
if _, ok := header[name+".scale"]; !ok {
continue
}
quantType := globalQuantType
groupSize := globalGroupSize
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
if metaMap != nil {
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
quantType = strings.ToUpper(qt)
}
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
if v, err := strconv.Atoi(gs); err == nil {
groupSize = v
}
}
}
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
if quantType == "" {
quantType = inferredType
}
if groupSize == 0 {
groupSize = inferredGroup
}
if quantType == "" {
continue
}
if groupSize == 0 {
groupSize = defaultGroupSize(quantType)
}
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
}
return infos, globalQuantType, globalGroupSize, nil
}
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
metaRaw, ok := header["__metadata__"]
if !ok {
return "", 0
}
var meta map[string]string
if err := json.Unmarshal(metaRaw, &meta); err != nil {
return "", 0
}
quantType = meta["quant_type"]
if gs := meta["group_size"]; gs != "" {
groupSize, _ = strconv.Atoi(gs)
}
return quantType, groupSize
}
func mainTensorNames(header map[string]json.RawMessage) []string {
names := make([]string, 0, len(header))
for name := range header {
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
continue
}
names = append(names, name)
}
sort.Strings(names)
return names
}
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
type tensorShape struct {
Shape []int64 `json:"shape"`
}
mainRaw, ok := header[tensorName]
if !ok {
return "", 0
}
scaleRaw, ok := header[tensorName+".scale"]
if !ok {
return "", 0
}
var mainInfo tensorShape
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
return "", 0
}
var scaleInfo tensorShape
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
return "", 0
}
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
if weightCols <= 0 || scalesCols <= 0 {
return "", 0
}
groupSize4 := weightCols * 8 / scalesCols
groupSize8 := weightCols * 4 / scalesCols
switch {
case groupSize4 == 32:
return "INT4", 32
case groupSize8 == 64:
return "INT8", 64
case groupSize4 == 64 && groupSize8 == 32:
h := strings.ToUpper(hintQuantType)
if strings.Contains(h, "8") {
return "INT8", 32
}
if strings.Contains(h, "4") {
return "INT4", 64
}
}
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
return "INT4", groupSize4
}
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
return "INT8", groupSize8
}
return "", 0
}

952
x/mlxrunner/mtp.go Normal file
View File

@@ -0,0 +1,952 @@
package mlxrunner
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
)
const (
mtpDefaultInitialDraftTokens = 4
mtpDefaultMaxDraftTokens = 16
)
type mtpDraftSchedule string
const (
mtpDraftScheduleHeuristic mtpDraftSchedule = "heuristic"
mtpDraftScheduleConstant mtpDraftSchedule = "constant"
)
type mtpStats struct {
iterations int
drafted int
accepted int
mismatches int
allAccepted int
batched int
serial int
compared int
batchSerialMismatches int
maxDraft int
targetDuration time.Duration
draftDuration time.Duration
validateDuration time.Duration
}
type mtpOptions struct {
initialDraftTokens int
maxDraftTokens int
draftSchedule mtpDraftSchedule
serialValidate bool
compareSerialValidate bool
}
func (r *Runner) mtpDefaults(sample bool) base.MTPDefaults {
defaults := base.MTPDefaults{
InitialDraftTokens: mtpDefaultInitialDraftTokens,
MaxDraftTokens: mtpDefaultMaxDraftTokens,
Enabled: true,
}
if p, ok := r.Model.(base.MTPDefaultsProvider); ok {
defaults = p.MTPDraftDefaults(sample)
}
if defaults.InitialDraftTokens <= 0 {
defaults.InitialDraftTokens = mtpDefaultInitialDraftTokens
}
if defaults.MaxDraftTokens <= 0 {
defaults.MaxDraftTokens = mtpDefaultMaxDraftTokens
}
return defaults
}
func (r *Runner) loadMTPOptions(sample bool) mtpOptions {
defaults := r.mtpDefaults(sample)
opts := mtpOptions{
initialDraftTokens: defaults.InitialDraftTokens,
maxDraftTokens: defaults.MaxDraftTokens,
draftSchedule: mtpDraftScheduleConstant,
}
if v := positiveEnvInt("OLLAMA_MLX_MTP_MAX_DRAFT_TOKENS"); v > 0 {
opts.maxDraftTokens = v
}
if v := positiveEnvInt("OLLAMA_MLX_MTP_INITIAL_DRAFT_TOKENS"); v > 0 {
opts.initialDraftTokens = v
}
if opts.initialDraftTokens > opts.maxDraftTokens {
opts.initialDraftTokens = opts.maxDraftTokens
}
if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil {
opts.serialValidate = b
}
if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil {
opts.compareSerialValidate = b
}
switch schedule := strings.ToLower(strings.TrimSpace(os.Getenv("OLLAMA_MLX_MTP_DRAFT_SCHEDULE"))); schedule {
case "", string(mtpDraftScheduleConstant):
opts.draftSchedule = mtpDraftScheduleConstant
case string(mtpDraftScheduleHeuristic):
opts.draftSchedule = mtpDraftScheduleHeuristic
default:
slog.Warn("invalid MTP env setting", "key", "OLLAMA_MLX_MTP_DRAFT_SCHEDULE", "value", schedule)
}
return opts
}
func positiveEnvInt(key string) int {
raw := os.Getenv(key)
if raw == "" {
return 0
}
v, err := strconv.Atoi(raw)
if err != nil || v <= 0 {
slog.Warn("invalid MTP env setting", "key", key, "value", raw)
return 0
}
return v
}
func (r *Runner) useGreedyMTP(opts sampler.Options) bool {
if r.Draft == nil {
return false
}
if _, ok := r.Draft.(base.MTPDraftModel); !ok {
return false
}
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
return false
}
if !r.mtpDefaults(false).Enabled {
return false
}
if opts.Logprobs || opts.TopLogprobs > 0 {
return false
}
if opts.Temperature != 0 {
return false
}
repeatPenaltyNeutral := opts.RepeatPenalty <= 0 || opts.RepeatPenalty == 1
topPNeutral := opts.TopP <= 0 || opts.TopP >= 1
topKNeutral := opts.TopK <= 0
return repeatPenaltyNeutral && opts.PresencePenalty == 0 && opts.FrequencyPenalty == 0 && topPNeutral && topKNeutral && opts.MinP == 0
}
func (r *Runner) useSampleMTP(opts sampler.Options) bool {
if serial, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil && serial {
return false
}
if compare, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil && compare {
return false
}
if r.Draft == nil {
return false
}
if _, ok := r.Draft.(base.MTPDraftModel); !ok {
return false
}
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
return false
}
if !r.mtpDefaults(true).Enabled {
return false
}
if opts.Logprobs || opts.TopLogprobs > 0 {
return false
}
return opts.Temperature != 0
}
func (r *Runner) runGreedyMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
draft := r.Draft.(base.MTPDraftModel)
mtpOpts := r.loadMTPOptions(false)
stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
draftLimit := mtpOpts.initialDraftTokens
slog.Info("MTP greedy decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate, "compare_serial_validate", mtpOpts.compareSerialValidate)
targetForward := func(token *mlx.Array) *mlx.Array {
fwd := r.Model.Forward(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, caches)
*position += token.Dim(1)
return fwd
}
hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
mlx.Pin(current.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
defer func() {
mlx.Unpin(current.Arrays()...)
}()
dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
now := started
generated := 0
for generated < request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
t0 := time.Now()
hidden = targetForward(current.Token.ExpandDims(-1))
baseLogits := r.lastLogits(hidden)
stats.targetDuration += time.Since(t0)
if generated == 0 {
mlx.Eval(current.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
if err != nil {
return err
}
if !done {
generated++
}
if done || generated >= request.Options.NumPredict {
break
}
stats.iterations++
maxDraft := min(draftLimit, request.Options.NumPredict-generated)
t0 = time.Now()
draftTokens := r.generateMTPDrafts(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
draftCount := 0
if draftTokens != nil {
draftCount = draftTokens.Dim(1)
mlx.Pin(baseLogits, draftTokens)
mlx.Eval(draftTokens)
mlx.Sweep()
}
stats.draftDuration += time.Since(t0)
stats.drafted += draftCount
var next sampler.Result
if draftCount == 0 {
next = sampler.Result{Token: greedyTokenFromLogits(baseLogits)}
} else {
var accepted int
t0 = time.Now()
next, accepted, done, err = r.acceptMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, draftTokens, &final, &generated, &stats, mtpOpts)
stats.validateDuration += time.Since(t0)
mlx.Unpin(baseLogits, draftTokens)
if err != nil {
return err
}
stats.accepted += accepted
switch {
case mtpOpts.draftSchedule == mtpDraftScheduleConstant:
case accepted == draftCount:
stats.allAccepted++
draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
default:
stats.mismatches++
draftLimit = max(1, draftLimit-1)
}
if mtpOpts.draftSchedule == mtpDraftScheduleConstant {
if accepted == draftCount {
stats.allAccepted++
} else {
stats.mismatches++
}
}
stats.maxDraft = max(stats.maxDraft, draftLimit)
if next.Token == nil {
mlx.Sweep()
}
if done || generated >= request.Options.NumPredict {
break
}
}
mlx.Pin(next.Arrays()...)
old := current
current = next
mlx.Unpin(old.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
if generated%256 == 0 {
mlx.ClearCache()
}
}
final.EvalCount = generated
final.EvalDuration = time.Since(now)
acceptance := 0.0
if stats.drafted > 0 {
acceptance = float64(stats.accepted) / float64(stats.drafted)
}
avgDraft := 0.0
avgAccepted := 0.0
if stats.iterations > 0 {
avgDraft = float64(stats.drafted) / float64(stats.iterations)
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
}
slog.Info("MTP decode stats", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "compared", stats.compared, "batch_serial_mismatches", stats.batchSerialMismatches, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
draft := r.Draft.(base.MTPDraftModel)
mtpOpts := r.loadMTPOptions(true)
stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
draftLimit := mtpOpts.initialDraftTokens
slog.Info("MTP sample decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate)
targetForward := func(token *mlx.Array) *mlx.Array {
fwd := r.Model.Forward(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, caches)
*position += token.Dim(1)
return fwd
}
hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
mlx.Pin(current.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
defer func() {
mlx.Unpin(current.Arrays()...)
}()
dec := decoder{tokenizer: r.Tokenizer}
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
now := started
generated := 0
for generated < request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
t0 := time.Now()
hidden = targetForward(mtpTokenInput(current.Token))
baseLogits := r.lastLogits(hidden)
stats.targetDuration += time.Since(t0)
if generated == 0 {
mlx.Eval(current.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
if err != nil {
return err
}
if !done {
generated++
}
if done || generated >= request.Options.NumPredict {
break
}
stats.iterations++
maxDraft := min(draftLimit, request.Options.NumPredict-generated)
t0 = time.Now()
candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
draftCount := 0
var candidateArrays []*mlx.Array
if candidates != nil {
draftCount = candidates.tokens.Dim(1)
candidateArrays = append([]*mlx.Array{baseLogits}, candidates.Arrays()...)
mlx.Pin(candidateArrays...)
mlx.Sweep()
}
stats.draftDuration += time.Since(t0)
stats.drafted += draftCount
var next sampler.Result
if draftCount == 0 {
next = r.Sampler.Sample([]int{pipelineSlot}, baseLogits)
} else {
var accepted int
t0 = time.Now()
next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats)
stats.validateDuration += time.Since(t0)
mlx.Unpin(candidateArrays...)
if err != nil {
return err
}
stats.accepted += accepted
switch {
case mtpOpts.draftSchedule == mtpDraftScheduleConstant:
case accepted == draftCount:
stats.allAccepted++
draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
default:
stats.mismatches++
draftLimit = max(1, draftLimit-1)
}
if mtpOpts.draftSchedule == mtpDraftScheduleConstant {
if accepted == draftCount {
stats.allAccepted++
} else {
stats.mismatches++
}
}
stats.maxDraft = max(stats.maxDraft, draftLimit)
if next.Token == nil {
mlx.Sweep()
}
if done || generated >= request.Options.NumPredict {
break
}
}
mlx.Pin(next.Arrays()...)
old := current
current = next
mlx.Unpin(old.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(current.Arrays()...)
if generated%256 == 0 {
mlx.ClearCache()
}
}
final.EvalCount = generated
final.EvalDuration = time.Since(now)
acceptance := 0.0
if stats.drafted > 0 {
acceptance = float64(stats.accepted) / float64(stats.drafted)
}
avgDraft := 0.0
avgAccepted := 0.0
if stats.iterations > 0 {
avgDraft = float64(stats.drafted) / float64(stats.iterations)
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
}
slog.Info("MTP decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
type mtpDraftCandidates struct {
tokens *mlx.Array
// dist is the proposal distribution used to sample each drafted token.
dist sampler.Distribution
}
func (c *mtpDraftCandidates) Arrays() []*mlx.Array {
if c == nil {
return nil
}
return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...)
}
func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array {
if maxDraft <= 0 {
return nil
}
lastToken := token.ExpandDims(-1)
lastHidden := hidden
draftTokens := make([]*mlx.Array, 0, maxDraft)
// Gemma4 assistant MTP is trained as "single-position" drafting:
// keep the RoPE/cache position anchored at the last target-seen token
// while the proposed token and projected hidden state advance.
for range maxDraft {
tokenEmbedding := target.TokenEmbeddings(lastToken)
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
logits, projected := draft.Draft(inputs, position, caches)
stepLogits := r.lastLogitsFromLogits(logits)
nextToken := greedyTokenFromLogits(stepLogits)
lastToken = nextToken.ExpandDims(-1)
lastHidden = projected
draftTokens = append(draftTokens, lastToken)
}
if len(draftTokens) == 0 {
return nil
}
return mlx.Concatenate(draftTokens, 1)
}
func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mtpDraftCandidates {
if maxDraft <= 0 {
return nil
}
lastToken := mtpTokenInput(token)
lastHidden := hidden
draftTokens := make([]*mlx.Array, 0, maxDraft)
draftDists := make([]sampler.Distribution, 0, maxDraft)
var prefix *mlx.Array
// Gemma4 assistant MTP is trained as "single-position" drafting:
// keep the RoPE/cache position anchored at the last target-seen token
// while the proposed token and projected hidden state advance.
for range maxDraft {
tokenEmbedding := target.TokenEmbeddings(lastToken)
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
logits, projected := draft.Draft(inputs, position, caches)
stepLogits := r.lastLogitsFromLogits(logits)
dist := r.Sampler.Distribution(pipelineSlot, stepLogits, prefix)
nextToken := r.Sampler.SampleDistribution(pipelineSlot, dist)
lastToken = mtpTokenInput(nextToken)
lastHidden = projected
draftTokens = append(draftTokens, lastToken)
draftDists = append(draftDists, dist)
if prefix == nil {
prefix = lastToken
} else {
prefix = prefix.Concatenate(1, lastToken)
}
}
if len(draftTokens) == 0 {
return nil
}
return &mtpDraftCandidates{
tokens: mlx.Concatenate(draftTokens, 1),
dist: sampler.ConcatenateDistributions(draftDists),
}
}
func (r *Runner) acceptMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
if opts.serialValidate {
stats.serial++
return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
}
specCaches, spec, ok := cache.BeginSpeculation(caches)
if ok {
stats.batched++
return r.acceptMTPDraftsBatched(ctx, request, session, dec, caches, specCaches, spec, position, baseLogits, draftTokens, final, generated, stats, opts)
}
stats.serial++
return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
}
func (r *Runner) acceptMTPDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, liveCaches []cache.Cache, caches []cache.Cache, spec *cache.Speculation, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
before := *position
draftCount := draftTokens.Dim(1)
hiddenSeq := r.Model.Forward(&batch.Batch{
InputIDs: draftTokens,
SeqOffsets: []int32{int32(before)},
SeqQueryLens: []int32{int32(draftCount)},
}, caches)
accepted := 0
var next sampler.Result
done := false
selectedTokens := r.mtpValidationTokens(baseLogits, hiddenSeq)
mlx.Eval(draftTokens, selectedTokens)
draftIDs := draftTokens.Ints()
selectedIDs := selectedTokens.Ints()
if len(selectedIDs) < draftCount+1 {
return sampler.Result{}, accepted, false, fmt.Errorf("mtp validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount)
}
for i, id := range draftIDs {
if selectedIDs[i] != id {
next = sampler.Result{Token: mtpTokenAt(selectedTokens, i)}
break
}
accepted++
if r.Tokenizer.IsEOS(int32(id)) {
done = true
break
}
}
if opts.compareSerialValidate {
spec.Commit(0)
r.compareMTPBatchedWithSerial(ctx, liveCaches, before, baseLogits, hiddenSeq, draftIDs, selectedIDs, accepted, draftCount, stats)
}
spec.Commit(accepted)
*position = before + accepted
for _, id := range draftIDs[:accepted] {
if *generated >= request.Options.NumPredict {
done = true
break
}
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
var err error
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done {
break
}
}
if done || *generated >= request.Options.NumPredict {
return sampler.Result{}, accepted, true, nil
}
if next.Token == nil {
next = sampler.Result{Token: mtpTokenAt(selectedTokens, draftCount)}
}
return next, accepted, false, nil
}
func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, candidates *mtpDraftCandidates, final *CompletionResponse, generated *int, stats *mtpStats) (sampler.Result, int, bool, error) {
specCaches, spec, ok := cache.BeginSpeculation(caches)
if !ok {
stats.serial++
return r.Sampler.Sample([]int{pipelineSlot}, baseLogits), 0, false, nil
}
stats.batched++
before := *position
draftCount := candidates.tokens.Dim(1)
hiddenSeq := r.Model.Forward(&batch.Batch{
InputIDs: candidates.tokens,
SeqOffsets: []int32{int32(before)},
SeqQueryLens: []int32{int32(draftCount)},
}, specCaches)
targetDist := r.Sampler.Distribution(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
draftDist := candidates.dist
acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens)
mlx.Eval(candidates.tokens, acceptedMask)
draftIDs := candidates.tokens.Ints()
acceptedFlags := acceptedMask.Ints()
accepted := 0
for _, ok := range acceptedFlags {
if ok == 0 {
break
}
accepted++
}
if accepted > draftCount {
return sampler.Result{}, 0, false, fmt.Errorf("mtp sample validation accepted %d tokens for %d draft tokens", accepted, draftCount)
}
commitIDs := make([]int32, 0, accepted+1)
done := false
for i, id := range draftIDs[:accepted] {
commitIDs = append(commitIDs, int32(id))
if r.Tokenizer.IsEOS(int32(id)) {
done = true
accepted = i + 1
commitIDs = commitIDs[:accepted]
break
}
}
spec.Commit(accepted)
*position = before + accepted
for _, id := range draftIDs[:accepted] {
if *generated >= request.Options.NumPredict {
done = true
break
}
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
var err error
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done {
break
}
}
if done || *generated >= request.Options.NumPredict {
r.Sampler.Commit(pipelineSlot, commitIDs)
return sampler.Result{}, accepted, true, nil
}
var nextToken *mlx.Array
if accepted == draftCount {
nextToken = r.mtpSampleTokenAt(targetDist, draftCount)
} else {
nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted)
}
mlx.Eval(nextToken)
nextID := int32(tokenID(nextToken))
commitIDs = append(commitIDs, nextID)
r.Sampler.Commit(pipelineSlot, commitIDs)
return sampler.Result{Token: nextToken}, accepted, false, nil
}
func (r *Runner) mtpSampleAcceptedMask(targetDist, draftDist sampler.Distribution, draftTokens *mlx.Array) *mlx.Array {
p := targetDist.Prob(draftTokens)
q := draftDist.Prob(draftTokens)
acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1)))
return r.Sampler.Bernoulli(pipelineSlot, acceptP).AsType(mlx.DTypeInt32)
}
func (r *Runner) mtpSampleTokenAt(dist sampler.Distribution, index int) *mlx.Array {
return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist.SliceRows(index, index+1)))
}
func (r *Runner) mtpSampleResidualToken(targetDist, draftDist sampler.Distribution, index int) *mlx.Array {
residual := targetDist.SliceRows(index, index+1).ResidualAgainst(draftDist.SliceRows(index, index+1))
return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, residual))
}
func mtpTokenInput(token *mlx.Array) *mlx.Array {
switch token.NumDims() {
case 0:
return token.Reshape(1, 1)
case 1:
return token.ExpandDims(-1)
case 2:
return token
default:
panic(fmt.Sprintf("mtp token must be rank 0, 1, or 2, got rank %d", token.NumDims()))
}
}
func mtpTokenVector(token *mlx.Array) *mlx.Array {
switch token.NumDims() {
case 0:
return token.Reshape(1)
case 1:
return token
default:
panic(fmt.Sprintf("mtp sampled token must be rank 0 or 1, got rank %d", token.NumDims()))
}
}
func (r *Runner) compareMTPBatchedWithSerial(ctx context.Context, caches []cache.Cache, before int, baseLogits, hiddenSeq *mlx.Array, draftIDs, selectedIDs []int, accepted, draftCount int, stats *mtpStats) {
serialCaches, ok := cache.BeginIsolatedSpeculation(caches)
if !ok {
return
}
compareCount := accepted + 1
if accepted == draftCount {
// Include the target bonus token when every draft was accepted.
compareCount = draftCount + 1
}
serialLogits := baseLogits
for i := range compareCount {
if err := ctx.Err(); err != nil {
return
}
if i >= len(selectedIDs) {
return
}
batchedLogits := baseLogits
if i > 0 {
batchedLogits = r.targetLogitsAt(hiddenSeq, i-1)
}
batchedToken := greedyTokenFromLogits(batchedLogits)
serialToken := greedyTokenFromLogits(serialLogits)
mlx.Eval(batchedToken, serialToken)
batchedID := tokenID(batchedToken)
vectorizedID := selectedIDs[i]
serialID := tokenID(serialToken)
stats.compared++
if vectorizedID != serialID {
firstMismatch := stats.batchSerialMismatches == 0
stats.batchSerialMismatches++
if !firstMismatch {
return
}
draftID := -1
if i < draftCount {
draftID = draftIDs[i]
}
batchedTop := top2FromLogits(batchedLogits)
serialTop := top2FromLogits(serialLogits)
slog.Warn("MTP batched validation differs from serial validation",
"position", before+i,
"draft", draftID,
"batched", vectorizedID,
"batched_slice", batchedID,
"serial", serialID,
"batched_slice_top1", batchedTop.firstToken,
"batched_slice_top2", batchedTop.secondToken,
"batched_slice_margin", batchedTop.margin,
"serial_top1", serialTop.firstToken,
"serial_top2", serialTop.secondToken,
"serial_margin", serialTop.margin,
)
return
}
if i >= draftCount || i >= accepted {
return
}
hidden := r.Model.Forward(&batch.Batch{
InputIDs: mlx.FromValues([]int32{int32(draftIDs[i])}, 1, 1),
SeqOffsets: []int32{int32(before + i)},
SeqQueryLens: []int32{1},
}, serialCaches)
serialLogits = r.lastLogits(hidden)
}
}
type mtpTop2 struct {
firstToken int
secondToken int
margin float64
}
func top2FromLogits(logits *mlx.Array) mtpTop2 {
indices := logits.Negative().ArgsortAxis(-1).Slice(mlx.Slice(), mlx.Slice(0, 2))
indices32 := indices.AsType(mlx.DTypeInt32)
values := logits.TakeAlongAxis(indices, -1).AsType(mlx.DTypeFloat32)
mlx.Eval(indices32, values)
tokenIDs := indices32.Ints()
logitValues := values.Floats()
if len(tokenIDs) < 2 || len(logitValues) < 2 {
return mtpTop2{}
}
return mtpTop2{
firstToken: tokenIDs[0],
secondToken: tokenIDs[1],
margin: float64(logitValues[0] - logitValues[1]),
}
}
func (r *Runner) acceptMTPDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
logits := baseLogits
accepted := 0
draftIDs := draftTokens.Ints()
for _, id := range draftIDs {
selected := greedyTokenFromLogits(logits)
mlx.Eval(selected)
selectedID := tokenID(selected)
if selectedID != id {
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil
}
hidden := r.Model.Forward(&batch.Batch{
InputIDs: mlx.FromValues([]int32{int32(id)}, 1, 1),
SeqOffsets: []int32{int32(*position)},
SeqQueryLens: []int32{1},
}, caches)
(*position)++
accepted++
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
if err != nil {
return sampler.Result{}, accepted, done, err
}
if !done {
(*generated)++
}
if done || *generated >= request.Options.NumPredict {
return sampler.Result{}, accepted, true, nil
}
logits = r.lastLogits(hidden)
}
return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil
}
func (r *Runner) emitMTPToken(ctx context.Context, request Request, session *cacheSession, dec *decoder, res sampler.Result, final *CompletionResponse) (bool, error) {
output := int32(tokenID(res.Token))
session.outputs = append(session.outputs, output)
if r.Tokenizer.IsEOS(output) {
final.DoneReason = 0
return true, nil
}
if resp, ok := dec.decode(res); ok {
select {
case <-ctx.Done():
return false, ctx.Err()
case request.Responses <- resp:
}
}
return false, nil
}
func (r *Runner) lastLogits(hidden *mlx.Array) *mlx.Array {
logits := r.Model.Unembed(hidden)
return r.lastLogitsFromLogits(logits)
}
func (r *Runner) targetLogitsAt(hiddenSeq *mlx.Array, index int) *mlx.Array {
hidden := hiddenSeq.Slice(mlx.Slice(), mlx.Slice(index), mlx.Slice())
return r.lastLogits(hidden)
}
func (r *Runner) mtpValidationTokens(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
return greedyTokenFromLogits(r.mtpValidationLogits(baseLogits, hiddenSeq))
}
func (r *Runner) mtpValidationLogits(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
seqLogits := r.Model.Unembed(hiddenSeq)
return baseLogits.ExpandDims(1).Concatenate(1, seqLogits)
}
func mtpTokenAt(tokens *mlx.Array, index int) *mlx.Array {
return tokens.Slice(mlx.Slice(), mlx.Slice(index)).Squeeze(0)
}
func (r *Runner) lastLogitsFromLogits(logits *mlx.Array) *mlx.Array {
return logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
}
func greedyTokenFromLogits(logits *mlx.Array) *mlx.Array {
return logits.Argmax(-1, false).AsType(mlx.DTypeInt32)
}
func tokenID(token *mlx.Array) int {
if token == nil {
return -1
}
if token.DType() == mlx.DTypeInt32 {
ids := token.Ints()
if len(ids) > 0 {
return ids[0]
}
}
return token.Int()
}

434
x/mlxrunner/pipeline.go Normal file
View File

@@ -0,0 +1,434 @@
package mlxrunner
import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"sort"
"time"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model/base"
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
func prefillChunkSize() int {
return 2 << 10
}
// Prepare tokenizes the prompt and validates it against the model's
// context length. It is safe to call from any goroutine. On success it
// populates request.Tokens and adjusts request.Options.NumPredict.
func (r *Runner) Prepare(request *Request) error {
if r.Model == nil {
return errors.New("model not loaded")
}
tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
if len(tokens) == 0 {
return errors.New("empty prompt")
}
if len(tokens) >= r.contextLength {
return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength)
}
// Cap generation to stay within the model's context length
maxGenerate := r.contextLength - len(tokens)
if request.Options.NumPredict <= 0 {
request.Options.NumPredict = maxGenerate
} else {
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
}
request.Tokens = tokens
return nil
}
// The runner serializes requests today so we just use a fixed slot ID.
const pipelineSlot = 0
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
mlx.ResetPeakMemory()
var sample, nextSample sampler.Result
defer func() {
r.Sampler.Remove(pipelineSlot)
mlx.Unpin(sample.Arrays()...)
mlx.Unpin(nextSample.Arrays()...)
mlx.Sweep()
mlx.ClearCache()
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
mlx.LogArrays()
r.cache.dumpTree()
}
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
}()
inputs := request.Tokens
session := r.cache.begin(r.Model, inputs)
defer session.close()
caches := session.caches
tokens := session.remaining
prefillChunk := prefillChunkSize()
dflashMode, dflashDisabledReason := r.dflashGate(request.SamplerOpts)
dflashEnabled := dflashMode.enabled()
var dflashDraft base.DFlashDraftModel
var dflashTarget base.DFlashTargetModel
var dflashCaches []cache.Cache
var dflashSession *cacheSession
if dflashEnabled {
dflashDraft = r.Draft.(base.DFlashDraftModel)
dflashTarget = r.Model.(base.DFlashTargetModel)
targetCachedPrefix := len(inputs) - len(tokens)
dflashSession = r.dflashCache.beginWithFactoryLimit(inputs, dflashDraft.NewCaches, "DFlash draft", targetCachedPrefix, false)
dflashCaches = dflashSession.caches
defer func() {
dflashSession.outputs = append([]int32(nil), session.outputs...)
dflashSession.close()
}()
} else if _, ok := r.Draft.(base.DFlashDraftModel); ok {
slog.Info("DFlash decode disabled",
"reason", dflashDisabledReason,
"temperature", request.SamplerOpts.Temperature,
"top_p", request.SamplerOpts.TopP,
"top_k", request.SamplerOpts.TopK,
"min_p", request.SamplerOpts.MinP,
"repeat_penalty", request.SamplerOpts.RepeatPenalty,
"presence_penalty", request.SamplerOpts.PresencePenalty,
"frequency_penalty", request.SamplerOpts.FrequencyPenalty,
"logprobs", request.SamplerOpts.Logprobs,
"top_logprobs", request.SamplerOpts.TopLogprobs,
)
}
requestPipelineSnapshots := func(s *cacheSession) {
if s == nil {
return
}
// Request periodic snapshots during prefill and near the end of the
// prompt so that long prompts can be partially restored and
// thinking/generation can be retried without full reprocessing.
const snapshotInterval = 8192
for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
s.requestSnapshot(offset)
}
const preThinking = 4
if end := len(inputs) - preThinking; end > 0 {
s.requestSnapshot(end)
}
}
requestPipelineSnapshots(session)
requestPipelineSnapshots(dflashSession)
nextSnapshotOffset := func() int {
next := session.nextPendingSnapshot()
if dflashSession != nil {
if offset := dflashSession.nextPendingSnapshot(); offset > 0 && (next == 0 || offset < next) {
next = offset
}
}
return next
}
snapshotReadySessions := func(position int) {
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 && position >= snapOffset {
session.snapshot()
}
if dflashSession != nil {
if snapOffset := dflashSession.nextPendingSnapshot(); snapOffset > 0 && position >= snapOffset {
dflashSession.snapshot()
}
}
}
materializeCaches := func(cacheSets ...[]cache.Cache) {
if len(cacheSets) == 0 {
cacheSets = [][]cache.Cache{caches}
}
state := make([]*mlx.Array, 0, 2*len(caches))
for _, set := range cacheSets {
for _, c := range set {
if c == nil {
continue
}
state = append(state, c.State()...)
}
}
if len(state) == 0 {
return
}
mlx.Eval(state...)
}
if dflashEnabled {
targetCachedPrefix := len(inputs) - len(tokens)
dflashCachedPrefix := len(inputs) - len(dflashSession.remaining)
if targetCachedPrefix > dflashCachedPrefix {
t0 := time.Now()
rebuildCaches := newDFlashTargetCaches(r.Model)
rebuildProcessed := 0
for targetCachedPrefix-rebuildProcessed > 0 {
if err := ctx.Err(); err != nil {
freeCacheSet(rebuildCaches)
return err
}
n := min(prefillChunk, targetCachedPrefix-rebuildProcessed)
if snapOffset := dflashSession.nextPendingSnapshot(); snapOffset > rebuildProcessed && snapOffset < rebuildProcessed+n {
n = snapOffset - rebuildProcessed
}
start, end := rebuildProcessed, rebuildProcessed+n
b := &batch.Batch{
InputIDs: mlx.FromValues(inputs[start:end], 1, n),
SeqOffsets: []int32{int32(start)},
SeqQueryLens: []int32{int32(n)},
}
_, targetHidden := dflashTarget.ForwardDFlash(b, rebuildCaches, dflashDraft.TargetLayerIDs())
if end > dflashCachedPrefix {
appendHidden := targetHidden
if start < dflashCachedPrefix {
appendHidden = targetHidden.Slice(mlx.Slice(), mlx.Slice(dflashCachedPrefix-start, n), mlx.Slice())
}
dflashDraft.AppendContext(appendHidden, dflashCaches)
}
mlx.Sweep()
materializeCaches(rebuildCaches, dflashCaches)
rebuildProcessed = end
if snapOffset := dflashSession.nextPendingSnapshot(); snapOffset > 0 && rebuildProcessed >= snapOffset {
dflashSession.snapshot()
}
mlx.ClearCache()
}
freeCacheSet(rebuildCaches)
slog.Info("DFlash draft cache rebuild",
"target_cached", targetCachedPrefix,
"draft_cached", dflashCachedPrefix,
"rebuilt", targetCachedPrefix-dflashCachedPrefix,
"draft_offset", r.dflashCache.minCacheOffset(),
"duration", time.Since(t0),
)
} else {
slog.Info("DFlash draft cache restored",
"target_cached", targetCachedPrefix,
"draft_cached", dflashCachedPrefix,
"draft_offset", r.dflashCache.minCacheOffset(),
)
}
}
now := time.Now()
total, processed := len(tokens), 0
position := len(inputs) - len(tokens)
for total-processed > 1 {
if err := ctx.Err(); err != nil {
return err
}
n := min(prefillChunk, total-processed-1)
// If there's a pending snapshot, split the batch so we can
// capture it at the exact offset.
if snapOffset := nextSnapshotOffset(); snapOffset > 0 {
tokensUntilSnapshot := snapOffset - position
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
n = tokensUntilSnapshot
}
}
b := &batch.Batch{
InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n),
SeqOffsets: []int32{int32(position)},
SeqQueryLens: []int32{int32(n)},
}
if dflashEnabled {
_, targetHidden := dflashTarget.ForwardDFlash(b, caches, dflashDraft.TargetLayerIDs())
dflashDraft.AppendContext(targetHidden, dflashCaches)
} else {
r.Model.Forward(b, caches)
}
mlx.Sweep()
if dflashEnabled {
materializeCaches(caches, dflashCaches)
} else {
materializeCaches()
}
processed += n
position += n
slog.Info("Prompt processing progress", "processed", processed, "total", total)
logutil.TraceContext(ctx, "mlx prompt forward", "processed", processed, "total", total, "tokens", n, "memory", mlx.Memory{})
// Create snapshot if we've reached a pending offset.
snapshotReadySessions(position)
mlx.ClearCache()
}
// Register the sampler after prefill completes.
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
if dflashMode == dflashDecodeGreedy {
return r.runGreedyDFlashDecode(ctx, request, session, caches, dflashCaches, tokens[processed:], &position, now)
}
if dflashMode == dflashDecodeSample {
return r.runSampleDFlashDecode(ctx, request, session, caches, dflashCaches, tokens[processed:], &position, now)
}
if r.useGreedyMTP(request.SamplerOpts) {
return r.runGreedyMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now)
}
if r.useSampleMTP(request.SamplerOpts) {
return r.runSampleMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now)
}
step := func(token *mlx.Array) sampler.Result {
fwd := r.Model.Forward(&batch.Batch{
InputIDs: token,
SeqOffsets: []int32{int32(position)},
SeqQueryLens: []int32{int32(token.Dim(1))},
}, caches)
position += token.Dim(1)
logits := r.Model.Unembed(fwd)
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
sample := r.Sampler.Sample([]int{pipelineSlot}, logits)
mlx.Pin(sample.Arrays()...)
mlx.Sweep()
mlx.AsyncEval(sample.Arrays()...)
return sample
}
sample = step(mlx.FromValues(tokens[processed:], 1, total-processed))
logutil.TraceContext(ctx, "mlx decode seed", "tokens", total-processed, "memory", mlx.Memory{})
dec := decoder{
tokenizer: r.Tokenizer,
wantLogprobs: request.SamplerOpts.Logprobs,
wantTopLogprobs: request.SamplerOpts.TopLogprobs,
}
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
for i := range request.Options.NumPredict {
if err := ctx.Err(); err != nil {
return err
}
nextSample = step(sample.Token.ExpandDims(-1))
if i == 0 {
mlx.Eval(sample.Arrays()...)
final.PromptEvalDuration = time.Since(now)
now = time.Now()
}
output := int32(sample.Token.Int())
session.outputs = append(session.outputs, output)
if i == 0 {
logutil.TraceContext(ctx, "mlx decode first token", "memory", mlx.Memory{})
}
if r.Tokenizer.IsEOS(output) {
final.DoneReason = 0
final.EvalCount = i
break
}
if resp, ok := dec.decode(sample); ok {
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- resp:
}
}
mlx.Unpin(sample.Arrays()...)
sample, nextSample = nextSample, sampler.Result{}
if i%256 == 0 {
mlx.ClearCache()
}
}
final.EvalDuration = time.Since(now)
select {
case <-ctx.Done():
return ctx.Err()
case request.Responses <- final:
return nil
}
}
// decoder serializes sampled tokens into response chunks, holding bytes
// whose UTF-8 sequence hasn't completed yet and the logprobs that belong
// with those bytes so Content and Logprobs stay aligned when a chunk does
// flush.
type decoder struct {
tokenizer *tokenizer.Tokenizer
buf bytes.Buffer
logprobs []llm.Logprob
wantLogprobs bool
wantTopLogprobs int
}
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
output := int32(res.Token.Int())
d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
d.logprobs = append(d.logprobs, buildLogprob(res, d.wantLogprobs, d.wantTopLogprobs, d.tokenizer.Decode)...)
content := flushValidUTF8Prefix(&d.buf)
if content == "" {
return CompletionResponse{}, false
}
resp := CompletionResponse{Content: content, Logprobs: d.logprobs}
d.logprobs = nil
return resp, true
}
// buildLogprob converts the sampler's logprob tensors into the wire-format
// llm.Logprob entries the caller wants. The sampler populates its logprob
// tensors whenever any registered slot requested them, so the caller must
// gate emission on its own request config (wantLogprobs / wantTopLogprobs)
// rather than on whether the tensors happen to be non-nil.
func buildLogprob(sample sampler.Result, wantLogprobs bool, wantTopLogprobs int, decode func([]int32) string) []llm.Logprob {
if !wantLogprobs || sample.Logprob == nil {
return nil
}
tok := func(id int32) string { return decode([]int32{id}) }
out := llm.Logprob{
TokenLogprob: llm.TokenLogprob{
Token: tok(int32(sample.Token.Int())),
Logprob: float64(sample.Logprob.Floats()[0]),
},
}
if wantTopLogprobs > 0 && sample.TopTokens != nil {
ids := sample.TopTokens.Ints()
vals := sample.TopLogprobs.Floats()
pairs := make([]llm.TokenLogprob, len(ids))
for i, id := range ids {
pairs[i] = llm.TokenLogprob{
Token: tok(int32(id)),
Logprob: float64(vals[i]),
}
}
// The sampler emits the top maxK across registered slots via
// Argpartition, which leaves entries unsorted.
sort.Slice(pairs, func(i, j int) bool {
return pairs[i].Logprob > pairs[j].Logprob
})
if wantTopLogprobs < len(pairs) {
pairs = pairs[:wantTopLogprobs]
}
out.TopLogprobs = pairs
}
return []llm.Logprob{out}
}

210
x/mlxrunner/runner.go Normal file
View File

@@ -0,0 +1,210 @@
package mlxrunner
import (
"context"
"errors"
"log/slog"
"net"
"net/http"
"strings"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/x/internal/mlxthread"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/mlxrunner/sample"
"github.com/ollama/ollama/x/tokenizer"
)
// Request is a short-lived struct that carries a completion request through
// a channel from the HTTP handler to the runner goroutine. The ctx field
// must travel with the request so that cancellation propagates across the
// channel boundary.
type Request struct {
CompletionRequest
Responses chan CompletionResponse
Pipeline func(context.Context, Request) error
Ctx context.Context //nolint:containedctx
Tokens []int32
SamplerOpts sample.Options
}
type Runner struct {
Model base.Model
Draft base.DraftModel
Tokenizer *tokenizer.Tokenizer
Requests chan Request
Sampler *sample.Sampler
cache kvCache
dflashCache kvCache
contextLength int
mlxThread *mlxthread.Thread
}
func (r *Runner) Load(modelName string) error {
root, err := model.Open(modelName)
if err != nil {
return err
}
defer root.Close()
m, err := base.New(root)
if err != nil {
return err
}
// Load all tensor blobs from manifest
tensors, err := loadTensorsFromManifest(root)
if err != nil {
return err
}
// Assign weights to model (model-specific logic). Target and draft weights
// must be loaded before sweeping so tensors from a combined manifest are
// not discarded before the draft model can retain them.
if err := m.LoadWeights(tensors); err != nil {
return err
}
r.Draft = nil
draft, err := base.NewDraft(root, m)
if err != nil {
return err
}
if draft != nil {
if err := draft.LoadWeights(tensors); err != nil {
return err
}
r.Draft = draft
}
collected := mlx.Collect(m)
if draft != nil {
draftArrays := mlx.Collect(draft)
collected = append(collected, draftArrays...)
if root.Draft != nil {
slog.Info("Loaded draft model", "tensor_prefix", root.Draft.TensorPrefix, "config", root.Draft.Config, "arrays", len(draftArrays))
} else {
slog.Info("Loaded draft model", "arrays", len(draftArrays))
}
}
for _, arr := range collected {
mlx.Pin(arr)
}
mlx.Sweep()
mlx.Eval(collected...)
r.Model = m
r.Tokenizer = m.Tokenizer()
r.contextLength = m.MaxContextLength()
r.Sampler = sample.New(r.contextLength)
mlx.EnableCompile()
return nil
}
// loadTensorsFromManifest loads all tensor blobs from the manifest into a
// flat map, deduplicating by digest and remapping safetensors key suffixes.
//
// Uses a two-phase approach: first loads all raw tensors, then remaps
// .bias → _qbias with complete knowledge of which base names have .scale
// entries. This avoids a race condition where Go map iteration order could
// cause .bias to be processed before .scale within the same blob.
func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) {
// Phase 1: Load all tensors raw from all blobs
rawTensors := make(map[string]*mlx.Array)
seen := make(map[string]bool)
for _, layer := range root.Manifest.GetTensorLayers("") {
if seen[layer.Digest] {
continue
}
seen[layer.Digest] = true
blobPath := root.Manifest.BlobPath(layer.Digest)
for name, arr := range mlx.Load(blobPath) {
rawTensors[name] = arr
}
}
// Phase 2: Identify all base names that have .scale tensors and remap them
scaleBaseNames := make(map[string]bool)
allTensors := make(map[string]*mlx.Array, len(rawTensors))
for name, arr := range rawTensors {
if strings.HasSuffix(name, ".scale") {
baseName := strings.TrimSuffix(name, ".scale")
allTensors[baseName+"_scale"] = arr
scaleBaseNames[baseName] = true
}
}
// Phase 3: Process remaining tensors with complete scale knowledge
for name, arr := range rawTensors {
if strings.HasSuffix(name, ".scale") {
continue // already handled
}
if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
baseName := strings.TrimSuffix(name, ".bias")
if scaleBaseNames[baseName] {
allTensors[baseName+"_qbias"] = arr
} else {
allTensors[name] = arr
}
} else {
allTensors[name] = arr
}
}
slog.Info("Loaded tensors from manifest", "count", len(allTensors))
return allTensors, nil
}
func (r *Runner) Run(host, port string, mux http.Handler) error {
g, ctx := errgroup.WithContext(context.Background())
g.Go(func() error {
for {
select {
case <-ctx.Done():
return nil
case request := <-r.Requests:
err := r.runRequest(request)
if err != nil {
slog.Info("Request terminated", "error", err)
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
statusErr = api.StatusError{
StatusCode: http.StatusInternalServerError,
ErrorMessage: err.Error(),
}
}
select {
case request.Responses <- CompletionResponse{Error: &statusErr}:
case <-request.Ctx.Done():
}
}
close(request.Responses)
}
}
})
g.Go(func() error {
slog.Info("Starting HTTP server", "host", host, "port", port)
return http.ListenAndServe(net.JoinHostPort(host, port), mux)
})
return g.Wait()
}
func (r *Runner) runRequest(request Request) error {
if r.mlxThread == nil {
return request.Pipeline(request.Ctx, request)
}
return r.mlxThread.Do(request.Ctx, func() error {
return request.Pipeline(request.Ctx, request)
})
}

View File

@@ -0,0 +1,300 @@
//go:build mlx
package sample
import (
"math"
"sort"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
// logprobEntry is the (token id, logprob) pair returned by the sampler's
// top-K extraction, used after the test-side descending sort.
type logprobEntry struct {
id int
logprob float64
}
// runSampleLogprobs drives Sample on a fresh Sampler configured for logprobs
// and returns the greedily-sampled token id, its logprob, and the top-K
// entries sorted descending by logprob. Logits must be a [vocab]-shaped
// slice; the helper reshapes it to [1, vocab] before calling the sampler.
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
t.Helper()
s := New(128)
defer func() {
s.Free()
mlx.Sweep()
}()
s.Add(0, Options{Logprobs: true, TopLogprobs: topK}, nil)
tensor := mlx.FromValues(logits, 1, len(logits))
res := s.Sample([]int{0}, tensor)
mlx.Pin(res.Arrays()...)
defer mlx.Unpin(res.Arrays()...)
mlx.Sweep()
mlx.Eval(res.Arrays()...)
selected := res.Token.Int()
selLP := float64(res.Logprob.Floats()[0])
var top []logprobEntry
if topK > 0 && res.TopTokens != nil {
ids := res.TopTokens.Ints()
vals := res.TopLogprobs.Floats()
top = make([]logprobEntry, len(ids))
for i, id := range ids {
top[i] = logprobEntry{id: id, logprob: float64(vals[i])}
}
sort.Slice(top, func(i, j int) bool { return top[i].logprob > top[j].logprob })
}
return selected, selLP, top
}
func TestSampleLogprobsBasic(t *testing.T) {
skipIfNoMLX(t)
tests := []struct {
name string
logits []float32
topK int
wantSelectedID int
wantTopLen int
}{
{
name: "single token without top logprobs",
logits: []float32{1.0, 0.5, 0.3, 0.1},
topK: 0,
wantSelectedID: 0,
wantTopLen: 0,
},
{
name: "single token with top logprobs",
logits: []float32{1.0, 0.5, 0.3, 0.1},
topK: 3,
wantSelectedID: 0,
wantTopLen: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
selected, _, top := runSampleLogprobs(t, tt.logits, tt.topK)
if selected != tt.wantSelectedID {
t.Errorf("selected = %d, want %d", selected, tt.wantSelectedID)
}
if len(top) != tt.wantTopLen {
t.Errorf("top-K length = %d, want %d", len(top), tt.wantTopLen)
}
})
}
}
func TestSampleLogprobsNumericalStability(t *testing.T) {
skipIfNoMLX(t)
logits := []float32{1000.0, 999.0, 998.0}
_, selLP, top := runSampleLogprobs(t, logits, 3)
if math.IsInf(selLP, 0) || math.IsNaN(selLP) {
t.Errorf("selected logprob is not finite: %f", selLP)
}
for i, e := range top {
if math.IsInf(e.logprob, 0) || math.IsNaN(e.logprob) {
t.Errorf("top[%d] logprob is not finite: %f", i, e.logprob)
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top logprobs not descending: %f > %f", top[i].logprob, top[i-1].logprob)
}
}
}
func TestSampleLogprobsProbabilityCorrectness(t *testing.T) {
skipIfNoMLX(t)
tests := []struct {
name string
logits []float32
}{
{"uniform", []float32{1.0, 1.0, 1.0, 1.0}},
{"different", []float32{2.0, 1.0, 0.5, 0.1}},
{"negative", []float32{-1.0, -2.0, -3.0, -4.0}},
{"mixed", []float32{5.0, -5.0, 0.0, 2.5}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
selected, selLP, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
if selLP > 0 {
t.Errorf("selected logprob should be <= 0, got %f", selLP)
}
for i, e := range top {
if e.logprob > 0 {
t.Errorf("top[%d] logprob should be <= 0, got %f", i, e.logprob)
}
}
if tt.name == "uniform" {
want := 1.0 / float64(len(tt.logits))
got := math.Exp(selLP)
if math.Abs(got-want) > 1e-6 {
t.Errorf("uniform logits: selected prob = %f, want %f", got, want)
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top logprobs not descending at %d: %f > %f",
i, top[i].logprob, top[i-1].logprob)
}
}
found := false
for _, e := range top {
if e.id == selected {
found = true
if math.Abs(e.logprob-selLP) > 1e-6 {
t.Errorf("selected logprob mismatch: selLP=%f top=%f", selLP, e.logprob)
}
break
}
}
if !found {
t.Errorf("selected token %d not present in top-K", selected)
}
})
}
}
func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) {
skipIfNoMLX(t)
tests := []struct {
name string
logits []float32
}{
{"small vocabulary", []float32{1.0, 2.0, 3.0}},
{"large differences", []float32{10.0, 0.0, -10.0}},
{"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}},
{"very large values", []float32{500.0, 499.0, 498.0}},
{"very small values", []float32{-500.0, -499.0, -498.0}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, _, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
if len(top) != len(tt.logits) {
t.Fatalf("top-K length = %d, want %d", len(top), len(tt.logits))
}
var sum float64
for _, e := range top {
p := math.Exp(e.logprob)
if p < 0 || p > 1 {
t.Errorf("token %d: probability %f out of [0,1]", e.id, p)
}
sum += p
}
if math.Abs(sum-1.0) > 1e-5 {
t.Errorf("probabilities sum = %f, want 1.0", sum)
}
})
}
}
func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
skipIfNoMLX(t)
logits := []float32{3.0, 1.0, 2.0, 0.5}
maxIdx := 0
for i, v := range logits[1:] {
if v > logits[maxIdx] {
maxIdx = i + 1
}
}
selected, selLP, top := runSampleLogprobs(t, logits, len(logits))
if selected != maxIdx {
t.Errorf("selected = %d, want argmax %d", selected, maxIdx)
}
if top[0].id != maxIdx {
t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx)
}
if math.Abs(top[0].logprob-selLP) > 1e-6 {
t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP)
}
}
// TestBatchedLogprobsPerRow verifies that per-row logprobs in a batched
// sample call match the per-slot reference. The numerically-stable softmax
// must reduce along the last axis only, not over the whole batch.
func TestBatchedLogprobsPerRow(t *testing.T) {
skipIfNoMLX(t)
rowA := []float32{2, 1, 0}
rowB := []float32{0, 5, 0}
_, wantA, _ := runSampleLogprobs(t, rowA, 0)
_, wantB, _ := runSampleLogprobs(t, rowB, 0)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(1, Options{Logprobs: true}, nil)
s.Add(2, Options{Logprobs: true}, nil)
logits := mlx.FromValues(append(append([]float32{}, rowA...), rowB...), 2, 3)
res := s.Sample([]int{1, 2}, logits)
mlx.Pin(res.Arrays()...)
t.Cleanup(func() { mlx.Unpin(res.Arrays()...) })
mlx.Eval(res.Arrays()...)
got := res.Logprob.Floats()
if len(got) != 2 {
t.Fatalf("Logprob length = %d, want 2", len(got))
}
if math.Abs(float64(got[0])-wantA) > 1e-5 {
t.Errorf("row 0 logprob = %f, want %f (per-slot reference)", got[0], wantA)
}
if math.Abs(float64(got[1])-wantB) > 1e-5 {
t.Errorf("row 1 logprob = %f, want %f (per-slot reference)", got[1], wantB)
}
}
func TestSampleLogprobsTopKOrdering(t *testing.T) {
skipIfNoMLX(t)
// Logits chosen so argmax order differs from index order.
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
wantOrder := []int{1, 3, 4, 0, 2}
_, _, top := runSampleLogprobs(t, logits, len(logits))
if len(top) != len(wantOrder) {
t.Fatalf("top-K length = %d, want %d", len(top), len(wantOrder))
}
for i, e := range top {
if e.id != wantOrder[i] {
t.Errorf("top[%d].id = %d, want %d", i, e.id, wantOrder[i])
}
}
for i := 1; i < len(top); i++ {
if top[i].logprob > top[i-1].logprob {
t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)",
i, top[i].logprob, i-1, top[i-1].logprob)
}
}
}

View File

@@ -0,0 +1,897 @@
package sample
import (
"fmt"
"math"
"slices"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
type Options struct {
Temperature float32
TopP float32
MinP float32
TopK int
RepeatLastN int
RepeatPenalty float32
PresencePenalty float32
FrequencyPenalty float32
Seed int
UseSeed bool
// Logprobs causes Sample to populate Result.Logprob with the selected
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
Logprobs bool
TopLogprobs int
}
// Result bundles the outputs of one decode step. Logprob/TopTokens/
// TopLogprobs are populated whenever any registered slot has Logprobs
// (respectively TopLogprobs>0). Consumers need to filter by their
// per-slot Options.
type Result struct {
Token *mlx.Array // sampled token ids, shape [B]
Logprob *mlx.Array // sampled-token logprobs, shape [B,1]; nil unless any registered slot has Logprobs
TopTokens *mlx.Array // top-K token ids, shape [B,maxK]; nil unless any registered slot has TopLogprobs>0
TopLogprobs *mlx.Array // top-K logprobs, shape [B,maxK]; same
}
// Arrays returns the tensor fields as a slice so callers can drive the mlx
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
// fields stay nil; the mlx helpers skip them.
func (r Result) Arrays() []*mlx.Array {
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
}
// Distribution is the filtered probability distribution used by the sampler.
// When IDs is nil, Probs is dense over the vocabulary. When IDs is set, Probs
// is sparse over the token ids in IDs, preserving GPU residency for the
// top-k-first path used by normal and speculative sampling.
type Distribution struct {
IDs *mlx.Array // sparse token ids, shape [B,K]; nil for dense distributions
Probs *mlx.Array // probabilities, shape [B,K] or [B,V]
}
// Arrays returns the tensor fields for mlx lifecycle management.
func (d Distribution) Arrays() []*mlx.Array {
return []*mlx.Array{d.IDs, d.Probs}
}
// Rows returns the number of rows in the distribution.
func (d Distribution) Rows() int {
if d.Probs == nil {
return 0
}
return d.Probs.Dim(0)
}
// SliceRows returns a row slice while preserving sparse/dense layout.
func (d Distribution) SliceRows(start, stop int) Distribution {
out := Distribution{Probs: d.Probs.Slice(mlx.Slice(start, stop), mlx.Slice())}
if d.IDs != nil {
out.IDs = d.IDs.Slice(mlx.Slice(start, stop), mlx.Slice())
}
return out
}
// SampleWithKey draws one token per row using key when supplied.
func (d Distribution) SampleWithKey(key *mlx.Array) *mlx.Array {
choice := logitsFromProbs(d.Probs).CategoricalWithKey(-1, key).AsType(mlx.DTypeInt32)
if d.IDs == nil {
return choice
}
return d.IDs.TakeAlongAxis(choice.ExpandDims(-1), -1).Squeeze(-1).AsType(mlx.DTypeInt32)
}
// Prob returns the probability assigned to one token per row.
func (d Distribution) Prob(tokens *mlx.Array) *mlx.Array {
switch tokens.NumDims() {
case 2:
if tokens.Dim(0) == 1 {
tokens = tokens.Squeeze(0)
} else if tokens.Dim(1) == 1 {
tokens = tokens.Squeeze(1)
}
case 0:
tokens = tokens.Reshape(1)
}
return d.ProbsForIDs(tokens.ExpandDims(-1)).Squeeze(-1)
}
// ProbsForIDs returns probabilities for each requested token id. ids must be
// rank-2 [B,N], matching the distribution rows.
func (d Distribution) ProbsForIDs(ids *mlx.Array) *mlx.Array {
if d.IDs == nil {
return d.Probs.TakeAlongAxis(ids, -1)
}
eq := d.IDs.ExpandDims(-1).Equal(ids.ExpandDims(1))
values := mlx.Where(eq, d.Probs.ExpandDims(-1), mlx.FromValue(float32(0)))
return values.SumAxis(1, false)
}
// ResidualAgainst returns the Leviathan/Chen rejection distribution
// proportional to max(target - draft, 0). Sparse target distributions stay
// sparse over the target support; tokens outside target support have zero mass.
func (d Distribution) ResidualAgainst(draft Distribution) Distribution {
if d.IDs != nil {
diff := d.Probs.Subtract(draft.ProbsForIDs(d.IDs))
return Distribution{IDs: d.IDs, Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
}
if draft.IDs != nil {
panic("sample.Distribution.ResidualAgainst: dense target with sparse draft is unsupported")
}
diff := d.Probs.Subtract(draft.Probs)
return Distribution{Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
}
// LogProbs returns dense log-probabilities, scattering sparse distributions
// into a full-vocabulary tensor when needed.
func (d Distribution) LogProbs(vocab int) *mlx.Array {
logProbs := logitsFromProbs(d.Probs)
if d.IDs == nil {
return logProbs
}
out := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, d.Probs.Dim(0), vocab), float32(math.Inf(-1)))
return out.PutAlongAxis(d.IDs, logProbs, -1)
}
// ConcatenateDistributions concatenates distribution rows. All inputs must use
// the same sparse/dense layout.
func ConcatenateDistributions(dists []Distribution) Distribution {
if len(dists) == 0 {
return Distribution{}
}
probs := make([]*mlx.Array, 0, len(dists))
ids := make([]*mlx.Array, 0, len(dists))
sparse := dists[0].IDs != nil
for _, d := range dists {
if (d.IDs != nil) != sparse {
panic("sample.ConcatenateDistributions: mixed sparse and dense distributions")
}
probs = append(probs, d.Probs)
if sparse {
ids = append(ids, d.IDs)
}
}
out := Distribution{Probs: mlx.Concatenate(probs, 0)}
if sparse {
out.IDs = mlx.Concatenate(ids, 0)
}
return out
}
// Sampler is a batched, slot-based sampler. Sequences are registered with
// Add and released with Remove. Each Sample call takes a subset of
// registered slots (in any order) with their [B,V] logits, samples one
// token per row, and appends it to that slot's ring-buffer history. Slots
// not named in a given call are untouched.
type Sampler struct {
slots []*slotState
byID map[int]*slotState
// history is the pooled ring-buffer storage, [B, W] int32. Row i
// belongs to slots[i]; W is max(RepeatLastN) across penalty slots.
// Allocated on the first penalty slot, rebuilt only in Add/Remove.
history *mlx.Array
// allSameOpts: every registered slot shares Options. When true the
// canonical shared value is s.slots[0].opts.
allSameOpts bool
// anyLogprobs / maxTopLogprobs: compute-for-all output config.
// Sample populates Logprob (and Top* when maxTopLogprobs>0) whenever
// any registered slot requests them, even if that slot isn't in the
// current call.
anyLogprobs bool
maxTopLogprobs int
// numCtx is the runner's context window; normalize uses it to
// resolve the repeat_last_n == -1 sentinel.
numCtx int
}
type slotState struct {
opts Options
historyLen int
randomCounter uint64
}
type slotCtx struct {
opts Options
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
}
// New constructs an empty sampler with no registered slots. numCtx is
// the runner's context window and must be positive.
func New(numCtx int) *Sampler {
return &Sampler{
byID: make(map[int]*slotState),
allSameOpts: true,
numCtx: numCtx,
}
}
// historyWidth returns the column count of the pooled history tensor,
// or 0 when no penalty slot has forced it to be allocated.
func (s *Sampler) historyWidth() int {
if s.history == nil {
return 0
}
return s.history.Dim(1)
}
func (o Options) usesHistory() bool {
// RepeatLastN == 0 disables the penalty ring per the repeat_last_n API
// contract (0 = disabled), overriding any penalty coefficients.
if o.RepeatLastN == 0 {
return false
}
return o.RepeatPenalty != 1 || o.PresencePenalty != 0 || o.FrequencyPenalty != 0
}
func (o Options) normalize(numCtx int) Options {
if o.RepeatPenalty <= 0 {
o.RepeatPenalty = 1
}
// Resolve the repeat_last_n == -1 sentinel ("-1 = num_ctx") against
// the caller's context window.
if o.RepeatLastN < 0 {
o.RepeatLastN = numCtx
}
if !o.usesHistory() {
// Zero the ring capacity so slots that differ only in a spurious
// RepeatLastN still batch together and don't inflate pool width.
o.RepeatLastN = 0
}
if o.Seed < 0 {
o.UseSeed = false
}
if !o.UseSeed {
// Keep unseeded callers on the same batching path even when a
// meaningless Seed value is present in an Options literal.
o.Seed = 0
}
return o
}
// Add registers a sequence under seqID. The last RepeatLastN entries of
// priorTokens seed the ring buffer.
func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
if _, dup := s.byID[seqID]; dup {
panic(fmt.Sprintf("sample.Sampler.Add: seqID %d already registered", seqID))
}
opts = opts.normalize(s.numCtx)
slot := &slotState{
opts: opts,
}
// Grow the pool to hold this slot's row. The pool is lazy — the first
// penalty slot allocates it — and thereafter every registered slot
// gets a row (rows for non-penalty slots are zero and never read).
// Invariant: s.history is pinned whenever non-nil.
if s.history != nil || opts.usesHistory() {
targetWidth := max(opts.RepeatLastN, s.historyWidth())
newRow := makeHistoryRow(priorTokens, opts.RepeatLastN, targetWidth)
var pool *mlx.Array
switch {
case s.history == nil && len(s.slots) == 0:
pool = newRow
case s.history == nil:
// First penalty slot with non-penalty slots already registered;
// seed zero rows so s.slots and pool row indices stay aligned.
zeros := mlx.Zeros(mlx.DTypeInt32, len(s.slots), targetWidth)
pool = zeros.Concatenate(0, newRow)
case targetWidth > s.historyWidth():
pad := mlx.Zeros(mlx.DTypeInt32, s.history.Dim(0), targetWidth-s.historyWidth())
pool = s.history.Concatenate(1, pad).Concatenate(0, newRow)
default:
pool = s.history.Concatenate(0, newRow)
}
mlx.Pin(pool)
mlx.Unpin(s.history)
s.history = pool
if opts.usesHistory() {
// Cap on seed so the next write's ring position
// (historyLen % RepeatLastN) lands at 0, overwriting the
// oldest entry when the ring was filled from priors.
slot.historyLen = min(len(priorTokens), opts.RepeatLastN)
}
}
s.slots = append(s.slots, slot)
s.byID[seqID] = slot
s.recomputeInvariants()
}
// makeHistoryRow builds a [1, width] int32 row with the last repeatLastN
// entries of priorTokens packed into [0, min(len, repeatLastN)), zeros
// elsewhere.
func makeHistoryRow(priorTokens []int32, repeatLastN, width int) *mlx.Array {
take := min(len(priorTokens), repeatLastN)
if take <= 0 {
return mlx.Zeros(mlx.DTypeInt32, 1, width)
}
row := make([]int32, width)
copy(row, priorTokens[len(priorTokens)-take:])
return mlx.NewArrayInt32(row, []int32{1, int32(width)})
}
// recomputeInvariants refreshes allSameOpts and anyLogprobs/maxTopLogprobs
// from s.slots. Called at the end of Add and Remove.
func (s *Sampler) recomputeInvariants() {
if len(s.slots) == 0 {
s.allSameOpts = true
s.anyLogprobs = false
s.maxTopLogprobs = 0
return
}
first := s.slots[0].opts
s.allSameOpts = true
s.anyLogprobs = false
s.maxTopLogprobs = 0
for _, slot := range s.slots {
if slot.opts != first {
s.allSameOpts = false
}
if slot.opts.Logprobs {
s.anyLogprobs = true
if slot.opts.TopLogprobs > s.maxTopLogprobs {
s.maxTopLogprobs = slot.opts.TopLogprobs
}
}
}
}
// Remove releases the slot. The pool tensor is rebuilt to drop the row.
func (s *Sampler) Remove(seqID int) {
slot, ok := s.byID[seqID]
if !ok {
return
}
delete(s.byID, seqID)
row := slices.Index(s.slots, slot)
s.slots = slices.Delete(s.slots, row, row+1)
s.recomputeInvariants()
if s.history == nil {
return
}
n := s.history.Dim(0)
var newHistory *mlx.Array
switch {
case n == 1:
newHistory = nil
case row == 0:
newHistory = s.history.Slice(mlx.Slice(1, n), mlx.Slice())
case row == n-1:
newHistory = s.history.Slice(mlx.Slice(0, row), mlx.Slice())
default:
before := s.history.Slice(mlx.Slice(0, row), mlx.Slice())
after := s.history.Slice(mlx.Slice(row+1, n), mlx.Slice())
newHistory = before.Concatenate(0, after)
}
mlx.Pin(newHistory)
mlx.Unpin(s.history)
s.history = newHistory
}
// Free releases the pooled history tensor and resets the sampler to the
// New-equivalent state so it may be reused.
func (s *Sampler) Free() {
mlx.Unpin(s.history)
*s = Sampler{
byID: make(map[int]*slotState),
allSameOpts: true,
numCtx: s.numCtx,
}
}
// Sample draws one token per row of logits ([B,V]); seqIDs[i] names the
// slot whose logits live at row i. Each sampled token is appended to its
// slot's ring. Slots not named in seqIDs are untouched.
func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
if len(seqIDs) == 0 {
return Result{}
}
slots := make([]*slotState, len(seqIDs))
for i, id := range seqIDs {
slot, ok := s.byID[id]
if !ok {
panic(fmt.Sprintf("sample.Sampler.Sample: seqID %d not registered", id))
}
slots[i] = slot
}
var token *mlx.Array
if opts0, ok := s.canBatch(slots); ok {
token = s.sampleTokensUniform(slots, opts0, logits)
} else {
token = s.sampleTokensSerial(slots, logits)
}
res := Result{Token: token}
if s.anyLogprobs {
// Log-softmax over original logits so every row holds a truthful
// value (compute-for-all; consumers filter per-slot). Subtract
// max first for numerical stability in the logsumexp.
lp := logits.AsType(mlx.DTypeFloat32)
lp = lp.Subtract(lp.MaxAxis(-1, true))
lp = lp.Subtract(lp.LogsumexpAxis(-1, true))
res.Logprob = lp.TakeAlongAxis(token.ExpandDims(-1), -1)
if s.maxTopLogprobs > 0 {
k := s.maxTopLogprobs
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
k = vocab
}
// Argpartition on the negated values places the K largest
// (unsorted) in positions [0:K].
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
res.TopTokens = idx.AsType(mlx.DTypeInt32)
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
}
}
return res
}
// Distribution applies this slot's sampling transforms to logits without
// mutating sampler state. Row i is built as if draftTokens[:i] had already
// been appended to the slot history. logits must be [R,V] or [1,R,V].
func (s *Sampler) Distribution(seqID int, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
slot, logits, draftTokens := s.speculativeInputs("Distribution", seqID, logits, draftTokens)
rows := logits.Dim(0)
var hist *mlx.Array
if slot.opts.usesHistory() {
if s.history == nil {
panic(fmt.Sprintf("sample.Sampler.Distribution: seqID %d has no history", seqID))
}
if slot.historyLen < slot.opts.RepeatLastN {
return s.speculativeDistributionSerial(slot, logits, draftTokens)
}
hist = s.speculativeHistory(slot, draftTokens, rows)
}
return slot.distribution(&slotCtx{opts: slot.opts, history: hist}, logits)
}
// SpeculativeScores applies this slot's sampling transforms to logits without
// mutating sampler state and returns dense log-probability scores for sampled
// decoding. Greedy decoding returns the penalty-adjusted logits.
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
slot, logits, draftTokens := s.speculativeInputs("SpeculativeScores", seqID, logits, draftTokens)
rows := logits.Dim(0)
var hist *mlx.Array
if slot.opts.usesHistory() {
if s.history == nil {
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d has no history", seqID))
}
if slot.historyLen < slot.opts.RepeatLastN {
return s.speculativeScoresSerial(slot, logits, draftTokens)
}
hist = s.speculativeHistory(slot, draftTokens, rows)
}
return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits)
}
// SampleDistribution draws from a precomputed distribution while advancing
// seqID's deterministic RNG stream when a seed is configured.
func (s *Sampler) SampleDistribution(seqID int, dist Distribution) *mlx.Array {
slot := s.mustSlot("SampleDistribution", seqID)
return dist.SampleWithKey(slot.nextRandomKey())
}
// Bernoulli samples boolean outcomes while advancing seqID's deterministic RNG
// stream when a seed is configured.
func (s *Sampler) Bernoulli(seqID int, p *mlx.Array) *mlx.Array {
slot := s.mustSlot("Bernoulli", seqID)
return mlx.BernoulliWithKey(p, slot.nextRandomKey())
}
func (s *Sampler) mustSlot(caller string, seqID int) *slotState {
slot, ok := s.byID[seqID]
if !ok {
panic(fmt.Sprintf("sample.Sampler.%s: seqID %d not registered", caller, seqID))
}
return slot
}
func (s *Sampler) speculativeInputs(caller string, seqID int, logits *mlx.Array, draftTokens *mlx.Array) (*slotState, *mlx.Array, *mlx.Array) {
slot := s.mustSlot(caller, seqID)
if logits.NumDims() == 3 {
if logits.Dim(0) != 1 {
panic(fmt.Sprintf("sample.Sampler.%s: only batch size 1 is supported", caller))
}
logits = logits.Squeeze(0)
}
if logits.NumDims() != 2 {
panic(fmt.Sprintf("sample.Sampler.%s: logits must be rank 2 or 3, got rank %d", caller, logits.NumDims()))
}
if draftTokens != nil && draftTokens.NumDims() == 1 {
draftTokens = draftTokens.ExpandDims(0)
}
return slot, logits, draftTokens
}
// Commit appends already-selected tokens to seqID's repeat-penalty history.
// It is used after speculative sampling once the accepted continuation is
// known. Normal Sample calls continue to mutate history themselves.
func (s *Sampler) Commit(seqID int, tokens []int32) {
if len(tokens) == 0 {
return
}
slot, ok := s.byID[seqID]
if !ok {
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d not registered", seqID))
}
if !slot.opts.usesHistory() {
return
}
if s.history == nil {
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d has no history", seqID))
}
row := slices.Index(s.slots, slot)
width := s.historyWidth()
take := min(len(tokens), slot.opts.RepeatLastN)
startLen := slot.historyLen + len(tokens) - take
writeTokens := tokens[len(tokens)-take:]
flatOffsets := make([]int32, take)
for i := range take {
ringPos := (startLen + i) % slot.opts.RepeatLastN
flatOffsets[i] = int32(row*width + ringPos)
}
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(take), 1})
values := mlx.NewArrayInt32(writeTokens, []int32{int32(take), 1})
flatHist := s.history.Reshape(s.history.Dim(0)*width, 1)
s.history.Set(flatHist.PutAlongAxis(flatIdx, values, 0).Reshape(s.history.Dim(0), width))
slot.historyLen += len(tokens)
}
func (s *Sampler) speculativeDistributionSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
rows := logits.Dim(0)
draftCount := 0
if draftTokens != nil {
draftCount = draftTokens.Dim(1)
}
row := slices.Index(s.slots, slot)
baseFill := min(slot.historyLen, slot.opts.RepeatLastN)
var base *mlx.Array
if baseFill > 0 {
base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
}
dists := make([]Distribution, 0, rows)
for i := range rows {
rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
hist := base
prefixLen := min(i, draftCount)
if prefixLen > 0 {
prefix := draftTokens.Slice(mlx.Slice(), mlx.Slice(0, prefixLen))
if hist == nil {
hist = prefix
} else {
hist = hist.Concatenate(1, prefix)
}
if hist.Dim(1) > slot.opts.RepeatLastN {
hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
}
}
dists = append(dists, slot.distribution(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
}
return ConcatenateDistributions(dists)
}
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
return s.speculativeDistributionSerial(slot, logits, draftTokens).LogProbs(logits.Dim(logits.NumDims() - 1))
}
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
row := slices.Index(s.slots, slot)
width := slot.opts.RepeatLastN
base := s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, width))
base = mlx.Tile(base, []int32{int32(rows), 1})
next := slot.historyLen % width
draftCount := 0
if draftTokens != nil {
draftCount = draftTokens.Dim(1)
}
if draftCount == 0 {
return base
}
sourceIdx := make([]int32, rows*width)
writeMask := make([]bool, rows*width)
for i := range rows {
prefixLen := min(i, draftCount)
for j := range prefixLen {
pos := (next + j) % width
sourceIdx[i*width+pos] = int32(j)
writeMask[i*width+pos] = true
}
}
draftRows := mlx.Tile(draftTokens, []int32{int32(rows), 1})
idx := mlx.NewArrayInt32(sourceIdx, []int32{int32(rows), int32(width)})
mask := mlx.FromValues(writeMask, rows, width)
values := draftRows.TakeAlongAxis(idx, 1)
return mlx.Where(mask, values, base)
}
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
if slot.opts.Temperature == 0 {
return slot.baseScores(ctx, logits)
}
return slot.distribution(ctx, logits).LogProbs(logits.Dim(logits.NumDims() - 1))
}
// canBatch reports whether the call can take the uniform batched path.
// All slots must share Options; when penalties are active the call must
// additionally cover every registered slot in registration order with a
// full ring, because the uniform path indexes the pool positionally.
func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
if !s.allSameOpts {
return Options{}, false
}
// slots is non-empty (Sample guards) and every slot is registered,
// so s.slots[0].opts is the canonical shared value.
shared := s.slots[0].opts
// TODO(pdevine): Before using multi-slot batching with seeded stochastic sampling,
// make sure each row gets its own per-slot random key instead of sharing
// slots[0]'s key through one batched categorical op.
if !shared.usesHistory() {
return shared, true
}
if len(slots) != len(s.slots) {
return Options{}, false
}
for i, slot := range slots {
if s.slots[i] != slot || slot.historyLen < shared.RepeatLastN {
return Options{}, false
}
}
return shared, true
}
// sampleTokensUniform runs one fused sampling pass over the whole batch.
// Reached only when canBatch is true, which lets the pool be used in place
// with a single PutAlongAxis write-back and no gather.
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
B := len(slots)
var hist *mlx.Array
if opts.usesHistory() {
hist = s.history
if s.historyWidth() > opts.RepeatLastN {
hist = hist.Slice(mlx.Slice(), mlx.Slice(0, opts.RepeatLastN))
}
}
ctx := &slotCtx{opts: opts, history: hist}
token := slots[0].sample(ctx, logits)
if opts.UseSeed && opts.Temperature != 0 {
// TODO: This only keeps counters aligned; it does not give each slot
// an independent key for the batched draw.
for _, slot := range slots[1:] {
slot.randomCounter++
}
}
if !opts.usesHistory() {
return token
}
writeIdxData := make([]int32, B)
for i, slot := range slots {
writeIdxData[i] = int32(slot.historyLen % opts.RepeatLastN)
slot.historyLen++
}
writeIdx := mlx.NewArrayInt32(writeIdxData, []int32{int32(B), 1})
s.history.Set(s.history.PutAlongAxis(writeIdx, token.ExpandDims(-1), 1))
return token
}
// sampleTokensSerial samples each slot against its own row of logits.
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
perSlotTokens := make([]*mlx.Array, len(slots))
rowOf := make(map[*slotState]int, len(s.slots))
for i, slot := range s.slots {
rowOf[slot] = i
}
for i, slot := range slots {
row := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
var hist *mlx.Array
if slot.opts.usesHistory() && slot.historyLen > 0 && s.history != nil {
poolRow := rowOf[slot]
fill := min(slot.historyLen, slot.opts.RepeatLastN)
hist = s.history.Slice(
mlx.Slice(poolRow, poolRow+1),
mlx.Slice(0, fill),
)
}
ctx := &slotCtx{opts: slot.opts, history: hist}
perSlotTokens[i] = slot.sample(ctx, row)
}
token := mlx.Concatenate(perSlotTokens, 0)
if s.history != nil {
// For each writing slot collect its flat (row-major) pool offset
// and the call-order position of its token. One PutAlongAxis on a
// flat view of the pool scatters all writes in a single op.
flatOffsets := make([]int32, 0, len(slots))
tokenPos := make([]int32, 0, len(slots))
for i, slot := range slots {
if !slot.opts.usesHistory() {
continue
}
ringPos := slot.historyLen % slot.opts.RepeatLastN
flatOffsets = append(flatOffsets, int32(rowOf[slot]*s.historyWidth()+ringPos))
tokenPos = append(tokenPos, int32(i))
slot.historyLen++
}
if len(flatOffsets) > 0 {
m := len(flatOffsets)
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(m), 1})
writingTokens := token
if m != len(slots) {
tokenPosIdx := mlx.NewArrayInt32(tokenPos, []int32{int32(m)})
writingTokens = token.TakeAxis(tokenPosIdx, 0)
}
flatHist := s.history.Reshape(s.history.Dim(0)*s.historyWidth(), 1)
s.history.Set(flatHist.PutAlongAxis(flatIdx, writingTokens.ExpandDims(-1), 0).Reshape(s.history.Dim(0), s.historyWidth()))
}
}
return token
}
func (slot *slotState) sample(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
if slot.opts.Temperature == 0 {
return slot.baseScores(ctx, logits).Argmax(-1, false).AsType(mlx.DTypeInt32)
}
return slot.distribution(ctx, logits).SampleWithKey(slot.nextRandomKey())
}
func (slot *slotState) nextRandomKey() *mlx.Array {
if !slot.opts.UseSeed {
return nil
}
seed := mixSeed(uint64(slot.opts.Seed), slot.randomCounter)
slot.randomCounter++
return mlx.RandomKey(seed)
}
const (
// SplitMix64 constants used to decorrelate nearby (seed, counter) pairs.
splitMix64Weyl = 0x9e3779b97f4a7c15
splitMix64Mul1 = 0xbf58476d1ce4e5b9
splitMix64Mul2 = 0x94d049bb133111eb
splitMix64Shift1 = 30
splitMix64Shift2 = 27
splitMix64FinalShift = 31
)
func mixSeed(seed, counter uint64) uint64 {
z := seed + splitMix64Weyl*(counter+1)
z = (z ^ (z >> splitMix64Shift1)) * splitMix64Mul1
z = (z ^ (z >> splitMix64Shift2)) * splitMix64Mul2
return z ^ (z >> splitMix64FinalShift)
}
func (slot *slotState) baseScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
scores := logits
if slot.opts.usesHistory() {
scores = penalty(ctx, scores)
}
return scores
}
func (slot *slotState) distribution(ctx *slotCtx, logits *mlx.Array) Distribution {
scores := slot.baseScores(ctx, logits)
if slot.opts.Temperature <= 0 {
ids := scores.Argmax(-1, false).AsType(mlx.DTypeInt32).ExpandDims(-1)
probs := mlx.AddScalar(ids.AsType(mlx.DTypeFloat32).Multiply(mlx.FromValue(float32(0))), 1)
return Distribution{IDs: ids, Probs: probs}
}
vocab := scores.Dim(scores.NumDims() - 1)
if slot.opts.TopK > 0 && slot.opts.TopK < vocab {
return sparseDistribution(ctx.opts, scores)
}
return denseDistribution(ctx.opts, scores)
}
func sparseDistribution(opts Options, scores *mlx.Array) Distribution {
ids := scores.Negative().ArgpartitionAxis(opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(0, opts.TopK)).AsType(mlx.DTypeInt32)
topScores := scores.TakeAlongAxis(ids, -1).AsType(mlx.DTypeFloat32)
probs := mlx.SoftmaxAxis(mlx.DivScalar(topScores, opts.Temperature), -1, true)
probs = applyTopPProbs(probs, opts.TopP)
probs = applyMinPProbs(probs, opts.MinP)
return Distribution{IDs: ids, Probs: normalizeProbs(probs)}
}
func denseDistribution(opts Options, scores *mlx.Array) Distribution {
probs := mlx.SoftmaxAxis(mlx.DivScalar(scores.AsType(mlx.DTypeFloat32), opts.Temperature), -1, true)
probs = applyTopPProbs(probs, opts.TopP)
probs = applyMinPProbs(probs, opts.MinP)
return Distribution{Probs: normalizeProbs(probs)}
}
func applyTopPProbs(probs *mlx.Array, topP float32) *mlx.Array {
if topP <= 0 || topP >= 1 {
return probs
}
order := probs.Negative().ArgsortAxis(-1)
sorted := probs.TakeAlongAxis(order, -1)
prevCumProbs := sorted.Cumsum(-1, false, true).Subtract(sorted)
keep := prevCumProbs.Less(mlx.FromValue(topP))
filtered := mlx.Where(keep, sorted, mlx.FromValue(float32(0)))
return mlx.Zeros(probs.DType(), probs.Dims()...).PutAlongAxis(order, filtered, -1)
}
func applyMinPProbs(probs *mlx.Array, minP float32) *mlx.Array {
if minP <= 0 || minP > 1 {
return probs
}
threshold := mlx.MulScalar(probs.MaxAxis(-1, true), minP)
return mlx.Where(probs.Less(threshold), mlx.FromValue(float32(0)), probs)
}
func normalizeProbs(probs *mlx.Array) *mlx.Array {
sum := mlx.Maximum(probs.SumAxis(-1, true), mlx.FromValue(float32(1e-20)))
return probs.Divide(sum)
}
func logitsFromProbs(probs *mlx.Array) *mlx.Array {
positive := mlx.Maximum(probs, mlx.FromValue(float32(1e-20)))
logits := mlx.Log(positive)
return mlx.Where(probs.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
}
func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
tokenIndices := ctx.history
if tokenIndices == nil {
return scores
}
if ctx.opts.RepeatPenalty != 1 || ctx.opts.PresencePenalty != 0 {
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
if ctx.opts.RepeatPenalty != 1 {
factor := mlx.Where(
adjusted.Less(mlx.FromValue(float32(0))),
mlx.FromValue(ctx.opts.RepeatPenalty),
mlx.FromValue(1/ctx.opts.RepeatPenalty),
)
adjusted = adjusted.Multiply(factor)
}
if ctx.opts.PresencePenalty != 0 {
adjusted = mlx.AddScalar(adjusted, -ctx.opts.PresencePenalty)
}
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
}
if ctx.opts.FrequencyPenalty != 0 {
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-ctx.opts.FrequencyPenalty), -1)
}
return scores
}

View File

@@ -0,0 +1,492 @@
//go:build mlx
package sample
import (
"math"
"slices"
"testing"
"github.com/ollama/ollama/x/mlxrunner/mlx"
)
func skipIfNoMLX(t *testing.T) {
t.Helper()
if err := mlx.CheckInit(); err != nil {
t.Skipf("MLX not available: %v", err)
}
}
// slotLogits builds a [1, V] logits tensor for a single-slot Sample call.
func slotLogits(values []float32) *mlx.Array {
return mlx.FromValues(values, 1, len(values))
}
// batchLogits stacks per-row float32 slices of equal length into a [B, V]
// logits tensor.
func batchLogits(rows ...[]float32) *mlx.Array {
v := len(rows[0])
flat := make([]float32, 0, len(rows)*v)
for _, r := range rows {
if len(r) != v {
panic("batchLogits: rows must share vocab size")
}
flat = append(flat, r...)
}
return mlx.FromValues(flat, len(rows), v)
}
// sampleOne runs Sample on a freshly-added single slot and returns the
// sampled token id. Used both for the single-slot options table and as the
// reference oracle for the batched-equivalence test.
func sampleOne(t *testing.T, opts Options, priorTokens []int32, values []float32) int {
t.Helper()
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(0, opts, priorTokens)
got := s.Sample([]int{0}, slotLogits(values)).Token
mlx.Eval(got)
return got.Int()
}
// logOf returns log(p) as a float32 so tests can build logits that softmax to
// a chosen probability distribution.
func logOf(p float64) float32 { return float32(math.Log(p)) }
// TestSampleSingleSlotOptions pins the per-slot behavior of each Options
// knob against a concrete expected token. Expected values are worked out by
// hand from the math of each transform, not from a second call into the
// sampler — so a regression in any single transform shows up here.
func TestSampleSingleSlotOptions(t *testing.T) {
skipIfNoMLX(t)
cases := []struct {
name string
opts Options
priors []int32
logits []float32
want int
}{
{
name: "presence penalty",
opts: Options{RepeatLastN: 1, PresencePenalty: 6},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // token 1: 5 - 6 = -1, argmax shifts to 2
},
{
name: "repeat penalty on positive logits",
opts: Options{RepeatLastN: 1, RepeatPenalty: 2},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // token 1 positive → divided: 5/2 = 2.5, argmax shifts to 2
},
{
name: "repeat penalty on negative logits",
opts: Options{RepeatLastN: 1, RepeatPenalty: 4},
priors: []int32{1},
logits: []float32{-5, -1, -3},
want: 2, // token 1 negative → multiplied: -1*4 = -4, argmax shifts to 2
},
{
name: "frequency penalty",
opts: Options{RepeatLastN: 4, FrequencyPenalty: 2},
priors: []int32{1, 1},
logits: []float32{0, 5, 4},
want: 2, // 5 - 2*count(1)=2*2=4 → 1, argmax shifts to 2
},
{
name: "top-k",
opts: Options{Temperature: 1, TopK: 1},
logits: []float32{1, 5, 4},
want: 1, // only argmax survives → deterministic even with temperature
},
{
name: "top-p",
opts: Options{Temperature: 1, TopP: 0.4},
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
want: 0, // exclusive cumsum below 0.4 keeps only token 0
},
{
name: "min-p",
opts: Options{Temperature: 1, MinP: 0.7},
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
want: 0, // threshold 0.5*0.7=0.35 drops all but the top token
},
{
name: "RepeatLastN=0 disables penalties",
opts: Options{RepeatLastN: 0, RepeatPenalty: 2, PresencePenalty: 10},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 1, // 0 = disabled per API contract, argmax unchanged
},
{
name: "RepeatLastN=-1 resolves to num_ctx",
opts: Options{RepeatLastN: -1, PresencePenalty: 6},
priors: []int32{1},
logits: []float32{0, 5, 4},
want: 2, // -1 → num_ctx (128); penalty applies, argmax shifts
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
if got := sampleOne(t, tc.opts, tc.priors, tc.logits); got != tc.want {
t.Errorf("got %d, want %d", got, tc.want)
}
})
}
}
func TestDistributionAppliesTopKBeforeTopP(t *testing.T) {
skipIfNoMLX(t)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(0, Options{Temperature: 1, TopK: 2, TopP: 0.7}, nil)
dist := s.Distribution(0, slotLogits([]float32{logOf(0.6), logOf(0.2), logOf(0.2)}), nil)
mlx.Eval(dist.Arrays()...)
ids := dist.IDs.Ints()
probs := dist.Probs.Floats()
if len(ids) != 2 || len(probs) != 2 {
t.Fatalf("support = ids %v probs %v, want 2 sparse entries", ids, probs)
}
foundTop := false
for i, id := range ids {
switch id {
case 0:
foundTop = true
if math.Abs(float64(probs[i]-1)) > 1e-5 {
t.Fatalf("top token prob = %v, want 1; ids=%v probs=%v", probs[i], ids, probs)
}
default:
if math.Abs(float64(probs[i])) > 1e-5 {
t.Fatalf("non-top token %d prob = %v, want 0; ids=%v probs=%v", id, probs[i], ids, probs)
}
}
}
if !foundTop {
t.Fatalf("top-k support %v did not include token 0", ids)
}
}
func TestDistributionResidualUsesTargetSupport(t *testing.T) {
skipIfNoMLX(t)
target := Distribution{
IDs: mlx.NewArrayInt32([]int32{2, 5}, []int32{1, 2}),
Probs: mlx.FromValues([]float32{0.7, 0.3}, 1, 2),
}
draft := Distribution{
IDs: mlx.NewArrayInt32([]int32{2, 4}, []int32{1, 2}),
Probs: mlx.FromValues([]float32{0.2, 0.8}, 1, 2),
}
residual := target.ResidualAgainst(draft)
mlx.Eval(residual.Arrays()...)
ids := residual.IDs.Ints()
probs := residual.Probs.Floats()
want := map[int]float64{2: 0.625, 5: 0.375}
if len(ids) != 2 || len(probs) != 2 {
t.Fatalf("residual = ids %v probs %v, want 2 sparse entries", ids, probs)
}
for i, id := range ids {
w, ok := want[id]
if !ok {
t.Fatalf("residual includes token %d outside target support: ids=%v probs=%v", id, ids, probs)
}
if math.Abs(float64(probs[i])-w) > 1e-5 {
t.Fatalf("residual token %d prob = %v, want %v; ids=%v probs=%v", id, probs[i], w, ids, probs)
}
}
}
func TestSeededSamplingIsReproducible(t *testing.T) {
skipIfNoMLX(t)
seededSequence := func(seed int) []int {
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(0, Options{Temperature: 1, TopK: 4, Seed: seed, UseSeed: true}, nil)
logits := slotLogits([]float32{0, 0, 0, 0})
out := make([]int, 32)
for i := range out {
token := s.Sample([]int{0}, logits).Token
mlx.Eval(token)
out[i] = token.Int()
}
return out
}
a := seededSequence(1234)
b := seededSequence(1234)
if !slices.Equal(a, b) {
t.Fatalf("same seed produced different sequences:\n%v\n%v", a, b)
}
c := seededSequence(5678)
if slices.Equal(a, c) {
t.Fatalf("different seeds produced the same sequence: %v", a)
}
}
func TestSeededBernoulliIsReproducible(t *testing.T) {
skipIfNoMLX(t)
seededMask := func() []int {
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(0, Options{Seed: 99, UseSeed: true}, nil)
mask := s.Bernoulli(0, mlx.FromValues([]float32{0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, 6)).AsType(mlx.DTypeInt32)
mlx.Eval(mask)
return mask.Ints()
}
a := seededMask()
b := seededMask()
if !slices.Equal(a, b) {
t.Fatalf("same seed produced different bernoulli masks:\n%v\n%v", a, b)
}
}
// TestSampleHistoryWindow verifies that penalty history respects the
// RepeatLastN window: priors longer than RepeatLastN are trimmed on Add,
// and once the ring wraps, tokens that rotate out no longer contribute
// to penalties.
func TestSampleHistoryWindow(t *testing.T) {
skipIfNoMLX(t)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
// RepeatLastN=2 with priors {1, 2, 3}: makeHistoryRow keeps only
// {2, 3}. Token 1 was trimmed — its penalty is NOT active.
s.Add(0, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 2, 3})
// Step 1: logits favor token 1 (trimmed). If the trim were broken it
// would be penalized and the argmax would move.
step1 := s.Sample([]int{0}, slotLogits([]float32{0, 5, 0, 0, 0})).Token
mlx.Eval(step1)
if got := step1.Int(); got != 1 {
t.Fatalf("step 1 = %d, want 1 (token 1 trimmed from priors)", got)
}
// After step 1 the ring holds {1, 3}; token 2 has rotated out.
// Step 2: logits favor token 2 (rotated out). If the ring wrap were
// wrong, token 2 would still be penalized.
step2 := s.Sample([]int{0}, slotLogits([]float32{0, 0, 5, 0, 0})).Token
mlx.Eval(step2)
if got := step2.Int(); got != 2 {
t.Fatalf("step 2 = %d, want 2 (token 2 rotated out of ring)", got)
}
}
func TestSpeculativeScoresUsesDraftHistoryWithoutCommit(t *testing.T) {
skipIfNoMLX(t)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(0, Options{RepeatLastN: 2, RepeatPenalty: 10}, []int32{1, 2})
draftTokens := mlx.NewArrayInt32([]int32{3, 4}, []int32{1, 2})
scores := s.SpeculativeScores(0, batchLogits(
[]float32{0, 9, 9, 8, 0}, // history {1,2}; token 3 wins
[]float32{0, 0, 9, 9, 8}, // history {2,3}; token 4 wins
[]float32{0, 0, 9, 9, 8}, // history {3,4}; token 2 wins
), draftTokens)
tokens := scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
mlx.Eval(tokens)
if got, want := tokens.Ints(), []int{3, 4, 2}; len(got) != len(want) {
t.Fatalf("tokens = %v, want %v", got, want)
} else {
for i := range want {
if got[i] != want[i] {
t.Fatalf("tokens = %v, want %v", got, want)
}
}
}
if s.byID[0].historyLen != 2 {
t.Fatalf("historyLen = %d, want 2", s.byID[0].historyLen)
}
}
func TestCommitBatchesRingWrites(t *testing.T) {
skipIfNoMLX(t)
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(0, Options{RepeatLastN: 4, RepeatPenalty: 1.1}, []int32{10, 11, 12})
s.Commit(0, []int32{20, 21, 22})
s.Commit(0, []int32{30, 31, 32, 33, 34})
mlx.Eval(s.history)
got := s.history.Ints()
want := []int{32, 33, 34, 31}
for i := range want {
if got[i] != want[i] {
t.Fatalf("history = %v, want %v", got, want)
}
}
if s.byID[0].historyLen != 11 {
t.Fatalf("historyLen = %d, want 11", s.byID[0].historyLen)
}
}
// TestBatchSamplingPreservesPerSlotBehavior is the core equivalence test:
// for every representative dispatch branch (uniform, serial on mixed opts,
// serial on partial ring, subset/out-of-order), a batched Sample call must
// produce the same token per row as running the same slot alone.
func TestBatchSamplingPreservesPerSlotBehavior(t *testing.T) {
skipIfNoMLX(t)
type slot struct {
id int
opts Options
priors []int32
}
cases := []struct {
name string
slots []slot
sample []int
rows [][]float32
}{
{
name: "uniform",
slots: []slot{
{10, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{1, 2}},
{20, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{0, 2}},
},
sample: []int{10, 20},
rows: [][]float32{{0, 5, 4}, {3, 0, 0}},
},
{
name: "serial — mixed opts",
slots: []slot{
{1, Options{RepeatLastN: 1, RepeatPenalty: 2}, []int32{1}},
{2, Options{Temperature: 1, TopK: 1}, nil},
},
sample: []int{1, 2},
rows: [][]float32{{0, 5, 4, 1}, {2, 1, 5, 3}},
},
{
name: "serial — partial ring",
slots: []slot{
{1, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{1, 1, 1, 1}},
{2, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{2}},
},
sample: []int{1, 2},
rows: [][]float32{{0, 5, 4}, {0, 4, 5}},
},
{
name: "subset out-of-order",
slots: []slot{
{10, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 1}},
{20, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{2, 2}},
{30, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{3, 3}},
},
sample: []int{30, 10},
rows: [][]float32{{5, 5, 5, 0, 5, 5}, {5, 0, 5, 5, 0, 5}},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// Per-slot reference for each sampled seq.
want := make([]int, len(tc.sample))
for i, id := range tc.sample {
var spec slot
for _, s := range tc.slots {
if s.id == id {
spec = s
break
}
}
want[i] = sampleOne(t, spec.opts, spec.priors, tc.rows[i])
}
// Batched call.
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
for _, spec := range tc.slots {
s.Add(spec.id, spec.opts, spec.priors)
}
res := s.Sample(tc.sample, batchLogits(tc.rows...))
mlx.Eval(res.Token)
got := res.Token.Ints()
for i, id := range tc.sample {
if got[i] != want[i] {
t.Errorf("seq %d: batched = %d, per-slot = %d", id, got[i], want[i])
}
}
})
}
}
// TestRemoveDoesNotLeakHistory: after Remove, a newly-added slot at the
// recycled row must start from its own priors only — no carryover from
// the removed slot's history.
func TestRemoveDoesNotLeakHistory(t *testing.T) {
skipIfNoMLX(t)
opts := Options{RepeatLastN: 1, PresencePenalty: 10}
s := New(128)
t.Cleanup(func() {
s.Free()
mlx.Sweep()
})
s.Add(1, opts, []int32{1})
s.Add(2, opts, []int32{2})
s.Remove(1)
s.Add(3, opts, []int32{0})
// Slot 2 retains history {2}; slot 3 retains history {0}. With
// equal logits and PresencePenalty=10 the argmax drops to the first
// unpenalized token.
res := s.Sample([]int{2, 3}, batchLogits(
[]float32{3, 3, 0},
[]float32{3, 3, 0},
))
mlx.Eval(res.Token)
tokens := res.Token.Ints()
if tokens[0] != 0 {
t.Errorf("slot 2 = %d, want 0 (token 2 penalized)", tokens[0])
}
if tokens[1] != 1 {
t.Errorf("slot 3 = %d, want 1 (token 0 penalized, no slot-1 carryover)", tokens[1])
}
}

247
x/mlxrunner/server.go Normal file
View File

@@ -0,0 +1,247 @@
package mlxrunner
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"strconv"
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/x/internal/mlxthread"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/sample"
)
func Execute(args []string) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
var (
modelName string
port int
)
flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
flagSet.StringVar(&modelName, "model", "", "Model name")
flagSet.IntVar(&port, "port", 0, "Port to listen on")
_ = flagSet.Bool("verbose", false, "Enable debug logging")
flagSet.Parse(args)
worker, err := mlxthread.Start("mlxrunner", func() error {
if err := mlx.CheckInit(); err != nil {
return fmt.Errorf("MLX not available: %w", err)
}
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "gpu")
} else {
slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "cpu")
}
return nil
})
if err != nil {
return err
}
defer worker.Stop(context.Background(), func() {
mlx.Sweep()
mlx.ClearCache()
})
runnerCtx, cancelRunner := context.WithCancel(context.Background())
defer cancelRunner()
runner := Runner{
Requests: make(chan Request),
mlxThread: worker,
}
if err := worker.Do(context.Background(), func() error {
return runner.Load(modelName)
}); err != nil {
return err
}
readMemory := func() (uint64, error) {
return uint64(mlx.ActiveMemory() + mlx.CacheMemory()), nil
}
initialMemory, err := mlxthread.Call(context.Background(), worker, readMemory)
if err != nil {
return err
}
memoryCache := newStatusMemoryCache(
runnerCtx,
initialMemory,
time.Now(),
statusMemoryRefreshWait,
func() (uint64, error) {
return mlxthread.Call(runnerCtx, worker, readMemory)
},
)
mux := http.NewServeMux()
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
if err := json.NewEncoder(w).Encode(statusResponse{
Status: 0,
Progress: 100,
ContextLength: runner.contextLength,
Memory: memoryCache.Memory(),
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "POST":
fallthrough
case "GET":
if err := json.NewEncoder(w).Encode(map[string]any{
"Success": true,
}); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
case "DELETE":
// TODO: cleanup model and cache
}
})
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
request := Request{Responses: make(chan CompletionResponse)}
if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
slog.Error("Failed to decode request", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
request.Pipeline = runner.TextGenerationPipeline
request.SamplerOpts = sample.Options{
Temperature: request.Options.Temperature,
TopP: request.Options.TopP,
MinP: request.Options.MinP,
TopK: request.Options.TopK,
RepeatLastN: request.Options.RepeatLastN,
RepeatPenalty: request.Options.RepeatPenalty,
PresencePenalty: request.Options.PresencePenalty,
FrequencyPenalty: request.Options.FrequencyPenalty,
Seed: request.Options.Seed,
UseSeed: request.Options.Seed >= 0,
Logprobs: request.Logprobs,
TopLogprobs: request.TopLogprobs,
}
if err := runner.Prepare(&request); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var cancel context.CancelFunc
request.Ctx, cancel = context.WithCancel(r.Context())
defer cancel()
select {
case <-r.Context().Done():
return
case runner.Requests <- request:
}
w.Header().Set("Content-Type", "application/jsonl")
w.WriteHeader(http.StatusOK)
enc := json.NewEncoder(w)
for {
select {
case <-r.Context().Done():
return
case response, ok := <-request.Responses:
if !ok {
return
}
if err := enc.Encode(response); err != nil {
slog.Error("Failed to encode response", "error", err)
return
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
}
})
mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
var b bytes.Buffer
if _, err := io.Copy(&b, r.Body); err != nil {
slog.Error("Failed to read request body", "error", err)
http.Error(w, "Bad Request", http.StatusBadRequest)
return
}
tokens := runner.Tokenizer.Encode(b.String(), runner.Tokenizer.AddBOS())
if err := json.NewEncoder(w).Encode(tokens); err != nil {
slog.Error("Failed to encode response", "error", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
})
for source, target := range map[string]string{
"GET /health": "/v1/status",
"POST /load": "/v1/models",
"POST /completion": "/v1/completions",
} {
mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
}
return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
t := time.Now()
mux.ServeHTTP(recorder, r)
var level slog.Level
switch {
case recorder.code >= 500:
level = slog.LevelError
case recorder.code >= 400:
level = slog.LevelWarn
case recorder.code >= 300:
return
}
slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
}))
}
type statusRecorder struct {
http.ResponseWriter
code int
}
func (w *statusRecorder) WriteHeader(code int) {
w.code = code
w.ResponseWriter.WriteHeader(code)
}
func (w *statusRecorder) Status() string {
return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
}
func (w *statusRecorder) Flush() {
if f, ok := w.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}

View File

@@ -0,0 +1,109 @@
package mlxrunner
import (
"context"
"log/slog"
"sync"
"time"
)
const statusMemoryRefreshWait = 50 * time.Millisecond
type statusMemoryRefreshFunc func() (uint64, error)
// statusMemoryCache keeps health checks from depending synchronously on the
// serialized MLX worker while still refreshing memory telemetry opportunistically.
type statusMemoryCache struct {
done <-chan struct{}
wait time.Duration
refresh statusMemoryRefreshFunc
mu sync.Mutex
memory uint64
refreshedAt time.Time
inFlight chan struct{}
}
func newStatusMemoryCache(ctx context.Context, memory uint64, refreshedAt time.Time, wait time.Duration, refresh statusMemoryRefreshFunc) *statusMemoryCache {
return &statusMemoryCache{
done: ctx.Done(),
wait: wait,
refresh: refresh,
memory: memory,
refreshedAt: refreshedAt,
}
}
func (c *statusMemoryCache) Memory() uint64 {
done := c.startRefresh()
if c.wait <= 0 {
<-done
memory, _ := c.snapshot()
return memory
}
timer := time.NewTimer(c.wait)
defer timer.Stop()
select {
case <-done:
case <-timer.C:
memory, refreshedAt := c.snapshot()
if refreshedAt.IsZero() {
slog.Debug("using cached MLX memory status before first refresh")
} else {
slog.Debug("using cached MLX memory status", "stale", time.Since(refreshedAt))
}
return memory
case <-c.done:
}
memory, _ := c.snapshot()
return memory
}
func (c *statusMemoryCache) startRefresh() chan struct{} {
c.mu.Lock()
if c.inFlight != nil {
done := c.inFlight
c.mu.Unlock()
return done
}
refreshDone := make(chan struct{})
c.inFlight = refreshDone
refresh := c.refresh
lifecycleDone := c.done
c.mu.Unlock()
go func() {
memory, err := refresh()
now := time.Now()
c.mu.Lock()
defer c.mu.Unlock()
defer close(refreshDone)
if err != nil {
select {
case <-lifecycleDone:
default:
slog.Debug("failed to refresh MLX memory status", "error", err)
}
c.inFlight = nil
return
}
c.memory = memory
c.refreshedAt = now
c.inFlight = nil
}()
return refreshDone
}
func (c *statusMemoryCache) snapshot() (uint64, time.Time) {
c.mu.Lock()
defer c.mu.Unlock()
return c.memory, c.refreshedAt
}

View File

@@ -0,0 +1,246 @@
package mlxrunner
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestStatusMemoryCacheWaitsForFastRefresh(t *testing.T) {
var calls atomic.Int32
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
calls.Add(1)
return 42, nil
})
if got := cache.Memory(); got != 42 {
t.Fatalf("got memory %d, want 42", got)
}
if got := calls.Load(); got != 1 {
t.Fatalf("refresh calls = %d, want 1", got)
}
}
func TestStatusMemoryCacheSupportsBlockingWait(t *testing.T) {
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), 0, func() (uint64, error) {
return 42, nil
})
if got := cache.Memory(); got != 42 {
t.Fatalf("got memory %d, want 42", got)
}
}
func TestStatusMemoryCacheReturnsCachedValueAndRefreshesLater(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
started := make(chan struct{})
release := make(chan struct{})
var calls atomic.Int32
cache := newStatusMemoryCache(ctx, 7, time.Now().Add(-time.Minute), time.Millisecond, func() (uint64, error) {
if calls.Add(1) == 1 {
close(started)
}
select {
case <-release:
return 42, nil
case <-ctx.Done():
return 0, ctx.Err()
}
})
start := time.Now()
if got := cache.Memory(); got != 7 {
t.Fatalf("got memory %d, want cached value 7", got)
}
if elapsed := time.Since(start); elapsed > time.Second {
t.Fatalf("cached memory lookup took too long: %s", elapsed)
}
waitForRefreshStart(t, started)
close(release)
waitForCachedMemory(t, cache, 42)
if got := calls.Load(); got != 1 {
t.Fatalf("refresh calls = %d, want 1", got)
}
}
func TestStatusMemoryCacheReturnsCachedValueBeforeFirstRefresh(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
started := make(chan struct{})
release := make(chan struct{})
cache := newStatusMemoryCache(ctx, 7, time.Time{}, time.Millisecond, func() (uint64, error) {
close(started)
select {
case <-release:
return 42, nil
case <-ctx.Done():
return 0, ctx.Err()
}
})
if got := cache.Memory(); got != 7 {
t.Fatalf("got memory %d, want cached value 7", got)
}
waitForRefreshStart(t, started)
close(release)
waitForCachedMemory(t, cache, 42)
}
func TestStatusMemoryCacheKeepsCachedValueWhenRefreshFails(t *testing.T) {
var calls atomic.Int32
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
calls.Add(1)
return 0, errors.New("refresh failed")
})
if got := cache.Memory(); got != 7 {
t.Fatalf("got memory %d, want cached value 7", got)
}
if got := calls.Load(); got != 1 {
t.Fatalf("refresh calls = %d, want 1", got)
}
}
func TestStatusMemoryCacheReturnsCachedValueWhenContextDone(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
started := make(chan struct{})
release := make(chan struct{})
cache := newStatusMemoryCache(ctx, 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
close(started)
<-release
return 0, ctx.Err()
})
cancel()
if got := cache.Memory(); got != 7 {
t.Fatalf("got memory %d, want cached value 7", got)
}
waitForRefreshStart(t, started)
close(release)
waitForInflightRefresh(t, cache)
}
func TestStatusMemoryCacheAllowsRefreshAfterFailure(t *testing.T) {
var calls atomic.Int32
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
if calls.Add(1) == 1 {
return 0, errors.New("refresh failed")
}
return 42, nil
})
if got := cache.Memory(); got != 7 {
t.Fatalf("got memory %d, want cached value 7", got)
}
if got := cache.Memory(); got != 42 {
t.Fatalf("got memory %d after retry, want 42", got)
}
if got := calls.Load(); got != 2 {
t.Fatalf("refresh calls = %d, want 2", got)
}
}
func TestStatusMemoryCacheAllowsOneInflightRefresh(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
started := make(chan struct{})
release := make(chan struct{})
var calls atomic.Int32
cache := newStatusMemoryCache(ctx, 11, time.Now().Add(-time.Minute), time.Millisecond, func() (uint64, error) {
if calls.Add(1) == 1 {
close(started)
}
select {
case <-release:
return 99, nil
case <-ctx.Done():
return 0, ctx.Err()
}
})
const goroutines = 8
var wg sync.WaitGroup
errCh := make(chan string, goroutines)
for range goroutines {
wg.Add(1)
go func() {
defer wg.Done()
if got := cache.Memory(); got != 11 {
errCh <- "got non-cached memory value"
}
}()
}
wg.Wait()
close(errCh)
for err := range errCh {
t.Fatal(err)
}
waitForRefreshStart(t, started)
if got := calls.Load(); got != 1 {
t.Fatalf("refresh calls = %d, want 1", got)
}
close(release)
waitForCachedMemory(t, cache, 99)
}
func waitForRefreshStart(t *testing.T, started <-chan struct{}) {
t.Helper()
select {
case <-started:
case <-time.After(time.Second):
t.Fatal("timeout waiting for refresh to start")
}
}
func waitForCachedMemory(t *testing.T, cache *statusMemoryCache, want uint64) {
t.Helper()
deadline := time.After(time.Second)
for {
got, _ := cache.snapshot()
if got == want {
return
}
select {
case <-deadline:
t.Fatalf("cached memory = %d, want %d", got, want)
case <-time.After(time.Millisecond):
}
}
}
func waitForInflightRefresh(t *testing.T, cache *statusMemoryCache) {
t.Helper()
deadline := time.After(time.Second)
for {
cache.mu.Lock()
inFlight := cache.inFlight
cache.mu.Unlock()
if inFlight == nil {
return
}
select {
case <-deadline:
t.Fatal("timeout waiting for refresh to finish")
case <-time.After(time.Millisecond):
}
}
}

View File

@@ -0,0 +1,47 @@
package mlxrunner
import (
"bytes"
"unicode/utf8"
)
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
// currently buffered, leaving any incomplete trailing bytes in place.
func flushValidUTF8Prefix(b *bytes.Buffer) string {
data := b.Bytes()
if len(data) == 0 {
return ""
}
prefix := validUTF8PrefixLen(data)
if prefix == 0 {
return ""
}
text := string(data[:prefix])
b.Next(prefix)
return text
}
func validUTF8PrefixLen(data []byte) int {
i := 0
prefix := 0
for i < len(data) {
r, size := utf8.DecodeRune(data[i:])
if r == utf8.RuneError && size == 1 {
if !utf8.FullRune(data[i:]) {
break
}
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
i++
prefix = i
continue
}
i += size
prefix = i
}
return prefix
}

View File

@@ -0,0 +1,46 @@
package mlxrunner
import (
"bytes"
"testing"
)
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
var b bytes.Buffer
b.Write([]byte{0xE3, 0x81})
if got := flushValidUTF8Prefix(&b); got != "" {
t.Fatalf("first flush = %q, want empty", got)
}
b.Write([]byte{0x93, 0xE3})
if got := flushValidUTF8Prefix(&b); got != "こ" {
t.Fatalf("second flush = %q, want %q", got, "こ")
}
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
}
b.Write([]byte{0x82, 0x93})
if got := flushValidUTF8Prefix(&b); got != "ん" {
t.Fatalf("third flush = %q, want %q", got, "ん")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after third flush: %d", b.Len())
}
}
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
var b bytes.Buffer
b.WriteString("hello 世界")
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
t.Fatalf("flush = %q, want %q", got, "hello 世界")
}
if b.Len() != 0 {
t.Fatalf("buffer not empty after flush: %d", b.Len())
}
}