ollama source for Momentry Core verification
This commit is contained in:
42
x/mlxrunner/batch/batch.go
Normal file
42
x/mlxrunner/batch/batch.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package batch
|
||||
|
||||
import "github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
|
||||
// Batch is the per-forward-pass input handed to a model.
|
||||
type Batch struct {
|
||||
// InputIDs is the input token IDs for this forward pass, shape (B, L).
|
||||
InputIDs *mlx.Array
|
||||
|
||||
// SeqOffsets gives each row's current position within its sequence —
|
||||
// where the chunk in InputIDs starts. Length equals the batch dimension
|
||||
// of InputIDs.
|
||||
SeqOffsets []int32
|
||||
|
||||
// SeqQueryLens is each row's real query length in this forward. Values
|
||||
// less than L mean the row's tail is padding that must be masked out.
|
||||
// Length equals the batch dimension of InputIDs.
|
||||
SeqQueryLens []int32
|
||||
|
||||
// Memo is per-forward memoization used to cache results, such as masks,
|
||||
// which are often the same across layers.
|
||||
Memo Memo
|
||||
}
|
||||
|
||||
type Memo struct {
|
||||
entries map[any]any
|
||||
}
|
||||
|
||||
// Get returns the memoized value for key and true if present, or nil
|
||||
// and false otherwise.
|
||||
func (m *Memo) Get(key any) (any, bool) {
|
||||
v, ok := m.entries[key]
|
||||
return v, ok
|
||||
}
|
||||
|
||||
// Put stores value under key, allocating on first use.
|
||||
func (m *Memo) Put(key, value any) {
|
||||
if m.entries == nil {
|
||||
m.entries = map[any]any{}
|
||||
}
|
||||
m.entries[key] = value
|
||||
}
|
||||
632
x/mlxrunner/cache.go
Normal file
632
x/mlxrunner/cache.go
Normal file
@@ -0,0 +1,632 @@
|
||||
// cache.go manages a shared KV cache across conversations using a compressed
|
||||
// prefix trie. Each trie node stores a token sequence (edge) and optional
|
||||
// per-layer snapshots that can be paged in/out of the live MLX cache arrays.
|
||||
//
|
||||
// Key properties:
|
||||
// - Only one path through the trie is "active" (backed by live MLX arrays)
|
||||
// at a time. Switching paths pages out the frontier node and pages in the
|
||||
// new path.
|
||||
// - Snapshots are only captured at the frontier (end) of the active path.
|
||||
// Intermediate node snapshots come from split prefill.
|
||||
// - All cache layers must stay at the same token offset.
|
||||
// - Sibling edges must not share a common token prefix (compressed trie
|
||||
// invariant).
|
||||
// - begin() always re-evaluates at least one token so the pipeline can seed
|
||||
// generation, even on a full prefix match.
|
||||
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
)
|
||||
|
||||
const maxPagedOutBytes int64 = 8 << 30 // 8 GiB eviction threshold for paged-out snapshot memory
|
||||
|
||||
type kvCache struct {
|
||||
root *trieNode // root of the prefix trie
|
||||
activePath []*trieNode // current root→leaf path with live MLX arrays
|
||||
caches []cache.Cache
|
||||
pagedOutBytes int64 // total bytes in paged-out snapshots across the trie
|
||||
}
|
||||
|
||||
type cacheFactoryFunc func() []cache.Cache
|
||||
|
||||
// pendingSnapshot is a snapshot scheduled to be taken during prefill.
|
||||
type pendingSnapshot struct {
|
||||
offset int
|
||||
user bool
|
||||
}
|
||||
|
||||
// cacheSession manages caches for a single pipeline run.
|
||||
// Callers should append generated tokens to outputs and
|
||||
// defer close to save the cache state.
|
||||
type cacheSession struct {
|
||||
cache *kvCache
|
||||
inputs []int32
|
||||
outputs []int32
|
||||
|
||||
caches []cache.Cache
|
||||
remaining []int32
|
||||
|
||||
// pendingSnapshots lists offsets where snapshots should be captured
|
||||
// during prefill, sorted by offset. Entries are consumed as the
|
||||
// cache advances past them.
|
||||
pendingSnapshots []pendingSnapshot
|
||||
}
|
||||
|
||||
func (c *kvCache) ensureCachesWithFactory(newCaches cacheFactoryFunc) {
|
||||
if len(c.caches) != 0 {
|
||||
return
|
||||
}
|
||||
c.caches = newCaches()
|
||||
}
|
||||
|
||||
func newModelCaches(m base.Model) []cache.Cache {
|
||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
return cacheFactory.NewCaches()
|
||||
}
|
||||
caches := make([]cache.Cache, m.NumLayers())
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func (c *kvCache) ensureRoot() {
|
||||
if c.root == nil {
|
||||
c.root = &trieNode{
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
c.activePath = []*trieNode{c.root}
|
||||
}
|
||||
}
|
||||
|
||||
// begin prepares caches for a new request. It finds the nearest
|
||||
// matching cache or creates new caches if none match.
|
||||
func (c *kvCache) begin(m base.Model, inputs []int32) *cacheSession {
|
||||
return c.beginWithFactory(inputs, func() []cache.Cache { return newModelCaches(m) }, "")
|
||||
}
|
||||
|
||||
func (c *kvCache) beginWithFactory(inputs []int32, newCaches cacheFactoryFunc, logPrefix string) *cacheSession {
|
||||
return c.beginWithFactoryLimit(inputs, newCaches, logPrefix, -1, true)
|
||||
}
|
||||
|
||||
func (c *kvCache) beginWithFactoryLimit(inputs []int32, newCaches cacheFactoryFunc, logPrefix string, maxCachedPrefix int, keepSeedToken bool) *cacheSession {
|
||||
c.ensureCachesWithFactory(newCaches)
|
||||
c.ensureRoot()
|
||||
|
||||
matchPath, matched := findBestMatch(c.root, inputs)
|
||||
originalMatched := matched
|
||||
if maxCachedPrefix >= 0 {
|
||||
maxCachedPrefix = min(maxCachedPrefix, len(inputs))
|
||||
if matched > maxCachedPrefix {
|
||||
matchPath, matched = findBestMatch(c.root, inputs[:maxCachedPrefix])
|
||||
}
|
||||
}
|
||||
|
||||
// Always keep at least one token to re-evaluate so the
|
||||
// pipeline can seed token generation from it.
|
||||
if keepSeedToken && matched == len(inputs) && matched > 0 {
|
||||
matchPath, matched = findBestMatch(c.root, inputs[:len(inputs)-1])
|
||||
}
|
||||
|
||||
// Switch to the matched path, paging in/out as needed.
|
||||
c.switchToPath(matchPath, matched)
|
||||
|
||||
// switchToPath aligns caches to a common offset
|
||||
prefix := c.minCacheOffset()
|
||||
remaining := inputs[prefix:]
|
||||
|
||||
session := &cacheSession{
|
||||
cache: c,
|
||||
inputs: inputs,
|
||||
caches: c.caches,
|
||||
remaining: remaining,
|
||||
}
|
||||
|
||||
// Schedule a snapshot at the branch point during prefill so future
|
||||
// requests diverging here can restore instead of re-evaluating.
|
||||
if prefix < matched {
|
||||
session.pendingSnapshots = append(session.pendingSnapshots, pendingSnapshot{offset: matched, user: false})
|
||||
}
|
||||
|
||||
msg := "cache hit"
|
||||
if prefix == 0 {
|
||||
msg = "cache miss"
|
||||
}
|
||||
if logPrefix != "" {
|
||||
msg = logPrefix + " " + msg
|
||||
}
|
||||
slog.Info(msg, "total", len(inputs), "matched", originalMatched, "cached", prefix, "left", len(remaining))
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
// switchToPath transitions from the current active path to a new path,
|
||||
// paging out diverging segments and paging in the new path.
|
||||
func (c *kvCache) switchToPath(newPath []*trieNode, matched int) {
|
||||
defer c.enforceEvictionPolicy()
|
||||
|
||||
// Find common ancestor index.
|
||||
commonLen := 0
|
||||
for commonLen < len(c.activePath) && commonLen < len(newPath) {
|
||||
if c.activePath[commonLen] != newPath[commonLen] {
|
||||
break
|
||||
}
|
||||
commonLen++
|
||||
}
|
||||
|
||||
ancestorOffset := 0
|
||||
if commonLen > 0 {
|
||||
ancestorOffset = c.activePath[commonLen-1].endOffset
|
||||
}
|
||||
|
||||
var pageOutCount, pageInCount int
|
||||
|
||||
// Page out the leaf of the old path. Only the leaf's live cache
|
||||
// state is correct — intermediate nodes already have snapshots
|
||||
// captured during their creation (splitNode + prefill). Snapshotting
|
||||
// non-leaf nodes here would produce wrong results for non-rewindable
|
||||
// caches (e.g. RecurrentCache) whose state reflects the leaf, not
|
||||
// the intermediate boundary.
|
||||
leaf := len(c.activePath) - 1
|
||||
leafDiverges := leaf >= commonLen
|
||||
leafNeedsRewind := matched < c.activePath[leaf].endOffset
|
||||
if leafDiverges || leafNeedsRewind {
|
||||
node := c.activePath[leaf]
|
||||
if !node.hasAllSnapshots() {
|
||||
fromOffset := node.startOffset()
|
||||
snaps := make([]cache.Snapshot, len(c.caches))
|
||||
for j, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
snaps[j] = kv.Snapshot(fromOffset)
|
||||
}
|
||||
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||
pageOutCount++
|
||||
logutil.Trace(fmt.Sprintf("page out: [%d, %d)", fromOffset, node.endOffset))
|
||||
}
|
||||
}
|
||||
|
||||
// Rewind each cache to the target offset or free it. When matched
|
||||
// falls within the ancestor's range (same-path case), we rewind
|
||||
// directly to the match point. Otherwise we rewind to the ancestor
|
||||
// and let page-in bring us forward to matched.
|
||||
rewindTarget := min(ancestorOffset, matched)
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if !kv.Restore(nil, rewindTarget) {
|
||||
kv.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Page in — walk the full new path, restoring from snapshots.
|
||||
// Freed caches naturally pick up the first available snapshot.
|
||||
// Caches already past a node skip it via offset check.
|
||||
pageIn:
|
||||
for _, node := range newPath {
|
||||
if !node.hasSnapshots() {
|
||||
continue
|
||||
}
|
||||
nodeTarget := min(node.endOffset, matched)
|
||||
for j, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if j >= len(node.snapshots) || node.snapshots[j] == nil {
|
||||
continue
|
||||
}
|
||||
if kv.Offset() >= nodeTarget {
|
||||
continue
|
||||
}
|
||||
if !kv.Restore(node.snapshots[j], nodeTarget) {
|
||||
// Restore failed — stop page-in and let alignment
|
||||
// bring all caches to a consistent offset.
|
||||
break pageIn
|
||||
}
|
||||
}
|
||||
if node.endOffset > ancestorOffset {
|
||||
pageInCount++
|
||||
logutil.Trace(fmt.Sprintf("page in: [%d, %d)", node.startOffset(), nodeTarget))
|
||||
}
|
||||
}
|
||||
|
||||
// Align all caches to the minimum offset.
|
||||
c.activePath = newPath
|
||||
minOff := c.minCacheOffset()
|
||||
for _, kv := range c.caches {
|
||||
if kv != nil && kv.Offset() != minOff {
|
||||
if !kv.Restore(nil, minOff) {
|
||||
slog.Warn("failed to restore cache, freeing all caches", "offset", minOff)
|
||||
c.freeAll()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
for i := len(c.activePath) - 1; i >= 0; i-- {
|
||||
if c.activePath[i].endOffset <= minOff {
|
||||
c.activePath = c.activePath[:i+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Update last-used time on only the final used node. For recurrent
|
||||
// caches we don't need the intermediate snapshots and for KV caches
|
||||
// we can reslice the data out of merged edges.
|
||||
if len(c.activePath) > 0 {
|
||||
c.activePath[len(c.activePath)-1].lastUsed = time.Now()
|
||||
}
|
||||
|
||||
if pageOutCount > 0 || pageInCount > 0 {
|
||||
slog.Debug("switching cache path", "page_out", pageOutCount, "page_in", pageInCount)
|
||||
}
|
||||
}
|
||||
|
||||
// requestSnapshot schedules a user snapshot at the given absolute token
|
||||
// offset. The snapshot will be captured during prefill when the cache
|
||||
// reaches this offset.
|
||||
func (s *cacheSession) requestSnapshot(offset int) {
|
||||
baseOffset := len(s.inputs) - len(s.remaining)
|
||||
if offset <= baseOffset || offset > len(s.inputs) {
|
||||
return
|
||||
}
|
||||
// Deduplicate: if this offset already exists, upgrade to user.
|
||||
for i := range s.pendingSnapshots {
|
||||
if s.pendingSnapshots[i].offset == offset {
|
||||
s.pendingSnapshots[i].user = true
|
||||
return
|
||||
}
|
||||
}
|
||||
s.pendingSnapshots = append(s.pendingSnapshots, pendingSnapshot{offset: offset, user: true})
|
||||
slices.SortFunc(s.pendingSnapshots, func(a, b pendingSnapshot) int {
|
||||
return a.offset - b.offset
|
||||
})
|
||||
}
|
||||
|
||||
// nextPendingSnapshot returns the offset of the next pending snapshot,
|
||||
// or 0 if there are none.
|
||||
func (s *cacheSession) nextPendingSnapshot() int {
|
||||
if len(s.pendingSnapshots) == 0 {
|
||||
return 0
|
||||
}
|
||||
return s.pendingSnapshots[0].offset
|
||||
}
|
||||
|
||||
// snapshot creates a snapshot at the current cache position. It determines
|
||||
// whether this is a user snapshot by consuming pending entries whose offset
|
||||
// has been reached.
|
||||
func (s *cacheSession) snapshot() {
|
||||
c := s.cache
|
||||
cacheOffset := c.minCacheOffset()
|
||||
if cacheOffset <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Consume pending snapshots up to the current offset and derive
|
||||
// the user flag from them.
|
||||
user := false
|
||||
for len(s.pendingSnapshots) > 0 && cacheOffset >= s.pendingSnapshots[0].offset {
|
||||
if s.pendingSnapshots[0].user {
|
||||
user = true
|
||||
}
|
||||
s.pendingSnapshots = s.pendingSnapshots[1:]
|
||||
}
|
||||
|
||||
// The last node in activePath is the frontier where caches are advancing.
|
||||
// cacheOffset is always >= its endOffset: begin() restores caches to this
|
||||
// boundary and prefill advances monotonically forward.
|
||||
frontier := c.activePath[len(c.activePath)-1]
|
||||
|
||||
// If the frontier already ends at cacheOffset, just ensure it has snapshots.
|
||||
if frontier.endOffset == cacheOffset {
|
||||
if user {
|
||||
frontier.user = true
|
||||
}
|
||||
if !frontier.hasAllSnapshots() {
|
||||
s.attachSnapshots(frontier, cacheOffset)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if frontier.endOffset > cacheOffset {
|
||||
slog.Warn("snapshot skipped: cacheOffset is behind frontier", "cacheOffset", cacheOffset, "frontierEndOffset", frontier.endOffset)
|
||||
return
|
||||
}
|
||||
|
||||
// Advance the trie to cacheOffset — find or create a node there.
|
||||
edgeTokens := append(s.inputs, s.outputs...)[frontier.endOffset:cacheOffset]
|
||||
frontier = c.advancePath(frontier, edgeTokens, cacheOffset)
|
||||
|
||||
// Attach fresh snapshots from the live caches. Always use fresh
|
||||
// snapshots even if the node already has some (e.g. from splitNode's
|
||||
// Cache.Split which may be incomplete for non-splittable caches
|
||||
// like RecurrentCache).
|
||||
if user {
|
||||
frontier.user = true
|
||||
}
|
||||
s.attachSnapshots(frontier, cacheOffset)
|
||||
}
|
||||
|
||||
// advancePath advances the active path from the current frontier by matching
|
||||
// tokens against existing trie children, splitting partial matches, and
|
||||
// appending any remaining tokens as new nodes. Returns the new frontier.
|
||||
func (c *kvCache) advancePath(frontier *trieNode, tokens []int32, endOffset int) *trieNode {
|
||||
// Check if existing children already cover some or all of tokens.
|
||||
// tokens may span multiple trie nodes when extending a previous run's
|
||||
// leaf and this snapshot now overlaps that same range.
|
||||
matchPath, matched := findBestMatch(frontier, tokens)
|
||||
// matchPath[0] is frontier itself; the rest are newly traversed nodes.
|
||||
remaining := tokens[matched:]
|
||||
|
||||
// Check for a partial match within the last node's edge — if so, split it.
|
||||
if len(matchPath) > 1 {
|
||||
lastNode := matchPath[len(matchPath)-1]
|
||||
matchedInEdge := frontier.endOffset + matched - lastNode.startOffset()
|
||||
if matchedInEdge > 0 && matchedInEdge < len(lastNode.tokens) {
|
||||
matchPath[len(matchPath)-1] = splitNode(lastNode, matchedInEdge, c.caches, &c.pagedOutBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Append traversed nodes (excluding frontier) to the active path.
|
||||
c.activePath = append(c.activePath, matchPath[1:]...)
|
||||
dest := matchPath[len(matchPath)-1]
|
||||
|
||||
if len(remaining) > 0 {
|
||||
// Drop non-user snapshots so appendTokens can extend in-place
|
||||
// rather than creating a new child node.
|
||||
if len(dest.children) == 0 && !dest.user {
|
||||
dest.setSnapshots(nil, &c.pagedOutBytes)
|
||||
}
|
||||
newDest := dest.appendTokens(c.root, remaining, endOffset)
|
||||
if newDest != dest {
|
||||
c.activePath = append(c.activePath, newDest)
|
||||
}
|
||||
dest = newDest
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
// attachSnapshots attaches cache snapshots to a trie node at the given offset.
|
||||
// The node must be on the active path (and thus protected from eviction;
|
||||
// lastUsed is updated in close()). All non-nil caches must be at the same
|
||||
// offset (cacheOffset); a mismatch indicates a bug in the caller.
|
||||
func (s *cacheSession) attachSnapshots(node *trieNode, cacheOffset int) {
|
||||
c := s.cache
|
||||
|
||||
if c.activePath[len(c.activePath)-1] != node {
|
||||
slog.Warn("attachSnapshots skipped: node is not the active frontier", "nodeEndOffset", node.endOffset)
|
||||
return
|
||||
}
|
||||
|
||||
snaps := make([]cache.Snapshot, len(c.caches))
|
||||
for i, kv := range c.caches {
|
||||
if kv != nil {
|
||||
if kv.Offset() != cacheOffset {
|
||||
panic(fmt.Sprintf("attachSnapshots: cache offset mismatch layer %d: expected %d, got %d", i, cacheOffset, kv.Offset()))
|
||||
}
|
||||
snaps[i] = kv.Snapshot(node.startOffset())
|
||||
}
|
||||
}
|
||||
node.setSnapshots(snaps, &c.pagedOutBytes)
|
||||
node.lastUsed = time.Now()
|
||||
slog.Debug("created snapshot", "offset", cacheOffset)
|
||||
c.enforceEvictionPolicy()
|
||||
}
|
||||
|
||||
// freeAll releases all cache layers.
|
||||
func (c *kvCache) freeAll() {
|
||||
for _, kv := range c.caches {
|
||||
if kv != nil {
|
||||
kv.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *kvCache) minCacheOffset() int {
|
||||
offset := 0
|
||||
found := false
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
if off := kv.Offset(); !found || off < offset {
|
||||
offset = off
|
||||
found = true
|
||||
}
|
||||
}
|
||||
return offset
|
||||
}
|
||||
|
||||
// close saves the token state if the forward pass ran.
|
||||
func (s *cacheSession) close() {
|
||||
offset := s.cache.minCacheOffset()
|
||||
if offset <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
arrays := make([]*mlx.Array, 0, 2*len(s.caches))
|
||||
for _, kv := range s.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
arrays = append(arrays, kv.State()...)
|
||||
}
|
||||
|
||||
// Ensure that if we have run the forward pass and set the metadata
|
||||
// that we also actually have the data.
|
||||
mlx.AsyncEval(arrays...)
|
||||
|
||||
// Advance the trie frontier with any newly generated tokens.
|
||||
c := s.cache
|
||||
if len(c.activePath) > 0 {
|
||||
frontier := c.activePath[len(c.activePath)-1]
|
||||
stored := append(s.inputs, s.outputs...)
|
||||
|
||||
if offset > frontier.endOffset {
|
||||
newTokens := stored[frontier.endOffset:offset]
|
||||
c.advancePath(frontier, newTokens, offset)
|
||||
}
|
||||
c.activePath[len(c.activePath)-1].lastUsed = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// enforceEvictionPolicy evicts eligible nodes until paged-out memory is within limits.
|
||||
func (c *kvCache) enforceEvictionPolicy() {
|
||||
if c.pagedOutBytes <= maxPagedOutBytes {
|
||||
return
|
||||
}
|
||||
|
||||
activeSet := make(map[*trieNode]bool, len(c.activePath))
|
||||
for _, n := range c.activePath {
|
||||
activeSet[n] = true
|
||||
}
|
||||
|
||||
for c.pagedOutBytes > maxPagedOutBytes {
|
||||
var best *trieNode
|
||||
walkNodes(c.root, func(n *trieNode) bool {
|
||||
if n == c.root || activeSet[n] || len(n.children) > 1 {
|
||||
return true
|
||||
}
|
||||
// Evict: oldest, then deepest, then largest.
|
||||
if best == nil || cmp.Or(
|
||||
n.lastUsed.Compare(best.lastUsed),
|
||||
cmp.Compare(best.endOffset, n.endOffset),
|
||||
cmp.Compare(best.snapshotBytes(), n.snapshotBytes()),
|
||||
) < 0 {
|
||||
best = n
|
||||
}
|
||||
return true
|
||||
})
|
||||
if best == nil {
|
||||
break
|
||||
}
|
||||
c.evictNode(best)
|
||||
}
|
||||
}
|
||||
|
||||
// evictNode evicts a single node from the trie, freeing its snapshot memory.
|
||||
func (c *kvCache) evictNode(node *trieNode) {
|
||||
if len(node.children) == 0 {
|
||||
// Leaf: remove entirely.
|
||||
slog.Debug("evicting leaf", "offset", node.startOffset(), "tokens", len(node.tokens), "freed", mlx.PrettyBytes(int(node.snapshotBytes())))
|
||||
removeNode(node, &c.pagedOutBytes)
|
||||
} else if len(node.children) == 1 {
|
||||
// Interior node with one child: merge with child.
|
||||
before := c.pagedOutBytes
|
||||
tokens := len(node.tokens)
|
||||
mergeWithChild(node, c.caches, &c.pagedOutBytes)
|
||||
slog.Debug("evicting interior node", "offset", node.startOffset(), "tokens", tokens, "freed", mlx.PrettyBytes(int(before-c.pagedOutBytes)))
|
||||
} else {
|
||||
panic("evictNode called on multi-child branch point")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *kvCache) dumpTree() {
|
||||
// Summary stats
|
||||
var cacheBytes int
|
||||
for _, kv := range c.caches {
|
||||
if kv == nil {
|
||||
continue
|
||||
}
|
||||
for _, a := range kv.State() {
|
||||
if a != nil {
|
||||
cacheBytes += a.NumBytes()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build active path set for marking.
|
||||
active := make(map[*trieNode]bool, len(c.activePath))
|
||||
for _, n := range c.activePath {
|
||||
active[n] = true
|
||||
}
|
||||
|
||||
var nodeCount, snapshotCount int
|
||||
var pagedBytes int64
|
||||
var lines []string
|
||||
var dump func(n *trieNode, prefix string, isLast bool)
|
||||
dump = func(n *trieNode, prefix string, isLast bool) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
nodeCount++
|
||||
|
||||
// Build connector
|
||||
var connector string
|
||||
if n.parent == nil {
|
||||
connector = ""
|
||||
} else if isLast {
|
||||
connector = prefix + "`-- "
|
||||
} else {
|
||||
connector = prefix + "|-- "
|
||||
}
|
||||
|
||||
// Node label
|
||||
nodeBytes := n.snapshotBytes()
|
||||
pagedBytes += nodeBytes
|
||||
|
||||
label := fmt.Sprintf("[%d,%d) %dt", n.startOffset(), n.endOffset, len(n.tokens))
|
||||
if nodeBytes > 0 {
|
||||
label += " " + mlx.PrettyBytes(int(nodeBytes)).String()
|
||||
}
|
||||
if !n.lastUsed.IsZero() {
|
||||
label += fmt.Sprintf(" %s ago", time.Since(n.lastUsed).Truncate(time.Millisecond))
|
||||
}
|
||||
var flags []string
|
||||
if n.user {
|
||||
flags = append(flags, "user")
|
||||
}
|
||||
if n.hasAllSnapshots() {
|
||||
snapshotCount++
|
||||
flags = append(flags, "snap")
|
||||
}
|
||||
if active[n] {
|
||||
flags = append(flags, "active")
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
label += " (" + flags[0]
|
||||
for _, f := range flags[1:] {
|
||||
label += ", " + f
|
||||
}
|
||||
label += ")"
|
||||
}
|
||||
lines = append(lines, connector+label)
|
||||
|
||||
// Recurse children
|
||||
childPrefix := prefix
|
||||
if n.parent != nil {
|
||||
if isLast {
|
||||
childPrefix += " "
|
||||
} else {
|
||||
childPrefix += "| "
|
||||
}
|
||||
}
|
||||
for i, child := range n.children {
|
||||
dump(child, childPrefix, i == len(n.children)-1)
|
||||
}
|
||||
}
|
||||
dump(c.root, "", true)
|
||||
|
||||
offset := c.minCacheOffset()
|
||||
logutil.Trace(fmt.Sprintf("kv cache active_tokens: %d, active_size: %s, paged_out: %s, trie: nodes=%d, snapshots=%d",
|
||||
offset, mlx.PrettyBytes(cacheBytes), mlx.PrettyBytes(int(pagedBytes)), nodeCount, snapshotCount))
|
||||
for i, l := range lines {
|
||||
if i == 0 {
|
||||
logutil.Trace("cache trie: " + l)
|
||||
} else {
|
||||
logutil.Trace(" " + l)
|
||||
}
|
||||
}
|
||||
}
|
||||
846
x/mlxrunner/cache/cache.go
vendored
Normal file
846
x/mlxrunner/cache/cache.go
vendored
Normal file
@@ -0,0 +1,846 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// Cache is common state management shared by every cache kind. Writers
|
||||
// live on the specific caches
|
||||
type Cache interface {
|
||||
// State returns the cache-owned state roots that should be kept/evaluated.
|
||||
State() []*mlx.Array
|
||||
Free()
|
||||
Offset() int
|
||||
|
||||
// Snapshot copies cache state from fromOffset to current offset into
|
||||
// pinned VRAM arrays. The active cache is unchanged.
|
||||
Snapshot(fromOffset int) Snapshot
|
||||
|
||||
// Restore brings the cache to target. If snapshot is nil, rewinds
|
||||
// using the cache's own live state. Returns false if the target is
|
||||
// unreachable (e.g. target > current offset, or negative).
|
||||
Restore(snapshot Snapshot, target int) bool
|
||||
|
||||
// Merge combines two sequential snapshots [a,b) and [b,c) into [a,c).
|
||||
// Takes ownership of both inputs.
|
||||
Merge(parent, child Snapshot) Snapshot
|
||||
|
||||
// Split divides a snapshot [a,c) at offset b into [a,b) and [b,c).
|
||||
// Takes ownership of the input. Cache types that cannot split
|
||||
// (e.g. recurrent) return (nil, snapshot).
|
||||
Split(snapshot Snapshot, at int) (parent, child Snapshot)
|
||||
}
|
||||
|
||||
// Snapshot is paged-out cache state that can be restored later.
|
||||
type Snapshot interface {
|
||||
// Size returns the byte size of the paged-out data (in VRAM).
|
||||
Size() int
|
||||
// Close unpins the snapshot's arrays so they can be freed by Sweep.
|
||||
Close()
|
||||
}
|
||||
|
||||
// Attention is the contract for caches that back attention layers
|
||||
// (KVCache, RotatingKVCache).
|
||||
type Attention interface {
|
||||
Cache
|
||||
|
||||
// Update appends (k, v) and returns an opaque nn.KVHistory for
|
||||
// this layer's SDPA.
|
||||
Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory
|
||||
}
|
||||
|
||||
// Viewer exposes a read-only attention history for a cache.
|
||||
type Viewer interface {
|
||||
View(b *batch.Batch) *nn.KVHistory
|
||||
}
|
||||
|
||||
type speculativeCommitter interface {
|
||||
Cache
|
||||
commit(n int)
|
||||
}
|
||||
|
||||
// Speculation is an isolated cache transaction for speculative target
|
||||
// validation. Updates record generated K/V without mutating the live caches;
|
||||
// Commit appends only the accepted prefix to the live caches.
|
||||
type Speculation struct {
|
||||
layers []speculativeCommitter
|
||||
}
|
||||
|
||||
// BeginSpeculation returns cache wrappers suitable for a speculative target
|
||||
// forward. The returned caches must only be used for that forward.
|
||||
func BeginSpeculation(caches []Cache) ([]Cache, *Speculation, bool) {
|
||||
specCaches := make([]Cache, len(caches))
|
||||
layers := make([]speculativeCommitter, len(caches))
|
||||
|
||||
for i, c := range caches {
|
||||
switch c := c.(type) {
|
||||
case nil:
|
||||
case *RotatingKVCache:
|
||||
sc := newSpeculativeRotatingKVCache(c)
|
||||
specCaches[i] = sc
|
||||
layers[i] = sc
|
||||
case *KVCache:
|
||||
sc := newSpeculativeKVCache(c)
|
||||
specCaches[i] = sc
|
||||
layers[i] = sc
|
||||
case *RecurrentCache:
|
||||
sc := newSpeculativeRecurrentCache(c)
|
||||
specCaches[i] = sc
|
||||
layers[i] = sc
|
||||
default:
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return specCaches, &Speculation{layers: layers}, true
|
||||
}
|
||||
|
||||
// BeginIsolatedSpeculation returns cache wrappers that never mutate live cache
|
||||
// state. It is intended for correctness instrumentation, not the hot path.
|
||||
func BeginIsolatedSpeculation(caches []Cache) ([]Cache, bool) {
|
||||
specCaches := make([]Cache, len(caches))
|
||||
|
||||
for i, c := range caches {
|
||||
switch c := c.(type) {
|
||||
case nil:
|
||||
case *RotatingKVCache:
|
||||
specCaches[i] = newSpeculativeRotatingKVCache(c)
|
||||
case *KVCache:
|
||||
specCaches[i] = newIsolatedKVCache(c)
|
||||
case *RecurrentCache:
|
||||
specCaches[i] = newSpeculativeRecurrentCache(c)
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
return specCaches, true
|
||||
}
|
||||
|
||||
// Commit appends the accepted prefix from the speculative forward to the live
|
||||
// caches. The target bonus token is intentionally not committed.
|
||||
func (s *Speculation) Commit(n int) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
for _, layer := range s.layers {
|
||||
if layer != nil {
|
||||
layer.commit(n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
step int
|
||||
}
|
||||
|
||||
func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
// Assumes B = 1; heterogeneous batches are not supported.
|
||||
func (c *KVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
|
||||
newK, newV := c.appendKV(keys, values)
|
||||
return nn.NewKVHistory(newK, newV, nil)
|
||||
}
|
||||
|
||||
// View returns the current cache contents as attention history without writing.
|
||||
func (c *KVCache) View(_ *batch.Batch) *nn.KVHistory {
|
||||
state := c.State()
|
||||
if len(state) < 2 {
|
||||
return nil
|
||||
}
|
||||
return nn.NewKVHistory(state[0], state[1], nil)
|
||||
}
|
||||
|
||||
// appendKV is the raw write path shared by Update and Restore.
|
||||
func (c *KVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
|
||||
|
||||
prev := c.offset
|
||||
|
||||
// Grow buffer if needed
|
||||
if c.keys == nil || (prev+L) > c.keys.Dim(2) {
|
||||
steps := (c.step + L - 1) / c.step
|
||||
newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
|
||||
newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
|
||||
|
||||
if c.keys != nil {
|
||||
if prev%c.step != 0 {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
|
||||
}
|
||||
c.keys.Set(c.keys.Concatenate(2, newKeys))
|
||||
c.values.Set(c.values.Concatenate(2, newValues))
|
||||
} else {
|
||||
c.keys, c.values = newKeys, newValues
|
||||
mlx.Pin(c.keys, c.values)
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += L
|
||||
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
|
||||
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
|
||||
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{
|
||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
|
||||
}
|
||||
}
|
||||
|
||||
// kvSnapshot holds paged-out KV data for a range [fromOffset, toOffset).
|
||||
type kvSnapshot struct {
|
||||
keys, values *mlx.Array
|
||||
fromOffset, toOffset int
|
||||
}
|
||||
|
||||
func (s *kvSnapshot) Size() int { return s.keys.NumBytes() + s.values.NumBytes() }
|
||||
func (s *kvSnapshot) Close() { mlx.Unpin(s.keys, s.values) }
|
||||
|
||||
func (c *KVCache) Snapshot(fromOffset int) Snapshot {
|
||||
if c.keys == nil || c.offset <= fromOffset {
|
||||
return nil
|
||||
}
|
||||
from := max(0, fromOffset)
|
||||
to := c.offset
|
||||
|
||||
kSlice := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
|
||||
vSlice := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(from, to), mlx.Slice())
|
||||
kCopy := mlx.Contiguous(kSlice, false)
|
||||
vCopy := mlx.Contiguous(vSlice, false)
|
||||
mlx.Pin(kCopy, vCopy)
|
||||
mlx.AsyncEval(kCopy, vCopy)
|
||||
|
||||
return &kvSnapshot{
|
||||
keys: kCopy,
|
||||
values: vCopy,
|
||||
fromOffset: from,
|
||||
toOffset: to,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *KVCache) Restore(snapshot Snapshot, target int) bool {
|
||||
if target < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
if target > c.offset {
|
||||
return false
|
||||
}
|
||||
c.offset = target
|
||||
return true
|
||||
}
|
||||
|
||||
snap := snapshot.(*kvSnapshot)
|
||||
|
||||
if target > snap.toOffset || c.offset < snap.fromOffset {
|
||||
return false
|
||||
}
|
||||
|
||||
// Rewind to snapshot start, then feed snapshot.
|
||||
c.offset = snap.fromOffset
|
||||
c.appendKV(snap.keys, snap.values)
|
||||
|
||||
// Clamp to target if needed (target may be less than full snapshot).
|
||||
if target < c.offset {
|
||||
c.offset = target
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *KVCache) Merge(parent, child Snapshot) Snapshot {
|
||||
if parent == nil || child == nil {
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
if child != nil {
|
||||
child.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
p := parent.(*kvSnapshot)
|
||||
ch := child.(*kvSnapshot)
|
||||
|
||||
mk := p.keys.Concatenate(2, ch.keys)
|
||||
mv := p.values.Concatenate(2, ch.values)
|
||||
mlx.Pin(mk, mv)
|
||||
mlx.AsyncEval(mk, mv)
|
||||
|
||||
p.Close()
|
||||
ch.Close()
|
||||
|
||||
return &kvSnapshot{
|
||||
keys: mk,
|
||||
values: mv,
|
||||
fromOffset: p.fromOffset,
|
||||
toOffset: ch.toOffset,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *KVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
if snapshot == nil {
|
||||
return nil, nil
|
||||
}
|
||||
snap := snapshot.(*kvSnapshot)
|
||||
splitIdx := at - snap.fromOffset
|
||||
seqLen := snap.toOffset - snap.fromOffset
|
||||
if splitIdx <= 0 {
|
||||
return nil, snapshot
|
||||
}
|
||||
if splitIdx >= seqLen {
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
pk := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
|
||||
pv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, splitIdx), mlx.Slice()), false)
|
||||
ck := mlx.Contiguous(snap.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
|
||||
cv := mlx.Contiguous(snap.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(splitIdx, seqLen), mlx.Slice()), false)
|
||||
mlx.Pin(pk, pv, ck, cv)
|
||||
mlx.AsyncEval(pk, pv, ck, cv)
|
||||
|
||||
snap.Close()
|
||||
|
||||
p := &kvSnapshot{
|
||||
keys: pk,
|
||||
values: pv,
|
||||
fromOffset: snap.fromOffset,
|
||||
toOffset: at,
|
||||
}
|
||||
ch := &kvSnapshot{
|
||||
keys: ck,
|
||||
values: cv,
|
||||
fromOffset: at,
|
||||
toOffset: snap.toOffset,
|
||||
}
|
||||
return p, ch
|
||||
}
|
||||
|
||||
func (c *KVCache) Free() {
|
||||
mlx.Unpin(c.keys, c.values)
|
||||
c.keys, c.values = nil, nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
|
||||
type speculativeBase struct {
|
||||
offset int
|
||||
}
|
||||
|
||||
func (s *speculativeBase) Free() {}
|
||||
func (s *speculativeBase) Offset() int { return s.offset }
|
||||
func (s *speculativeBase) Snapshot(int) Snapshot { return nil }
|
||||
func (s *speculativeBase) Restore(Snapshot, int) bool { return false }
|
||||
func (s *speculativeBase) Merge(parent, child Snapshot) Snapshot { return nil }
|
||||
func (s *speculativeBase) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
return nil, snapshot
|
||||
}
|
||||
|
||||
type speculativeKVCache struct {
|
||||
speculativeBase
|
||||
target *KVCache
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
func newSpeculativeKVCache(target *KVCache) *speculativeKVCache {
|
||||
return &speculativeKVCache{
|
||||
speculativeBase: speculativeBase{offset: target.Offset()},
|
||||
target: target,
|
||||
start: target.Offset(),
|
||||
end: target.Offset(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *speculativeKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
|
||||
history := c.target.Update(b, keys, values)
|
||||
c.offset = c.target.Offset()
|
||||
c.end = c.target.Offset()
|
||||
return history
|
||||
}
|
||||
|
||||
func (c *speculativeKVCache) State() []*mlx.Array {
|
||||
return c.target.State()
|
||||
}
|
||||
|
||||
func (c *speculativeKVCache) commit(n int) {
|
||||
target := max(c.start, c.start+n)
|
||||
if target > c.end {
|
||||
target = c.end
|
||||
}
|
||||
c.target.offset = target
|
||||
c.offset = target
|
||||
}
|
||||
|
||||
type isolatedKVCache struct {
|
||||
speculativeBase
|
||||
target *KVCache
|
||||
keys, values *mlx.Array
|
||||
}
|
||||
|
||||
func newIsolatedKVCache(target *KVCache) *isolatedKVCache {
|
||||
return &isolatedKVCache{
|
||||
speculativeBase: speculativeBase{offset: target.Offset()},
|
||||
target: target,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *isolatedKVCache) Update(_ *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
|
||||
c.keys = concatKV(c.keys, keys)
|
||||
c.values = concatKV(c.values, values)
|
||||
c.offset += keys.Dim(2)
|
||||
|
||||
state := c.target.State()
|
||||
if len(state) < 2 {
|
||||
return nn.NewKVHistory(c.keys, c.values, nil)
|
||||
}
|
||||
return nn.NewKVHistory(state[0].Concatenate(2, c.keys), state[1].Concatenate(2, c.values), nil)
|
||||
}
|
||||
|
||||
func (c *isolatedKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return c.target.State()
|
||||
}
|
||||
state := c.target.State()
|
||||
if len(state) < 2 {
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
return []*mlx.Array{
|
||||
state[0].Concatenate(2, c.keys),
|
||||
state[1].Concatenate(2, c.values),
|
||||
}
|
||||
}
|
||||
|
||||
// RotatingKVCache implements sliding window attention with bounded memory
|
||||
type RotatingKVCache struct {
|
||||
maxSize int
|
||||
idx int
|
||||
|
||||
*KVCache
|
||||
}
|
||||
|
||||
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
|
||||
}
|
||||
|
||||
// Assumes B = 1; heterogeneous batches are not supported.
|
||||
func (c *RotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
|
||||
newK, newV := c.appendKV(keys, values)
|
||||
return nn.NewKVHistory(newK, newV, rotatingApplier{
|
||||
b: b,
|
||||
K: newK.Dim(2),
|
||||
L: keys.Dim(2),
|
||||
window: c.maxSize,
|
||||
ringIdx: c.idx,
|
||||
dtype: keys.DType(),
|
||||
})
|
||||
}
|
||||
|
||||
// View returns the current rotating cache contents in logical order for
|
||||
// assistant KV sharing.
|
||||
func (c *RotatingKVCache) View(_ *batch.Batch) *nn.KVHistory {
|
||||
k, v := c.logicalTail(c.maxSize - 1)
|
||||
if k == nil || v == nil {
|
||||
return nil
|
||||
}
|
||||
return nn.NewKVHistory(k, v, nil)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) logicalTail(keep int) (*mlx.Array, *mlx.Array) {
|
||||
state := c.State()
|
||||
if len(state) < 2 || keep <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
keys, values := state[0], state[1]
|
||||
K := keys.Dim(2)
|
||||
if K == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
keep = min(keep, K)
|
||||
if K > c.maxSize || c.offset < c.maxSize {
|
||||
start := K - keep
|
||||
return keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
|
||||
values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
|
||||
}
|
||||
|
||||
oldest := c.idx % K
|
||||
var logicalK, logicalV *mlx.Array
|
||||
if oldest == 0 {
|
||||
logicalK, logicalV = keys, values
|
||||
} else {
|
||||
tailK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
|
||||
tailV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(oldest, K), mlx.Slice())
|
||||
headK := keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
|
||||
headV := values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, oldest), mlx.Slice())
|
||||
logicalK = tailK.Concatenate(2, headK)
|
||||
logicalV = tailV.Concatenate(2, headV)
|
||||
}
|
||||
|
||||
start := K - keep
|
||||
return logicalK.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice()),
|
||||
logicalV.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(start, K), mlx.Slice())
|
||||
}
|
||||
|
||||
type speculativeRotatingKVCache struct {
|
||||
speculativeBase
|
||||
target *RotatingKVCache
|
||||
keys, values *mlx.Array
|
||||
}
|
||||
|
||||
func newSpeculativeRotatingKVCache(target *RotatingKVCache) *speculativeRotatingKVCache {
|
||||
return &speculativeRotatingKVCache{
|
||||
speculativeBase: speculativeBase{offset: target.Offset()},
|
||||
target: target,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *speculativeRotatingKVCache) Update(b *batch.Batch, keys, values *mlx.Array) *nn.KVHistory {
|
||||
c.keys = concatKV(c.keys, keys)
|
||||
c.values = concatKV(c.values, values)
|
||||
c.offset += keys.Dim(2)
|
||||
|
||||
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
|
||||
histK, histV := c.keys, c.values
|
||||
if oldK != nil && oldV != nil {
|
||||
histK = oldK.Concatenate(2, c.keys)
|
||||
histV = oldV.Concatenate(2, c.values)
|
||||
}
|
||||
|
||||
return nn.NewKVHistory(histK, histV, logicalSlidingApplier{
|
||||
b: b,
|
||||
K: histK.Dim(2),
|
||||
window: c.target.maxSize,
|
||||
dtype: keys.DType(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *speculativeRotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return c.target.State()
|
||||
}
|
||||
oldK, oldV := c.target.logicalTail(c.target.maxSize - 1)
|
||||
if oldK == nil || oldV == nil {
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
return []*mlx.Array{oldK.Concatenate(2, c.keys), oldV.Concatenate(2, c.values)}
|
||||
}
|
||||
|
||||
func (c *speculativeRotatingKVCache) commit(n int) {
|
||||
if c.keys == nil || c.values == nil || n <= 0 {
|
||||
return
|
||||
}
|
||||
n = min(n, c.keys.Dim(2))
|
||||
c.target.appendKV(prefixKV(c.keys, n), prefixKV(c.values, n))
|
||||
}
|
||||
|
||||
type logicalSlidingApplier struct {
|
||||
b *batch.Batch
|
||||
K int
|
||||
window int
|
||||
dtype mlx.DType
|
||||
}
|
||||
|
||||
func (a logicalSlidingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask {
|
||||
return logical.Intersect(nn.SlidingWindowMask(a.b, a.K, a.window, a.dtype))
|
||||
}
|
||||
|
||||
func concatKV(prev, next *mlx.Array) *mlx.Array {
|
||||
if prev == nil {
|
||||
return next
|
||||
}
|
||||
return prev.Concatenate(2, next)
|
||||
}
|
||||
|
||||
func prefixKV(a *mlx.Array, n int) *mlx.Array {
|
||||
return a.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, n), mlx.Slice())
|
||||
}
|
||||
|
||||
// appendKV is the raw write path shared by Update and Restore —
|
||||
// routes to concat for prefill (L > 1) and update for decode.
|
||||
func (c *RotatingKVCache) appendKV(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
if keys.Dim(2) > 1 {
|
||||
return c.concat(keys, values)
|
||||
}
|
||||
return c.update(keys, values)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
|
||||
logutil.Trace("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = keys.Clone(), values.Clone()
|
||||
mlx.Pin(c.keys, c.values)
|
||||
} else {
|
||||
if c.idx < c.keys.Dim(2) {
|
||||
if c.offset <= c.maxSize {
|
||||
// Not yet wrapped: slots [c.idx, Dim) are grow padding
|
||||
// or stale post-rewind data, not live window content.
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
|
||||
} else {
|
||||
// Wrapped: logical order is slots[idx..Dim) then slots[0..idx).
|
||||
// Linearize so the trim + concat below operate on contiguous
|
||||
// positions and preserve the last (maxSize - 1) old tokens.
|
||||
tailK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.keys.Dim(2)), mlx.Slice())
|
||||
tailV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.values.Dim(2)), mlx.Slice())
|
||||
headK := c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||
headV := c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice())
|
||||
c.keys.Set(tailK.Concatenate(2, headK))
|
||||
c.values.Set(tailV.Concatenate(2, headV))
|
||||
c.idx = c.keys.Dim(2)
|
||||
}
|
||||
}
|
||||
|
||||
// Trim to max_size to maintain sliding window
|
||||
if trim := c.idx - c.maxSize + 1; trim > 0 {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
|
||||
}
|
||||
|
||||
c.keys.Set(c.keys.Concatenate(2, keys))
|
||||
c.values.Set(c.values.Concatenate(2, values))
|
||||
c.idx = c.keys.Dim(2)
|
||||
}
|
||||
|
||||
c.offset += keys.Dim(2)
|
||||
c.idx = c.keys.Dim(2)
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
logutil.Trace("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
|
||||
B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
|
||||
|
||||
prev := c.offset
|
||||
|
||||
// Grow buffer if not yet at max
|
||||
if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
|
||||
newSize := min(c.step, c.maxSize-prev)
|
||||
newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
|
||||
newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
|
||||
if c.keys != nil {
|
||||
c.keys.Set(c.keys.Concatenate(2, newKeys))
|
||||
c.values.Set(c.values.Concatenate(2, newValues))
|
||||
} else {
|
||||
c.keys, c.values = newKeys, newValues
|
||||
mlx.Pin(c.keys, c.values)
|
||||
}
|
||||
c.idx = prev
|
||||
}
|
||||
|
||||
// Trim to max_size to maintain sliding window
|
||||
if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
|
||||
c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
|
||||
c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
|
||||
c.idx = c.maxSize
|
||||
}
|
||||
|
||||
// Rotate when hitting max
|
||||
if c.idx >= c.maxSize {
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
|
||||
c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
|
||||
|
||||
c.offset += L
|
||||
c.idx += L
|
||||
|
||||
validLen := min(c.offset, c.maxSize)
|
||||
return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil || c.values == nil {
|
||||
return nil
|
||||
}
|
||||
liveLen := min(c.offset, c.keys.Dim(2))
|
||||
return []*mlx.Array{
|
||||
c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
|
||||
c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, liveLen), mlx.Slice()),
|
||||
}
|
||||
}
|
||||
|
||||
// rotatingSnapshot holds paged-out data for a RotatingKVCache.
|
||||
type rotatingSnapshot struct {
|
||||
kvSnapshot // embedded KV data
|
||||
idx int // buffer write position at snapshot time
|
||||
}
|
||||
|
||||
func (s *rotatingSnapshot) Size() int { return s.kvSnapshot.Size() }
|
||||
func (s *rotatingSnapshot) Close() { s.kvSnapshot.Close() }
|
||||
|
||||
func (c *RotatingKVCache) Snapshot(fromOffset int) Snapshot {
|
||||
if c.keys == nil || c.offset <= fromOffset {
|
||||
return nil
|
||||
}
|
||||
|
||||
state := c.State()
|
||||
k := state[0].Clone()
|
||||
v := state[1].Clone()
|
||||
mlx.Pin(k, v)
|
||||
|
||||
return &rotatingSnapshot{
|
||||
kvSnapshot: kvSnapshot{
|
||||
keys: k,
|
||||
values: v,
|
||||
fromOffset: fromOffset,
|
||||
toOffset: c.offset,
|
||||
},
|
||||
idx: c.idx,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Restore(snapshot Snapshot, target int) bool {
|
||||
if target < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
if target >= c.offset {
|
||||
return target == c.offset
|
||||
}
|
||||
// Live rewind is only safe when the buffer hasn't filled yet
|
||||
// (offset <= maxSize). Once the window has shifted, rewinding
|
||||
// leaves fewer than maxSize trailing tokens to attend to —
|
||||
// a snapshot is required to restore the full window.
|
||||
if c.offset > c.maxSize {
|
||||
return false
|
||||
}
|
||||
c.offset = target
|
||||
c.idx = target
|
||||
return true
|
||||
}
|
||||
|
||||
snap := snapshot.(*rotatingSnapshot)
|
||||
|
||||
if target > snap.toOffset {
|
||||
return false
|
||||
}
|
||||
|
||||
// Reject if clamping would leave an incomplete window.
|
||||
if target < snap.toOffset && snap.toOffset > c.maxSize {
|
||||
return false
|
||||
}
|
||||
|
||||
// Restore from snapshot: rebuild buffer state.
|
||||
// Free existing state first.
|
||||
if c.keys != nil {
|
||||
mlx.Unpin(c.keys, c.values)
|
||||
}
|
||||
c.keys = snap.keys.Clone()
|
||||
c.values = snap.values.Clone()
|
||||
mlx.Pin(c.keys, c.values)
|
||||
c.offset = snap.toOffset
|
||||
c.idx = snap.idx
|
||||
|
||||
// Clamp to target if needed.
|
||||
if target < c.offset {
|
||||
c.offset = target
|
||||
c.idx = target
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Merge(parent, child Snapshot) Snapshot {
|
||||
// For rotating caches, the child snapshot supersedes the parent
|
||||
// since it contains the full window state.
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
return child
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
// Rotating cache snapshots contain the full window state.
|
||||
// Cannot cleanly split a ring buffer at an arbitrary point.
|
||||
return nil, snapshot
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Free() {
|
||||
c.KVCache.Free()
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
// rotatingApplier composes the sliding-window storage restriction
|
||||
// onto the caller's logical mask.
|
||||
//
|
||||
// ringIdx is the cache's write cursor at Update time. At L=1 decode
|
||||
// the ring buffer is not position-ordered — logical col j lives at
|
||||
// storage slot (ringIdx+j) mod K — so tensor masks built in
|
||||
// logical space must be gathered into this layout before the kernel
|
||||
// sees them. At L>1 prefill the concat path has already linearised
|
||||
// storage, so the gather is identity and ringIdx is unused.
|
||||
type rotatingApplier struct {
|
||||
b *batch.Batch
|
||||
K int
|
||||
L int
|
||||
window int
|
||||
ringIdx int
|
||||
dtype mlx.DType
|
||||
}
|
||||
|
||||
func (r rotatingApplier) ApplyMask(logical nn.AttentionMask) nn.AttentionMask {
|
||||
if r.L == 1 {
|
||||
// Single-query decode: storage already enforces the window
|
||||
// (Update keeps the last maxSize tokens, all within
|
||||
// [absQ-window+1, absQ]), and every stored key's absolute
|
||||
// position <= absQ. For a zero or plain-causal logical mask
|
||||
// both constraints reduce to "no mask", so return the zero
|
||||
// mask and let SDPA dispatch to mode="".
|
||||
if logical.IsZero() || logical.IsCausal() {
|
||||
return nn.AttentionMask{}
|
||||
}
|
||||
|
||||
// Tensor-backed mask (user ArrayMask, causal+Relax, causal
|
||||
// with accumulated array): materialize in logical-position
|
||||
// order then gather K cols into ring-slot order so they
|
||||
// align with the cache output the kernel will index.
|
||||
arr := logical.AsArray(r.b, r.K, r.dtype)
|
||||
arr = gatherRingCols(arr, r.ringIdx, r.K)
|
||||
return nn.ArrayMask(arr)
|
||||
}
|
||||
|
||||
return logical.Intersect(nn.SlidingWindowMask(r.b, r.K, r.window, r.dtype))
|
||||
}
|
||||
|
||||
// gatherRingCols reorders a [B, 1, L, K] mask's K axis from
|
||||
// logical-position order (col 0 = oldest stored position) into the
|
||||
// cache's ring-slot order (col 0 = buffer slot 0). Logical col j
|
||||
// lives at slot (ringIdx+j) mod K, so storage slot s reads from
|
||||
// logical col (s-ringIdx+K) mod K. Returns arr unchanged when the
|
||||
// permutation is a no-op: ringIdx % K == 0 (layouts coincide), or
|
||||
// the K axis broadcasts (dim 3 == 1, i.e. Q-padding-shaped masks
|
||||
// where every key shares the same value).
|
||||
func gatherRingCols(arr *mlx.Array, ringIdx, K int) *mlx.Array {
|
||||
if w := arr.Dim(3); w != 1 && w != K {
|
||||
panic(fmt.Sprintf("gatherRingCols: K-axis width %d must be 1 or %d", w, K))
|
||||
}
|
||||
ringIdx %= K
|
||||
if ringIdx == 0 || arr.Dim(3) == 1 {
|
||||
return arr
|
||||
}
|
||||
shift := K - ringIdx
|
||||
tail := arr.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(shift, K))
|
||||
head := arr.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, shift))
|
||||
return tail.Concatenate(3, head)
|
||||
}
|
||||
283
x/mlxrunner/cache/cache_test.go
vendored
Normal file
283
x/mlxrunner/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,283 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// newKVBatch builds a B=1 batch at SeqOffsets=off with all-real
|
||||
// queries (SeqQueryLens=L) — the standard single-sequence cache
|
||||
// test shape.
|
||||
func newKVBatch(off, L int) *batch.Batch {
|
||||
return &batch.Batch{
|
||||
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, L),
|
||||
SeqOffsets: []int32{int32(off)},
|
||||
SeqQueryLens: []int32{int32(L)},
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCacheSnapshotRestoreNeedBase(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
// Snapshot [5, 10).
|
||||
snap := c.Snapshot(5)
|
||||
|
||||
// Free the cache completely — offset is now 0.
|
||||
c.Free()
|
||||
|
||||
// Restore should fail because cache doesn't have data up to fromOffset=5.
|
||||
if c.Restore(snap, 10) {
|
||||
t.Fatal("expected Restore to fail with no base data")
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheDataSurvivesSnapshotRestore verifies that actual array data
|
||||
// is preserved through a snapshot→free→restore cycle.
|
||||
func TestKVCacheDataSurvivesSnapshotRestore(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
if snap == nil {
|
||||
t.Fatal("Snapshot returned nil")
|
||||
}
|
||||
|
||||
// Free and restore to a fresh cache.
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||
}
|
||||
|
||||
// Verify State() returns arrays with correct sequence dimension.
|
||||
state := c2.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
// keys shape: [B, H, seqLen, Dk]
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
||||
}
|
||||
if state[1].Dim(2) != 10 {
|
||||
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheSplitPreservesData verifies that split produces two snapshots
|
||||
// that can be sequentially restored to rebuild the original cache state.
|
||||
func TestKVCacheSplitPreservesData(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
parent, child := c.Split(snap, 5)
|
||||
if parent == nil || child == nil {
|
||||
t.Fatal("Split returned nil")
|
||||
}
|
||||
|
||||
// Restore parent → offset=5, seq dim=5.
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(parent, 5) {
|
||||
t.Fatal("Restore(parent) failed")
|
||||
}
|
||||
if c2.Offset() != 5 {
|
||||
t.Fatalf("offset after parent = %d, want 5", c2.Offset())
|
||||
}
|
||||
state := c2.State()
|
||||
if state[0].Dim(2) != 5 {
|
||||
t.Fatalf("keys seq dim after parent = %d, want 5", state[0].Dim(2))
|
||||
}
|
||||
|
||||
// Restore child on top → offset=10, seq dim=10.
|
||||
if !c2.Restore(child, 10) {
|
||||
t.Fatal("Restore(child) failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset after child = %d, want 10", c2.Offset())
|
||||
}
|
||||
state = c2.State()
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim after child = %d, want 10", state[0].Dim(2))
|
||||
}
|
||||
}
|
||||
|
||||
// TestKVCacheSplitMergeRoundTripData verifies that splitting and merging back
|
||||
// produces a snapshot equivalent to the original.
|
||||
func TestKVCacheSplitMergeRoundTripData(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewKVCache()
|
||||
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
parent, child := c.Split(snap, 6)
|
||||
merged := c.Merge(parent, child)
|
||||
if merged == nil {
|
||||
t.Fatal("Merge returned nil")
|
||||
}
|
||||
|
||||
c2 := NewKVCache()
|
||||
if !c2.Restore(merged, 10) {
|
||||
t.Fatal("Restore(merged) failed")
|
||||
}
|
||||
if c2.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c2.Offset())
|
||||
}
|
||||
|
||||
state := c2.State()
|
||||
if state[0].Dim(2) != 10 {
|
||||
t.Fatalf("keys seq dim = %d, want 10", state[0].Dim(2))
|
||||
}
|
||||
if state[1].Dim(2) != 10 {
|
||||
t.Fatalf("values seq dim = %d, want 10", state[1].Dim(2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotatingKVCacheRestoreOutsideWindow(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
|
||||
// Feed 10 tokens (window size 4, so positions 0-5 are evicted).
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
// Offset 3 is outside the window.
|
||||
if c.Restore(nil, 3) {
|
||||
t.Fatal("Restore(nil, 3) should fail when outside window")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheSnapshotPreservesWindow verifies that after restoring
|
||||
// from a snapshot, the rotating cache has the correct window of data.
|
||||
func TestRotatingKVCacheSnapshotPreservesWindow(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
|
||||
// Feed 10 tokens one at a time. Window size 4, so only last 4 are kept.
|
||||
for range 10 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
if snap == nil {
|
||||
t.Fatal("Snapshot returned nil")
|
||||
}
|
||||
|
||||
// Feed 5 more tokens.
|
||||
for range 5 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
// Restore to offset 10.
|
||||
if !c.Restore(snap, 10) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||
}
|
||||
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
// Seq dim should be min(offset, maxSize) = min(10, 4) = 4.
|
||||
seqDim := state[0].Dim(2)
|
||||
if seqDim != 4 {
|
||||
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheRestoreFromSnapshot verifies that restoring from a
|
||||
// snapshot correctly preserves the write position (idx), so subsequent
|
||||
// single-token updates land in the right buffer slot.
|
||||
func TestRotatingKVCacheRestoreFromSnapshot(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRotatingKVCache(4)
|
||||
|
||||
// Fill the window: 6 tokens into a size-4 window.
|
||||
// After this, idx has wrapped and the buffer has rotated.
|
||||
for range 6 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
if c.Offset() != 6 {
|
||||
t.Fatalf("offset = %d, want 6", c.Offset())
|
||||
}
|
||||
|
||||
snap := c.Snapshot(0)
|
||||
|
||||
// Mutate the cache further so live state diverges from snapshot.
|
||||
for range 3 {
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
// Restore to snapshot state.
|
||||
if !c.Restore(snap, 6) {
|
||||
t.Fatal("Restore failed")
|
||||
}
|
||||
if c.Offset() != 6 {
|
||||
t.Fatalf("offset after restore = %d, want 6", c.Offset())
|
||||
}
|
||||
|
||||
// Feed one more token. If idx was restored correctly, this should
|
||||
// produce a valid window of size 4 at offset 7.
|
||||
k := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
v := mlx.Zeros(mlx.DTypeFloat16, 1, 4, 1, 8)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
|
||||
if c.Offset() != 7 {
|
||||
t.Fatalf("offset after post-restore update = %d, want 7", c.Offset())
|
||||
}
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
seqDim := state[0].Dim(2)
|
||||
if seqDim != 4 {
|
||||
t.Fatalf("keys seq dim = %d, want 4 (window size)", seqDim)
|
||||
}
|
||||
}
|
||||
285
x/mlxrunner/cache/recurrent.go
vendored
Normal file
285
x/mlxrunner/cache/recurrent.go
vendored
Normal file
@@ -0,0 +1,285 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// Recurrent is the contract for caches that back recurrent linear-attention layers.
|
||||
type Recurrent interface {
|
||||
Cache
|
||||
Get(b *batch.Batch, dtype mlx.DType) *nn.RecurrentHistory
|
||||
Put(b *batch.Batch, newConv, newDelta *mlx.Array)
|
||||
}
|
||||
|
||||
// RecurrentRecorder records the per-token scan inputs needed to commit an
|
||||
// accepted prefix after a speculative recurrent forward.
|
||||
type RecurrentRecorder interface {
|
||||
Record(qkv, q, k, v, gDecay, beta *mlx.Array)
|
||||
}
|
||||
|
||||
// RecurrentCache stores state for linear-recurrent layers.
|
||||
//
|
||||
// Conv state shape: [B, convTail, convDim]
|
||||
// Delta state shape: [B, numVHeads, headVDim, headKDim]
|
||||
type RecurrentCache struct {
|
||||
convState *mlx.Array
|
||||
deltaState *mlx.Array
|
||||
offset int
|
||||
|
||||
convTail int
|
||||
convDim int
|
||||
numVHeads int
|
||||
headVDim int
|
||||
headKDim int
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) setState(old, v *mlx.Array, contiguous bool) *mlx.Array {
|
||||
if v == nil || !v.Valid() {
|
||||
return old
|
||||
}
|
||||
|
||||
if contiguous {
|
||||
v = mlx.Contiguous(v, false)
|
||||
}
|
||||
v = v.Clone()
|
||||
|
||||
mlx.Pin(v)
|
||||
mlx.Unpin(old)
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim int32) *RecurrentCache {
|
||||
return &RecurrentCache{
|
||||
convTail: int(convTail),
|
||||
convDim: int(convDim),
|
||||
numVHeads: int(numVHeads),
|
||||
headVDim: int(headVDim),
|
||||
headKDim: int(headKDim),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) ensure(batch int, dtype mlx.DType) {
|
||||
if batch <= 0 {
|
||||
batch = 1
|
||||
}
|
||||
|
||||
// Keep the gated-delta recurrent state in float32 even when activations are
|
||||
// bf16/fp16. The convolution tail stays in the activation dtype.
|
||||
deltaDType := mlx.DTypeFloat32
|
||||
needConv := c.convState == nil || !c.convState.Valid() || c.convState.DType() != dtype ||
|
||||
c.convState.Dim(0) != batch || c.convState.Dim(1) != c.convTail || c.convState.Dim(2) != c.convDim
|
||||
needDelta := c.deltaState == nil || !c.deltaState.Valid() || c.deltaState.DType() != deltaDType ||
|
||||
c.deltaState.Dim(0) != batch || c.deltaState.Dim(1) != c.numVHeads || c.deltaState.Dim(2) != c.headVDim || c.deltaState.Dim(3) != c.headKDim
|
||||
if !needConv && !needDelta {
|
||||
return
|
||||
}
|
||||
|
||||
if needConv {
|
||||
c.convState = c.setState(c.convState, mlx.Zeros(dtype, batch, c.convTail, c.convDim), false)
|
||||
}
|
||||
if needDelta {
|
||||
c.deltaState = c.setState(c.deltaState, mlx.Zeros(deltaDType, batch, c.numVHeads, c.headVDim, c.headKDim), false)
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the current conv/delta state for the SSM layer's read
|
||||
// phase. Lazy-initializes zero-filled state tensors using b.InputIDs
|
||||
// for the batch size; reallocates if the existing state's batch size
|
||||
// or dtype no longer matches.
|
||||
func (c *RecurrentCache) Get(b *batch.Batch, dtype mlx.DType) *nn.RecurrentHistory {
|
||||
c.ensure(b.InputIDs.Dim(0), dtype)
|
||||
return nn.NewRecurrentHistory(c.convState, c.deltaState)
|
||||
}
|
||||
|
||||
// Put stores the post-computation conv/delta states for the SSM
|
||||
// layer's write phase and advances the cache offset by the current
|
||||
// forward's real token count.
|
||||
//
|
||||
// Assumes B = 1; heterogeneous batches are not supported.
|
||||
func (c *RecurrentCache) Put(b *batch.Batch, newConv, newDelta *mlx.Array) {
|
||||
c.convState = c.setState(c.convState, newConv, true)
|
||||
c.deltaState = c.setState(c.deltaState, newDelta, false)
|
||||
c.offset += int(b.SeqQueryLens[0])
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) State() []*mlx.Array {
|
||||
return []*mlx.Array{c.convState, c.deltaState}
|
||||
}
|
||||
|
||||
// recurrentSnapshot holds paged-out recurrent state. Self-contained —
|
||||
// does not depend on any parent state.
|
||||
type recurrentSnapshot struct {
|
||||
convState, deltaState *mlx.Array
|
||||
offset int
|
||||
}
|
||||
|
||||
func (s *recurrentSnapshot) Size() int { return s.convState.NumBytes() + s.deltaState.NumBytes() }
|
||||
func (s *recurrentSnapshot) Close() { mlx.Unpin(s.convState, s.deltaState) }
|
||||
|
||||
func (c *RecurrentCache) Snapshot(fromOffset int) Snapshot {
|
||||
// Recurrent state is not position-sliceable — always snapshot the full state.
|
||||
if c.convState == nil && c.deltaState == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
snap := &recurrentSnapshot{offset: c.offset}
|
||||
snap.convState = c.convState.Clone()
|
||||
snap.deltaState = c.deltaState.Clone()
|
||||
mlx.Pin(snap.convState, snap.deltaState)
|
||||
|
||||
return snap
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Restore(snapshot Snapshot, target int) bool {
|
||||
if snapshot == nil {
|
||||
// Recurrent state is cumulative and can't rewind. Only succeed
|
||||
// if we're already at the target (no-op).
|
||||
return target == c.offset
|
||||
}
|
||||
|
||||
snap := snapshot.(*recurrentSnapshot)
|
||||
|
||||
// Recurrent snapshots encode cumulative state up to exactly
|
||||
// snap.offset. Target must match — rewinding would leave stale
|
||||
// state, and advancing isn't possible without feeding tokens.
|
||||
if target != snap.offset {
|
||||
return false
|
||||
}
|
||||
|
||||
c.convState = c.setState(c.convState, snap.convState, false)
|
||||
c.deltaState = c.setState(c.deltaState, snap.deltaState, false)
|
||||
c.offset = snap.offset
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Merge(parent, child Snapshot) Snapshot {
|
||||
// Recurrent snapshots are self-contained — child supersedes parent.
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
return child
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Split(snapshot Snapshot, at int) (Snapshot, Snapshot) {
|
||||
// Recurrent state is cumulative and not position-sliceable.
|
||||
// Cannot recover intermediate state at the split point.
|
||||
return nil, snapshot
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Free() {
|
||||
mlx.Unpin(c.convState, c.deltaState)
|
||||
c.convState, c.deltaState = nil, nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *RecurrentCache) Offset() int { return c.offset }
|
||||
|
||||
type speculativeRecurrentCache struct {
|
||||
speculativeBase
|
||||
target *RecurrentCache
|
||||
|
||||
start int
|
||||
|
||||
initialConv *mlx.Array
|
||||
initialDelta *mlx.Array
|
||||
|
||||
qkv, q, k, v, gDecay, beta *mlx.Array
|
||||
fullConv, fullDelta *mlx.Array
|
||||
length int
|
||||
}
|
||||
|
||||
func newSpeculativeRecurrentCache(target *RecurrentCache) *speculativeRecurrentCache {
|
||||
return &speculativeRecurrentCache{
|
||||
speculativeBase: speculativeBase{offset: target.Offset()},
|
||||
target: target,
|
||||
start: target.Offset(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *speculativeRecurrentCache) Get(b *batch.Batch, dtype mlx.DType) *nn.RecurrentHistory {
|
||||
if c.fullConv != nil && c.fullDelta != nil {
|
||||
return nn.NewRecurrentHistory(c.fullConv, c.fullDelta)
|
||||
}
|
||||
|
||||
history := c.target.Get(b, dtype)
|
||||
if c.initialConv == nil {
|
||||
c.initialConv = history.ConvState()
|
||||
}
|
||||
if c.initialDelta == nil {
|
||||
c.initialDelta = history.DeltaState()
|
||||
}
|
||||
return history
|
||||
}
|
||||
|
||||
func (c *speculativeRecurrentCache) Record(qkv, q, k, v, gDecay, beta *mlx.Array) {
|
||||
c.qkv, c.q, c.k, c.v, c.gDecay, c.beta = qkv, q, k, v, gDecay, beta
|
||||
if qkv != nil {
|
||||
c.length = qkv.Dim(1)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *speculativeRecurrentCache) Put(b *batch.Batch, newConv, newDelta *mlx.Array) {
|
||||
c.fullConv, c.fullDelta = newConv, newDelta
|
||||
c.offset += int(b.SeqQueryLens[0])
|
||||
}
|
||||
|
||||
func (c *speculativeRecurrentCache) State() []*mlx.Array {
|
||||
if c.fullConv != nil && c.fullDelta != nil {
|
||||
return []*mlx.Array{c.fullConv, c.fullDelta}
|
||||
}
|
||||
return c.target.State()
|
||||
}
|
||||
|
||||
func (c *speculativeRecurrentCache) commit(n int) {
|
||||
if n <= 0 {
|
||||
return
|
||||
}
|
||||
if c.length > 0 && n > c.length {
|
||||
n = c.length
|
||||
}
|
||||
|
||||
if c.length > 0 && n == c.length && c.fullConv != nil && c.fullDelta != nil {
|
||||
c.target.convState = c.target.setState(c.target.convState, c.fullConv, true)
|
||||
c.target.deltaState = c.target.setState(c.target.deltaState, c.fullDelta, false)
|
||||
c.target.offset = c.start + n
|
||||
return
|
||||
}
|
||||
|
||||
if c.initialConv == nil || c.initialDelta == nil || c.qkv == nil || c.q == nil || c.k == nil || c.v == nil || c.gDecay == nil || c.beta == nil {
|
||||
return
|
||||
}
|
||||
|
||||
qkv := sliceSeq(c.qkv, n)
|
||||
convConcat := mlx.Concatenate([]*mlx.Array{c.initialConv, qkv}, 1)
|
||||
total := convConcat.Dim(1)
|
||||
nextConv := convConcat.Slice(mlx.Slice(), mlx.Slice(total-c.target.convTail, total), mlx.Slice())
|
||||
|
||||
_, delta := mlx.FastGatedDelta(
|
||||
sliceSeq(c.q, n),
|
||||
sliceSeq(c.k, n),
|
||||
sliceSeq(c.v, n),
|
||||
sliceSeq(c.gDecay, n),
|
||||
sliceSeq(c.beta, n),
|
||||
c.initialDelta,
|
||||
nil,
|
||||
)
|
||||
|
||||
c.target.convState = c.target.setState(c.target.convState, nextConv, true)
|
||||
c.target.deltaState = c.target.setState(c.target.deltaState, delta, false)
|
||||
c.target.offset = c.start + n
|
||||
}
|
||||
|
||||
func sliceSeq(a *mlx.Array, n int) *mlx.Array {
|
||||
switch a.NumDims() {
|
||||
case 3:
|
||||
return a.Slice(mlx.Slice(), mlx.Slice(0, n), mlx.Slice())
|
||||
case 4:
|
||||
return a.Slice(mlx.Slice(), mlx.Slice(0, n), mlx.Slice(), mlx.Slice())
|
||||
default:
|
||||
panic("recurrent speculative sequence tensor must be rank 3 or 4")
|
||||
}
|
||||
}
|
||||
284
x/mlxrunner/cache/recurrent_test.go
vendored
Normal file
284
x/mlxrunner/cache/recurrent_test.go
vendored
Normal file
@@ -0,0 +1,284 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// TestRecurrentCacheRestoreExactOffset verifies that RecurrentCache restore
|
||||
// only succeeds when target exactly matches the snapshot's offset. Recurrent
|
||||
// state is cumulative, so it can't be rewound or fast-forwarded.
|
||||
func TestRecurrentCacheRestoreExactOffset(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRecurrentCache(3, 12, 4, 8, 8)
|
||||
b1 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1)}
|
||||
c.Get(b1, mlx.DTypeFloat16) // lazy-init
|
||||
|
||||
b10 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 10), SeqQueryLens: []int32{10}}
|
||||
c.Put(b10, nil, nil) // advance to 10
|
||||
|
||||
snap := c.Snapshot(0) // snap.offset == 10
|
||||
|
||||
b5 := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 5), SeqQueryLens: []int32{5}}
|
||||
c.Put(b5, nil, nil) // cache now at 15
|
||||
|
||||
// target < snap.offset: fails (can't rewind past snapshot)
|
||||
if c.Restore(snap, 5) {
|
||||
t.Fatal("Restore(snap, 5) should fail — target != snap.offset")
|
||||
}
|
||||
|
||||
// target > snap.offset: fails (can't advance without feeding tokens)
|
||||
if c.Restore(snap, 15) {
|
||||
t.Fatal("Restore(snap, 15) should fail — target != snap.offset")
|
||||
}
|
||||
|
||||
// target == snap.offset: succeeds
|
||||
if !c.Restore(snap, 10) {
|
||||
t.Fatal("Restore(snap, 10) should succeed — target == snap.offset")
|
||||
}
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("offset = %d, want 10", c.Offset())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecurrentCacheGetLazyInit(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRecurrentCache(3, 4, 2, 4, 4)
|
||||
b := &batch.Batch{
|
||||
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1),
|
||||
SeqOffsets: []int32{0},
|
||||
SeqQueryLens: []int32{1},
|
||||
}
|
||||
h := c.Get(b, mlx.DTypeBFloat16)
|
||||
if c.Offset() != 0 {
|
||||
t.Fatalf("Get should not advance; got offset %d", c.Offset())
|
||||
}
|
||||
if h.ConvState() == nil || h.DeltaState() == nil {
|
||||
t.Fatal("history should expose conv/delta tensors")
|
||||
}
|
||||
if got := h.ConvState().DType(); got != mlx.DTypeBFloat16 {
|
||||
t.Fatalf("conv state dtype = %v, want %v", got, mlx.DTypeBFloat16)
|
||||
}
|
||||
if got := h.DeltaState().DType(); got != mlx.DTypeFloat32 {
|
||||
t.Fatalf("delta state dtype = %v, want %v", got, mlx.DTypeFloat32)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpeculativeRecurrentCacheUsesStagedState(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
target := NewRecurrentCache(2, 3, 1, 2, 3)
|
||||
caches, ok := BeginIsolatedSpeculation([]Cache{target})
|
||||
if !ok {
|
||||
t.Fatal("BeginIsolatedSpeculation failed")
|
||||
}
|
||||
c := caches[0].(*speculativeRecurrentCache)
|
||||
b := &batch.Batch{
|
||||
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1),
|
||||
SeqOffsets: []int32{0},
|
||||
SeqQueryLens: []int32{1},
|
||||
}
|
||||
|
||||
c.Get(b, mlx.DTypeFloat32)
|
||||
|
||||
convVals := []float32{1, 2, 3, 4, 5, 6}
|
||||
deltaVals := []float32{7, 8, 9, 10, 11, 12}
|
||||
nextConv := mlx.FromValues(convVals, 1, 2, 3)
|
||||
nextDelta := mlx.FromValues(deltaVals, 1, 1, 2, 3)
|
||||
c.Put(b, nextConv, nextDelta)
|
||||
|
||||
h := c.Get(b, mlx.DTypeFloat32)
|
||||
state := c.State()
|
||||
if len(state) != 2 {
|
||||
t.Fatalf("State() returned %d arrays, want 2", len(state))
|
||||
}
|
||||
|
||||
assertArray := func(name string, got, want *mlx.Array) {
|
||||
t.Helper()
|
||||
if got != want {
|
||||
t.Fatalf("%s = %p, want %p", name, got, want)
|
||||
}
|
||||
}
|
||||
assertArray("history conv", h.ConvState(), nextConv)
|
||||
assertArray("history delta", h.DeltaState(), nextDelta)
|
||||
assertArray("state conv", state[0], nextConv)
|
||||
assertArray("state delta", state[1], nextDelta)
|
||||
|
||||
if got := c.Offset(); got != 1 {
|
||||
t.Fatalf("speculative offset = %d, want 1", got)
|
||||
}
|
||||
if got := target.Offset(); got != 0 {
|
||||
t.Fatalf("target offset = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecurrentCachePaddedRoundTrip runs Get → CausalConv1D →
|
||||
// GatedDelta → Put on a B=1 batch with qLen<L, then again on a
|
||||
// fresh cache with an unpadded length-qLen batch using the same
|
||||
// real prefix. After the call, Offset() must equal qLen (not L),
|
||||
// and the resulting cache state must match the unpadded equivalent.
|
||||
// Pins the recurrent contract: a forward with padding produces the
|
||||
// same end-state as a forward with the real-prefix-only input.
|
||||
func TestRecurrentCachePaddedRoundTrip(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
const convTail, convDim = 2, 6
|
||||
const numVHeads, headVDim, headKDim = 1, 4, 6
|
||||
const L = 4
|
||||
const qLen = 2
|
||||
|
||||
// Use distinct values for the real prefix and the padded tail so
|
||||
// we can detect any leak from padded positions into the result.
|
||||
makeQKV := func(seed float32, T int) (q, k, v *mlx.Array) {
|
||||
mkLast := func(off float32, T, n, d int) *mlx.Array {
|
||||
vals := make([]float32, 1*T*n*d)
|
||||
for i := range vals {
|
||||
vals[i] = off + 0.05*float32(i)
|
||||
}
|
||||
return mlx.FromValues(vals, 1, T, n, d)
|
||||
}
|
||||
q = mkLast(seed, T, 1, headKDim)
|
||||
k = mkLast(seed+0.1, T, 1, headKDim)
|
||||
v = mkLast(seed+0.2, T, numVHeads, headVDim)
|
||||
return
|
||||
}
|
||||
makeGB := func(seed float32, T int) (g, beta *mlx.Array) {
|
||||
gVals := make([]float32, 1*T*numVHeads)
|
||||
bVals := make([]float32, 1*T*numVHeads)
|
||||
for i := range gVals {
|
||||
gVals[i] = seed + 0.01*float32(i)
|
||||
bVals[i] = seed - 0.02*float32(i)
|
||||
}
|
||||
g = mlx.FromValues(gVals, 1, T, numVHeads)
|
||||
beta = mlx.FromValues(bVals, 1, T, numVHeads)
|
||||
return
|
||||
}
|
||||
makeQKVPadded := func() (q, k, v *mlx.Array) {
|
||||
qReal, kReal, vReal := makeQKV(0.3, qLen)
|
||||
// Distinct, large junk values in the padded tail to surface
|
||||
// any leak (real outputs are O(1)).
|
||||
qPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, 1, headKDim), 99)
|
||||
kPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, 1, headKDim), 99)
|
||||
vPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, numVHeads, headVDim), 99)
|
||||
q = mlx.Concatenate([]*mlx.Array{qReal, qPad}, 1)
|
||||
k = mlx.Concatenate([]*mlx.Array{kReal, kPad}, 1)
|
||||
v = mlx.Concatenate([]*mlx.Array{vReal, vPad}, 1)
|
||||
return
|
||||
}
|
||||
makeGBPadded := func() (g, beta *mlx.Array) {
|
||||
gReal, betaReal := makeGB(0.1, qLen)
|
||||
gPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, numVHeads), 99)
|
||||
betaPad := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, L-qLen, numVHeads), 99)
|
||||
g = mlx.Concatenate([]*mlx.Array{gReal, gPad}, 1)
|
||||
beta = mlx.Concatenate([]*mlx.Array{betaReal, betaPad}, 1)
|
||||
return
|
||||
}
|
||||
|
||||
// The conv input dimension must match the cache's convDim.
|
||||
mkConvInput := func(seed float32, T int) *mlx.Array {
|
||||
vals := make([]float32, 1*T*convDim)
|
||||
for i := range vals {
|
||||
vals[i] = seed + 0.05*float32(i)
|
||||
}
|
||||
return mlx.FromValues(vals, 1, T, convDim)
|
||||
}
|
||||
mkWeight := func(seed float32) *mlx.Array {
|
||||
vals := make([]float32, convDim*(convTail+1))
|
||||
for i := range vals {
|
||||
vals[i] = seed + 0.1*float32(i)
|
||||
}
|
||||
return mlx.FromValues(vals, convDim, convTail+1)
|
||||
}
|
||||
weight := mkWeight(0.2)
|
||||
|
||||
runForward := func(c *RecurrentCache, b *batch.Batch, T int) (*mlx.Array, *mlx.Array) {
|
||||
var convInput *mlx.Array
|
||||
if T == L {
|
||||
realPart := mkConvInput(0.4, qLen)
|
||||
padPart := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, 1, T-qLen, convDim), 99)
|
||||
convInput = mlx.Concatenate([]*mlx.Array{realPart, padPart}, 1)
|
||||
} else {
|
||||
convInput = mkConvInput(0.4, T)
|
||||
}
|
||||
|
||||
history := c.Get(b, mlx.DTypeFloat32)
|
||||
_, nextConv := nn.CausalConv1D(b, convInput, nil, weight, convTail,
|
||||
nn.WithRecurrentHistory(history))
|
||||
|
||||
var q, k, v, g, beta *mlx.Array
|
||||
if T == L {
|
||||
q, k, v = makeQKVPadded()
|
||||
g, beta = makeGBPadded()
|
||||
} else {
|
||||
q, k, v = makeQKV(0.3, T)
|
||||
g, beta = makeGB(0.1, T)
|
||||
}
|
||||
_, newDelta := nn.GatedDelta(b, q, k, v, g, beta,
|
||||
nn.WithRecurrentHistory(history))
|
||||
|
||||
c.Put(b, nextConv, newDelta)
|
||||
return nextConv, newDelta
|
||||
}
|
||||
|
||||
// Padded forward.
|
||||
cPad := NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim)
|
||||
bPad := &batch.Batch{
|
||||
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, L),
|
||||
SeqOffsets: []int32{0},
|
||||
SeqQueryLens: []int32{int32(qLen)},
|
||||
}
|
||||
nextConvPad, deltaPad := runForward(cPad, bPad, L)
|
||||
mlx.Eval(nextConvPad, deltaPad)
|
||||
if got := cPad.Offset(); got != qLen {
|
||||
t.Fatalf("padded forward: Offset() = %d, want %d (must advance by SeqQueryLens, not L)", got, qLen)
|
||||
}
|
||||
|
||||
// Unpadded reference.
|
||||
cRef := NewRecurrentCache(convTail, convDim, numVHeads, headVDim, headKDim)
|
||||
bRef := &batch.Batch{
|
||||
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, qLen),
|
||||
SeqOffsets: []int32{0},
|
||||
SeqQueryLens: []int32{int32(qLen)},
|
||||
}
|
||||
nextConvRef, deltaRef := runForward(cRef, bRef, qLen)
|
||||
mlx.Eval(nextConvRef, deltaRef)
|
||||
if got := cRef.Offset(); got != qLen {
|
||||
t.Fatalf("unpadded forward: Offset() = %d, want %d", got, qLen)
|
||||
}
|
||||
|
||||
gp := nextConvPad.Floats()
|
||||
gr := nextConvRef.Floats()
|
||||
if len(gp) != len(gr) {
|
||||
t.Fatalf("nextConv shape mismatch: padded %d vs unpadded %d", len(gp), len(gr))
|
||||
}
|
||||
for i := range gp {
|
||||
if math.Abs(float64(gp[i]-gr[i])) > 1e-4 {
|
||||
t.Fatalf("nextConv[%d]: padded=%v unpadded=%v (padding leaked into conv state)", i, gp[i], gr[i])
|
||||
}
|
||||
}
|
||||
|
||||
dp := deltaPad.Floats()
|
||||
dr := deltaRef.Floats()
|
||||
if len(dp) != len(dr) {
|
||||
t.Fatalf("delta state shape mismatch: padded %d vs unpadded %d", len(dp), len(dr))
|
||||
}
|
||||
for i := range dp {
|
||||
if math.Abs(float64(dp[i]-dr[i])) > 1e-3 {
|
||||
t.Fatalf("delta state[%d]: padded=%v unpadded=%v (padding leaked into recurrent state)", i, dp[i], dr[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecurrentCachePutAdvances(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
c := NewRecurrentCache(3, 4, 2, 4, 4)
|
||||
b := &batch.Batch{InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 2), SeqQueryLens: []int32{2}}
|
||||
newConv := mlx.Zeros(mlx.DTypeFloat16, 1, 3, 4)
|
||||
newDelta := mlx.Zeros(mlx.DTypeFloat16, 1, 2, 4, 4)
|
||||
c.Put(b, newConv, newDelta)
|
||||
if c.Offset() != 2 {
|
||||
t.Fatalf("cache offset not advanced: %d", c.Offset())
|
||||
}
|
||||
}
|
||||
292
x/mlxrunner/cache/rotating_attention_test.go
vendored
Normal file
292
x/mlxrunner/cache/rotating_attention_test.go
vendored
Normal file
@@ -0,0 +1,292 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// TestRotatingKVCacheDecodeParity drives a rotating cache past its
|
||||
// wrap point with single-token writes, runs an L=1 decode through
|
||||
// SDPA, and compares against a reference computed from the same per-
|
||||
// position K/V in logical-position order with the same caller mask.
|
||||
//
|
||||
// Attention is permutation-invariant when K, V, and the mask are
|
||||
// permuted together, so the cache's storage-order output (with the
|
||||
// applier's gather composing the caller's logical mask back into
|
||||
// storage order) must equal the logical-order reference.
|
||||
func TestRotatingKVCacheDecodeParity(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
const H, D = 1, 4
|
||||
const window = 4
|
||||
const totalWrites = 7 // past wrap (window=4); last write is the L=1 decode
|
||||
const scale = 1.0
|
||||
|
||||
// Per-position k, v values. Use distinct seeds so the per-position
|
||||
// values are clearly distinguishable.
|
||||
perPosKV := func(pos int) (k, v *mlx.Array) {
|
||||
kVals := make([]float32, H*D)
|
||||
vVals := make([]float32, H*D)
|
||||
for i := range kVals {
|
||||
kVals[i] = 0.1*float32(pos+1) + 0.01*float32(i)
|
||||
vVals[i] = -0.1*float32(pos+1) + 0.01*float32(i)
|
||||
}
|
||||
k = mlx.FromValues(kVals, 1, H, 1, D)
|
||||
v = mlx.FromValues(vVals, 1, H, 1, D)
|
||||
return
|
||||
}
|
||||
|
||||
q := mlx.FromValues([]float32{0.7, -0.4, 0.2, 0.9}, 1, H, 1, D)
|
||||
mlx.Eval(q)
|
||||
|
||||
// Drive the cache: write positions 0..totalWrites-2 as a "history",
|
||||
// then position totalWrites-1 is the actual L=1 decode under test.
|
||||
c := NewRotatingKVCache(window)
|
||||
for pos := range totalWrites - 1 {
|
||||
k, v := perPosKV(pos)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
finalPos := totalWrites - 1
|
||||
kFinal, vFinal := perPosKV(finalPos)
|
||||
b := &batch.Batch{
|
||||
InputIDs: mlx.Zeros(mlx.DTypeInt32, 1, 1),
|
||||
SeqOffsets: []int32{int32(finalPos)},
|
||||
SeqQueryLens: []int32{1},
|
||||
}
|
||||
history := c.Update(b, kFinal, vFinal)
|
||||
|
||||
// Reference: the in-window logical-position-ordered K and V are
|
||||
// the last `window` per-position values (positions
|
||||
// [finalPos-window+1, finalPos]). Build them in that order.
|
||||
startPos := max(finalPos-window+1, 0)
|
||||
logicalKs := make([]*mlx.Array, 0, window)
|
||||
logicalVs := make([]*mlx.Array, 0, window)
|
||||
for pos := startPos; pos <= finalPos; pos++ {
|
||||
kp, vp := perPosKV(pos)
|
||||
logicalKs = append(logicalKs, kp)
|
||||
logicalVs = append(logicalVs, vp)
|
||||
}
|
||||
kLogical := mlx.Concatenate(logicalKs, 2)
|
||||
vLogical := mlx.Concatenate(logicalVs, 2)
|
||||
|
||||
// A logical-order ArrayMask with distinct, non-trivial values per
|
||||
// key column. Picked so each column's contribution to softmax is
|
||||
// distinct — the test fails if the cache's gather permutes the
|
||||
// columns wrong before the kernel sees them.
|
||||
maskVals := []float32{0.1, -0.3, 0.7, -0.2}
|
||||
logicalMask := mlx.FromValues(maskVals, 1, 1, 1, window)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model nn.AttentionMask
|
||||
// reference mask uses the same coordinates the model mask
|
||||
// represents; for ArrayMask it's the same tensor (since the
|
||||
// reference K/V is in logical order).
|
||||
refMode string
|
||||
refMask *mlx.Array
|
||||
}{
|
||||
{"zero", nn.AttentionMask{}, "", nil},
|
||||
{"causal-at-L1", nn.CausalMask(), "", nil},
|
||||
{"array", nn.ArrayMask(logicalMask), "array", logicalMask},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := nn.ScaledDotProductAttention(b, q, scale,
|
||||
nn.WithKVHistory(history),
|
||||
nn.WithMask(tc.model))
|
||||
|
||||
want := mlx.FastScaledDotProductAttention(q, kLogical, vLogical, scale,
|
||||
tc.refMode, tc.refMask)
|
||||
|
||||
mlx.Eval(got, want)
|
||||
gs, ws := got.Floats(), want.Floats()
|
||||
for i := range ws {
|
||||
if math.Abs(float64(gs[i]-ws[i])) > 1e-5 {
|
||||
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssistantSharedHistoryL1MasksMatchNoMask(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
if !mlx.MetalIsAvailable() {
|
||||
t.Skip("MLX Metal not available")
|
||||
}
|
||||
const H, D = 1, 4
|
||||
const window = 4
|
||||
const total = 7
|
||||
const scale = 1.0
|
||||
|
||||
q := mlx.FromValues([]float32{0.7, -0.4, 0.2, 0.9}, 1, H, 1, D)
|
||||
mlx.Eval(q)
|
||||
|
||||
full := NewKVCache()
|
||||
sliding := NewRotatingKVCache(window)
|
||||
for pos := range total {
|
||||
kVals := make([]float32, H*D)
|
||||
vVals := make([]float32, H*D)
|
||||
for i := range kVals {
|
||||
kVals[i] = 0.1*float32(pos+1) + 0.01*float32(i)
|
||||
vVals[i] = -0.1*float32(pos+1) + 0.01*float32(i)
|
||||
}
|
||||
k := mlx.FromValues(kVals, 1, H, 1, D)
|
||||
v := mlx.FromValues(vVals, 1, H, 1, D)
|
||||
full.Update(newKVBatch(full.Offset(), 1), k, v)
|
||||
sliding.Update(newKVBatch(sliding.Offset(), 1), k, v)
|
||||
}
|
||||
|
||||
b := newKVBatch(total-1, 1)
|
||||
slidingHistory := sliding.View(b)
|
||||
cases := []struct {
|
||||
name string
|
||||
h *nn.KVHistory
|
||||
mask nn.AttentionMask
|
||||
}{
|
||||
{name: "full", h: full.View(b), mask: nn.CausalMask()},
|
||||
{name: "sliding", h: slidingHistory, mask: nn.CausalMask().Intersect(nn.SlidingWindowMask(b, slidingHistory.K().Dim(2), window, q.DType()))},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := nn.ScaledDotProductAttention(b, q, scale, nn.WithKVHistory(tc.h), nn.WithMask(tc.mask))
|
||||
want := mlx.FastScaledDotProductAttention(q, tc.h.K(), tc.h.V(), scale, "", nil)
|
||||
|
||||
mlx.Eval(got, want)
|
||||
gs, ws := got.Floats(), want.Floats()
|
||||
for i := range ws {
|
||||
if math.Abs(float64(gs[i]-ws[i])) > 1e-5 {
|
||||
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCachePrefillParity drives an L>1 prefill into a
|
||||
// rotating cache and verifies SDPA output through WithKVHistory
|
||||
// matches a reference computed from the same K/V with the model mask
|
||||
// and window restriction composed manually.
|
||||
func TestRotatingKVCachePrefillParity(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
const H, L, D = 1, 6, 4
|
||||
const window = 4
|
||||
const scale = 1.0
|
||||
|
||||
qVals := make([]float32, 1*H*L*D)
|
||||
kVals := make([]float32, 1*H*L*D)
|
||||
vVals := make([]float32, 1*H*L*D)
|
||||
for i := range qVals {
|
||||
qVals[i] = 0.5 + 0.05*float32(i)
|
||||
kVals[i] = -0.3 + 0.07*float32(i)
|
||||
vVals[i] = 0.3 + 0.03*float32(i)
|
||||
}
|
||||
q := mlx.FromValues(qVals, 1, H, L, D)
|
||||
k := mlx.FromValues(kVals, 1, H, L, D)
|
||||
v := mlx.FromValues(vVals, 1, H, L, D)
|
||||
b := newKVBatch(0, L)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
mask nn.AttentionMask
|
||||
// rect arguments matching nn.AttentionMask.Relax (qLo, qHi, kLo, kHi)
|
||||
relax [][4]int
|
||||
causal bool
|
||||
}{
|
||||
{"zero", nn.AttentionMask{}, nil, false},
|
||||
{"causal", nn.CausalMask(), nil, true},
|
||||
{"causal+relax", nn.CausalMask().Relax(0, 1, 4, 2, 5), [][4]int{{1, 4, 2, 5}}, true},
|
||||
}
|
||||
|
||||
negInf := float32(math.Inf(-1))
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := NewRotatingKVCache(window)
|
||||
history := c.Update(b, k, v)
|
||||
|
||||
got := nn.ScaledDotProductAttention(b, q, scale,
|
||||
nn.WithKVHistory(history),
|
||||
nn.WithMask(tc.mask))
|
||||
|
||||
// Reference mask: causal blocks k > absQ; relax rectangles
|
||||
// release causal-blocked cells; window blocks k < absQ - window + 1.
|
||||
refVals := make([]float32, L*L)
|
||||
for qi := range L {
|
||||
absQ := qi
|
||||
for ki := range L {
|
||||
blocked := false
|
||||
if tc.causal && ki > absQ {
|
||||
blocked = true
|
||||
}
|
||||
for _, r := range tc.relax {
|
||||
qLo, qHi, kLo, kHi := r[0], r[1], r[2], r[3]
|
||||
if absQ >= qLo && absQ < qHi && ki >= kLo && ki < kHi {
|
||||
blocked = false
|
||||
}
|
||||
}
|
||||
if window > 0 && ki < absQ-window+1 {
|
||||
blocked = true
|
||||
}
|
||||
if blocked {
|
||||
refVals[qi*L+ki] = negInf
|
||||
}
|
||||
}
|
||||
}
|
||||
refMask := mlx.FromValues(refVals, 1, 1, L, L)
|
||||
want := mlx.FastScaledDotProductAttention(q, k, v, scale, "array", refMask)
|
||||
|
||||
mlx.Eval(got, want)
|
||||
gs, ws := got.Floats(), want.Floats()
|
||||
for i := range ws {
|
||||
if math.Abs(float64(gs[i]-ws[i])) > 1e-4 {
|
||||
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheMLAParity drives a rotating cache with the MLA
|
||||
// shape — K = [kvLatent, kPE] concatenated, V = zero-width — then
|
||||
// uses WithMLAHistory to slice V from K and compares output against
|
||||
// a manual reference. Pins the cache+MLA integration that
|
||||
// glm4_moe_lite uses in production.
|
||||
func TestRotatingKVCacheMLAParity(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
const H, L, D, valueDim = 1, 3, 6, 4
|
||||
const scale = 1.0
|
||||
const window = 8 // window >= L so no window restriction
|
||||
|
||||
kVals := make([]float32, 1*H*L*D)
|
||||
for i := range kVals {
|
||||
kVals[i] = 0.1 * float32(i+1)
|
||||
}
|
||||
k := mlx.FromValues(kVals, 1, H, L, D)
|
||||
v := mlx.Zeros(mlx.DTypeFloat32, 1, H, L, 0)
|
||||
|
||||
q := mlx.Zeros(mlx.DTypeFloat32, 1, H, L, D)
|
||||
b := newKVBatch(0, L)
|
||||
|
||||
c := NewRotatingKVCache(window)
|
||||
history := c.Update(b, k, v)
|
||||
got := nn.ScaledDotProductAttention(b, q, scale,
|
||||
nn.WithMLAHistory(history, valueDim),
|
||||
nn.WithMask(nn.CausalMask()))
|
||||
|
||||
vRef := k.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(), mlx.Slice(0, valueDim))
|
||||
want := mlx.FastScaledDotProductAttention(q, k, vRef, scale, "causal", nil)
|
||||
|
||||
mlx.Eval(got, want)
|
||||
gs, ws := got.Floats(), want.Floats()
|
||||
for i := range ws {
|
||||
if math.Abs(float64(gs[i]-ws[i])) > 1e-5 {
|
||||
t.Fatalf("index %d: got %v, want %v", i, gs[i], ws[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
338
x/mlxrunner/cache/rotating_multiturn_test.go
vendored
Normal file
@@ -0,0 +1,338 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// singleTokenKV and multiTokenKV fabricate [B=1, H=1, L, D=2] key/value
|
||||
// tensors whose channel value is the token id, so stateIDs can recover
|
||||
// which ids survived in the cache.
|
||||
func singleTokenKV(id float32) (*mlx.Array, *mlx.Array) {
|
||||
k := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||
v := mlx.FromValues([]float32{id, id}, 1, 1, 1, 2)
|
||||
return k, v
|
||||
}
|
||||
|
||||
func multiTokenKV(ids []float32) (*mlx.Array, *mlx.Array) {
|
||||
data := make([]float32, 0, 2*len(ids))
|
||||
for _, id := range ids {
|
||||
data = append(data, id, id)
|
||||
}
|
||||
k := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||
v := mlx.FromValues(data, 1, 1, len(ids), 2)
|
||||
return k, v
|
||||
}
|
||||
|
||||
// stateIDs returns the ids currently in the cache in slot order (logical
|
||||
// after a concat, physical/rotated after a single-token update).
|
||||
func stateIDs(t *testing.T, c *RotatingKVCache) []float32 {
|
||||
t.Helper()
|
||||
state := c.State()
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
mlx.Eval(state[0])
|
||||
flat := state[0].Floats()
|
||||
n := state[0].Dim(2)
|
||||
out := make([]float32, n)
|
||||
for i := range n {
|
||||
out[i] = flat[i*2]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalSlice(a, b []float32) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func feedMulti(c *RotatingKVCache, startID float32, n int) float32 {
|
||||
ids := make([]float32, n)
|
||||
for i := range ids {
|
||||
ids[i] = startID + float32(i)
|
||||
}
|
||||
k, v := multiTokenKV(ids)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
return startID + float32(n)
|
||||
}
|
||||
|
||||
func feedSingle(c *RotatingKVCache, id float32) {
|
||||
k, v := singleTokenKV(id)
|
||||
c.Update(newKVBatch(c.Offset(), k.Dim(2)), k, v)
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatMidRotationPreservesContext: after the buffer
|
||||
// has wrapped, a multi-token concat must keep the (maxSize-1) most recent
|
||||
// pre-existing tokens in logical order so the first Q of the new batch
|
||||
// has a full sliding window.
|
||||
func TestRotatingKVCacheConcatMidRotationPreservesContext(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
nextID := feedMulti(c, 1, 3)
|
||||
for range 6 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 9 {
|
||||
t.Fatalf("setup: offset=%d want 9", c.Offset())
|
||||
}
|
||||
if c.idx >= c.maxSize {
|
||||
t.Fatalf("setup: expected mid-rotation idx (<%d), got %d", c.maxSize, c.idx)
|
||||
}
|
||||
|
||||
feedMulti(c, 10, 2)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{7, 8, 9, 10, 11}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-concat window=%v want %v", got, want)
|
||||
}
|
||||
if c.Offset() != 11 {
|
||||
t.Fatalf("offset=%d want 11", c.Offset())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAlignedInvariant: with an aligned buffer
|
||||
// (c.idx == Dim), an L>1 concat keeps the last (maxSize-1) pre-existing
|
||||
// tokens plus the full new batch. This is the chunked-prefill contract
|
||||
// x/mlxrunner/pipeline.go relies on.
|
||||
func TestRotatingKVCacheConcatAlignedInvariant(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
// Chunk 1 fills past maxSize, leaving Dim == maxSize aligned.
|
||||
feedMulti(c, 1, 6)
|
||||
// Chunk 2: the buffer is intentionally oversized to (maxSize-1) + L
|
||||
// so the first new Q has its full window in scope for this forward.
|
||||
feedMulti(c, 7, 3)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{4, 5, 6, 7, 8, 9}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-chunk-2 buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
// The next decode trims oversize back to maxSize; order may be
|
||||
// physical (rotated), so check as a set.
|
||||
feedSingle(c, 10)
|
||||
got = stateIDs(t, c)
|
||||
if len(got) != window {
|
||||
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||
}
|
||||
seen := map[float32]bool{}
|
||||
for _, v := range got {
|
||||
seen[v] = true
|
||||
}
|
||||
for _, w := range []float32{7, 8, 9, 10} {
|
||||
if !seen[w] {
|
||||
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAfterDecodeGrowsBuffer: update() grows the
|
||||
// underlying buffer by `step` slots via mlx.Zeros before writing, so
|
||||
// after one decode on a short prefill c.idx < Dim even though the cache
|
||||
// has not wrapped. Those trailing slots are zero padding and must not
|
||||
// be pulled back into the live window on the next concat.
|
||||
func TestRotatingKVCacheConcatAfterDecodeGrowsBuffer(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 512
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 3)
|
||||
feedSingle(c, 4)
|
||||
feedMulti(c, 5, 3)
|
||||
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 3, 4, 5, 6, 7}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("growing-buffer concat=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatAfterLiveRewind: x/mlxrunner/cache.go calls
|
||||
// Restore(nil, target) between conversation turns to rewind the cache to
|
||||
// the matched prefix. Restore moves c.offset/c.idx without trimming the
|
||||
// underlying buffer, so slots [c.idx, Dim) still hold stale pre-rewind
|
||||
// tokens. A subsequent concat must drop those, not treat them as wrapped
|
||||
// window content.
|
||||
func TestRotatingKVCacheConcatAfterLiveRewind(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 8
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
// Grow the buffer to exactly maxSize without wrapping.
|
||||
feedMulti(c, 1, 2)
|
||||
for id := float32(3); id <= 8; id++ {
|
||||
feedSingle(c, id)
|
||||
}
|
||||
if c.Offset() != window {
|
||||
t.Fatalf("setup: offset=%d want %d", c.Offset(), window)
|
||||
}
|
||||
|
||||
if !c.Restore(nil, 2) {
|
||||
t.Fatalf("live rewind to 2 failed")
|
||||
}
|
||||
if c.Offset() != 2 {
|
||||
t.Fatalf("post-rewind offset=%d want 2", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, 9, 3)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 9, 10, 11}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("post-rewind concat=%v want %v", got, want)
|
||||
}
|
||||
if c.Offset() != 5 {
|
||||
t.Fatalf("offset=%d want 5", c.Offset())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheConcatGrowingBuffer: when oldLen < maxSize the trim
|
||||
// formula drops to non-positive and all pre-existing tokens are kept.
|
||||
func TestRotatingKVCacheConcatGrowingBuffer(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 2)
|
||||
feedMulti(c, 3, 2)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{1, 2, 3, 4}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("growing buffer=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheRunnerChunkedPrefill mirrors the
|
||||
// x/mlxrunner/pipeline.go prefill loop: a long prompt fed through
|
||||
// repeated L>1 Update() calls on a single cache. Scaled-down proxy for
|
||||
// the Gemma 4 26B case (sliding_window=1024, prefillChunkSize=2048).
|
||||
func TestRotatingKVCacheRunnerChunkedPrefill(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
feedMulti(c, 1, 8)
|
||||
if c.Offset() != 8 {
|
||||
t.Fatalf("chunk 1: offset=%d want 8", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, 9, 8)
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("chunk 2: buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
feedMulti(c, 17, 4)
|
||||
got = stateIDs(t, c)
|
||||
want = []float32{14, 15, 16, 17, 18, 19, 20}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("chunk 3: buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
// Decode trims oversize back to maxSize; order may be physical.
|
||||
feedSingle(c, 21)
|
||||
got = stateIDs(t, c)
|
||||
if len(got) != window {
|
||||
t.Fatalf("post-decode Dim=%d want %d", len(got), window)
|
||||
}
|
||||
seen := map[float32]bool{}
|
||||
for _, v := range got {
|
||||
seen[v] = true
|
||||
}
|
||||
for _, w := range []float32{18, 19, 20, 21} {
|
||||
if !seen[w] {
|
||||
t.Fatalf("post-decode window missing %v (got %v)", w, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheMultiTurnChatSimulation walks a prefill → decode →
|
||||
// prefill sequence and checks that each new prefill retains the last
|
||||
// (maxSize-1) pre-existing tokens in logical order.
|
||||
func TestRotatingKVCacheMultiTurnChatSimulation(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
const window = 4
|
||||
c := NewRotatingKVCache(window)
|
||||
|
||||
nextID := feedMulti(c, 1, 2)
|
||||
for range 5 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 7 {
|
||||
t.Fatalf("turn 1: offset=%d want 7", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, nextID, 3)
|
||||
nextID += 3
|
||||
got := stateIDs(t, c)
|
||||
want := []float32{5, 6, 7, 8, 9, 10}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("turn 2 prefill buffer=%v want %v", got, want)
|
||||
}
|
||||
|
||||
for range 4 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
}
|
||||
if c.Offset() != 14 {
|
||||
t.Fatalf("turn 2 decode: offset=%d want 14", c.Offset())
|
||||
}
|
||||
|
||||
feedMulti(c, nextID, 2)
|
||||
got = stateIDs(t, c)
|
||||
want = []float32{12, 13, 14, 15, 16}
|
||||
if !equalSlice(got, want) {
|
||||
t.Fatalf("turn 3 prefill buffer=%v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRotatingKVCacheOffsetTracking: Offset() is the monotonic logical
|
||||
// token count through any mix of Update() calls — Gemma 4 uses
|
||||
// donorEntry.Offset - L for the consumer's RoPE offset.
|
||||
func TestRotatingKVCacheOffsetTracking(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
c := NewRotatingKVCache(4)
|
||||
nextID := feedMulti(c, 1, 3)
|
||||
if c.Offset() != 3 {
|
||||
t.Fatalf("after prefill 3: offset=%d want 3", c.Offset())
|
||||
}
|
||||
for i := range 5 {
|
||||
feedSingle(c, nextID)
|
||||
nextID++
|
||||
if c.Offset() != 3+i+1 {
|
||||
t.Fatalf("after decode %d: offset=%d want %d", i, c.Offset(), 3+i+1)
|
||||
}
|
||||
}
|
||||
nextID = feedMulti(c, nextID, 2)
|
||||
if c.Offset() != 10 {
|
||||
t.Fatalf("after turn-2 prefill: offset=%d want 10", c.Offset())
|
||||
}
|
||||
// L > maxSize concat.
|
||||
feedMulti(c, nextID, 7)
|
||||
if c.Offset() != 17 {
|
||||
t.Fatalf("after large prefill: offset=%d want 17", c.Offset())
|
||||
}
|
||||
}
|
||||
982
x/mlxrunner/cache_test.go
Normal file
982
x/mlxrunner/cache_test.go
Normal file
@@ -0,0 +1,982 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// snapshotTracker records every fakeSnapshot created and every Close() call
|
||||
// so tests can detect leaked (created but never closed) or double-closed snapshots.
|
||||
type snapshotTracker struct {
|
||||
all []*fakeSnapshot
|
||||
}
|
||||
|
||||
func (tr *snapshotTracker) track(s *fakeSnapshot) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.tracker = tr
|
||||
tr.all = append(tr.all, s)
|
||||
}
|
||||
|
||||
// Fake caches that store actual token sequences so tests can verify the right
|
||||
// data was restored, not just the right offset.
|
||||
|
||||
// fakeSnapshot stores a copy of the token sub-sequence it covers.
|
||||
type fakeSnapshot struct {
|
||||
tokens []int32
|
||||
from, to int
|
||||
byteSize int // configurable for eviction tests
|
||||
|
||||
tracker *snapshotTracker
|
||||
closeCount int
|
||||
}
|
||||
|
||||
func (s *fakeSnapshot) Size() int { return s.byteSize }
|
||||
func (s *fakeSnapshot) Close() {
|
||||
s.closeCount++
|
||||
}
|
||||
|
||||
// fakeRewindableCache tracks the full token sequence and supports
|
||||
// arbitrary rewind via Restore(nil, target).
|
||||
type fakeRewindableCache struct {
|
||||
tokens []int32
|
||||
tracker *snapshotTracker
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *fakeRewindableCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRewindableCache) Offset() int { return len(c.tokens) }
|
||||
|
||||
func (c *fakeRewindableCache) Free() {
|
||||
c.tokens = nil
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
if fromOffset >= len(c.tokens) {
|
||||
return nil
|
||||
}
|
||||
from := fromOffset
|
||||
if from < 0 {
|
||||
from = 0
|
||||
}
|
||||
s := &fakeSnapshot{
|
||||
tokens: slices.Clone(c.tokens[from:]),
|
||||
from: from,
|
||||
to: len(c.tokens),
|
||||
}
|
||||
c.tracker.track(s)
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
if target < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
if target > len(c.tokens) {
|
||||
return false
|
||||
}
|
||||
c.tokens = c.tokens[:target]
|
||||
return true
|
||||
}
|
||||
s := snapshot.(*fakeSnapshot)
|
||||
if target > s.to || len(c.tokens) < s.from {
|
||||
return false
|
||||
}
|
||||
c.tokens = append(c.tokens[:s.from], s.tokens...)
|
||||
if target < len(c.tokens) {
|
||||
c.tokens = c.tokens[:target]
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
||||
if parent == nil || child == nil {
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
if child != nil {
|
||||
child.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
p := parent.(*fakeSnapshot)
|
||||
ch := child.(*fakeSnapshot)
|
||||
merged := make([]int32, len(p.tokens)+len(ch.tokens))
|
||||
copy(merged, p.tokens)
|
||||
copy(merged[len(p.tokens):], ch.tokens)
|
||||
s := &fakeSnapshot{
|
||||
tokens: merged,
|
||||
from: p.from,
|
||||
to: ch.to,
|
||||
byteSize: p.byteSize + ch.byteSize,
|
||||
}
|
||||
c.tracker.track(s)
|
||||
p.Close()
|
||||
ch.Close()
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *fakeRewindableCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
||||
if snapshot == nil {
|
||||
return nil, nil
|
||||
}
|
||||
s := snapshot.(*fakeSnapshot)
|
||||
relAt := at - s.from
|
||||
if relAt <= 0 {
|
||||
return nil, snapshot
|
||||
}
|
||||
if relAt >= len(s.tokens) {
|
||||
return snapshot, nil
|
||||
}
|
||||
p := &fakeSnapshot{
|
||||
tokens: slices.Clone(s.tokens[:relAt]),
|
||||
from: s.from,
|
||||
to: at,
|
||||
byteSize: s.byteSize,
|
||||
}
|
||||
ch := &fakeSnapshot{
|
||||
tokens: slices.Clone(s.tokens[relAt:]),
|
||||
from: at,
|
||||
to: s.to,
|
||||
byteSize: s.byteSize,
|
||||
}
|
||||
c.tracker.track(p)
|
||||
c.tracker.track(ch)
|
||||
s.Close()
|
||||
return p, ch
|
||||
}
|
||||
|
||||
func TestKVCacheBeginWithFactoryLimitCapsPrefix(t *testing.T) {
|
||||
inputs := []int32{1, 2, 3, 4, 5}
|
||||
tracker := &snapshotTracker{}
|
||||
var kc kvCache
|
||||
|
||||
factory := func() []cache.Cache {
|
||||
return []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
||||
}
|
||||
|
||||
session := kc.beginWithFactoryLimit(inputs, factory, "test", -1, false)
|
||||
session.caches[0].(*fakeRewindableCache).feed(inputs)
|
||||
session.close()
|
||||
|
||||
session = kc.beginWithFactoryLimit(inputs, factory, "test", 3, false)
|
||||
if got, want := session.caches[0].Offset(), 3; got != want {
|
||||
t.Fatalf("cache offset = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := session.remaining, inputs[3:]; !slices.Equal(got, want) {
|
||||
t.Fatalf("remaining = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKVCacheBeginWithFactoryLimitDoesNotKeepSeedToken(t *testing.T) {
|
||||
inputs := []int32{1, 2, 3, 4, 5}
|
||||
tracker := &snapshotTracker{}
|
||||
var kc kvCache
|
||||
|
||||
factory := func() []cache.Cache {
|
||||
return []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
||||
}
|
||||
|
||||
session := kc.beginWithFactoryLimit(inputs, factory, "test", -1, false)
|
||||
session.caches[0].(*fakeRewindableCache).feed(inputs)
|
||||
session.close()
|
||||
|
||||
session = kc.beginWithFactoryLimit(inputs, factory, "test", len(inputs), false)
|
||||
if got, want := session.caches[0].Offset(), len(inputs); got != want {
|
||||
t.Fatalf("cache offset = %d, want %d", got, want)
|
||||
}
|
||||
if len(session.remaining) != 0 {
|
||||
t.Fatalf("remaining = %v, want empty", session.remaining)
|
||||
}
|
||||
}
|
||||
|
||||
// fakeSlidingWindowCache models RotatingKVCache semantics: stores the full
|
||||
// token sequence but only the trailing maxSize tokens are "live" in the window.
|
||||
// Once the window fills, live rewind is impossible without a snapshot.
|
||||
type fakeSlidingWindowCache struct {
|
||||
tokens []int32
|
||||
maxSize int
|
||||
tracker *snapshotTracker
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *fakeSlidingWindowCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeSlidingWindowCache) Offset() int { return len(c.tokens) }
|
||||
|
||||
func (c *fakeSlidingWindowCache) Free() {
|
||||
c.tokens = nil
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
if len(c.tokens) == 0 || len(c.tokens) <= fromOffset {
|
||||
return nil
|
||||
}
|
||||
// Snapshot captures the full window state (like RotatingKVCache.Snapshot).
|
||||
s := &fakeSnapshot{
|
||||
tokens: slices.Clone(c.tokens),
|
||||
from: 0,
|
||||
to: len(c.tokens),
|
||||
}
|
||||
c.tracker.track(s)
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
if target < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
if snapshot == nil {
|
||||
if target >= len(c.tokens) {
|
||||
return target == len(c.tokens)
|
||||
}
|
||||
// Live rewind only works when buffer hasn't filled (offset <= maxSize).
|
||||
if len(c.tokens) > c.maxSize {
|
||||
return false
|
||||
}
|
||||
c.tokens = c.tokens[:target]
|
||||
return true
|
||||
}
|
||||
s := snapshot.(*fakeSnapshot)
|
||||
if target > s.to {
|
||||
return false
|
||||
}
|
||||
// Reject if clamping would leave an incomplete window
|
||||
// (matches RotatingKVCache behavior).
|
||||
if target < s.to && s.to > c.maxSize {
|
||||
return false
|
||||
}
|
||||
c.tokens = slices.Clone(s.tokens)
|
||||
if target < len(c.tokens) {
|
||||
c.tokens = c.tokens[:target]
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
||||
// Child supersedes parent for sliding window (full window state).
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
return child
|
||||
}
|
||||
|
||||
func (c *fakeSlidingWindowCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
||||
// Can't split a ring buffer at an arbitrary point.
|
||||
return nil, snapshot
|
||||
}
|
||||
|
||||
// fakeRecurrentCache models RecurrentCache semantics: stores tokens
|
||||
// but cannot rewind without a snapshot.
|
||||
type fakeRecurrentCache struct {
|
||||
tokens []int32
|
||||
tracker *snapshotTracker
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) feed(tokens []int32) {
|
||||
c.tokens = append(c.tokens, tokens...)
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
return nil, nil
|
||||
}
|
||||
func (c *fakeRecurrentCache) State() []*mlx.Array { return nil }
|
||||
func (c *fakeRecurrentCache) Offset() int { return len(c.tokens) }
|
||||
|
||||
func (c *fakeRecurrentCache) Free() {
|
||||
c.tokens = nil
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) Snapshot(fromOffset int) cache.Snapshot {
|
||||
// Recurrent state is cumulative; snapshot captures the full state.
|
||||
if len(c.tokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
s := &fakeSnapshot{
|
||||
tokens: slices.Clone(c.tokens),
|
||||
from: 0,
|
||||
to: len(c.tokens),
|
||||
}
|
||||
c.tracker.track(s)
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) Restore(snapshot cache.Snapshot, target int) bool {
|
||||
if snapshot == nil {
|
||||
return target == len(c.tokens) // can only no-op
|
||||
}
|
||||
s := snapshot.(*fakeSnapshot)
|
||||
if target != s.to {
|
||||
return false // cumulative state requires exact match
|
||||
}
|
||||
c.tokens = slices.Clone(s.tokens)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) Merge(parent, child cache.Snapshot) cache.Snapshot {
|
||||
// Child supersedes parent for cumulative state.
|
||||
if parent != nil {
|
||||
parent.Close()
|
||||
}
|
||||
return child
|
||||
}
|
||||
|
||||
func (c *fakeRecurrentCache) Split(snapshot cache.Snapshot, at int) (cache.Snapshot, cache.Snapshot) {
|
||||
return nil, snapshot // can't split cumulative state
|
||||
}
|
||||
|
||||
type feedableCache interface {
|
||||
cache.Cache
|
||||
feed(tokens []int32)
|
||||
}
|
||||
|
||||
// testEnv encapsulates a kvCache and its fake caches for a test scenario.
|
||||
type testEnv struct {
|
||||
kvc *kvCache
|
||||
caches []cache.Cache // typed references for assertions
|
||||
tracker *snapshotTracker
|
||||
rewindable bool // true when all caches support arbitrary Restore(nil, target)
|
||||
}
|
||||
|
||||
// newTransformerEnv creates a test environment with a single rewindable cache
|
||||
// (pure transformer model).
|
||||
func newTransformerEnv() *testEnv {
|
||||
tracker := &snapshotTracker{}
|
||||
caches := []cache.Cache{&fakeRewindableCache{tracker: tracker}}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tracker,
|
||||
rewindable: true,
|
||||
}
|
||||
}
|
||||
|
||||
// newSlidingWindowEnv creates a test environment with one rewindable cache and
|
||||
// one sliding window cache (Mistral-style architecture). The sliding window
|
||||
// maxSize is set small enough that test sequences fill it, making
|
||||
// Restore(nil, target) fail — the same behavior as production models where
|
||||
// the window fills after a few turns.
|
||||
func newSlidingWindowEnv() *testEnv {
|
||||
tr := &snapshotTracker{}
|
||||
rc := &fakeRewindableCache{tracker: tr}
|
||||
sw := &fakeSlidingWindowCache{maxSize: 4, tracker: tr}
|
||||
caches := []cache.Cache{rc, sw}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tr,
|
||||
rewindable: false,
|
||||
}
|
||||
}
|
||||
|
||||
// newRecurrentEnv creates a test environment with one rewindable cache and one
|
||||
// non-rewindable cache (Jamba-style architecture).
|
||||
func newRecurrentEnv() *testEnv {
|
||||
tr := &snapshotTracker{}
|
||||
rc := &fakeRewindableCache{tracker: tr}
|
||||
nrc := &fakeRecurrentCache{tracker: tr}
|
||||
caches := []cache.Cache{rc, nrc}
|
||||
return &testEnv{
|
||||
kvc: &kvCache{caches: caches},
|
||||
caches: caches,
|
||||
tracker: tr,
|
||||
rewindable: false,
|
||||
}
|
||||
}
|
||||
|
||||
// assertAllTokens checks that every cache in the environment contains exactly
|
||||
// the expected token sequence.
|
||||
func (e *testEnv) assertAllTokens(t *testing.T, label string, expected []int32) {
|
||||
t.Helper()
|
||||
for i, c := range e.caches {
|
||||
assertTokens(t, label, c, expected)
|
||||
// Verify all caches report the same offset.
|
||||
if i > 0 && c.Offset() != e.caches[0].Offset() {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d",
|
||||
label, i, c.Offset(), e.caches[0].Offset())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// simulateRequest mirrors the production pipeline lifecycle:
|
||||
// begin -> prefill with snapshot(false) at branch points -> generate -> close
|
||||
|
||||
type requestResult struct {
|
||||
remaining []int32
|
||||
pendingSnapshots int
|
||||
}
|
||||
|
||||
// simulateRequest runs a request through the harness. If userSnapshotAt > 0,
|
||||
// a user snapshot is requested at that offset during prefill.
|
||||
func simulateRequest(t *testing.T, kvc *kvCache, inputs, generated []int32, userSnapshotAt ...int) requestResult {
|
||||
t.Helper()
|
||||
|
||||
session := kvc.begin(nil, inputs)
|
||||
for _, at := range userSnapshotAt {
|
||||
if at > 0 {
|
||||
session.requestSnapshot(at)
|
||||
}
|
||||
}
|
||||
|
||||
result := requestResult{
|
||||
remaining: slices.Clone(session.remaining),
|
||||
pendingSnapshots: len(session.pendingSnapshots),
|
||||
}
|
||||
|
||||
assertCacheOffsetAlignment(t, kvc, "after begin")
|
||||
|
||||
baseOffset := kvc.minCacheOffset()
|
||||
remaining := inputs[baseOffset:]
|
||||
|
||||
// Prefill: feed tokens, pausing at each pending snapshot.
|
||||
for len(session.pendingSnapshots) > 0 {
|
||||
sp := session.pendingSnapshots[0]
|
||||
count := sp.offset - baseOffset
|
||||
if count > len(remaining) {
|
||||
break
|
||||
}
|
||||
if count > 0 {
|
||||
feedAll(kvc.caches, remaining[:count])
|
||||
remaining = remaining[count:]
|
||||
baseOffset = sp.offset
|
||||
}
|
||||
assertCacheOffsetAlignment(t, kvc, "at snapshot point")
|
||||
session.snapshot()
|
||||
}
|
||||
|
||||
// Feed rest of input tokens.
|
||||
if len(remaining) > 0 {
|
||||
feedAll(kvc.caches, remaining)
|
||||
}
|
||||
|
||||
assertCacheOffsetAlignment(t, kvc, "after prefill")
|
||||
|
||||
// Generate tokens.
|
||||
if len(generated) > 0 {
|
||||
session.outputs = generated
|
||||
feedAll(kvc.caches, generated)
|
||||
}
|
||||
|
||||
assertCacheOffsetAlignment(t, kvc, "before close")
|
||||
session.close()
|
||||
return result
|
||||
}
|
||||
|
||||
func feedAll(caches []cache.Cache, tokens []int32) {
|
||||
for _, c := range caches {
|
||||
if fc, ok := c.(feedableCache); ok {
|
||||
fc.feed(tokens)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assertCacheOffsetAlignment verifies all caches report the same offset.
|
||||
func assertCacheOffsetAlignment(t *testing.T, kvc *kvCache, label string) {
|
||||
t.Helper()
|
||||
if len(kvc.caches) < 2 {
|
||||
return
|
||||
}
|
||||
expected := kvc.caches[0].Offset()
|
||||
for i := 1; i < len(kvc.caches); i++ {
|
||||
if got := kvc.caches[i].Offset(); got != expected {
|
||||
t.Errorf("%s: cache %d offset=%d != cache 0 offset=%d", label, i, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assertTokens checks that a feedable cache contains the expected token sequence.
|
||||
// For sliding window caches, only the trailing maxSize tokens are checked.
|
||||
func assertTokens(t *testing.T, label string, c cache.Cache, expected []int32) {
|
||||
t.Helper()
|
||||
switch fc := c.(type) {
|
||||
case *fakeRewindableCache:
|
||||
if !slices.Equal(fc.tokens, expected) {
|
||||
t.Errorf("%s: rewindable tokens = %v, want %v", label, fc.tokens, expected)
|
||||
}
|
||||
case *fakeSlidingWindowCache:
|
||||
// Sliding window stores full history but only trailing maxSize are live.
|
||||
// Verify the full token sequence matches (the window semantics are
|
||||
// enforced by Snapshot/Restore, not by the token log).
|
||||
if !slices.Equal(fc.tokens, expected) {
|
||||
t.Errorf("%s: sliding window tokens = %v, want %v", label, fc.tokens, expected)
|
||||
}
|
||||
case *fakeRecurrentCache:
|
||||
if !slices.Equal(fc.tokens, expected) {
|
||||
t.Errorf("%s: non-rewindable tokens = %v, want %v", label, fc.tokens, expected)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("%s: unknown cache type %T", label, c)
|
||||
}
|
||||
}
|
||||
|
||||
// checkTrieInvariants walks the trie and checks structural invariants.
|
||||
func checkTrieInvariants(t *testing.T, root *trieNode) {
|
||||
t.Helper()
|
||||
walkNodes(root, func(n *trieNode) bool {
|
||||
if n.parent != nil {
|
||||
if n.startOffset() != n.parent.endOffset {
|
||||
t.Errorf("node [%d,%d): startOffset %d != parent endOffset %d",
|
||||
n.startOffset(), n.endOffset, n.startOffset(), n.parent.endOffset)
|
||||
}
|
||||
}
|
||||
if len(n.tokens) != n.endOffset-n.startOffset() {
|
||||
t.Errorf("node [%d,%d): token count %d != offset span %d",
|
||||
n.startOffset(), n.endOffset, len(n.tokens), n.endOffset-n.startOffset())
|
||||
}
|
||||
for _, c := range n.children {
|
||||
if c.parent != n {
|
||||
t.Errorf("child [%d,%d) parent mismatch", c.startOffset(), c.endOffset)
|
||||
}
|
||||
}
|
||||
// No two siblings should start with the same token.
|
||||
seen := make(map[int32]bool)
|
||||
for _, c := range n.children {
|
||||
if len(c.tokens) > 0 {
|
||||
first := c.tokens[0]
|
||||
if seen[first] {
|
||||
t.Errorf("node [%d,%d): duplicate sibling first token %d",
|
||||
n.startOffset(), n.endOffset, first)
|
||||
}
|
||||
seen[first] = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// checkSnapshotLeaks verifies that every tracked snapshot is either still live
|
||||
// in the trie (closeCount == 0) or has been closed exactly once. It reports
|
||||
// leaked snapshots (not in trie, never closed) and double-closes.
|
||||
func checkSnapshotLeaks(t *testing.T, tracker *snapshotTracker, root *trieNode) {
|
||||
t.Helper()
|
||||
if tracker == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Collect all live snapshots still referenced by trie nodes.
|
||||
live := make(map[*fakeSnapshot]bool)
|
||||
walkNodes(root, func(n *trieNode) bool {
|
||||
for _, s := range n.snapshots {
|
||||
if s != nil {
|
||||
if fs, ok := s.(*fakeSnapshot); ok {
|
||||
live[fs] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
for i, s := range tracker.all {
|
||||
if live[s] {
|
||||
if s.closeCount != 0 {
|
||||
t.Errorf("snapshot #%d [%d,%d) is still in trie but was closed %d time(s)",
|
||||
i, s.from, s.to, s.closeCount)
|
||||
}
|
||||
} else {
|
||||
if s.closeCount == 0 {
|
||||
t.Errorf("snapshot #%d [%d,%d) leaked: created but never closed and not in trie",
|
||||
i, s.from, s.to)
|
||||
} else if s.closeCount > 1 {
|
||||
t.Errorf("snapshot #%d [%d,%d) double-closed: closed %d times",
|
||||
i, s.from, s.to, s.closeCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forEachEnv runs fn as subtests for three realistic model configurations:
|
||||
// pure transformer, transformer + sliding window (Mistral-style), and
|
||||
// transformer + recurrent (Jamba-style). Leak checking runs automatically
|
||||
// at the end of each subtest.
|
||||
func forEachEnv(t *testing.T, fn func(t *testing.T, env *testEnv)) {
|
||||
t.Helper()
|
||||
run := func(t *testing.T, env *testEnv) {
|
||||
t.Cleanup(func() {
|
||||
checkSnapshotLeaks(t, env.tracker, env.kvc.root)
|
||||
})
|
||||
fn(t, env)
|
||||
}
|
||||
t.Run("Transformer", func(t *testing.T) { run(t, newTransformerEnv()) })
|
||||
t.Run("SlidingWindow", func(t *testing.T) { run(t, newSlidingWindowEnv()) })
|
||||
t.Run("Recurrent", func(t *testing.T) { run(t, newRecurrentEnv()) })
|
||||
}
|
||||
|
||||
// TestBranchCreationAndReuse exercises the core multi-conversation lifecycle:
|
||||
// two conversations share a prefix and diverge, creating a branch point.
|
||||
// A third conversation extends the first. Verifies trie structure, cache
|
||||
// hit lengths, and that semantic caches contain the correct token sequences.
|
||||
func TestBranchCreationAndReuse(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
|
||||
// Request A: [1,2,3,4,5,6,7,8] + generate [20,21] — full miss.
|
||||
resA := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8}, []int32{20, 21})
|
||||
if len(resA.remaining) != 8 {
|
||||
t.Fatalf("A: remaining = %d, want 8 (full miss)", len(resA.remaining))
|
||||
}
|
||||
env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
|
||||
|
||||
// Verify trie was populated by close().
|
||||
_, mA := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
|
||||
if mA != 10 {
|
||||
t.Fatalf("A findable: expected 10 matched, got %d", mA)
|
||||
}
|
||||
|
||||
// Request B: [1,2,3,4,5,10,11,12] — shares 5-token prefix with A.
|
||||
// For rewindable caches, switchToPath rewinds to the match point
|
||||
// so only the non-matching suffix needs evaluation. For non-rewindable
|
||||
// caches (RecurrentCache), the rewind fails and freeAll fires.
|
||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12}, []int32{30, 31})
|
||||
if env.rewindable {
|
||||
if resB.pendingSnapshots != 0 {
|
||||
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
|
||||
}
|
||||
if len(resB.remaining) != 3 {
|
||||
t.Fatalf("B: remaining = %d, want 3 (rewind to match point)", len(resB.remaining))
|
||||
}
|
||||
} else {
|
||||
if resB.pendingSnapshots != 1 {
|
||||
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
|
||||
}
|
||||
if len(resB.remaining) != 8 {
|
||||
t.Fatalf("B: remaining = %d, want 8 (freeAll fallback)", len(resB.remaining))
|
||||
}
|
||||
}
|
||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
||||
|
||||
// Both A and B should be findable in the trie.
|
||||
_, mA2 := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 6, 7, 8, 20, 21})
|
||||
if mA2 < 5 {
|
||||
t.Fatalf("A still findable: expected >= 5 matched, got %d", mA2)
|
||||
}
|
||||
_, mB := findBestMatch(kvc.root, []int32{1, 2, 3, 4, 5, 10, 11, 12, 30, 31})
|
||||
if mB < 5 {
|
||||
t.Fatalf("B findable: expected >= 5 matched, got %d", mB)
|
||||
}
|
||||
|
||||
// Request C: [1,2,3,4,5,6,7,8,40,41] — extends A's prefix.
|
||||
// Should get a cache hit for the shared prefix.
|
||||
resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41}, nil)
|
||||
if len(resC.remaining) >= 10 {
|
||||
t.Fatalf("C: remaining = %d, want < 10 (should get cache hit)", len(resC.remaining))
|
||||
}
|
||||
env.assertAllTokens(t, "after C", []int32{1, 2, 3, 4, 5, 6, 7, 8, 40, 41})
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
// TestExactMatchSeedBehavior verifies the holdback mechanism: when the exact
|
||||
// same prompt is requested twice, the cache does not overclaim cached work.
|
||||
// The last token must be re-evaluated to seed generation.
|
||||
func TestExactMatchSeedBehavior(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
|
||||
// Request A: first time.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
||||
|
||||
// Request B: identical prompt. Holdback means matched=4, partial in
|
||||
// the 5-token edge. For rewindable caches, switchToPath rewinds to
|
||||
// offset 4, so only the held-back token needs re-evaluation. For
|
||||
// non-rewindable caches, the rewind fails and freeAll fires.
|
||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{20, 21})
|
||||
if env.rewindable {
|
||||
if len(resB.remaining) != 1 {
|
||||
t.Fatalf("B: remaining = %d, want 1 (rewind to holdback point)", len(resB.remaining))
|
||||
}
|
||||
if resB.pendingSnapshots != 0 {
|
||||
t.Fatalf("B: pendingSnapshots = %d, want 0 (rewind succeeded)", resB.pendingSnapshots)
|
||||
}
|
||||
} else {
|
||||
if len(resB.remaining) != 5 {
|
||||
t.Fatalf("B: remaining = %d, want 5 (freeAll fallback)", len(resB.remaining))
|
||||
}
|
||||
if resB.pendingSnapshots != 1 {
|
||||
t.Fatalf("B: pendingSnapshots = %d, want 1", resB.pendingSnapshots)
|
||||
}
|
||||
}
|
||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 20, 21})
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
// TestConversationResumption tests the most common pattern: user sends a message,
|
||||
// gets a response, then sends a follow-up. The follow-up should reuse the cached
|
||||
// prefix (system prompt + first turn + assistant response).
|
||||
func TestConversationResumption(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
|
||||
// Turn 1: system prompt + user message, assistant generates response.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11, 12})
|
||||
env.assertAllTokens(t, "turn 1", []int32{1, 2, 3, 4, 5, 10, 11, 12})
|
||||
|
||||
// Turn 2: full history + new user message. Should get a cache hit on
|
||||
// the prefix [1,2,3,4,5,10,11,12] and only need to evaluate [20,21].
|
||||
resB := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21}, []int32{30})
|
||||
if len(resB.remaining) > 5 {
|
||||
t.Fatalf("turn 2: remaining = %d, want <= 5 (should reuse most of history)", len(resB.remaining))
|
||||
}
|
||||
env.assertAllTokens(t, "turn 2", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30})
|
||||
|
||||
// Turn 3: even longer history.
|
||||
resC := simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41}, nil)
|
||||
if len(resC.remaining) > 5 {
|
||||
t.Fatalf("turn 3: remaining = %d, want <= 5", len(resC.remaining))
|
||||
}
|
||||
env.assertAllTokens(t, "turn 3", []int32{1, 2, 3, 4, 5, 10, 11, 12, 20, 21, 30, 40, 41})
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
// TestEvictionPreservesActiveConversations creates multiple conversations sharing
|
||||
// a system prompt, triggers eviction via large snapshot sizes, and verifies the
|
||||
// active path and shared prefix survive while memory stays bounded.
|
||||
func TestEvictionPreservesActiveConversations(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
systemPrompt := []int32{1, 2, 3, 4, 5}
|
||||
|
||||
// Create 5 conversations with unique suffixes.
|
||||
for i := range 5 {
|
||||
suffix := []int32{int32(100 + i*10), int32(101 + i*10), int32(102 + i*10)}
|
||||
inputs := append(slices.Clone(systemPrompt), suffix...)
|
||||
simulateRequest(t, kvc, inputs, []int32{int32(200 + i)})
|
||||
}
|
||||
|
||||
// Inflate snapshot sizes to trigger eviction.
|
||||
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||
if !n.hasSnapshots() {
|
||||
return true
|
||||
}
|
||||
snaps := make([]cache.Snapshot, len(n.snapshots))
|
||||
for i, s := range n.snapshots {
|
||||
if s != nil {
|
||||
snaps[i] = &fakeSnapshot{byteSize: 2 * 1024 * 1024 * 1024} // 2 GiB per snapshot
|
||||
}
|
||||
}
|
||||
n.setSnapshots(snaps, &kvc.pagedOutBytes)
|
||||
return true
|
||||
})
|
||||
|
||||
// Run eviction.
|
||||
kvc.enforceEvictionPolicy()
|
||||
|
||||
// Memory should be within limits.
|
||||
if kvc.pagedOutBytes > maxPagedOutBytes {
|
||||
t.Fatalf("pagedOutBytes = %d, want <= %d", kvc.pagedOutBytes, maxPagedOutBytes)
|
||||
}
|
||||
|
||||
// Active path should be untouched.
|
||||
if len(kvc.activePath) < 2 {
|
||||
t.Fatalf("activePath should have >= 2 nodes, got %d", len(kvc.activePath))
|
||||
}
|
||||
|
||||
// System prompt prefix should still be findable (multi-child
|
||||
// branch points are protected from eviction entirely).
|
||||
_, matched := findBestMatch(kvc.root, systemPrompt)
|
||||
if matched < len(systemPrompt) {
|
||||
t.Fatalf("system prompt match = %d, want %d", matched, len(systemPrompt))
|
||||
}
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
// TestUserSnapshotPreservesRestorePoint verifies that user-created snapshots
|
||||
// (snapshot(true)) resist structural changes that would destroy them:
|
||||
// - A user node forces new tokens into a child instead of extending in-place
|
||||
// - The snapshot remains restorable after other branches are added
|
||||
func TestUserSnapshotPreservesRestorePoint(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
|
||||
// Request A: user snapshot at offset 5, then generate.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11}, 5)
|
||||
|
||||
assertUserNodeExists(t, kvc, "after A")
|
||||
|
||||
// Request B: extends A's prefix. The user node at offset 5 should
|
||||
// force tokens into a child rather than extending in-place.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21}, nil)
|
||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21})
|
||||
assertUserNodeExists(t, kvc, "after B")
|
||||
|
||||
// Request C: diverge from the user node.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 30, 31}, []int32{40})
|
||||
|
||||
// Request D: switch back to A's branch — user snapshot still restorable.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50}, nil)
|
||||
env.assertAllTokens(t, "back to A", []int32{1, 2, 3, 4, 5, 10, 11, 20, 21, 50})
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
// TestUserSnapshotResistsAutoMerge verifies that when a sibling leaf is evicted,
|
||||
// a user-marked parent node is not auto-merged with its remaining single child.
|
||||
func TestUserSnapshotResistsAutoMerge(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
|
||||
// Request A: user snapshot at offset 3, then continue to offset 5.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10}, 3)
|
||||
|
||||
// Request B: diverges at the user node, creating a second child.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20})
|
||||
|
||||
userNode := findUserNode(t, kvc)
|
||||
if len(userNode.children) != 2 {
|
||||
t.Fatalf("user node children = %d, want 2", len(userNode.children))
|
||||
}
|
||||
|
||||
// Inflate snapshot sizes and evict. The non-active branch should be
|
||||
// evicted, leaving the user node with one child.
|
||||
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||
if !n.hasSnapshots() {
|
||||
return true
|
||||
}
|
||||
snaps := make([]cache.Snapshot, len(n.snapshots))
|
||||
for i, s := range n.snapshots {
|
||||
if s != nil {
|
||||
snaps[i] = &fakeSnapshot{byteSize: 5 * 1024 * 1024 * 1024}
|
||||
}
|
||||
}
|
||||
n.setSnapshots(snaps, &kvc.pagedOutBytes)
|
||||
return true
|
||||
})
|
||||
kvc.enforceEvictionPolicy()
|
||||
|
||||
// The user node should still exist (not auto-merged) even with one child.
|
||||
assertUserNodeExists(t, kvc, "after eviction")
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
func findUserNode(t *testing.T, kvc *kvCache) *trieNode {
|
||||
t.Helper()
|
||||
var found *trieNode
|
||||
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||
if n.user {
|
||||
found = n
|
||||
}
|
||||
return true
|
||||
})
|
||||
if found == nil {
|
||||
t.Fatal("no user-marked node found")
|
||||
}
|
||||
return found
|
||||
}
|
||||
|
||||
func assertUserNodeExists(t *testing.T, kvc *kvCache, label string) {
|
||||
t.Helper()
|
||||
var exists bool
|
||||
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||
if n.user {
|
||||
exists = true
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !exists {
|
||||
t.Fatalf("%s: no user-marked node found", label)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBranchSwitchRestoresCorrectState exercises switching back to an older
|
||||
// branch after working on a different one, verifying that the restored cache
|
||||
// state contains the correct token sequence for both rewindable and
|
||||
// non-rewindable caches.
|
||||
func TestBranchSwitchRestoresCorrectState(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
|
||||
// Request A: [1,2,3,4,5] + generate [10,11]
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
||||
env.assertAllTokens(t, "after A", []int32{1, 2, 3, 4, 5, 10, 11})
|
||||
|
||||
// Request B: [1,2,3,6,7] — diverges at token 4
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{12, 13})
|
||||
env.assertAllTokens(t, "after B", []int32{1, 2, 3, 6, 7, 12, 13})
|
||||
|
||||
// Request C: switch back to A's branch [1,2,3,4,5,10,11,20]
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5, 10, 11, 20}, nil)
|
||||
env.assertAllTokens(t, "after C (back to A)", []int32{1, 2, 3, 4, 5, 10, 11, 20})
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
|
||||
// TestLRUOnlyUpdatesUsedNodes verifies that intermediate nodes on the active
|
||||
// path whose snapshots were not actually restored don't get their lastUsed
|
||||
// refreshed, allowing them to age out and collapse.
|
||||
func TestLRUOnlyUpdatesUsedNodes(t *testing.T) {
|
||||
forEachEnv(t, func(t *testing.T, env *testEnv) {
|
||||
kvc := env.kvc
|
||||
|
||||
// Request A: creates path [1,2,3,4,5] + generate [10,11]
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 4, 5}, []int32{10, 11})
|
||||
|
||||
// Request B: diverges at token 4, creating a branch point at offset 3
|
||||
// with a split snapshot.
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7}, []int32{20, 21})
|
||||
|
||||
// Set all lastUsed to a known old time.
|
||||
oldTime := time.Now().Add(-1 * time.Hour)
|
||||
walkNodes(kvc.root, func(n *trieNode) bool {
|
||||
n.lastUsed = oldTime
|
||||
return true
|
||||
})
|
||||
|
||||
// Request C: continue on B's branch. This will match B's path
|
||||
// and extend it. The branch point's snapshot may be paged in
|
||||
// for some cache types but not others.
|
||||
beforeRequest := time.Now()
|
||||
simulateRequest(t, kvc, []int32{1, 2, 3, 6, 7, 20, 21, 30}, nil)
|
||||
|
||||
// The path must have enough depth to exercise intermediate nodes.
|
||||
if len(kvc.activePath) < 3 {
|
||||
t.Fatalf("activePath too short to test intermediate nodes: got %d nodes", len(kvc.activePath))
|
||||
}
|
||||
|
||||
// The frontier (deepest node on the active path) must be updated.
|
||||
frontier := kvc.activePath[len(kvc.activePath)-1]
|
||||
if frontier.lastUsed.Before(beforeRequest) {
|
||||
t.Errorf("frontier lastUsed was not updated: got %v, want >= %v",
|
||||
frontier.lastUsed, beforeRequest)
|
||||
}
|
||||
|
||||
// Every non-frontier node on the active path (including root)
|
||||
// should retain its old lastUsed — only the frontier gets refreshed.
|
||||
for i, node := range kvc.activePath[:len(kvc.activePath)-1] {
|
||||
if !node.lastUsed.Before(beforeRequest) {
|
||||
t.Errorf("activePath[%d] (endOffset=%d) lastUsed was refreshed: got %v, want < %v",
|
||||
i, node.endOffset, node.lastUsed, beforeRequest)
|
||||
}
|
||||
}
|
||||
|
||||
checkTrieInvariants(t, kvc.root)
|
||||
})
|
||||
}
|
||||
296
x/mlxrunner/cache_trie.go
Normal file
296
x/mlxrunner/cache_trie.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
)
|
||||
|
||||
// trieNode represents a node in the compressed prefix trie for KV cache branching.
|
||||
// Each node stores a compressed edge (multiple tokens) and optional paged-out
|
||||
// snapshot data per cache layer.
|
||||
type trieNode struct {
|
||||
tokens []int32 // compressed edge — multiple tokens per node
|
||||
endOffset int // cumulative tokens from root to end of this node
|
||||
parent *trieNode
|
||||
children []*trieNode
|
||||
lastUsed time.Time // for LRU eviction
|
||||
snapshots []cache.Snapshot // per-layer paged-out snapshot data (nil if not paged out)
|
||||
user bool // true = explicit restore point (resist auto-merge)
|
||||
}
|
||||
|
||||
// startOffset returns the cumulative token offset at the start of this node's edge.
|
||||
func (n *trieNode) startOffset() int {
|
||||
return n.endOffset - len(n.tokens)
|
||||
}
|
||||
|
||||
// snapshotBytes returns the total bytes of paged-out snapshots on this node.
|
||||
func (n *trieNode) snapshotBytes() int64 {
|
||||
var total int64
|
||||
for _, s := range n.snapshots {
|
||||
if s != nil {
|
||||
total += int64(s.Size())
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// setSnapshots replaces this node's snapshots with snaps and closes the old ones.
|
||||
// If counter is non-nil, the net byte delta is applied to it.
|
||||
func (n *trieNode) setSnapshots(snaps []cache.Snapshot, counter *int64) {
|
||||
old := n.swapSnapshots(snaps, counter)
|
||||
for _, s := range old {
|
||||
if s != nil {
|
||||
s.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// swapSnapshots is like setSnapshots but returns the previous snapshots
|
||||
// without closing them. Use this when the old snapshots will be consumed
|
||||
// (e.g. by Split/Merge).
|
||||
func (n *trieNode) swapSnapshots(snaps []cache.Snapshot, counter *int64) []cache.Snapshot {
|
||||
old := n.snapshots
|
||||
if counter != nil {
|
||||
*counter -= n.snapshotBytes()
|
||||
}
|
||||
n.snapshots = snaps
|
||||
if counter != nil {
|
||||
*counter += n.snapshotBytes()
|
||||
}
|
||||
return old
|
||||
}
|
||||
|
||||
// hasSnapshots returns true if any layer has snapshot data.
|
||||
func (n *trieNode) hasSnapshots() bool {
|
||||
return slices.ContainsFunc(n.snapshots, func(s cache.Snapshot) bool { return s != nil })
|
||||
}
|
||||
|
||||
// hasAllSnapshots returns true if every layer has snapshot data.
|
||||
func (n *trieNode) hasAllSnapshots() bool {
|
||||
return len(n.snapshots) > 0 && !slices.Contains(n.snapshots, nil)
|
||||
}
|
||||
|
||||
// findBestMatch walks the trie matching input tokens, returning the path of
|
||||
// nodes traversed and the total number of tokens matched.
|
||||
func findBestMatch(root *trieNode, tokens []int32) (path []*trieNode, matched int) {
|
||||
if root == nil {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
path = []*trieNode{root}
|
||||
pos := 0
|
||||
|
||||
node := root
|
||||
for pos < len(tokens) {
|
||||
// When multiple children share the same first token (e.g. after
|
||||
// a split), prefer the child whose full edge matches over one
|
||||
// that only partially matches. This is just being defensive - it
|
||||
// shouldn't actually happen.
|
||||
var best *trieNode
|
||||
bestMatched := 0
|
||||
bestFull := false
|
||||
for _, child := range node.children {
|
||||
edge := child.tokens
|
||||
if len(edge) == 0 {
|
||||
continue
|
||||
}
|
||||
if edge[0] != tokens[pos] {
|
||||
continue
|
||||
}
|
||||
// Count matching tokens in this child's edge.
|
||||
j := 0
|
||||
for j < len(edge) && pos+j < len(tokens) && edge[j] == tokens[pos+j] {
|
||||
j++
|
||||
}
|
||||
full := j == len(edge)
|
||||
// Prefer full edge matches; among same type, prefer longer.
|
||||
if best == nil || (full && !bestFull) || (full == bestFull && j > bestMatched) {
|
||||
best = child
|
||||
bestMatched = j
|
||||
bestFull = full
|
||||
}
|
||||
}
|
||||
if best == nil {
|
||||
break
|
||||
}
|
||||
|
||||
pos += bestMatched
|
||||
path = append(path, best)
|
||||
|
||||
if !bestFull {
|
||||
// Partial match within this edge
|
||||
break
|
||||
}
|
||||
node = best
|
||||
}
|
||||
|
||||
return path, pos
|
||||
}
|
||||
|
||||
// appendTokens either creates a new child node or extends the leaf in place,
|
||||
// returning the node that now holds the tokens.
|
||||
func (n *trieNode) appendTokens(root *trieNode, tokens []int32, endOffset int) *trieNode {
|
||||
if n == root || len(n.children) > 0 || n.hasSnapshots() {
|
||||
child := &trieNode{
|
||||
tokens: make([]int32, len(tokens)),
|
||||
endOffset: endOffset,
|
||||
parent: n,
|
||||
lastUsed: n.lastUsed,
|
||||
}
|
||||
copy(child.tokens, tokens)
|
||||
n.children = append(n.children, child)
|
||||
return child
|
||||
}
|
||||
n.tokens = append(n.tokens, tokens...)
|
||||
n.endOffset = endOffset
|
||||
return n
|
||||
}
|
||||
|
||||
// removeNode removes a leaf node from the trie.
|
||||
func removeNode(node *trieNode, counter *int64) {
|
||||
if node.parent == nil {
|
||||
panic("removeNode called on root")
|
||||
}
|
||||
if len(node.children) != 0 {
|
||||
panic("removeNode called on non-leaf node")
|
||||
}
|
||||
p := node.parent
|
||||
for i, child := range p.children {
|
||||
if child == node {
|
||||
p.children = append(p.children[:i], p.children[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
node.parent = nil
|
||||
node.setSnapshots(nil, counter)
|
||||
}
|
||||
|
||||
// splitNode splits a node at the given token offset within its edge,
|
||||
// creating a new parent node. Returns the new parent.
|
||||
// `at` is relative to the node's edge (0-based index into node.tokens).
|
||||
// If caches are provided, snapshots are split between parent and child
|
||||
// using Cache.Split; otherwise snapshots are invalidated.
|
||||
func splitNode(node *trieNode, at int, caches []cache.Cache, counter *int64) *trieNode {
|
||||
if at <= 0 || at >= len(node.tokens) {
|
||||
panic(fmt.Sprintf("splitNode: invalid split offset %d for node with %d tokens", at, len(node.tokens)))
|
||||
}
|
||||
|
||||
// Create new parent with the prefix of the edge.
|
||||
newParent := &trieNode{
|
||||
tokens: make([]int32, at),
|
||||
endOffset: node.startOffset() + at,
|
||||
parent: node.parent,
|
||||
children: []*trieNode{node},
|
||||
lastUsed: node.lastUsed,
|
||||
}
|
||||
copy(newParent.tokens, node.tokens[:at])
|
||||
|
||||
// Update the original node to have only the suffix.
|
||||
node.tokens = node.tokens[at:]
|
||||
// endOffset stays the same for the original node.
|
||||
|
||||
// Split snapshots between parent and child using Cache.Split.
|
||||
// Split consumes the old snapshots, so we remove them first (adjusting
|
||||
// the counter), then assign the split halves (adjusting it back).
|
||||
if node.hasSnapshots() {
|
||||
oldSnaps := node.swapSnapshots(nil, counter)
|
||||
parentSnaps := make([]cache.Snapshot, len(oldSnaps))
|
||||
childSnaps := make([]cache.Snapshot, len(oldSnaps))
|
||||
for i, snap := range oldSnaps {
|
||||
if snap != nil {
|
||||
parentSnaps[i], childSnaps[i] = caches[i].Split(snap, newParent.endOffset)
|
||||
}
|
||||
}
|
||||
newParent.setSnapshots(parentSnaps, counter)
|
||||
node.setSnapshots(childSnaps, counter)
|
||||
}
|
||||
|
||||
// Reparent: replace node with newParent in the old parent's children.
|
||||
if node.parent != nil {
|
||||
for i, child := range node.parent.children {
|
||||
if child == node {
|
||||
node.parent.children[i] = newParent
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
node.parent = newParent
|
||||
|
||||
return newParent
|
||||
}
|
||||
|
||||
// mergeWithChild merges a node with its single child: concatenates tokens,
|
||||
// merges snapshot data via Cache.Merge, and removes the child.
|
||||
func mergeWithChild(node *trieNode, caches []cache.Cache, counter *int64) {
|
||||
if len(node.children) != 1 {
|
||||
panic(fmt.Sprintf("mergeWithChild called on node with %d children", len(node.children)))
|
||||
}
|
||||
|
||||
child := node.children[0]
|
||||
|
||||
// Concatenate tokens.
|
||||
node.tokens = append(node.tokens, child.tokens...)
|
||||
node.endOffset = child.endOffset
|
||||
|
||||
// Merge snapshots per layer. Merge consumes the old snapshots, so we
|
||||
// remove them first (adjusting the counter), then assign the merged
|
||||
// result (adjusting it back).
|
||||
if len(node.snapshots) > 0 || len(child.snapshots) > 0 {
|
||||
nodeSnaps := node.swapSnapshots(nil, counter)
|
||||
childSnaps := child.swapSnapshots(nil, counter)
|
||||
merged := make([]cache.Snapshot, len(caches))
|
||||
for i := range caches {
|
||||
var ps, cs cache.Snapshot
|
||||
if nodeSnaps != nil {
|
||||
ps = nodeSnaps[i]
|
||||
}
|
||||
if childSnaps != nil {
|
||||
cs = childSnaps[i]
|
||||
}
|
||||
|
||||
merged[i] = caches[i].Merge(ps, cs)
|
||||
}
|
||||
node.setSnapshots(merged, counter)
|
||||
}
|
||||
|
||||
// Adopt grandchildren.
|
||||
node.children = child.children
|
||||
for _, gc := range node.children {
|
||||
gc.parent = node
|
||||
}
|
||||
|
||||
// Inherit user flag from child if child was a user-created snapshot node.
|
||||
node.user = child.user
|
||||
|
||||
// Update lastUsed to the more recent of the two.
|
||||
if child.lastUsed.After(node.lastUsed) {
|
||||
node.lastUsed = child.lastUsed
|
||||
}
|
||||
|
||||
child.parent = nil
|
||||
child.children = nil
|
||||
}
|
||||
|
||||
// walkNodes calls fn for every node in the trie (depth-first).
|
||||
// If fn returns false, the walk stops.
|
||||
func walkNodes(root *trieNode, fn func(*trieNode) bool) {
|
||||
if root == nil {
|
||||
return
|
||||
}
|
||||
var walk func(*trieNode) bool
|
||||
walk = func(n *trieNode) bool {
|
||||
if !fn(n) {
|
||||
return false
|
||||
}
|
||||
for _, child := range n.children {
|
||||
if !walk(child) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
walk(root)
|
||||
}
|
||||
455
x/mlxrunner/cache_trie_test.go
Normal file
455
x/mlxrunner/cache_trie_test.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
)
|
||||
|
||||
func newTestTrie(tokens []int32) *trieNode {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
if len(tokens) > 0 {
|
||||
child := &trieNode{
|
||||
tokens: slices.Clone(tokens),
|
||||
endOffset: len(tokens),
|
||||
parent: root,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
root.children = []*trieNode{child}
|
||||
}
|
||||
return root
|
||||
}
|
||||
|
||||
func TestFindBestMatchMultipleBranches(t *testing.T) {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
|
||||
branch1 := &trieNode{
|
||||
tokens: []int32{1, 2, 3},
|
||||
endOffset: 3,
|
||||
parent: root,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
branch2 := &trieNode{
|
||||
tokens: []int32{4, 5, 6},
|
||||
endOffset: 3,
|
||||
parent: root,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
root.children = []*trieNode{branch1, branch2}
|
||||
|
||||
// Match branch 1.
|
||||
path, matched := findBestMatch(root, []int32{1, 2, 3, 7})
|
||||
if matched != 3 {
|
||||
t.Fatalf("expected 3 matched, got %d", matched)
|
||||
}
|
||||
if len(path) != 2 || path[1] != branch1 {
|
||||
t.Fatal("expected to match branch1")
|
||||
}
|
||||
|
||||
// Match branch 2.
|
||||
path, matched = findBestMatch(root, []int32{4, 5, 6, 8})
|
||||
if matched != 3 {
|
||||
t.Fatalf("expected 3 matched, got %d", matched)
|
||||
}
|
||||
if len(path) != 2 || path[1] != branch2 {
|
||||
t.Fatal("expected to match branch2")
|
||||
}
|
||||
|
||||
// Match neither.
|
||||
_, matched = findBestMatch(root, []int32{7, 8, 9})
|
||||
if matched != 0 {
|
||||
t.Fatalf("expected 0 matched, got %d", matched)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindBestMatchPrefersFullEdge(t *testing.T) {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
|
||||
shared := &trieNode{
|
||||
tokens: []int32{1, 2, 3},
|
||||
endOffset: 3,
|
||||
parent: root,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
root.children = []*trieNode{shared}
|
||||
|
||||
longer := &trieNode{
|
||||
tokens: []int32{10, 11, 12, 13, 14},
|
||||
endOffset: 8,
|
||||
parent: shared,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
shorter := &trieNode{
|
||||
tokens: []int32{10, 11, 12},
|
||||
endOffset: 6,
|
||||
parent: shared,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
// Put longer first so naive first-match would pick it.
|
||||
shared.children = []*trieNode{longer, shorter}
|
||||
|
||||
input := []int32{1, 2, 3, 10, 11, 12, 99, 100}
|
||||
path, matched := findBestMatch(root, input)
|
||||
|
||||
if matched != 6 {
|
||||
t.Fatalf("expected 6 matched, got %d", matched)
|
||||
}
|
||||
if len(path) != 3 {
|
||||
t.Fatalf("expected 3 nodes in path, got %d", len(path))
|
||||
}
|
||||
if path[2] != shorter {
|
||||
t.Fatal("expected findBestMatch to pick shorter (full edge match), not longer (partial)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindBestMatchPrefersLongerPartial(t *testing.T) {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
|
||||
child1 := &trieNode{
|
||||
tokens: []int32{1, 2, 3, 4, 5},
|
||||
endOffset: 5,
|
||||
parent: root,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
child2 := &trieNode{
|
||||
tokens: []int32{1, 2, 9},
|
||||
endOffset: 3,
|
||||
parent: root,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
root.children = []*trieNode{child2, child1}
|
||||
|
||||
input := []int32{1, 2, 3, 7, 8}
|
||||
path, matched := findBestMatch(root, input)
|
||||
|
||||
if matched != 3 {
|
||||
t.Fatalf("expected 3 matched, got %d", matched)
|
||||
}
|
||||
if path[1] != child1 {
|
||||
t.Fatal("expected findBestMatch to pick child1 (longer partial match)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitNodeWithSnapshots(t *testing.T) {
|
||||
root := newTestTrie([]int32{1, 2, 3, 4, 5})
|
||||
child := root.children[0]
|
||||
|
||||
rc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
|
||||
child.snapshots = []cache.Snapshot{rc.Snapshot(0)}
|
||||
child.user = true
|
||||
|
||||
caches := []cache.Cache{rc}
|
||||
|
||||
newParent := splitNode(child, 3, caches, nil)
|
||||
|
||||
if !newParent.hasSnapshots() {
|
||||
t.Fatal("newParent should have snapshots after split")
|
||||
}
|
||||
if newParent.user {
|
||||
t.Fatal("newParent should not be a user snapshot after splitNode")
|
||||
}
|
||||
if !child.hasSnapshots() {
|
||||
t.Fatal("child should have snapshots after split")
|
||||
}
|
||||
if !child.user {
|
||||
t.Fatal("child should remain a user snapshot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSplitAppendSequence(t *testing.T) {
|
||||
root := newTestTrie([]int32{1, 2, 3, 4, 5})
|
||||
|
||||
path, matched := findBestMatch(root, []int32{1, 2, 3, 6, 7})
|
||||
if matched != 3 {
|
||||
t.Fatalf("expected 3 matched, got %d", matched)
|
||||
}
|
||||
|
||||
lastNode := path[len(path)-1]
|
||||
matchedInEdge := matched - lastNode.startOffset()
|
||||
split := splitNode(lastNode, matchedInEdge, nil, nil)
|
||||
|
||||
split.appendTokens(root, []int32{6, 7}, 5)
|
||||
|
||||
if len(root.children) != 1 {
|
||||
t.Fatalf("root should have 1 child, got %d", len(root.children))
|
||||
}
|
||||
shared := root.children[0]
|
||||
if !slices.Equal(shared.tokens, []int32{1, 2, 3}) {
|
||||
t.Fatalf("shared tokens = %v, want [1,2,3]", shared.tokens)
|
||||
}
|
||||
if len(shared.children) != 2 {
|
||||
t.Fatalf("shared should have 2 children, got %d", len(shared.children))
|
||||
}
|
||||
|
||||
_, m1 := findBestMatch(root, []int32{1, 2, 3, 4, 5})
|
||||
if m1 != 5 {
|
||||
t.Fatalf("original branch: expected 5 matched, got %d", m1)
|
||||
}
|
||||
_, m2 := findBestMatch(root, []int32{1, 2, 3, 6, 7})
|
||||
if m2 != 5 {
|
||||
t.Fatalf("new branch: expected 5 matched, got %d", m2)
|
||||
}
|
||||
_, m3 := findBestMatch(root, []int32{1, 2, 3, 9, 9})
|
||||
if m3 != 3 {
|
||||
t.Fatalf("unrelated input: expected 3 matched, got %d", m3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepeatedBranching(t *testing.T) {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
|
||||
root.appendTokens(root, []int32{1, 2, 3, 4, 5}, 5)
|
||||
|
||||
_, matchedB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
|
||||
if matchedB != 3 {
|
||||
t.Fatalf("B: expected 3 matched, got %d", matchedB)
|
||||
}
|
||||
nodeA := root.children[0]
|
||||
split1 := splitNode(nodeA, 3, nil, nil)
|
||||
split1.appendTokens(root, []int32{6, 7}, 5)
|
||||
|
||||
_, matchedC := findBestMatch(root, []int32{1, 2, 8, 9})
|
||||
if matchedC != 2 {
|
||||
t.Fatalf("C: expected 2 matched, got %d", matchedC)
|
||||
}
|
||||
split2 := splitNode(split1, 2, nil, nil)
|
||||
split2.appendTokens(root, []int32{8, 9}, 4)
|
||||
|
||||
_, mA := findBestMatch(root, []int32{1, 2, 3, 4, 5})
|
||||
if mA != 5 {
|
||||
t.Fatalf("A: expected 5 matched, got %d", mA)
|
||||
}
|
||||
_, mB := findBestMatch(root, []int32{1, 2, 3, 6, 7})
|
||||
if mB != 5 {
|
||||
t.Fatalf("B: expected 5 matched, got %d", mB)
|
||||
}
|
||||
_, mC := findBestMatch(root, []int32{1, 2, 8, 9})
|
||||
if mC != 4 {
|
||||
t.Fatalf("C: expected 4 matched, got %d", mC)
|
||||
}
|
||||
|
||||
checkTrieInvariants(t, root)
|
||||
}
|
||||
|
||||
func TestMergeWithChild(t *testing.T) {
|
||||
t.Run("Basic", func(t *testing.T) {
|
||||
// root -> A[1,2,3] -> B[4,5] -> {C[6], D[7]}
|
||||
now := time.Now()
|
||||
root := &trieNode{lastUsed: now}
|
||||
a := &trieNode{
|
||||
tokens: []int32{1, 2, 3},
|
||||
endOffset: 3,
|
||||
parent: root,
|
||||
lastUsed: now,
|
||||
snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3}, from: 0, to: 3}},
|
||||
}
|
||||
b := &trieNode{
|
||||
tokens: []int32{4, 5},
|
||||
endOffset: 5,
|
||||
parent: a,
|
||||
lastUsed: now,
|
||||
snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{4, 5}, from: 3, to: 5}},
|
||||
}
|
||||
c := &trieNode{tokens: []int32{6}, endOffset: 6, parent: b, lastUsed: now}
|
||||
d := &trieNode{tokens: []int32{7}, endOffset: 6, parent: b, lastUsed: now}
|
||||
root.children = []*trieNode{a}
|
||||
a.children = []*trieNode{b}
|
||||
b.children = []*trieNode{c, d}
|
||||
|
||||
mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
|
||||
mergeWithChild(a, []cache.Cache{mc}, nil)
|
||||
|
||||
// Tokens concatenated.
|
||||
if !slices.Equal(a.tokens, []int32{1, 2, 3, 4, 5}) {
|
||||
t.Fatalf("merged tokens = %v, want [1,2,3,4,5]", a.tokens)
|
||||
}
|
||||
if a.endOffset != 5 {
|
||||
t.Fatalf("merged endOffset = %d, want 5", a.endOffset)
|
||||
}
|
||||
// Grandchildren reparented.
|
||||
if len(a.children) != 2 {
|
||||
t.Fatalf("merged children count = %d, want 2", len(a.children))
|
||||
}
|
||||
if c.parent != a || d.parent != a {
|
||||
t.Fatal("grandchildren should be reparented to merged node")
|
||||
}
|
||||
// B detached.
|
||||
if b.parent != nil || b.children != nil || b.snapshots != nil {
|
||||
t.Fatal("child B should be fully detached after merge")
|
||||
}
|
||||
// Merged snapshot should cover [0,5).
|
||||
if !a.hasSnapshots() {
|
||||
t.Fatal("merged node should have snapshots")
|
||||
}
|
||||
ms := a.snapshots[0].(*fakeSnapshot)
|
||||
if ms.from != 0 || ms.to != 5 {
|
||||
t.Fatalf("merged snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
|
||||
}
|
||||
|
||||
checkTrieInvariants(t, root)
|
||||
})
|
||||
|
||||
t.Run("UserFlag", func(t *testing.T) {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
parent := &trieNode{
|
||||
tokens: []int32{1, 2}, endOffset: 2, parent: root,
|
||||
lastUsed: time.Now(), user: false,
|
||||
}
|
||||
child := &trieNode{
|
||||
tokens: []int32{3, 4}, endOffset: 4, parent: parent,
|
||||
lastUsed: time.Now(), user: true,
|
||||
}
|
||||
root.children = []*trieNode{parent}
|
||||
parent.children = []*trieNode{child}
|
||||
|
||||
mergeWithChild(parent, nil, nil)
|
||||
|
||||
if !parent.user {
|
||||
t.Fatal("merged node should inherit user=true from child")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LastUsed", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
root := &trieNode{lastUsed: now}
|
||||
parent := &trieNode{
|
||||
tokens: []int32{1}, endOffset: 1, parent: root,
|
||||
lastUsed: now.Add(-1 * time.Hour),
|
||||
}
|
||||
child := &trieNode{
|
||||
tokens: []int32{2}, endOffset: 2, parent: parent,
|
||||
lastUsed: now.Add(1 * time.Hour),
|
||||
}
|
||||
root.children = []*trieNode{parent}
|
||||
parent.children = []*trieNode{child}
|
||||
|
||||
mergeWithChild(parent, nil, nil)
|
||||
|
||||
if !parent.lastUsed.Equal(now.Add(1 * time.Hour)) {
|
||||
t.Fatal("merged node should pick the more recent lastUsed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PanicOnMultipleChildren", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic on node with 2 children")
|
||||
}
|
||||
}()
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
node := &trieNode{
|
||||
tokens: []int32{1}, endOffset: 1, parent: root, lastUsed: time.Now(),
|
||||
children: []*trieNode{
|
||||
{tokens: []int32{2}, endOffset: 2, lastUsed: time.Now()},
|
||||
{tokens: []int32{3}, endOffset: 2, lastUsed: time.Now()},
|
||||
},
|
||||
}
|
||||
root.children = []*trieNode{node}
|
||||
mergeWithChild(node, nil, nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSplitMergeRoundTrip(t *testing.T) {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
leaf := &trieNode{
|
||||
tokens: []int32{1, 2, 3, 4, 5},
|
||||
endOffset: 5,
|
||||
parent: root,
|
||||
lastUsed: time.Now(),
|
||||
snapshots: []cache.Snapshot{&fakeSnapshot{tokens: []int32{1, 2, 3, 4, 5}, from: 0, to: 5}},
|
||||
}
|
||||
root.children = []*trieNode{leaf}
|
||||
|
||||
mc := &fakeRewindableCache{tracker: &snapshotTracker{}, tokens: []int32{1, 2, 3, 4, 5}}
|
||||
caches := []cache.Cache{mc}
|
||||
|
||||
// Split at 3: [1,2,3] -> [4,5]
|
||||
newParent := splitNode(leaf, 3, caches, nil)
|
||||
if !slices.Equal(newParent.tokens, []int32{1, 2, 3}) {
|
||||
t.Fatalf("after split: parent tokens = %v, want [1,2,3]", newParent.tokens)
|
||||
}
|
||||
if !slices.Equal(leaf.tokens, []int32{4, 5}) {
|
||||
t.Fatalf("after split: child tokens = %v, want [4,5]", leaf.tokens)
|
||||
}
|
||||
checkTrieInvariants(t, root)
|
||||
|
||||
// Merge back: should restore [1,2,3,4,5]
|
||||
mergeWithChild(newParent, caches, nil)
|
||||
if !slices.Equal(newParent.tokens, []int32{1, 2, 3, 4, 5}) {
|
||||
t.Fatalf("after merge: tokens = %v, want [1,2,3,4,5]", newParent.tokens)
|
||||
}
|
||||
if newParent.endOffset != 5 {
|
||||
t.Fatalf("after merge: endOffset = %d, want 5", newParent.endOffset)
|
||||
}
|
||||
if len(newParent.children) != 0 {
|
||||
t.Fatalf("after merge: children count = %d, want 0", len(newParent.children))
|
||||
}
|
||||
// Merged snapshot should cover [0,5).
|
||||
if !newParent.hasSnapshots() {
|
||||
t.Fatal("after merge: should have snapshots")
|
||||
}
|
||||
ms := newParent.snapshots[0].(*fakeSnapshot)
|
||||
if ms.from != 0 || ms.to != 5 {
|
||||
t.Fatalf("after merge: snapshot = [%d,%d), want [0,5)", ms.from, ms.to)
|
||||
}
|
||||
|
||||
checkTrieInvariants(t, root)
|
||||
}
|
||||
|
||||
func TestRemoveNode(t *testing.T) {
|
||||
t.Run("Leaf", func(t *testing.T) {
|
||||
root := &trieNode{lastUsed: time.Now()}
|
||||
shared := &trieNode{
|
||||
tokens: []int32{1, 2, 3}, endOffset: 3, parent: root, lastUsed: time.Now(),
|
||||
}
|
||||
leafA := &trieNode{
|
||||
tokens: []int32{4, 5}, endOffset: 5, parent: shared, lastUsed: time.Now(),
|
||||
snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
|
||||
}
|
||||
leafB := &trieNode{
|
||||
tokens: []int32{6, 7}, endOffset: 5, parent: shared, lastUsed: time.Now(),
|
||||
snapshots: []cache.Snapshot{&fakeSnapshot{from: 3, to: 5}},
|
||||
}
|
||||
root.children = []*trieNode{shared}
|
||||
shared.children = []*trieNode{leafA, leafB}
|
||||
|
||||
removeNode(leafA, nil)
|
||||
|
||||
if len(shared.children) != 1 {
|
||||
t.Fatalf("parent should have 1 child, got %d", len(shared.children))
|
||||
}
|
||||
if shared.children[0] != leafB {
|
||||
t.Fatal("remaining child should be leafB")
|
||||
}
|
||||
if leafA.parent != nil {
|
||||
t.Fatal("removed node parent should be nil")
|
||||
}
|
||||
if leafA.snapshots != nil {
|
||||
t.Fatal("removed node snapshots should be nil")
|
||||
}
|
||||
|
||||
checkTrieInvariants(t, root)
|
||||
})
|
||||
|
||||
t.Run("PanicOnRoot", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic when removing root")
|
||||
}
|
||||
}()
|
||||
removeNode(&trieNode{}, nil)
|
||||
})
|
||||
|
||||
t.Run("PanicOnNonLeaf", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("expected panic when removing non-leaf")
|
||||
}
|
||||
}()
|
||||
parent := &trieNode{parent: &trieNode{}}
|
||||
parent.children = []*trieNode{{}}
|
||||
removeNode(parent, nil)
|
||||
})
|
||||
}
|
||||
476
x/mlxrunner/client.go
Normal file
476
x/mlxrunner/client.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/x/imagegen"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
|
||||
type Client struct {
|
||||
port int
|
||||
modelName string
|
||||
contextLength atomic.Int64
|
||||
memory atomic.Uint64
|
||||
done chan struct{}
|
||||
doneErr error // valid after done is closed
|
||||
client *http.Client
|
||||
status *llm.StatusWriter
|
||||
mu sync.Mutex
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
// NewClient prepares a new MLX runner client for LLM models.
|
||||
// The subprocess is not started until Load() is called.
|
||||
func NewClient(modelName string) (*Client, error) {
|
||||
if err := imagegen.CheckPlatformSupport(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &Client{
|
||||
modelName: modelName,
|
||||
done: make(chan struct{}),
|
||||
client: http.DefaultClient,
|
||||
}
|
||||
|
||||
modelManifest, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.memory.Store(uint64(modelManifest.TotalTensorSize()))
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// WaitUntilRunning waits for the subprocess to be ready.
|
||||
func (c *Client) WaitUntilRunning(ctx context.Context) error {
|
||||
timeout := time.After(2 * time.Minute)
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
if msg := c.status.LastError(); msg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s (exit: %v)", msg, c.doneErr)
|
||||
}
|
||||
return fmt.Errorf("mlx runner exited unexpectedly: %w", c.doneErr)
|
||||
case <-timeout:
|
||||
if msg := c.status.LastError(); msg != "" {
|
||||
return fmt.Errorf("timeout waiting for mlx runner: %s", msg)
|
||||
}
|
||||
return errors.New("timeout waiting for mlx runner to start")
|
||||
case <-ticker.C:
|
||||
if err := c.Ping(ctx); err == nil {
|
||||
slog.Info("mlx runner is ready", "port", c.port)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Options api.Options
|
||||
Logprobs bool
|
||||
TopLogprobs int
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
DoneReason int
|
||||
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
|
||||
Logprobs []llm.Logprob
|
||||
|
||||
Error *api.StatusError
|
||||
}
|
||||
|
||||
// Close terminates the subprocess.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
slog.Info("stopping mlx runner subprocess", "pid", c.cmd.Process.Pid)
|
||||
c.cmd.Process.Signal(os.Interrupt)
|
||||
|
||||
select {
|
||||
case <-c.done:
|
||||
case <-time.After(5 * time.Second):
|
||||
c.cmd.Process.Kill()
|
||||
}
|
||||
c.cmd = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Completion implements llm.LlamaServer.
|
||||
func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
creq := CompletionRequest{
|
||||
Prompt: req.Prompt,
|
||||
Logprobs: req.Logprobs,
|
||||
TopLogprobs: req.TopLogprobs,
|
||||
}
|
||||
if req.Options != nil {
|
||||
creq.Options = *req.Options
|
||||
}
|
||||
|
||||
body, err := json.Marshal(creq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
httpURL := fmt.Sprintf("http://127.0.0.1:%d/completion", c.port)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", httpURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.client.Do(httpReq)
|
||||
if err != nil {
|
||||
if errMsg := c.status.LastError(); errMsg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s", errMsg)
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
return api.StatusError{StatusCode: resp.StatusCode, ErrorMessage: strings.TrimSpace(string(respBody))}
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
for scanner.Scan() {
|
||||
var raw CompletionResponse
|
||||
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
|
||||
slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
|
||||
continue
|
||||
}
|
||||
|
||||
if raw.Error != nil {
|
||||
return *raw.Error
|
||||
}
|
||||
|
||||
cresp := llm.CompletionResponse{
|
||||
Content: raw.Content,
|
||||
Done: raw.Done,
|
||||
DoneReason: llm.DoneReason(raw.DoneReason),
|
||||
PromptEvalCount: raw.PromptEvalCount,
|
||||
PromptEvalDuration: raw.PromptEvalDuration,
|
||||
EvalCount: raw.EvalCount,
|
||||
EvalDuration: raw.EvalDuration,
|
||||
Logprobs: raw.Logprobs,
|
||||
}
|
||||
|
||||
fn(cresp)
|
||||
if cresp.Done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if errMsg := c.status.LastError(); errMsg != "" {
|
||||
return fmt.Errorf("mlx runner failed: %s", errMsg)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ContextLength() int {
|
||||
return int(c.contextLength.Load())
|
||||
}
|
||||
|
||||
// Detokenize implements llm.LlamaServer.
|
||||
func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", errors.New("not supported")
|
||||
}
|
||||
|
||||
// Embedding implements llm.LlamaServer.
|
||||
func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
|
||||
return nil, 0, errors.New("not supported")
|
||||
}
|
||||
|
||||
// GetDeviceInfos implements llm.LlamaServer.
|
||||
func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPort implements llm.LlamaServer.
|
||||
func (c *Client) GetPort() int {
|
||||
return c.port
|
||||
}
|
||||
|
||||
// HasExited implements llm.LlamaServer.
|
||||
func (c *Client) HasExited() bool {
|
||||
select {
|
||||
case <-c.done:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Load checks whether the model fits in GPU memory and starts the subprocess.
|
||||
func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
|
||||
if len(gpus) > 0 {
|
||||
modelSize := c.memory.Load()
|
||||
// We currently only use the first GPU with MLX
|
||||
available := gpus[0].FreeMemory
|
||||
overhead := gpus[0].MinimumMemory() + envconfig.GpuOverhead()
|
||||
if available > overhead {
|
||||
available -= overhead
|
||||
} else {
|
||||
available = 0
|
||||
}
|
||||
|
||||
if modelSize > available {
|
||||
if requireFull {
|
||||
return nil, llm.ErrLoadRequiredFull
|
||||
}
|
||||
return nil, fmt.Errorf("model requires %s but only %s are available (after %s overhead)", format.HumanBytes2(modelSize), format.HumanBytes2(available), format.HumanBytes2(overhead))
|
||||
}
|
||||
}
|
||||
|
||||
// Find a free port
|
||||
port := 0
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
if l, err := net.ListenTCP("tcp", a); err == nil {
|
||||
port = l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
if port == 0 {
|
||||
port = rand.Intn(65535-49152) + 49152
|
||||
}
|
||||
c.port = port
|
||||
|
||||
// Get the current executable path
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// Spawn subprocess: ollama runner --mlx-engine --model <name> --port <port>
|
||||
cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", c.modelName, "--port", strconv.Itoa(port))
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// Set library path environment variable for MLX libraries
|
||||
// Linux: LD_LIBRARY_PATH, Windows: PATH
|
||||
var libPathEnvVar string
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
libPathEnvVar = "LD_LIBRARY_PATH"
|
||||
case "windows":
|
||||
libPathEnvVar = "PATH"
|
||||
}
|
||||
|
||||
if libPathEnvVar != "" {
|
||||
libraryPaths := []string{ml.LibOllamaPath}
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
|
||||
libraryPaths = append(libraryPaths, mlxDirs...)
|
||||
}
|
||||
|
||||
if existingPath, ok := os.LookupEnv(libPathEnvVar); ok {
|
||||
libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
|
||||
}
|
||||
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
found := false
|
||||
for i := range cmd.Env {
|
||||
envName := cmd.Env[i]
|
||||
if runtime.GOOS == "windows" {
|
||||
envName = strings.ToUpper(envName)
|
||||
}
|
||||
if strings.HasPrefix(envName, libPathEnvVar+"=") {
|
||||
cmd.Env[i] = libPathEnvVar + "=" + pathEnvVal
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
cmd.Env = append(cmd.Env, libPathEnvVar+"="+pathEnvVal)
|
||||
}
|
||||
slog.Debug("mlx subprocess library path", libPathEnvVar, pathEnvVal)
|
||||
}
|
||||
|
||||
// Point MLX's JIT compiler at our bundled CUDA runtime headers.
|
||||
// MLX resolves headers via $CUDA_PATH/include/*.h (and checks CUDA_HOME first).
|
||||
// Always use bundled headers to avoid version mismatches with any
|
||||
// system-installed CUDA toolkit.
|
||||
if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_cuda_*")); err == nil {
|
||||
for _, d := range mlxDirs {
|
||||
if _, err := os.Stat(filepath.Join(d, "include")); err == nil {
|
||||
setEnv(cmd, "CUDA_PATH", d)
|
||||
setEnv(cmd, "CUDA_HOME", d)
|
||||
slog.Debug("mlx subprocess CUDA headers", "CUDA_PATH", d)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.cmd = cmd
|
||||
|
||||
status := llm.NewStatusWriter(os.Stderr)
|
||||
c.status = status
|
||||
// os/exec serializes Write calls when shared, which keeps the status writer
|
||||
// from seeing concurrent stdout/stderr fragments.
|
||||
cmd.Stdout = status
|
||||
cmd.Stderr = status
|
||||
|
||||
slog.Info("starting mlx runner subprocess", "model", c.modelName, "port", c.port)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("failed to start mlx runner: %w", err)
|
||||
}
|
||||
|
||||
// Reap subprocess when it exits
|
||||
go func() {
|
||||
c.doneErr = cmd.Wait()
|
||||
close(c.done)
|
||||
}()
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// ModelPath implements llm.LlamaServer.
|
||||
func (c *Client) ModelPath() string {
|
||||
return c.modelName
|
||||
}
|
||||
|
||||
// Pid implements llm.LlamaServer.
|
||||
func (c *Client) Pid() int {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.cmd != nil && c.cmd.Process != nil {
|
||||
return c.cmd.Process.Pid
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
type statusResponse struct {
|
||||
Status int
|
||||
Progress int
|
||||
ContextLength int
|
||||
Memory uint64
|
||||
}
|
||||
|
||||
// Ping implements llm.LlamaServer.
|
||||
func (c *Client) Ping(ctx context.Context) error {
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/status", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("health check failed: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var status statusResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.contextLength.Store(int64(status.ContextLength))
|
||||
c.memory.Store(status.Memory)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Tokenize implements llm.LlamaServer.
|
||||
func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/tokenize", c.port)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var tokens []int
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (c *Client) currentMemory() uint64 {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
c.Ping(ctx) //nolint:errcheck
|
||||
return c.memory.Load()
|
||||
}
|
||||
|
||||
// MemorySize implements llm.LlamaServer.
|
||||
func (c *Client) MemorySize() (total, vram uint64) {
|
||||
mem := c.currentMemory()
|
||||
return mem, mem
|
||||
}
|
||||
|
||||
// VRAMByGPU implements llm.LlamaServer.
|
||||
func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
|
||||
return c.currentMemory()
|
||||
}
|
||||
|
||||
var _ llm.LlamaServer = (*Client)(nil)
|
||||
|
||||
// setEnv sets or replaces an environment variable in cmd.Env.
|
||||
func setEnv(cmd *exec.Cmd, key, value string) {
|
||||
entry := key + "=" + value
|
||||
prefix := strings.ToUpper(key + "=")
|
||||
for i, e := range cmd.Env {
|
||||
if strings.HasPrefix(strings.ToUpper(e), prefix) {
|
||||
cmd.Env[i] = entry
|
||||
return
|
||||
}
|
||||
}
|
||||
cmd.Env = append(cmd.Env, entry)
|
||||
}
|
||||
715
x/mlxrunner/dflash.go
Normal file
715
x/mlxrunner/dflash.go
Normal file
@@ -0,0 +1,715 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
)
|
||||
|
||||
type dflashStats struct {
|
||||
iterations int
|
||||
drafted int
|
||||
accepted int
|
||||
mismatches int
|
||||
allAccepted int
|
||||
batched int
|
||||
serial int
|
||||
targetDuration time.Duration
|
||||
draftDuration time.Duration
|
||||
validateDuration time.Duration
|
||||
}
|
||||
|
||||
type dflashDecodeMode string
|
||||
|
||||
const (
|
||||
dflashDecodeDisabled dflashDecodeMode = ""
|
||||
dflashDecodeGreedy dflashDecodeMode = "greedy"
|
||||
dflashDecodeSample dflashDecodeMode = "sample"
|
||||
)
|
||||
|
||||
func (m dflashDecodeMode) enabled() bool {
|
||||
return m != dflashDecodeDisabled
|
||||
}
|
||||
|
||||
func newDFlashTargetCaches(m base.Model) []cache.Cache {
|
||||
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
||||
return cacheFactory.NewCaches()
|
||||
}
|
||||
caches := make([]cache.Cache, m.NumLayers())
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
func freeCacheSet(caches []cache.Cache) {
|
||||
for _, c := range caches {
|
||||
if c != nil {
|
||||
c.Free()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) dflashGate(opts sampler.Options) (dflashDecodeMode, string) {
|
||||
if r.Draft == nil {
|
||||
return dflashDecodeDisabled, "no_draft"
|
||||
}
|
||||
if _, ok := r.Draft.(base.DFlashDraftModel); !ok {
|
||||
return dflashDecodeDisabled, "draft_not_dflash"
|
||||
}
|
||||
if _, ok := r.Model.(base.DFlashTargetModel); !ok {
|
||||
return dflashDecodeDisabled, "target_not_dflash"
|
||||
}
|
||||
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
|
||||
return dflashDecodeDisabled, "target_embeddings_missing"
|
||||
}
|
||||
if opts.Logprobs || opts.TopLogprobs > 0 {
|
||||
return dflashDecodeDisabled, "logprobs_requested"
|
||||
}
|
||||
|
||||
if opts.Temperature > 0 || dflashUsesSamplerHistory(opts) {
|
||||
return dflashDecodeSample, ""
|
||||
}
|
||||
|
||||
return dflashDecodeGreedy, ""
|
||||
}
|
||||
|
||||
func dflashUsesSamplerHistory(opts sampler.Options) bool {
|
||||
if opts.RepeatLastN == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
repeatPenalty := opts.RepeatPenalty
|
||||
if repeatPenalty <= 0 {
|
||||
repeatPenalty = 1
|
||||
}
|
||||
return repeatPenalty != 1 || opts.PresencePenalty != 0 || opts.FrequencyPenalty != 0
|
||||
}
|
||||
|
||||
func (r *Runner) runGreedyDFlashDecode(ctx context.Context, request Request, session *cacheSession, targetCaches []cache.Cache, draftCaches []cache.Cache, seed []int32, position *int, started time.Time) error {
|
||||
target := r.Model.(base.DFlashTargetModel)
|
||||
draft := r.Draft.(base.DFlashDraftModel)
|
||||
stats := dflashStats{}
|
||||
slog.Info("DFlash greedy decode enabled", "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs())
|
||||
|
||||
targetForward := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, targetCaches, draft.TargetLayerIDs())
|
||||
*position += token.Dim(1)
|
||||
return hidden, targetHidden
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
hidden, targetHidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
|
||||
draft.AppendContext(targetHidden, draftCaches)
|
||||
current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
|
||||
mlx.Pin(current.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
defer func() {
|
||||
mlx.Unpin(current.Arrays()...)
|
||||
}()
|
||||
|
||||
dec := decoder{tokenizer: r.Tokenizer}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
|
||||
now := started
|
||||
generated := 0
|
||||
|
||||
for generated < request.Options.NumPredict {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if generated == 0 {
|
||||
mlx.Eval(current.Arrays()...)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !done {
|
||||
generated++
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
|
||||
draftCount := min(draft.BlockSize()-1, request.Options.NumPredict-generated)
|
||||
if draftCount <= 0 {
|
||||
t0 = time.Now()
|
||||
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
|
||||
draft.AppendContext(targetHidden, draftCaches)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
next := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
|
||||
mlx.Pin(next.Arrays()...)
|
||||
old := current
|
||||
current = next
|
||||
mlx.Unpin(old.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
continue
|
||||
}
|
||||
|
||||
stats.iterations++
|
||||
t0 = time.Now()
|
||||
draftTokens := r.generateDFlashDrafts(draft, current.Token, draftCaches, draftCount)
|
||||
mlx.Pin(draftTokens)
|
||||
mlx.Eval(draftTokens)
|
||||
stats.draftDuration += time.Since(t0)
|
||||
stats.drafted += draftCount
|
||||
|
||||
t0 = time.Now()
|
||||
next, accepted, done, err := r.acceptDFlashDrafts(ctx, request, session, &dec, target, draft, targetCaches, draftCaches, position, current, draftTokens, &final, &generated, &stats)
|
||||
stats.validateDuration += time.Since(t0)
|
||||
mlx.Unpin(draftTokens)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stats.accepted += accepted
|
||||
if accepted == draftCount {
|
||||
stats.allAccepted++
|
||||
} else {
|
||||
stats.mismatches++
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
|
||||
mlx.Pin(next.Arrays()...)
|
||||
old := current
|
||||
current = next
|
||||
mlx.Unpin(old.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
|
||||
if generated%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
final.EvalCount = generated
|
||||
final.EvalDuration = time.Since(now)
|
||||
acceptance := 0.0
|
||||
if stats.drafted > 0 {
|
||||
acceptance = float64(stats.accepted) / float64(stats.drafted)
|
||||
}
|
||||
avgDraft := 0.0
|
||||
avgAccepted := 0.0
|
||||
if stats.iterations > 0 {
|
||||
avgDraft = float64(stats.drafted) / float64(stats.iterations)
|
||||
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
|
||||
}
|
||||
slog.Info("DFlash decode stats", "mode", "greedy", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", draft.BlockSize()-1, "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) runSampleDFlashDecode(ctx context.Context, request Request, session *cacheSession, targetCaches []cache.Cache, draftCaches []cache.Cache, seed []int32, position *int, started time.Time) error {
|
||||
target := r.Model.(base.DFlashTargetModel)
|
||||
draft := r.Draft.(base.DFlashDraftModel)
|
||||
stats := dflashStats{}
|
||||
slog.Info("DFlash sample decode enabled",
|
||||
"block_size", draft.BlockSize(),
|
||||
"target_layers", draft.TargetLayerIDs(),
|
||||
"temperature", request.SamplerOpts.Temperature,
|
||||
"top_p", request.SamplerOpts.TopP,
|
||||
"top_k", request.SamplerOpts.TopK,
|
||||
"min_p", request.SamplerOpts.MinP,
|
||||
"repeat_penalty", request.SamplerOpts.RepeatPenalty,
|
||||
"presence_penalty", request.SamplerOpts.PresencePenalty,
|
||||
"frequency_penalty", request.SamplerOpts.FrequencyPenalty,
|
||||
)
|
||||
|
||||
targetForward := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, targetCaches, draft.TargetLayerIDs())
|
||||
*position += token.Dim(1)
|
||||
return hidden, targetHidden
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
hidden, targetHidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
|
||||
draft.AppendContext(targetHidden, draftCaches)
|
||||
current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
|
||||
mlx.Pin(current.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
defer func() {
|
||||
mlx.Unpin(current.Arrays()...)
|
||||
}()
|
||||
|
||||
dec := decoder{tokenizer: r.Tokenizer}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
|
||||
now := started
|
||||
generated := 0
|
||||
|
||||
for generated < request.Options.NumPredict {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if generated == 0 {
|
||||
mlx.Eval(current.Arrays()...)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !done {
|
||||
generated++
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
|
||||
draftCount := min(draft.BlockSize()-1, request.Options.NumPredict-generated)
|
||||
if draftCount <= 0 {
|
||||
t0 = time.Now()
|
||||
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
|
||||
draft.AppendContext(targetHidden, draftCaches)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
next := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
|
||||
mlx.Pin(next.Arrays()...)
|
||||
old := current
|
||||
current = next
|
||||
mlx.Unpin(old.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
continue
|
||||
}
|
||||
|
||||
stats.iterations++
|
||||
t0 = time.Now()
|
||||
candidates := r.generateDFlashDraftCandidates(draft, current.Token, draftCaches, draftCount)
|
||||
var candidateArrays []*mlx.Array
|
||||
if candidates != nil {
|
||||
draftCount = candidates.tokens.Dim(1)
|
||||
candidateArrays = candidates.Arrays()
|
||||
mlx.Pin(candidateArrays...)
|
||||
mlx.Sweep()
|
||||
}
|
||||
stats.draftDuration += time.Since(t0)
|
||||
stats.drafted += draftCount
|
||||
|
||||
var next sampler.Result
|
||||
if draftCount == 0 {
|
||||
t0 = time.Now()
|
||||
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
|
||||
draft.AppendContext(targetHidden, draftCaches)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
next = r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
|
||||
} else {
|
||||
var accepted int
|
||||
t0 = time.Now()
|
||||
next, accepted, done, err = r.acceptSampleDFlashDrafts(ctx, request, session, &dec, target, draft, targetCaches, draftCaches, position, current, candidates, &final, &generated, &stats)
|
||||
stats.validateDuration += time.Since(t0)
|
||||
mlx.Unpin(candidateArrays...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stats.accepted += accepted
|
||||
if accepted == draftCount {
|
||||
stats.allAccepted++
|
||||
} else {
|
||||
stats.mismatches++
|
||||
}
|
||||
if next.Token == nil {
|
||||
mlx.Sweep()
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Pin(next.Arrays()...)
|
||||
old := current
|
||||
current = next
|
||||
mlx.Unpin(old.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
|
||||
if generated%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
final.EvalCount = generated
|
||||
final.EvalDuration = time.Since(now)
|
||||
acceptance := 0.0
|
||||
if stats.drafted > 0 {
|
||||
acceptance = float64(stats.accepted) / float64(stats.drafted)
|
||||
}
|
||||
avgDraft := 0.0
|
||||
avgAccepted := 0.0
|
||||
if stats.iterations > 0 {
|
||||
avgDraft = float64(stats.drafted) / float64(stats.iterations)
|
||||
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
|
||||
}
|
||||
slog.Info("DFlash decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", draft.BlockSize()-1, "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) dflashDraftLogits(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *mlx.Array {
|
||||
blockLen := draftCount + 1
|
||||
values := make([]int32, blockLen)
|
||||
values[0] = int32(tokenID(current))
|
||||
for i := 1; i < blockLen; i++ {
|
||||
values[i] = draft.MaskTokenID()
|
||||
}
|
||||
block := mlx.FromValues(values, 1, blockLen)
|
||||
logits := draft.Draft(block, caches)
|
||||
return logits.Slice(mlx.Slice(), mlx.Slice(1, blockLen), mlx.Slice())
|
||||
}
|
||||
|
||||
func (r *Runner) generateDFlashDrafts(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *mlx.Array {
|
||||
logits := r.dflashDraftLogits(draft, current, caches, draftCount)
|
||||
return logits.Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
type dflashDraftCandidates struct {
|
||||
tokens *mlx.Array
|
||||
dist sampler.Distribution
|
||||
}
|
||||
|
||||
func (c *dflashDraftCandidates) Arrays() []*mlx.Array {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...)
|
||||
}
|
||||
|
||||
func (r *Runner) generateDFlashDraftCandidates(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *dflashDraftCandidates {
|
||||
if draftCount <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logits := r.dflashDraftLogits(draft, current, caches, draftCount)
|
||||
draftTokens := make([]*mlx.Array, 0, draftCount)
|
||||
draftDists := make([]sampler.Distribution, 0, draftCount)
|
||||
var prefix *mlx.Array
|
||||
|
||||
for i := range draftCount {
|
||||
rows := logits.Slice(mlx.Slice(), mlx.Slice(0, i+1), mlx.Slice())
|
||||
dist := r.Sampler.Distribution(pipelineSlot, rows, prefix).SliceRows(i, i+1)
|
||||
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist))
|
||||
nextInput := mtpTokenInput(nextToken)
|
||||
|
||||
draftTokens = append(draftTokens, nextInput)
|
||||
draftDists = append(draftDists, dist)
|
||||
if prefix == nil {
|
||||
prefix = nextInput
|
||||
} else {
|
||||
prefix = prefix.Concatenate(1, nextInput)
|
||||
}
|
||||
}
|
||||
if len(draftTokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &dflashDraftCandidates{
|
||||
tokens: mlx.Concatenate(draftTokens, 1),
|
||||
dist: sampler.ConcatenateDistributions(draftDists),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) acceptDFlashDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *dflashStats) (sampler.Result, int, bool, error) {
|
||||
specCaches, spec, ok := cache.BeginSpeculation(targetCaches)
|
||||
if !ok {
|
||||
stats.serial++
|
||||
return r.acceptDFlashDraftsSerial(ctx, request, session, dec, target, draft, targetCaches, draftCaches, position, current, draftTokens, final, generated)
|
||||
}
|
||||
stats.batched++
|
||||
return r.acceptDFlashDraftsBatched(ctx, request, session, dec, target, draft, specCaches, spec, draftCaches, position, current, draftTokens, final, generated)
|
||||
}
|
||||
|
||||
func (r *Runner) acceptDFlashDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, specCaches []cache.Cache, spec *cache.Speculation, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
||||
before := *position
|
||||
draftCount := draftTokens.Dim(1)
|
||||
verifyInput := mtpTokenInput(current.Token).Concatenate(1, draftTokens)
|
||||
hiddenSeq, targetHiddenSeq := target.ForwardDFlash(&batch.Batch{
|
||||
InputIDs: verifyInput,
|
||||
SeqOffsets: []int32{int32(before)},
|
||||
SeqQueryLens: []int32{int32(verifyInput.Dim(1))},
|
||||
}, specCaches, draft.TargetLayerIDs())
|
||||
|
||||
selectedTokens := r.Model.Unembed(hiddenSeq).Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
mlx.Eval(draftTokens, selectedTokens)
|
||||
|
||||
draftIDs := draftTokens.Ints()
|
||||
selectedIDs := selectedTokens.Ints()
|
||||
if len(selectedIDs) < draftCount+1 {
|
||||
spec.Commit(0)
|
||||
return sampler.Result{}, 0, false, fmt.Errorf("dflash validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount)
|
||||
}
|
||||
|
||||
accepted := 0
|
||||
for i, id := range draftIDs {
|
||||
if selectedIDs[i] != id {
|
||||
break
|
||||
}
|
||||
accepted++
|
||||
if r.Tokenizer.IsEOS(int32(id)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
commitN := accepted + 1
|
||||
spec.Commit(0)
|
||||
|
||||
done := false
|
||||
for _, id := range draftIDs[:accepted] {
|
||||
if *generated >= request.Options.NumPredict {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
var err error
|
||||
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
spec.Commit(commitN)
|
||||
*position = before + commitN
|
||||
draft.AppendContext(targetHiddenSeq.Slice(mlx.Slice(), mlx.Slice(0, commitN), mlx.Slice()), draftCaches)
|
||||
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
|
||||
nextIndex := accepted
|
||||
if nextIndex >= len(selectedIDs) {
|
||||
nextIndex = len(selectedIDs) - 1
|
||||
}
|
||||
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedIDs[nextIndex])}, 1)}, accepted, false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) acceptDFlashDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
||||
targetForward := func(token *mlx.Array) *mlx.Array {
|
||||
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, targetCaches, draft.TargetLayerIDs())
|
||||
*position += token.Dim(1)
|
||||
draft.AppendContext(targetHidden, draftCaches)
|
||||
return r.lastLogits(hidden)
|
||||
}
|
||||
|
||||
logits := targetForward(mtpTokenInput(current.Token))
|
||||
accepted := 0
|
||||
for _, id := range draftTokens.Ints() {
|
||||
selected := greedyTokenFromLogits(logits)
|
||||
mlx.Eval(selected)
|
||||
selectedID := tokenID(selected)
|
||||
if selectedID != id {
|
||||
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil
|
||||
}
|
||||
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
accepted++
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
|
||||
logits = targetForward(mtpTokenInput(res.Token))
|
||||
}
|
||||
|
||||
return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) acceptSampleDFlashDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int, stats *dflashStats) (sampler.Result, int, bool, error) {
|
||||
specCaches, spec, ok := cache.BeginSpeculation(targetCaches)
|
||||
if !ok {
|
||||
stats.serial++
|
||||
return r.acceptSampleDFlashDraftsSerial(ctx, request, session, dec, target, draft, targetCaches, draftCaches, position, current, candidates, final, generated)
|
||||
}
|
||||
stats.batched++
|
||||
return r.acceptSampleDFlashDraftsBatched(ctx, request, session, dec, target, draft, specCaches, spec, draftCaches, position, current, candidates, final, generated)
|
||||
}
|
||||
|
||||
func (r *Runner) acceptSampleDFlashDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, specCaches []cache.Cache, spec *cache.Speculation, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
||||
before := *position
|
||||
draftCount := candidates.tokens.Dim(1)
|
||||
verifyInput := mtpTokenInput(current.Token).Concatenate(1, candidates.tokens)
|
||||
hiddenSeq, targetHiddenSeq := target.ForwardDFlash(&batch.Batch{
|
||||
InputIDs: verifyInput,
|
||||
SeqOffsets: []int32{int32(before)},
|
||||
SeqQueryLens: []int32{int32(verifyInput.Dim(1))},
|
||||
}, specCaches, draft.TargetLayerIDs())
|
||||
|
||||
targetDist := r.Sampler.Distribution(pipelineSlot, r.Model.Unembed(hiddenSeq), candidates.tokens)
|
||||
draftDist := candidates.dist
|
||||
acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens)
|
||||
mlx.Eval(candidates.tokens, acceptedMask)
|
||||
|
||||
draftIDs := candidates.tokens.Ints()
|
||||
acceptedFlags := acceptedMask.Ints()
|
||||
accepted := 0
|
||||
for _, ok := range acceptedFlags {
|
||||
if ok == 0 {
|
||||
break
|
||||
}
|
||||
accepted++
|
||||
}
|
||||
if accepted > draftCount {
|
||||
spec.Commit(0)
|
||||
return sampler.Result{}, 0, false, fmt.Errorf("dflash sample validation accepted %d tokens for %d draft tokens", accepted, draftCount)
|
||||
}
|
||||
|
||||
commitIDs := make([]int32, 0, accepted+1)
|
||||
done := false
|
||||
for i, id := range draftIDs[:accepted] {
|
||||
commitIDs = append(commitIDs, int32(id))
|
||||
if r.Tokenizer.IsEOS(int32(id)) {
|
||||
done = true
|
||||
accepted = i + 1
|
||||
commitIDs = commitIDs[:accepted]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
commitN := accepted + 1
|
||||
spec.Commit(0)
|
||||
|
||||
for _, id := range draftIDs[:accepted] {
|
||||
if *generated >= request.Options.NumPredict {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
var err error
|
||||
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
spec.Commit(commitN)
|
||||
*position = before + commitN
|
||||
draft.AppendContext(targetHiddenSeq.Slice(mlx.Slice(), mlx.Slice(0, commitN), mlx.Slice()), draftCaches)
|
||||
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
r.Sampler.Commit(pipelineSlot, commitIDs)
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
|
||||
var nextToken *mlx.Array
|
||||
if accepted == draftCount {
|
||||
nextToken = r.mtpSampleTokenAt(targetDist, draftCount)
|
||||
} else {
|
||||
nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted)
|
||||
}
|
||||
mlx.Eval(nextToken)
|
||||
nextID := int32(tokenID(nextToken))
|
||||
commitIDs = append(commitIDs, nextID)
|
||||
r.Sampler.Commit(pipelineSlot, commitIDs)
|
||||
|
||||
return sampler.Result{Token: nextToken}, accepted, false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) acceptSampleDFlashDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
||||
targetForward := func(token *mlx.Array) *mlx.Array {
|
||||
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, targetCaches, draft.TargetLayerIDs())
|
||||
*position += token.Dim(1)
|
||||
draft.AppendContext(targetHidden, draftCaches)
|
||||
return r.lastLogits(hidden)
|
||||
}
|
||||
|
||||
mlx.Eval(candidates.tokens)
|
||||
draftIDs := candidates.tokens.Ints()
|
||||
logits := targetForward(mtpTokenInput(current.Token))
|
||||
accepted := 0
|
||||
|
||||
for i, id := range draftIDs {
|
||||
targetDist := r.Sampler.Distribution(pipelineSlot, logits, nil)
|
||||
draftDist := candidates.dist.SliceRows(i, i+1)
|
||||
draftToken := mlx.FromValues([]int32{int32(id)}, 1)
|
||||
acceptedMask := r.mtpSampleAcceptedMask(targetDist, draftDist, draftToken)
|
||||
mlx.Eval(acceptedMask)
|
||||
|
||||
if acceptedMask.Ints()[0] == 0 {
|
||||
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, targetDist.ResidualAgainst(draftDist)))
|
||||
mlx.Eval(nextToken)
|
||||
r.Sampler.Commit(pipelineSlot, []int32{int32(tokenID(nextToken))})
|
||||
return sampler.Result{Token: nextToken}, accepted, false, nil
|
||||
}
|
||||
|
||||
accepted++
|
||||
r.Sampler.Commit(pipelineSlot, []int32{int32(id)})
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
|
||||
logits = targetForward(mtpTokenInput(res.Token))
|
||||
}
|
||||
|
||||
targetDist := r.Sampler.Distribution(pipelineSlot, logits, nil)
|
||||
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, targetDist))
|
||||
mlx.Eval(nextToken)
|
||||
r.Sampler.Commit(pipelineSlot, []int32{int32(tokenID(nextToken))})
|
||||
return sampler.Result{Token: nextToken}, accepted, false, nil
|
||||
}
|
||||
13
x/mlxrunner/imports.go
Normal file
13
x/mlxrunner/imports.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
_ "github.com/ollama/ollama/x/models/dflash"
|
||||
_ "github.com/ollama/ollama/x/models/gemma3"
|
||||
_ "github.com/ollama/ollama/x/models/gemma4"
|
||||
_ "github.com/ollama/ollama/x/models/glm4_moe_lite"
|
||||
_ "github.com/ollama/ollama/x/models/laguna"
|
||||
_ "github.com/ollama/ollama/x/models/llama"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5"
|
||||
_ "github.com/ollama/ollama/x/models/qwen3_5_moe"
|
||||
)
|
||||
3
x/mlxrunner/mlx/.gitignore
vendored
Normal file
3
x/mlxrunner/mlx/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
_deps
|
||||
build
|
||||
dist
|
||||
32
x/mlxrunner/mlx/CMakeLists.txt
Normal file
32
x/mlxrunner/mlx/CMakeLists.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
cmake_minimum_required(VERSION 3.5)
|
||||
|
||||
project(mlx)
|
||||
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
|
||||
endif()
|
||||
|
||||
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
|
||||
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
|
||||
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
||||
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
# Read MLX-C version from top-level file (shared with imagegen CMakeLists)
|
||||
file(READ "${CMAKE_SOURCE_DIR}/MLX_C_VERSION" MLX_C_GIT_TAG)
|
||||
string(STRIP "${MLX_C_GIT_TAG}" MLX_C_GIT_TAG)
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG}
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
|
||||
# Sync vendored headers with fetched version
|
||||
file(GLOB _mlx_c_hdrs "${mlx-c_SOURCE_DIR}/mlx/c/*.h")
|
||||
file(COPY ${_mlx_c_hdrs} DESTINATION "${CMAKE_CURRENT_SOURCE_DIR}/include/mlx/c/")
|
||||
99
x/mlxrunner/mlx/act.go
Normal file
99
x/mlxrunner/mlx/act.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package mlx
|
||||
|
||||
import "math"
|
||||
|
||||
var geluCoeff = float32(math.Sqrt(2 / math.Pi))
|
||||
|
||||
// GELUApprox returns 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
// as a fused kernel.
|
||||
var GELUApprox = Compile1(
|
||||
"GELUApprox",
|
||||
func(x *Array) *Array {
|
||||
// Dtype-matched scalars avoid implicit upcasts on bf16 inputs.
|
||||
dt := x.DType()
|
||||
half := FromValue[float32](0.5).AsType(dt)
|
||||
coeff := FromValue(geluCoeff).AsType(dt)
|
||||
c := FromValue[float32](0.044715).AsType(dt)
|
||||
one := FromValue[float32](1.0).AsType(dt)
|
||||
|
||||
// x^3 via x*x*x (avoids general Power which is slower).
|
||||
x3 := x.Multiply(x).Multiply(x)
|
||||
inner := x.Add(c.Multiply(x3))
|
||||
tanh := coeff.Multiply(inner).Tanh()
|
||||
return half.Multiply(x).Multiply(one.Add(tanh))
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SiLU returns a * sigmoid(a) as a fused kernel.
|
||||
var SiLU = Compile1(
|
||||
"SiLU",
|
||||
func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SoftplusF32 returns softplus(x) computed in float32 precision and cast back
|
||||
// to x's original dtype, as a fused kernel. Matches the laguna attention
|
||||
// output-gate formula: softplus(cast_f32(x)).cast(orig_dtype).
|
||||
var SoftplusF32 = Compile1(
|
||||
"SoftplusF32",
|
||||
func(x *Array) *Array {
|
||||
dt := x.DType()
|
||||
zero := FromValue[float32](0)
|
||||
return Logaddexp(x.AsType(DTypeFloat32), zero).AsType(dt)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SwiGLU returns silu(gate) * up as a fused kernel.
|
||||
var SwiGLU = Compile2(
|
||||
"SwiGLU",
|
||||
func(gate, up *Array) *Array {
|
||||
return SiLU(gate).Multiply(up)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// GeGLU returns gelu_approx(gate) * up as a fused kernel. Matches mlx_lm's
|
||||
// geglu, used by Gemma-family MLP and MoE paths.
|
||||
var GeGLU = Compile2(
|
||||
"GeGLU",
|
||||
func(gate, up *Array) *Array {
|
||||
return GELUApprox(gate).Multiply(up)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// LogitSoftcap returns tanh(x / cap) * cap as a fused kernel. Matches
|
||||
// mlx_lm's logit_softcap. cap must have the same dtype as x.
|
||||
var LogitSoftcap = Compile2(
|
||||
"LogitSoftcap",
|
||||
func(x, cap *Array) *Array {
|
||||
return x.Divide(cap).Tanh().Multiply(cap)
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// sigmoidRouterFused traces the DeepSeek-V2 / GLM-MoE aux-loss-free router
|
||||
// head. Two outputs are returned so the pre-bias sigmoid (used to gather
|
||||
// per-expert scores after top-k) and the post-bias negation (used as the
|
||||
// argpartition key for top-k) share a single kernel.
|
||||
var sigmoidRouterFused = Compile(
|
||||
"SigmoidRouter",
|
||||
func(in ...*Array) []*Array {
|
||||
gates, bias := in[0], in[1]
|
||||
orig := gates.Sigmoid()
|
||||
neg := orig.Add(bias).Negative()
|
||||
return []*Array{orig, neg}
|
||||
},
|
||||
Shapeless(),
|
||||
)
|
||||
|
||||
// SigmoidRouter returns (sigmoid(gates), -(sigmoid(gates)+bias)) as a fused
|
||||
// kernel — the DeepSeek-V2 / GLM-MoE aux-loss-free router head.
|
||||
func SigmoidRouter(gates, bias *Array) (origScores, negScores *Array) {
|
||||
out := sigmoidRouterFused(gates, bias)
|
||||
return out[0], out[1]
|
||||
}
|
||||
295
x/mlxrunner/mlx/array.go
Normal file
295
x/mlxrunner/mlx/array.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
name string
|
||||
pinned atomic.Int32
|
||||
}
|
||||
|
||||
var (
|
||||
arrays []*Array
|
||||
arraysMu sync.Mutex
|
||||
)
|
||||
|
||||
// constructor utilities
|
||||
|
||||
func New(name string) *Array {
|
||||
t := &Array{name: name}
|
||||
|
||||
if tracing {
|
||||
traceScratch = append(traceScratch, t)
|
||||
} else {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
|
||||
arrays = append(arrays, t)
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type scalarTypes interface {
|
||||
~bool | ~int | ~float32 | ~float64 | ~complex64
|
||||
}
|
||||
|
||||
func FromValue[T scalarTypes](t T) *Array {
|
||||
tt := New("")
|
||||
switch v := any(t).(type) {
|
||||
case bool:
|
||||
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
||||
case int:
|
||||
tt.ctx = C.mlx_array_new_int(C.int(v))
|
||||
case float32:
|
||||
tt.ctx = C.mlx_array_new_float32(C.float(v))
|
||||
case float64:
|
||||
tt.ctx = C.mlx_array_new_float64(C.double(v))
|
||||
case complex64:
|
||||
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
|
||||
default:
|
||||
panic("unsupported type")
|
||||
}
|
||||
return tt
|
||||
}
|
||||
|
||||
type arrayTypes interface {
|
||||
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
~int8 | ~int16 | ~int32 | ~int64 |
|
||||
~float32 | ~float64 |
|
||||
~complex64
|
||||
}
|
||||
|
||||
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
|
||||
if len(shape) == 0 {
|
||||
panic("shape must be provided for non-scalar tensors")
|
||||
}
|
||||
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i := range shape {
|
||||
cShape[i] = C.int(shape[i])
|
||||
}
|
||||
|
||||
var dtype DType
|
||||
switch reflect.TypeOf(s).Elem().Kind() {
|
||||
case reflect.Bool:
|
||||
dtype = DTypeBool
|
||||
case reflect.Uint8:
|
||||
dtype = DTypeUint8
|
||||
case reflect.Uint16:
|
||||
dtype = DTypeUint16
|
||||
case reflect.Uint32:
|
||||
dtype = DTypeUint32
|
||||
case reflect.Uint64:
|
||||
dtype = DTypeUint64
|
||||
case reflect.Int8:
|
||||
dtype = DTypeInt8
|
||||
case reflect.Int16:
|
||||
dtype = DTypeInt16
|
||||
case reflect.Int32:
|
||||
dtype = DTypeInt32
|
||||
case reflect.Int64:
|
||||
dtype = DTypeInt64
|
||||
case reflect.Float32:
|
||||
dtype = DTypeFloat32
|
||||
case reflect.Float64:
|
||||
dtype = DTypeFloat64
|
||||
case reflect.Complex64:
|
||||
dtype = DTypeComplex64
|
||||
default:
|
||||
panic("unsupported type")
|
||||
}
|
||||
|
||||
bts := make([]byte, binary.Size(s))
|
||||
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tt := New("")
|
||||
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
|
||||
return tt
|
||||
}
|
||||
|
||||
func (t *Array) Set(other *Array) {
|
||||
C.mlx_array_set(&t.ctx, other.ctx)
|
||||
}
|
||||
|
||||
func (t *Array) Clone() *Array {
|
||||
tt := New(t.name)
|
||||
C.mlx_array_set(&tt.ctx, t.ctx)
|
||||
return tt
|
||||
}
|
||||
|
||||
// lifecycle utilities
|
||||
|
||||
// Pin marks arrays as in-use so they are retained during Sweep.
|
||||
func Pin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
t.pinned.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unpin marks arrays as no longer in-use, allowing Sweep to free them.
|
||||
func Unpin(s ...*Array) {
|
||||
for _, t := range s {
|
||||
if t != nil {
|
||||
if t.pinned.Add(-1) < 0 {
|
||||
panic(fmt.Sprintf("mlx.Unpin: negative pin count on array %q", t.name))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sweep releases all unpinned arrays, primarily intermediate tensors. MLX will truly
|
||||
// free them when there are no other references, including dependencies in the graph.
|
||||
func Sweep() {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
n := 0
|
||||
for _, t := range arrays {
|
||||
if t.pinned.Load() > 0 && t.Valid() {
|
||||
arrays[n] = t
|
||||
n++
|
||||
} else if t.Valid() {
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
arrays = arrays[:n]
|
||||
}
|
||||
|
||||
// misc. utilities
|
||||
|
||||
func (t *Array) Valid() bool {
|
||||
return t.ctx.ctx != nil
|
||||
}
|
||||
|
||||
func (t *Array) String() string {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_array_tostring(&str, t.ctx)
|
||||
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
func (t *Array) LogValue() slog.Value {
|
||||
attrs := []slog.Attr{
|
||||
slog.String("name", t.name),
|
||||
slog.Int("pinned", int(t.pinned.Load())),
|
||||
}
|
||||
if t.Valid() {
|
||||
attrs = append(attrs,
|
||||
slog.Any("dtype", t.DType()),
|
||||
slog.Any("shape", t.Dims()),
|
||||
slog.Int("num_bytes", t.NumBytes()),
|
||||
)
|
||||
}
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// shape utilities
|
||||
|
||||
func (t *Array) Size() int {
|
||||
return int(C.mlx_array_size(t.ctx))
|
||||
}
|
||||
|
||||
func (t *Array) NumBytes() int {
|
||||
return int(C.mlx_array_nbytes(t.ctx))
|
||||
}
|
||||
|
||||
func (t *Array) NumDims() int {
|
||||
return int(C.mlx_array_ndim(t.ctx))
|
||||
}
|
||||
|
||||
func (t *Array) Dims() []int {
|
||||
dims := make([]int, t.NumDims())
|
||||
for i := range dims {
|
||||
dims[i] = t.Dim(i)
|
||||
}
|
||||
|
||||
return dims
|
||||
}
|
||||
|
||||
func (t *Array) Dim(dim int) int {
|
||||
return int(C.mlx_array_dim(t.ctx, C.int(dim)))
|
||||
}
|
||||
|
||||
func (t *Array) DType() DType {
|
||||
return DType(C.mlx_array_dtype(t.ctx))
|
||||
}
|
||||
|
||||
// data utilities
|
||||
|
||||
func (t *Array) Int() int {
|
||||
var item C.int64_t
|
||||
C.mlx_array_item_int64(&item, t.ctx)
|
||||
return int(item)
|
||||
}
|
||||
|
||||
func (t *Array) Float() float64 {
|
||||
var item C.double
|
||||
C.mlx_array_item_float64(&item, t.ctx)
|
||||
return float64(item)
|
||||
}
|
||||
|
||||
func (t *Array) Ints() []int {
|
||||
if dt := t.DType(); dt != DTypeInt32 {
|
||||
panic(fmt.Sprintf("mlx: Ints requires DTypeInt32, got %v", dt))
|
||||
}
|
||||
ints := make([]int, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
|
||||
ints[i] = int(f)
|
||||
}
|
||||
return ints
|
||||
}
|
||||
|
||||
func (t *Array) Floats() []float32 {
|
||||
if dt := t.DType(); dt != DTypeFloat32 {
|
||||
panic(fmt.Sprintf("mlx: Floats requires DTypeFloat32, got %v", dt))
|
||||
}
|
||||
floats := make([]float32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
|
||||
floats[i] = float32(f)
|
||||
}
|
||||
return floats
|
||||
}
|
||||
|
||||
func (t *Array) Save(name string) error {
|
||||
cName := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
C.mlx_save(cName, t.ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogArrays logs all live arrays, sorted by size
|
||||
func LogArrays() {
|
||||
arraysMu.Lock()
|
||||
defer arraysMu.Unlock()
|
||||
sort.Slice(arrays, func(i, j int) bool {
|
||||
return arrays[i].NumBytes() > arrays[j].NumBytes()
|
||||
})
|
||||
|
||||
var total int
|
||||
for _, t := range arrays {
|
||||
nb := t.NumBytes()
|
||||
total += nb
|
||||
logutil.Trace(fmt.Sprintf("tensor %-60s %5s %5s pinned=%d %v", t.name, t.DType(), PrettyBytes(nb), t.pinned.Load(), t.Dims()))
|
||||
}
|
||||
logutil.Trace(fmt.Sprintf("tensors total: %d, size: %s, active: %s", len(arrays), PrettyBytes(total), PrettyBytes(ActiveMemory())))
|
||||
}
|
||||
76
x/mlxrunner/mlx/array_test.go
Normal file
76
x/mlxrunner/mlx/array_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package mlx
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFromValue(t *testing.T) {
|
||||
withMLXThread(t, func() {
|
||||
for got, want := range map[*Array]DType{
|
||||
FromValue(true): DTypeBool,
|
||||
FromValue(false): DTypeBool,
|
||||
FromValue(int(7)): DTypeInt32,
|
||||
FromValue(float32(3.14)): DTypeFloat32,
|
||||
FromValue(float64(2.71)): DTypeFloat64,
|
||||
FromValue(complex64(1 + 2i)): DTypeComplex64,
|
||||
} {
|
||||
if got.DType() != want {
|
||||
t.Errorf("%s: want %v, got %v", want, want, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFromValues(t *testing.T) {
|
||||
withMLXThread(t, func() {
|
||||
for got, want := range map[*Array]DType{
|
||||
FromValues([]bool{true, false, true}, 3): DTypeBool,
|
||||
FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
|
||||
FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
|
||||
FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
|
||||
FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
|
||||
FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
|
||||
FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
|
||||
FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
|
||||
FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
|
||||
FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
|
||||
FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
|
||||
FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
|
||||
} {
|
||||
if got.DType() != want {
|
||||
t.Errorf("%s: want %v, got %v", want, want, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestComparisonOpsAndBernoulli(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
a := FromValues([]float32{1, 2, 3}, 3)
|
||||
b := FromValues([]float32{1, 1, 4}, 3)
|
||||
eq := a.Equal(b).AsType(DTypeInt32)
|
||||
gt := a.Greater(b).AsType(DTypeInt32)
|
||||
le := a.LessEqual(b).AsType(DTypeInt32)
|
||||
bern := Bernoulli(FromValues([]float32{1, 0}, 2)).AsType(DTypeInt32)
|
||||
Eval(eq, gt, le, bern)
|
||||
|
||||
for name, tc := range map[string]struct {
|
||||
got []int
|
||||
want []int
|
||||
}{
|
||||
"equal": {eq.Ints(), []int{1, 0, 0}},
|
||||
"greater": {gt.Ints(), []int{0, 1, 0}},
|
||||
"lessEqual": {le.Ints(), []int{1, 0, 1}},
|
||||
"bernoulli": {bern.Ints(), []int{1, 0}},
|
||||
} {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
if len(tc.got) != len(tc.want) {
|
||||
t.Fatalf("got %v, want %v", tc.got, tc.want)
|
||||
}
|
||||
for i := range tc.want {
|
||||
if tc.got[i] != tc.want[i] {
|
||||
t.Fatalf("got %v, want %v", tc.got, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
192
x/mlxrunner/mlx/compile.go
Normal file
192
x/mlxrunner/mlx/compile.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
//
|
||||
// extern int closureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
|
||||
// extern void closureDestructor(void* payload);
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"runtime/cgo"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// CompileFunc is the signature of a function that can be compiled.
|
||||
type CompileFunc func(inputs ...*Array) []*Array
|
||||
|
||||
// CompileOption configures Compile behavior.
|
||||
type CompileOption func(*compileConfig)
|
||||
|
||||
type compileConfig struct {
|
||||
shapeless bool
|
||||
}
|
||||
|
||||
// Shapeless traces the function once against symbolic shapes so the compiled
|
||||
// graph accepts any input shape afterwards. Without this option, MLX re-traces
|
||||
// on each new (shape, dtype) combination and caches each specialization.
|
||||
func Shapeless() CompileOption {
|
||||
return func(c *compileConfig) { c.shapeless = true }
|
||||
}
|
||||
|
||||
// Compile returns a compiled version of fn. When called during another
|
||||
// compile's trace, fn is inlined directly so outer compiles can fuse through
|
||||
// inner ones.
|
||||
//
|
||||
// Compiled functions must not have side effects outside of the function. Do
|
||||
// not access data other than the arguments passed in (either Go data or MLX
|
||||
// arrays) unless it is a constant.
|
||||
func Compile(name string, fn CompileFunc, opts ...CompileOption) CompileFunc {
|
||||
var cfg compileConfig
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
|
||||
var closure C.mlx_closure
|
||||
var once sync.Once
|
||||
|
||||
return func(inputs ...*Array) []*Array {
|
||||
if tracing {
|
||||
return fn(inputs...)
|
||||
}
|
||||
|
||||
once.Do(func() {
|
||||
payload := (*cgo.Handle)(C.malloc(C.size_t(unsafe.Sizeof(cgo.Handle(0)))))
|
||||
*payload = cgo.NewHandle(fn)
|
||||
src := C.mlx_closure_new_func_payload(
|
||||
(*[0]byte)(C.closureCallback),
|
||||
unsafe.Pointer(payload),
|
||||
(*[0]byte)(C.closureDestructor),
|
||||
)
|
||||
defer C.mlx_closure_free(src)
|
||||
|
||||
closure = C.mlx_closure_new()
|
||||
mlxCheck(name+": compile failed", func() C.int {
|
||||
return C.mlx_compile(&closure, src, C.bool(cfg.shapeless))
|
||||
})
|
||||
})
|
||||
|
||||
inVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
for _, in := range inputs {
|
||||
C.mlx_vector_array_append_value(inVec, in.ctx)
|
||||
}
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
mlxCheck(name+": closure apply failed", func() C.int {
|
||||
return C.mlx_closure_apply(&outVec, closure, inVec)
|
||||
})
|
||||
|
||||
n := int(C.mlx_vector_array_size(outVec))
|
||||
outputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
outputs[i] = New(name)
|
||||
C.mlx_vector_array_get(&outputs[i].ctx, outVec, C.size_t(i))
|
||||
}
|
||||
return outputs
|
||||
}
|
||||
}
|
||||
|
||||
// Compile1 compiles a unary function. See Compile.
|
||||
func Compile1(name string, fn func(*Array) *Array, opts ...CompileOption) func(*Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0])}
|
||||
}, opts...)
|
||||
return func(a *Array) *Array {
|
||||
return cf(a)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile2 compiles a binary function. See Compile.
|
||||
func Compile2(name string, fn func(*Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1])}
|
||||
}, opts...)
|
||||
return func(a, b *Array) *Array {
|
||||
return cf(a, b)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Compile3 compiles a ternary function. See Compile.
|
||||
func Compile3(name string, fn func(*Array, *Array, *Array) *Array, opts ...CompileOption) func(*Array, *Array, *Array) *Array {
|
||||
cf := Compile(name, func(in ...*Array) []*Array {
|
||||
return []*Array{fn(in[0], in[1], in[2])}
|
||||
}, opts...)
|
||||
return func(a, b, c *Array) *Array {
|
||||
return cf(a, b, c)[0]
|
||||
}
|
||||
}
|
||||
|
||||
// tracing is true while a compile callback is running. Since MLX is
|
||||
// single-threaded at this level a plain Go bool suffices.
|
||||
var tracing bool
|
||||
|
||||
// traceScratch collects arrays created during a compile trace so they can be
|
||||
// freed as a group when the callback returns.
|
||||
var traceScratch []*Array
|
||||
|
||||
//export closureCallback
|
||||
func closureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) (rc C.int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("mlx closure callback panicked", "panic", r)
|
||||
rc = 1
|
||||
}
|
||||
}()
|
||||
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
fn := handle.Value().(CompileFunc)
|
||||
|
||||
// When tracing, we track all of the intermediates that are created and free them separately at the end of
|
||||
// the process. This will give the effect of a single op - inputs are owned by the original caller (via
|
||||
// the MLX layer) and outputs are transferred back to MLX to create a new Go side tensor.
|
||||
if tracing {
|
||||
panic("mlx: nested compile trace")
|
||||
}
|
||||
tracing = true
|
||||
traceScratch = nil
|
||||
defer func() {
|
||||
for _, a := range traceScratch {
|
||||
if a.pinned.Load() > 0 {
|
||||
panic("mlx: traced array was pinned during compilation")
|
||||
}
|
||||
if a.Valid() {
|
||||
C.mlx_array_free(a.ctx)
|
||||
a.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
tracing = false
|
||||
traceScratch = nil
|
||||
}()
|
||||
|
||||
n := int(C.mlx_vector_array_size(input))
|
||||
inputs := make([]*Array, n)
|
||||
for i := range n {
|
||||
a := New("")
|
||||
C.mlx_vector_array_get(&a.ctx, input, C.size_t(i))
|
||||
inputs[i] = a
|
||||
}
|
||||
|
||||
outputs := fn(inputs...)
|
||||
|
||||
var arrPtr *C.mlx_array
|
||||
if len(outputs) > 0 {
|
||||
handles := make([]C.mlx_array, len(outputs))
|
||||
for i, out := range outputs {
|
||||
handles[i] = out.ctx
|
||||
}
|
||||
arrPtr = &handles[0]
|
||||
}
|
||||
C.mlx_vector_array_set_data(res, arrPtr, C.size_t(len(outputs)))
|
||||
return 0
|
||||
}
|
||||
|
||||
//export closureDestructor
|
||||
func closureDestructor(payload unsafe.Pointer) {
|
||||
handle := *(*cgo.Handle)(payload)
|
||||
handle.Delete()
|
||||
C.free(payload)
|
||||
}
|
||||
147
x/mlxrunner/mlx/compile_test.go
Normal file
147
x/mlxrunner/mlx/compile_test.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileFusion(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Compile fuses the ops inside a function body into a single kernel,
|
||||
// eliminating intermediate buffers. Use a diamond-shaped graph where
|
||||
// two branches must be materialized simultaneously without fusion,
|
||||
// then compare peak memory against the compiled version which fuses
|
||||
// everything into one kernel with no intermediates.
|
||||
const n = 1024 * 1024 // 4MB per float32 array
|
||||
data := make([]float32, n)
|
||||
for i := range data {
|
||||
data[i] = float32(i + 1)
|
||||
}
|
||||
|
||||
// Diamond: both a*b and a+b must be live for the final multiply.
|
||||
// Without fusion: peak includes both intermediates (~8MB extra).
|
||||
// With fusion: single kernel, no intermediates.
|
||||
body := func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Multiply(a.Add(b))
|
||||
}
|
||||
|
||||
a := FromValues(data, n)
|
||||
b := FromValues(data, n)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
// Compiled: ops fused into a single kernel.
|
||||
EnableCompile()
|
||||
fn := Compile2("diamond", body, Shapeless())
|
||||
warm := fn(a, b)
|
||||
Eval(warm)
|
||||
Sweep()
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
y := fn(a, b)
|
||||
Eval(y)
|
||||
compiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
// Uncompiled: ops evaluated individually, intermediates materialized.
|
||||
ClearCache()
|
||||
ResetPeakMemory()
|
||||
z := body(a, b)
|
||||
Eval(z)
|
||||
uncompiledPeak := PeakMemory()
|
||||
Sweep()
|
||||
|
||||
if compiledPeak == 0 && uncompiledPeak == 0 {
|
||||
t.Skip("peak memory tracking not available")
|
||||
}
|
||||
|
||||
t.Logf("peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
|
||||
if compiledPeak >= uncompiledPeak {
|
||||
t.Fatalf("compilation did not reduce peak memory: compiled=%d uncompiled=%d", compiledPeak, uncompiledPeak)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileNested(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// A compiled function that calls another compiled function should
|
||||
// produce correct results. The inner function inlines via isTracing()
|
||||
// during the outer's trace.
|
||||
inner := Compile1("silu", func(a *Array) *Array {
|
||||
return a.Multiply(a.Sigmoid())
|
||||
}, Shapeless())
|
||||
|
||||
outer := Compile2("swiglu", func(gate, up *Array) *Array {
|
||||
return inner(gate).Multiply(up)
|
||||
}, Shapeless())
|
||||
|
||||
gate := FromValues([]float32{0, 1, 2}, 3)
|
||||
up := FromValues([]float32{1, 1, 1}, 3)
|
||||
Pin(gate, up)
|
||||
defer Unpin(gate, up)
|
||||
|
||||
y := outer(gate, up)
|
||||
Eval(y)
|
||||
|
||||
// silu(x) = x * sigmoid(x); for x=0 → 0, x=1 → ~0.7311, x=2 → ~1.7616
|
||||
got := y.Floats()
|
||||
want := []float32{0, 0.7310586, 1.7615942}
|
||||
for i, v := range got {
|
||||
if v-want[i] > 1e-4 || want[i]-v > 1e-4 {
|
||||
t.Fatalf("got[%d]=%v want %v", i, v, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileCallbackPanicRecovers(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
boom := Compile1("boom", func(a *Array) *Array {
|
||||
panic("intentional test panic")
|
||||
})
|
||||
|
||||
x := FromValues([]float32{1}, 1)
|
||||
Pin(x)
|
||||
defer Unpin(x)
|
||||
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatal("expected panic from Call, got none")
|
||||
}
|
||||
if _, ok := r.(string); !ok {
|
||||
t.Fatalf("expected string panic, got %T: %v", r, r)
|
||||
}
|
||||
}()
|
||||
boom(x)
|
||||
}
|
||||
|
||||
func TestCompileNoTrackingGrowth(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Repeated invocations of a compiled kernel should not grow the
|
||||
// tracked-arrays list — the callback's traceScratch collects
|
||||
// intermediates during tracing and frees them when the callback returns.
|
||||
fn := Compile2("mul_add", func(a, b *Array) *Array {
|
||||
return a.Multiply(b).Add(b)
|
||||
})
|
||||
|
||||
a := FromValues([]float32{1, 2}, 2)
|
||||
b := FromValues([]float32{3, 4}, 2)
|
||||
Pin(a, b)
|
||||
defer Unpin(a, b)
|
||||
|
||||
Sweep()
|
||||
before := len(arrays)
|
||||
|
||||
for range 100 {
|
||||
_ = fn(a, b)
|
||||
Sweep()
|
||||
}
|
||||
|
||||
after := len(arrays)
|
||||
if after > before+2 {
|
||||
t.Fatalf("tracked arrays grew from %d to %d across 100 calls (includes initial trace)", before, after)
|
||||
}
|
||||
}
|
||||
94
x/mlxrunner/mlx/dtype.go
Normal file
94
x/mlxrunner/mlx/dtype.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
type DType int
|
||||
|
||||
func (t DType) String() string {
|
||||
switch t {
|
||||
case DTypeBool:
|
||||
return "BOOL"
|
||||
case DTypeUint8:
|
||||
return "U8"
|
||||
case DTypeUint16:
|
||||
return "U16"
|
||||
case DTypeUint32:
|
||||
return "U32"
|
||||
case DTypeUint64:
|
||||
return "U64"
|
||||
case DTypeInt8:
|
||||
return "I8"
|
||||
case DTypeInt16:
|
||||
return "I16"
|
||||
case DTypeInt32:
|
||||
return "I32"
|
||||
case DTypeInt64:
|
||||
return "I64"
|
||||
case DTypeFloat16:
|
||||
return "F16"
|
||||
case DTypeFloat32:
|
||||
return "F32"
|
||||
case DTypeFloat64:
|
||||
return "F64"
|
||||
case DTypeBFloat16:
|
||||
return "BF16"
|
||||
case DTypeComplex64:
|
||||
return "C64"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DType) UnmarshalJSON(b []byte) error {
|
||||
switch string(b) {
|
||||
case `"BOOL"`:
|
||||
*t = DTypeBool
|
||||
case `"U8"`:
|
||||
*t = DTypeUint8
|
||||
case `"U16"`:
|
||||
*t = DTypeUint16
|
||||
case `"U32"`:
|
||||
*t = DTypeUint32
|
||||
case `"U64"`:
|
||||
*t = DTypeUint64
|
||||
case `"I8"`:
|
||||
*t = DTypeInt8
|
||||
case `"I16"`:
|
||||
*t = DTypeInt16
|
||||
case `"I32"`:
|
||||
*t = DTypeInt32
|
||||
case `"I64"`:
|
||||
*t = DTypeInt64
|
||||
case `"F16"`:
|
||||
*t = DTypeFloat16
|
||||
case `"F64"`:
|
||||
*t = DTypeFloat64
|
||||
case `"F32"`:
|
||||
*t = DTypeFloat32
|
||||
case `"BF16"`:
|
||||
*t = DTypeBFloat16
|
||||
case `"C64"`:
|
||||
*t = DTypeComplex64
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
DTypeBool DType = C.MLX_BOOL
|
||||
DTypeUint8 DType = C.MLX_UINT8
|
||||
DTypeUint16 DType = C.MLX_UINT16
|
||||
DTypeUint32 DType = C.MLX_UINT32
|
||||
DTypeUint64 DType = C.MLX_UINT64
|
||||
DTypeInt8 DType = C.MLX_INT8
|
||||
DTypeInt16 DType = C.MLX_INT16
|
||||
DTypeInt32 DType = C.MLX_INT32
|
||||
DTypeInt64 DType = C.MLX_INT64
|
||||
DTypeFloat16 DType = C.MLX_FLOAT16
|
||||
DTypeFloat32 DType = C.MLX_FLOAT32
|
||||
DTypeFloat64 DType = C.MLX_FLOAT64
|
||||
DTypeBFloat16 DType = C.MLX_BFLOAT16
|
||||
DTypeComplex64 DType = C.MLX_COMPLEX64
|
||||
)
|
||||
36
x/mlxrunner/mlx/dynamic.c
Normal file
36
x/mlxrunner/mlx/dynamic.c
Normal file
@@ -0,0 +1,36 @@
|
||||
#include "dynamic.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#define DLOPEN(path) LoadLibraryA(path)
|
||||
#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
|
||||
#else
|
||||
#ifdef __APPLE__
|
||||
#include <mach-o/dyld.h>
|
||||
#include <libgen.h>
|
||||
#endif
|
||||
#include <dlfcn.h>
|
||||
#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
|
||||
#define DLCLOSE(handle) dlclose(handle)
|
||||
#endif
|
||||
|
||||
static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
|
||||
handle->ctx = (void*) DLOPEN(path);
|
||||
if (handle->ctx == NULL) {
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
|
||||
return mlx_dynamic_open(handle, path);
|
||||
}
|
||||
|
||||
void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
|
||||
if (handle->ctx) {
|
||||
DLCLOSE(handle->ctx);
|
||||
handle->ctx = NULL;
|
||||
}
|
||||
}
|
||||
253
x/mlxrunner/mlx/dynamic.go
Normal file
253
x/mlxrunner/mlx/dynamic.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package mlx
|
||||
|
||||
// #include "dynamic.h"
|
||||
// #include "generated.h"
|
||||
// #include <stdlib.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var initError error
|
||||
var initLoadError string
|
||||
var initLoadedPath string
|
||||
|
||||
// CheckInit returns any error that occurred during MLX dynamic library initialization.
|
||||
func CheckInit() error {
|
||||
if initLoadedPath != "" {
|
||||
slog.Debug("MLX dynamic library loaded", "path", initLoadedPath)
|
||||
}
|
||||
if initError != nil && initLoadError != "" {
|
||||
slog.Error(initLoadError)
|
||||
}
|
||||
return initError
|
||||
}
|
||||
|
||||
// tryLoadFromDir searches a directory for the mlxc shared library and loads it.
|
||||
func tryLoadFromDir(dir string) bool {
|
||||
// On Windows, MSVC produces mlxc.dll (no lib prefix)
|
||||
// On Unix, it's libmlxc.so or libmlxc.dylib
|
||||
pattern := "libmlxc.*"
|
||||
if runtime.GOOS == "windows" {
|
||||
pattern = "mlxc.*"
|
||||
}
|
||||
matches, err := fs.Glob(os.DirFS(dir), pattern)
|
||||
if err != nil || len(matches) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, match := range matches {
|
||||
path := filepath.Join(dir, match)
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
var handle C.mlx_dynamic_handle
|
||||
if C.mlx_dynamic_load(&handle, cPath) != 0 {
|
||||
initLoadError = fmt.Sprintf("failed to load MLX dynamic library: path=%s", path)
|
||||
continue
|
||||
}
|
||||
if C.mlx_dynamic_load_symbols(handle) != 0 {
|
||||
initLoadError = fmt.Sprintf("failed to load MLX dynamic library symbols: path=%s", path)
|
||||
C.mlx_dynamic_unload(&handle)
|
||||
continue
|
||||
}
|
||||
|
||||
initLoadedPath = path
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// libOllamaRoots returns candidate directories for MLX dynamic libraries.
|
||||
// Production: exe_dir/lib/ollama (dist tarball) and exe_dir (app bundle).
|
||||
// Development: build/lib/ollama and build/*/lib/ollama.
|
||||
func libOllamaRoots() []string {
|
||||
var roots []string
|
||||
|
||||
// Production paths relative to executable
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
exeDir := filepath.Dir(exe)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
roots = append(roots, filepath.Join(exeDir, "lib", "ollama"))
|
||||
roots = append(roots, exeDir) // app bundle: Contents/Resources/
|
||||
case "linux":
|
||||
roots = append(roots, filepath.Join(exeDir, "..", "lib", "ollama"))
|
||||
case "windows":
|
||||
roots = append(roots, filepath.Join(exeDir, "lib", "ollama"))
|
||||
}
|
||||
}
|
||||
|
||||
// Development paths: build/lib/ollama and build/*/lib/ollama.
|
||||
// Reverse-sort and filter the glob results so higher-versioned Metal
|
||||
// builds (e.g., metal-v4) are tried before lower ones (metal-v3),
|
||||
// and incompatible variants are skipped. Without this, alphabetical
|
||||
// order would always pick v3 over v4 in dev builds.
|
||||
for _, base := range repoBuildDirs() {
|
||||
roots = append(roots, filepath.Join(base, "lib", "ollama"))
|
||||
if matches, err := filepath.Glob(filepath.Join(base, "*", "lib", "ollama")); err == nil {
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(matches)))
|
||||
for _, m := range matches {
|
||||
// Extract the build dir name (e.g., "metal-v4" from "build/metal-v4/lib/ollama")
|
||||
rel, _ := filepath.Rel(base, m)
|
||||
variant := strings.SplitN(rel, string(filepath.Separator), 2)[0]
|
||||
if isCompatibleMLXVariant(variant) {
|
||||
roots = append(roots, m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return roots
|
||||
}
|
||||
|
||||
// repoBuildDirs returns candidate build/ directories relative to cwd and repo root.
|
||||
func repoBuildDirs() []string {
|
||||
var dirs []string
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
dirs = append(dirs, filepath.Join(cwd, "build"))
|
||||
for dir := cwd; ; {
|
||||
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
|
||||
if dir != cwd {
|
||||
dirs = append(dirs, filepath.Join(dir, "build"))
|
||||
}
|
||||
break
|
||||
}
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
break
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
return dirs
|
||||
}
|
||||
|
||||
// prependLibraryPath prepends dir to the platform's dynamic library search
|
||||
// path so the linker finds colocated libmlx before any stale copies.
|
||||
// Called once after successful library load.
|
||||
func prependLibraryPath(dir string) {
|
||||
var envVar string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
envVar = "DYLD_LIBRARY_PATH"
|
||||
case "linux":
|
||||
envVar = "LD_LIBRARY_PATH"
|
||||
default:
|
||||
return
|
||||
}
|
||||
if existing := os.Getenv(envVar); existing != "" {
|
||||
os.Setenv(envVar, dir+string(filepath.ListSeparator)+existing)
|
||||
} else {
|
||||
os.Setenv(envVar, dir)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "linux", "windows":
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
// OLLAMA_LLM_LIBRARY overrides variant selection (e.g., "mlx_metal_v3").
|
||||
// When set to an mlx_* value, only that specific subdir is tried.
|
||||
// The GGML runner ignores mlx_* values (see discover/runner.go).
|
||||
forcedVariant, _ := os.LookupEnv("OLLAMA_LLM_LIBRARY")
|
||||
if forcedVariant != "" && !strings.HasPrefix(forcedVariant, "mlx_") {
|
||||
forcedVariant = "" // not an MLX variant, ignore
|
||||
}
|
||||
|
||||
found := findMLXLibrary(forcedVariant)
|
||||
if !found {
|
||||
initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", libOllamaRoots())
|
||||
return
|
||||
}
|
||||
|
||||
prependLibraryPath(filepath.Dir(initLoadedPath))
|
||||
}
|
||||
|
||||
func findMLXLibrary(forcedVariant string) bool {
|
||||
for _, root := range libOllamaRoots() {
|
||||
if forcedVariant != "" {
|
||||
if tryLoadFromDir(filepath.Join(root, forcedVariant)) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if tryLoadFromMLXSubdirs(root) {
|
||||
return true
|
||||
}
|
||||
if tryLoadFromDir(root) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// tryLoadFromMLXSubdirs globs for mlx_* subdirs within dir, filters out
|
||||
// incompatible variants, tries the remainder in reverse sorted order (so
|
||||
// higher-versioned variants are preferred), and returns true on first
|
||||
// successful load.
|
||||
func tryLoadFromMLXSubdirs(dir string) bool {
|
||||
mlxDirs, err := filepath.Glob(filepath.Join(dir, "mlx_*"))
|
||||
if err != nil || len(mlxDirs) == 0 {
|
||||
return false
|
||||
}
|
||||
// Reverse sort: mlx_metal_v4 before mlx_metal_v3, mlx_cuda_v13 before v12
|
||||
sort.Sort(sort.Reverse(sort.StringSlice(mlxDirs)))
|
||||
for _, mlxDir := range mlxDirs {
|
||||
if !isCompatibleMLXVariant(filepath.Base(mlxDir)) {
|
||||
slog.Debug("skipping incompatible MLX variant", "dir", mlxDir)
|
||||
continue
|
||||
}
|
||||
if tryLoadFromDir(mlxDir) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isCompatibleMLXVariant checks whether an MLX variant directory is
|
||||
// compatible with the current OS. On macOS, dlopen does NOT enforce
|
||||
// the deployment target for dynamically loaded libraries, so we must
|
||||
// check compatibility ourselves to avoid loading Metal 4.x shaders
|
||||
// on a Metal 3.x driver.
|
||||
func isCompatibleMLXVariant(name string) bool {
|
||||
if runtime.GOOS != "darwin" {
|
||||
return true // non-macOS variants use dlopen failure for filtering
|
||||
}
|
||||
// Metal variant naming:
|
||||
// Production: mlx_metal_v3, mlx_metal_v4
|
||||
// Dev build: metal-v3, metal-v4
|
||||
var verStr string
|
||||
switch {
|
||||
case strings.HasPrefix(name, "mlx_metal_v"):
|
||||
verStr = strings.TrimPrefix(name, "mlx_metal_v")
|
||||
case strings.HasPrefix(name, "metal-v"):
|
||||
verStr = strings.TrimPrefix(name, "metal-v")
|
||||
}
|
||||
if verStr != "" {
|
||||
metalVer, err := strconv.Atoi(verStr)
|
||||
if err != nil {
|
||||
return true // unknown format, try it
|
||||
}
|
||||
// Metal 4.x requires macOS 26+
|
||||
if metalVer >= 4 && macOSMajorVersion() < 26 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
47
x/mlxrunner/mlx/dynamic.h
Normal file
47
x/mlxrunner/mlx/dynamic.h
Normal file
@@ -0,0 +1,47 @@
|
||||
#ifndef MLX_DYNAMIC_H
|
||||
#define MLX_DYNAMIC_H
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#define DLSYM(handle, symbol) (void*)GetProcAddress((HMODULE)(handle.ctx), symbol)
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
// Provide fallback typedefs for float16_t and bfloat16_t on non-ARM64
|
||||
// platforms where arm_fp16.h and arm_bf16.h are not available. These are
|
||||
// only used as function pointer signature placeholders since MLX requires
|
||||
// Apple Silicon at runtime.
|
||||
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
|
||||
typedef uint16_t float16_t;
|
||||
#endif
|
||||
|
||||
#if !defined(__aarch64__) && !defined(__ARM_FEATURE_BF16)
|
||||
typedef uint16_t bfloat16_t;
|
||||
#endif
|
||||
|
||||
// Undef ERROR to avoid conflict with wingdi.h on Windows
|
||||
#ifdef ERROR
|
||||
#undef ERROR
|
||||
#endif
|
||||
#define MLX_ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
|
||||
#define CHECK(x) if (!(x)) { MLX_ERROR("CHECK failed: " #x); }
|
||||
#define CHECK_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x); CHECK(x##_)
|
||||
// OPTIONAL_LOAD: load symbol if available, leave function pointer NULL otherwise
|
||||
#define OPTIONAL_LOAD(handle, x) *(void**)(&x##_) = DLSYM(handle, #x)
|
||||
|
||||
typedef struct {
|
||||
void* ctx;
|
||||
} mlx_dynamic_handle;
|
||||
|
||||
int mlx_dynamic_load(
|
||||
mlx_dynamic_handle* handle,
|
||||
const char *path);
|
||||
|
||||
void mlx_dynamic_unload(
|
||||
mlx_dynamic_handle* handle);
|
||||
|
||||
#endif // MLX_DYNAMIC_H
|
||||
17
x/mlxrunner/mlx/dynamic_darwin.go
Normal file
17
x/mlxrunner/mlx/dynamic_darwin.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func macOSMajorVersion() int {
|
||||
ver, err := syscall.Sysctl("kern.osproductversion")
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
parts := strings.SplitN(ver, ".", 2)
|
||||
major, _ := strconv.Atoi(parts[0])
|
||||
return major
|
||||
}
|
||||
5
x/mlxrunner/mlx/dynamic_other.go
Normal file
5
x/mlxrunner/mlx/dynamic_other.go
Normal file
@@ -0,0 +1,5 @@
|
||||
//go:build !darwin
|
||||
|
||||
package mlx
|
||||
|
||||
func macOSMajorVersion() int { return 0 }
|
||||
47
x/mlxrunner/mlx/fast.go
Normal file
47
x/mlxrunner/mlx/fast.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func FastScaledDotProductAttention(q, k, v *Array, scale float32, mode string, mask *Array) *Array {
|
||||
sinks := New("")
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
var maskCtx C.mlx_array
|
||||
if mask != nil {
|
||||
maskCtx = mask.ctx
|
||||
} else {
|
||||
empty := New("")
|
||||
maskCtx = empty.ctx
|
||||
}
|
||||
|
||||
out := New("FAST_SDPA")
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, maskCtx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
type LayerNorm struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
Bias *Array `weight:"bias"`
|
||||
}
|
||||
|
||||
func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
type RMSNorm struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
}
|
||||
|
||||
func (r *RMSNorm) Forward(x *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
663
x/mlxrunner/mlx/gated_delta.go
Normal file
663
x/mlxrunner/mlx/gated_delta.go
Normal file
@@ -0,0 +1,663 @@
|
||||
package mlx
|
||||
|
||||
// #include <stdlib.h>
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
gatedDeltaMetalKernelOnce sync.Once
|
||||
gatedDeltaMetalKernel C.mlx_fast_metal_kernel
|
||||
gatedDeltaMetalDisabled bool
|
||||
|
||||
gatedDeltaCUDAKernelOnce sync.Once
|
||||
gatedDeltaCUDAKernel C.mlx_fast_cuda_kernel
|
||||
gatedDeltaCUDADisabled bool
|
||||
)
|
||||
|
||||
const gatedDeltaMetalKernelSource = `
|
||||
auto n = thread_position_in_grid.z;
|
||||
auto b_idx = n / Hv;
|
||||
auto hv_idx = n % Hv;
|
||||
auto hk_idx = hv_idx / (Hv / Hk);
|
||||
constexpr int n_per_t = Dk / 32;
|
||||
|
||||
// q, k: [B, T, Hk, Dk]
|
||||
auto q_ = q + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
auto k_ = k + b_idx * T * Hk * Dk + hk_idx * Dk;
|
||||
|
||||
// v, y: [B, T, Hv, Dv]
|
||||
auto v_ = v + b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
y += b_idx * T * Hv * Dv + hv_idx * Dv;
|
||||
|
||||
auto dk_idx = thread_position_in_threadgroup.x;
|
||||
auto dv_idx = thread_position_in_grid.y;
|
||||
|
||||
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||
|
||||
float state[n_per_t];
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = static_cast<float>(i_state[s_idx]);
|
||||
}
|
||||
|
||||
// g: [B, T, Hv]
|
||||
auto g_ = g + b_idx * T * Hv;
|
||||
auto beta_ = beta + b_idx * T * Hv;
|
||||
|
||||
for (int t = 0; t < T; ++t) {
|
||||
float kv_mem = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] * g_[hv_idx];
|
||||
kv_mem += state[i] * k_[s_idx];
|
||||
}
|
||||
kv_mem = simd_sum(kv_mem);
|
||||
|
||||
auto delta = (v_[dv_idx] - kv_mem) * beta_[hv_idx];
|
||||
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] + k_[s_idx] * delta;
|
||||
out += state[i] * q_[s_idx];
|
||||
}
|
||||
out = simd_sum(out);
|
||||
if (thread_index_in_simdgroup == 0) {
|
||||
y[dv_idx] = static_cast<InT>(out);
|
||||
}
|
||||
|
||||
q_ += Hk * Dk;
|
||||
k_ += Hk * Dk;
|
||||
v_ += Hv * Dv;
|
||||
y += Hv * Dv;
|
||||
g_ += Hv;
|
||||
beta_ += Hv;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
o_state[s_idx] = static_cast<StT>(state[i]);
|
||||
}
|
||||
`
|
||||
|
||||
const gatedDeltaCUDAKernelSource = `
|
||||
auto tid_x = threadIdx.x;
|
||||
auto tid_y = threadIdx.y;
|
||||
auto grid_y = blockIdx.y * blockDim.y + tid_y;
|
||||
auto grid_z = blockIdx.z;
|
||||
|
||||
int T_val = static_cast<int>(*T);
|
||||
|
||||
auto n = grid_z;
|
||||
auto b_idx = n / Hv;
|
||||
auto hv_idx = n % Hv;
|
||||
auto hk_idx = hv_idx / (Hv / Hk);
|
||||
constexpr int n_per_t = Dk / 32;
|
||||
|
||||
// q, k: [B, T, Hk, Dk]
|
||||
auto q_ = q + b_idx * T_val * Hk * Dk + hk_idx * Dk;
|
||||
auto k_ = k + b_idx * T_val * Hk * Dk + hk_idx * Dk;
|
||||
|
||||
// v, y: [B, T, Hv, Dv]
|
||||
auto dv_idx = grid_y;
|
||||
auto v_ = v + b_idx * T_val * Hv * Dv + hv_idx * Dv;
|
||||
y += b_idx * T_val * Hv * Dv + hv_idx * Dv;
|
||||
|
||||
auto dk_idx = tid_x;
|
||||
|
||||
// state_in, state_out: [B, Hv, Dv, Dk]
|
||||
auto i_state = state_in + (n * Dv + dv_idx) * Dk;
|
||||
auto o_state = state_out + (n * Dv + dv_idx) * Dk;
|
||||
|
||||
float state[n_per_t];
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = static_cast<float>(i_state[s_idx]);
|
||||
}
|
||||
|
||||
// g: [B, T, Hv]
|
||||
auto g_ = g + b_idx * T_val * Hv;
|
||||
auto beta_ = beta + b_idx * T_val * Hv;
|
||||
|
||||
for (int t = 0; t < T_val; ++t) {
|
||||
float kv_mem = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] * static_cast<float>(g_[hv_idx]);
|
||||
kv_mem += state[i] * static_cast<float>(k_[s_idx]);
|
||||
}
|
||||
// Warp reduction (full warp, 32 threads in x)
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
kv_mem += __shfl_down_sync(0xffffffff, kv_mem, offset);
|
||||
kv_mem = __shfl_sync(0xffffffff, kv_mem, 0);
|
||||
|
||||
auto delta = (static_cast<float>(v_[dv_idx]) - kv_mem) * static_cast<float>(beta_[hv_idx]);
|
||||
|
||||
float out = 0.0f;
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
state[i] = state[i] + static_cast<float>(k_[s_idx]) * delta;
|
||||
out += state[i] * static_cast<float>(q_[s_idx]);
|
||||
}
|
||||
// Warp reduction
|
||||
for (int offset = 16; offset > 0; offset >>= 1)
|
||||
out += __shfl_down_sync(0xffffffff, out, offset);
|
||||
if (tid_x == 0) {
|
||||
y[dv_idx] = static_cast<InT>(out);
|
||||
}
|
||||
|
||||
q_ += Hk * Dk;
|
||||
k_ += Hk * Dk;
|
||||
v_ += Hv * Dv;
|
||||
y += Hv * Dv;
|
||||
g_ += Hv;
|
||||
beta_ += Hv;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_per_t; ++i) {
|
||||
auto s_idx = n_per_t * dk_idx + i;
|
||||
o_state[s_idx] = static_cast<StT>(state[i]);
|
||||
}
|
||||
`
|
||||
|
||||
func cStringVector(values []string) (C.mlx_vector_string, func(), bool) {
|
||||
vec := C.mlx_vector_string_new()
|
||||
ok := true
|
||||
for _, s := range values {
|
||||
cs := C.CString(s)
|
||||
if C.mlx_vector_string_append_value(vec, cs) != 0 {
|
||||
ok = false
|
||||
}
|
||||
C.free(unsafe.Pointer(cs))
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
cleanup := func() {
|
||||
C.mlx_vector_string_free(vec)
|
||||
}
|
||||
return vec, cleanup, ok
|
||||
}
|
||||
|
||||
func initGatedDeltaMetalKernel() {
|
||||
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled = true
|
||||
freeInputs()
|
||||
return
|
||||
}
|
||||
defer freeInputs()
|
||||
|
||||
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||
if !ok {
|
||||
gatedDeltaMetalDisabled = true
|
||||
freeOutputs()
|
||||
return
|
||||
}
|
||||
defer freeOutputs()
|
||||
|
||||
cName := C.CString("gated_delta_step")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
cSource := C.CString(gatedDeltaMetalKernelSource)
|
||||
defer C.free(unsafe.Pointer(cSource))
|
||||
cHeader := C.CString("")
|
||||
defer C.free(unsafe.Pointer(cHeader))
|
||||
|
||||
gatedDeltaMetalKernel = C.mlx_fast_metal_kernel_new(
|
||||
cName,
|
||||
inputs,
|
||||
outputs,
|
||||
cSource,
|
||||
cHeader,
|
||||
C.bool(true),
|
||||
C.bool(false),
|
||||
)
|
||||
}
|
||||
|
||||
// gatedDeltaKernel runs a fused Metal kernel for the qwen3.5 recurrent update.
|
||||
// It returns ok=false on unsupported shapes/devices or kernel setup/apply failure.
|
||||
func gatedDeltaKernel(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||
if gatedDeltaMetalDisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
Hv, Dv := vd[2], vd[3]
|
||||
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
inputDType := q.DType()
|
||||
stateDType := state.DType()
|
||||
if k.DType() != inputDType || v.DType() != inputDType || g.DType() != inputDType || beta.DType() != inputDType {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
gatedDeltaMetalKernelOnce.Do(initGatedDeltaMetalKernel)
|
||||
if gatedDeltaMetalDisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
cfg := C.mlx_fast_metal_kernel_config_new()
|
||||
defer C.mlx_fast_metal_kernel_config_free(cfg)
|
||||
|
||||
cInT := C.CString("InT")
|
||||
defer C.free(unsafe.Pointer(cInT))
|
||||
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(inputDType)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
cStT := C.CString("StT")
|
||||
defer C.free(unsafe.Pointer(cStT))
|
||||
if C.mlx_fast_metal_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(stateDType)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
for _, tpl := range []struct {
|
||||
name string
|
||||
value int
|
||||
}{
|
||||
{name: "Dk", value: Dk},
|
||||
{name: "Dv", value: Dv},
|
||||
{name: "Hk", value: Hk},
|
||||
{name: "Hv", value: Hv},
|
||||
} {
|
||||
cn := C.CString(tpl.name)
|
||||
rc := C.mlx_fast_metal_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||
C.free(unsafe.Pointer(cn))
|
||||
if rc != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(inputDType)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(stateDType)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
threadY := Dv
|
||||
if threadY > 4 {
|
||||
threadY = 4
|
||||
}
|
||||
if C.mlx_fast_metal_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tScalar := FromValue(T)
|
||||
inputs := []C.mlx_array{
|
||||
q.ctx,
|
||||
k.ctx,
|
||||
v.ctx,
|
||||
g.ctx,
|
||||
beta.ctx,
|
||||
state.ctx,
|
||||
tScalar.ctx,
|
||||
}
|
||||
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
if C.mlx_fast_metal_kernel_apply(&outVec, gatedDeltaMetalKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||
gatedDeltaMetalDisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
y = New("GATED_DELTA_METAL_Y")
|
||||
nextState = New("GATED_DELTA_METAL_STATE")
|
||||
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||
return y, nextState, true
|
||||
}
|
||||
|
||||
func repeatHeadsForGatedDelta(x *Array, repeatFactor int) *Array {
|
||||
if repeatFactor <= 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Dims()
|
||||
x = ExpandDims(x, 3)
|
||||
x = Tile(x, []int32{1, 1, 1, int32(repeatFactor), 1})
|
||||
return Reshape(x, int32(shape[0]), int32(shape[1]), int32(shape[2]*repeatFactor), int32(shape[3]))
|
||||
}
|
||||
|
||||
func gatedDeltaFallback(q, k, v, g, beta, state *Array) (y, nextState *Array) {
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := int32(qd[0]), int32(qd[1]), int32(qd[2]), int32(qd[3])
|
||||
Hv, Dv := int32(vd[2]), int32(vd[3])
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil
|
||||
}
|
||||
if kd[0] != int(B) || kd[1] != int(T) || kd[2] != int(Hk) || kd[3] != int(Dk) {
|
||||
return nil, nil
|
||||
}
|
||||
if vd[0] != int(B) || vd[1] != int(T) {
|
||||
return nil, nil
|
||||
}
|
||||
if gd[0] != int(B) || gd[1] != int(T) || gd[2] != int(Hv) {
|
||||
return nil, nil
|
||||
}
|
||||
if bd[0] != int(B) || bd[1] != int(T) || bd[2] != int(Hv) {
|
||||
return nil, nil
|
||||
}
|
||||
if sd[0] != int(B) || sd[1] != int(Hv) || sd[2] != int(Dv) || sd[3] != int(Dk) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
repeatFactor := int(Hv / Hk)
|
||||
q = repeatHeadsForGatedDelta(q, repeatFactor)
|
||||
k = repeatHeadsForGatedDelta(k, repeatFactor)
|
||||
|
||||
nextState = state
|
||||
if T == 1 {
|
||||
qt := Squeeze(q, 1)
|
||||
kt := Squeeze(k, 1)
|
||||
vt := Squeeze(v, 1)
|
||||
gt := Squeeze(g, 1)
|
||||
bt := Squeeze(beta, 1)
|
||||
|
||||
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
|
||||
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
|
||||
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
|
||||
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
|
||||
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
|
||||
return ExpandDims(yt, 1), nextState
|
||||
}
|
||||
|
||||
outs := make([]*Array, 0, T)
|
||||
for t := int32(0); t < T; t++ {
|
||||
qt := Squeeze(SliceStartStop(q, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
|
||||
kt := Squeeze(SliceStartStop(k, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dk}), 1)
|
||||
vt := Squeeze(SliceStartStop(v, []int32{0, t, 0, 0}, []int32{B, t + 1, Hv, Dv}), 1)
|
||||
gt := Squeeze(SliceStartStop(g, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
|
||||
bt := Squeeze(SliceStartStop(beta, []int32{0, t, 0}, []int32{B, t + 1, Hv}), 1)
|
||||
|
||||
nextState = Mul(nextState, ExpandDims(ExpandDims(gt, -1), -1))
|
||||
kvMem := Sum(Mul(nextState, ExpandDims(kt, 2)), -1, false)
|
||||
delta := Mul(Sub(vt, kvMem), ExpandDims(bt, -1))
|
||||
nextState = Add(nextState, Mul(ExpandDims(kt, 2), ExpandDims(delta, -1)))
|
||||
yt := Sum(Mul(nextState, ExpandDims(qt, 2)), -1, false)
|
||||
outs = append(outs, ExpandDims(yt, 1))
|
||||
}
|
||||
return Concatenate(outs, 1), nextState
|
||||
}
|
||||
|
||||
func initGatedDeltaCUDAKernel() {
|
||||
var cudaAvail C.bool
|
||||
if C.mlx_cuda_is_available(&cudaAvail) != 0 || !bool(cudaAvail) {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return
|
||||
}
|
||||
|
||||
inputs, freeInputs, ok := cStringVector([]string{"q", "k", "v", "g", "beta", "state_in", "T"})
|
||||
if !ok {
|
||||
gatedDeltaCUDADisabled = true
|
||||
freeInputs()
|
||||
return
|
||||
}
|
||||
defer freeInputs()
|
||||
|
||||
outputs, freeOutputs, ok := cStringVector([]string{"y", "state_out"})
|
||||
if !ok {
|
||||
gatedDeltaCUDADisabled = true
|
||||
freeOutputs()
|
||||
return
|
||||
}
|
||||
defer freeOutputs()
|
||||
|
||||
cName := C.CString("gated_delta_step")
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
cSource := C.CString(gatedDeltaCUDAKernelSource)
|
||||
defer C.free(unsafe.Pointer(cSource))
|
||||
cHeader := C.CString("")
|
||||
defer C.free(unsafe.Pointer(cHeader))
|
||||
|
||||
gatedDeltaCUDAKernel = C.mlx_fast_cuda_kernel_new(
|
||||
cName,
|
||||
inputs,
|
||||
outputs,
|
||||
cSource,
|
||||
cHeader,
|
||||
C.bool(true),
|
||||
C.int(0),
|
||||
)
|
||||
}
|
||||
|
||||
func gatedDeltaCUDAKernelApply(q, k, v, g, beta, state *Array) (y, nextState *Array, ok bool) {
|
||||
if gatedDeltaCUDADisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
if q == nil || k == nil || v == nil || g == nil || beta == nil || state == nil {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
qd := q.Dims()
|
||||
kd := k.Dims()
|
||||
vd := v.Dims()
|
||||
gd := g.Dims()
|
||||
bd := beta.Dims()
|
||||
sd := state.Dims()
|
||||
if len(qd) != 4 || len(kd) != 4 || len(vd) != 4 || len(gd) != 3 || len(bd) != 3 || len(sd) != 4 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
B, T, Hk, Dk := qd[0], qd[1], qd[2], qd[3]
|
||||
if T <= 0 || Hk <= 0 || Dk <= 0 || Dk%32 != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if kd[0] != B || kd[1] != T || kd[2] != Hk || kd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
Hv, Dv := vd[2], vd[3]
|
||||
if vd[0] != B || vd[1] != T || Hv <= 0 || Dv <= 0 || Hv%Hk != 0 {
|
||||
return nil, nil, false
|
||||
}
|
||||
if gd[0] != B || gd[1] != T || gd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if bd[0] != B || bd[1] != T || bd[2] != Hv {
|
||||
return nil, nil, false
|
||||
}
|
||||
if sd[0] != B || sd[1] != Hv || sd[2] != Dv || sd[3] != Dk {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
inputDType := q.DType()
|
||||
stateDType := state.DType()
|
||||
if k.DType() != inputDType || v.DType() != inputDType || g.DType() != inputDType || beta.DType() != inputDType {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
gatedDeltaCUDAKernelOnce.Do(initGatedDeltaCUDAKernel)
|
||||
if gatedDeltaCUDADisabled {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
cfg := C.mlx_fast_cuda_kernel_config_new()
|
||||
defer C.mlx_fast_cuda_kernel_config_free(cfg)
|
||||
|
||||
cInT := C.CString("InT")
|
||||
defer C.free(unsafe.Pointer(cInT))
|
||||
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cInT, C.mlx_dtype(inputDType)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
cStT := C.CString("StT")
|
||||
defer C.free(unsafe.Pointer(cStT))
|
||||
if C.mlx_fast_cuda_kernel_config_add_template_arg_dtype(cfg, cStT, C.mlx_dtype(stateDType)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
for _, tpl := range []struct {
|
||||
name string
|
||||
value int
|
||||
}{
|
||||
{name: "Dk", value: Dk},
|
||||
{name: "Dv", value: Dv},
|
||||
{name: "Hk", value: Hk},
|
||||
{name: "Hv", value: Hv},
|
||||
} {
|
||||
cn := C.CString(tpl.name)
|
||||
rc := C.mlx_fast_cuda_kernel_config_add_template_arg_int(cfg, cn, C.int(tpl.value))
|
||||
C.free(unsafe.Pointer(cn))
|
||||
if rc != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
}
|
||||
|
||||
yShape := []C.int{C.int(B), C.int(T), C.int(Hv), C.int(Dv)}
|
||||
stateShape := []C.int{C.int(B), C.int(Hv), C.int(Dv), C.int(Dk)}
|
||||
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(yShape), C.size_t(len(yShape)), C.mlx_dtype(inputDType)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_cuda_kernel_config_add_output_arg(cfg, unsafe.SliceData(stateShape), C.size_t(len(stateShape)), C.mlx_dtype(stateDType)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if C.mlx_fast_cuda_kernel_config_set_grid(cfg, 32, C.int(Dv), C.int(B*Hv)) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
threadY := Dv
|
||||
if threadY > 4 {
|
||||
threadY = 4
|
||||
}
|
||||
if C.mlx_fast_cuda_kernel_config_set_thread_group(cfg, 32, C.int(threadY), 1) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
tScalar := FromValue(T)
|
||||
inputs := []C.mlx_array{
|
||||
q.ctx,
|
||||
k.ctx,
|
||||
v.ctx,
|
||||
g.ctx,
|
||||
beta.ctx,
|
||||
state.ctx,
|
||||
tScalar.ctx,
|
||||
}
|
||||
inVec := C.mlx_vector_array_new_data(unsafe.SliceData(inputs), C.size_t(len(inputs)))
|
||||
defer C.mlx_vector_array_free(inVec)
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
if C.mlx_fast_cuda_kernel_apply(&outVec, gatedDeltaCUDAKernel, inVec, cfg, DefaultStream().ctx) != 0 {
|
||||
gatedDeltaCUDADisabled = true
|
||||
return nil, nil, false
|
||||
}
|
||||
if int(C.mlx_vector_array_size(outVec)) < 2 {
|
||||
return nil, nil, false
|
||||
}
|
||||
|
||||
y = New("GATED_DELTA_CUDA_Y")
|
||||
nextState = New("GATED_DELTA_CUDA_STATE")
|
||||
C.mlx_vector_array_get(&y.ctx, outVec, 0)
|
||||
C.mlx_vector_array_get(&nextState.ctx, outVec, 1)
|
||||
return y, nextState, true
|
||||
}
|
||||
|
||||
// FastGatedDelta runs the recurrent update operation.
|
||||
//
|
||||
// When mask is non-nil, it must be a [B, T] bool tensor identifying real
|
||||
// (true) vs. padded (false) positions in q/k/v/g/beta. Padded positions
|
||||
// are substituted with neutral values (q=k=v=beta=0, g=1) so each padded
|
||||
// kernel iteration is a no-op — state passes through unchanged and the
|
||||
// final state equals the state after the last real token of each row.
|
||||
//
|
||||
// It tries the fused CUDA kernel first, then Metal, then falls back to a
|
||||
// backend-agnostic MLX implementation with identical inputs/outputs.
|
||||
func FastGatedDelta(q, k, v, g, beta, state, mask *Array) (y, nextState *Array) {
|
||||
// TODO: handle this more efficiently with a masked kernel (MLX-LM has one).
|
||||
if mask != nil {
|
||||
B := int32(mask.Dim(0))
|
||||
T := int32(mask.Dim(1))
|
||||
m4 := Reshape(mask, B, T, 1, 1)
|
||||
m3 := Reshape(mask, B, T, 1)
|
||||
zeroQ := FromValue(float32(0)).AsType(q.DType())
|
||||
zeroK := FromValue(float32(0)).AsType(k.DType())
|
||||
zeroV := FromValue(float32(0)).AsType(v.DType())
|
||||
zeroBeta := FromValue(float32(0)).AsType(beta.DType())
|
||||
oneG := FromValue(float32(1)).AsType(g.DType())
|
||||
q = Where(m4, q, zeroQ)
|
||||
k = Where(m4, k, zeroK)
|
||||
v = Where(m4, v, zeroV)
|
||||
beta = Where(m3, beta, zeroBeta)
|
||||
g = Where(m3, g, oneG)
|
||||
}
|
||||
|
||||
if y, nextState, ok := gatedDeltaCUDAKernelApply(q, k, v, g, beta, state); ok {
|
||||
return y, nextState
|
||||
}
|
||||
if y, nextState, ok := gatedDeltaKernel(q, k, v, g, beta, state); ok {
|
||||
return y, nextState
|
||||
}
|
||||
y, nextState = gatedDeltaFallback(q, k, v, g, beta, state)
|
||||
if y == nil || nextState == nil {
|
||||
panic("mlx.FastGatedDelta: fallback failed (invalid inputs or unsupported shapes)")
|
||||
}
|
||||
return y, nextState
|
||||
}
|
||||
3012
x/mlxrunner/mlx/generated.c
Normal file
3012
x/mlxrunner/mlx/generated.c
Normal file
File diff suppressed because it is too large
Load Diff
7256
x/mlxrunner/mlx/generated.h
Normal file
7256
x/mlxrunner/mlx/generated.h
Normal file
File diff suppressed because it is too large
Load Diff
17
x/mlxrunner/mlx/generator/generated.c.gotmpl
Normal file
17
x/mlxrunner/mlx/generator/generated.c.gotmpl
Normal file
@@ -0,0 +1,17 @@
|
||||
// This code is auto-generated; DO NOT EDIT.
|
||||
|
||||
#include "generated.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
{{ range .Functions }}
|
||||
{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
|
||||
{{- end }}
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
|
||||
{{- range .Functions }}
|
||||
{{ if .Optional }}OPTIONAL_LOAD{{ else }}CHECK_LOAD{{ end }}(handle, {{ .Name }});
|
||||
{{- end }}
|
||||
return 0;
|
||||
}
|
||||
26
x/mlxrunner/mlx/generator/generated.h.gotmpl
Normal file
26
x/mlxrunner/mlx/generator/generated.h.gotmpl
Normal file
@@ -0,0 +1,26 @@
|
||||
// This code is auto-generated; DO NOT EDIT.
|
||||
|
||||
#ifndef MLX_GENERATED_H
|
||||
#define MLX_GENERATED_H
|
||||
|
||||
#include "dynamic.h"
|
||||
{{ range .Functions }}
|
||||
#define {{ .Name }} {{ .Name }}_mlx_gen_orig_
|
||||
{{- end }}
|
||||
|
||||
#include "mlx/c/mlx.h"
|
||||
{{ range .Functions }}
|
||||
#undef {{ .Name }}
|
||||
{{- end }}
|
||||
{{ range .Functions }}
|
||||
extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
|
||||
{{- end }}
|
||||
|
||||
int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
|
||||
{{ range .Functions }}
|
||||
static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
|
||||
return {{ .Name }}_({{ .Args }});
|
||||
{{ "}" }}
|
||||
{{- end }}
|
||||
|
||||
#endif // MLX_GENERATED_H
|
||||
157
x/mlxrunner/mlx/generator/main.go
Normal file
157
x/mlxrunner/mlx/generator/main.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
tree_sitter "github.com/tree-sitter/go-tree-sitter"
|
||||
tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
|
||||
)
|
||||
|
||||
//go:embed *.gotmpl
|
||||
var fsys embed.FS
|
||||
|
||||
// optionalSymbols lists symbols that may not be present in all builds
|
||||
// (e.g., float16/bfloat16 are unavailable in CUDA builds of MLX).
|
||||
var optionalSymbols = map[string]bool{
|
||||
"mlx_array_item_float16": true,
|
||||
"mlx_array_item_bfloat16": true,
|
||||
"mlx_array_data_float16": true,
|
||||
"mlx_array_data_bfloat16": true,
|
||||
}
|
||||
|
||||
type Function struct {
|
||||
Type,
|
||||
Name,
|
||||
Parameters,
|
||||
Args string
|
||||
Optional bool
|
||||
}
|
||||
|
||||
func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
|
||||
var fn Function
|
||||
fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
|
||||
if params := node.ChildByFieldName("parameters"); params != nil {
|
||||
fn.Parameters = params.Utf8Text(source)
|
||||
fn.Args = ParseParameters(params, tc, source)
|
||||
}
|
||||
|
||||
var types []string
|
||||
for node.Parent() != nil && node.Parent().Kind() != "declaration" {
|
||||
if node.Parent().Kind() == "pointer_declarator" {
|
||||
types = append(types, "*")
|
||||
}
|
||||
node = node.Parent()
|
||||
}
|
||||
|
||||
for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
|
||||
types = append(types, sibling.Utf8Text(source))
|
||||
}
|
||||
|
||||
slices.Reverse(types)
|
||||
fn.Type = strings.Join(types, " ")
|
||||
return fn
|
||||
}
|
||||
|
||||
func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
|
||||
var s []string
|
||||
for _, child := range node.Children(tc) {
|
||||
if child.IsNamed() {
|
||||
child := child.ChildByFieldName("declarator")
|
||||
for child != nil && child.Kind() != "identifier" {
|
||||
if child.Kind() == "parenthesized_declarator" {
|
||||
child = child.Child(1)
|
||||
} else {
|
||||
child = child.ChildByFieldName("declarator")
|
||||
}
|
||||
}
|
||||
|
||||
if child != nil {
|
||||
s = append(s, child.Utf8Text(source))
|
||||
}
|
||||
}
|
||||
}
|
||||
return strings.Join(s, ", ")
|
||||
}
|
||||
|
||||
func main() {
|
||||
var output string
|
||||
flag.StringVar(&output, "output", ".", "Output directory for generated files")
|
||||
flag.Parse()
|
||||
|
||||
parser := tree_sitter.NewParser()
|
||||
defer parser.Close()
|
||||
|
||||
language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
|
||||
parser.SetLanguage(language)
|
||||
|
||||
query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
|
||||
defer query.Close()
|
||||
|
||||
qc := tree_sitter.NewQueryCursor()
|
||||
defer qc.Close()
|
||||
|
||||
var files []string
|
||||
for _, arg := range flag.Args() {
|
||||
matches, err := filepath.Glob(arg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error expanding glob %s: %v\n", arg, err)
|
||||
continue
|
||||
}
|
||||
files = append(files, matches...)
|
||||
}
|
||||
|
||||
var funs []Function
|
||||
for _, arg := range files {
|
||||
bts, err := os.ReadFile(arg)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
|
||||
continue
|
||||
}
|
||||
|
||||
tree := parser.Parse(bts, nil)
|
||||
defer tree.Close()
|
||||
|
||||
tc := tree.Walk()
|
||||
defer tc.Close()
|
||||
|
||||
matches := qc.Matches(query, tree.RootNode(), bts)
|
||||
for match := matches.Next(); match != nil; match = matches.Next() {
|
||||
for _, capture := range match.Captures {
|
||||
fn := ParseFunction(&capture.Node, tc, bts)
|
||||
fn.Optional = optionalSymbols[fn.Name]
|
||||
funs = append(funs, fn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, tmpl := range tmpl.Templates() {
|
||||
name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
|
||||
|
||||
fmt.Println("Generating", name)
|
||||
f, err := os.Create(name)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
|
||||
continue
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := tmpl.Execute(f, map[string]any{
|
||||
"Functions": funs,
|
||||
}); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
12
x/mlxrunner/mlx/include/mlx/c/README.md
Normal file
12
x/mlxrunner/mlx/include/mlx/c/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# Vendored MLX-C Headers
|
||||
|
||||
These header files are vendored from [mlx-c](https://github.com/ml-explore/mlx-c).
|
||||
The pinned version is in `MLX_C_VERSION` at the repo root.
|
||||
|
||||
Headers are automatically refreshed when you run a CMake build:
|
||||
|
||||
```shell
|
||||
cmake --preset 'MLX CUDA 13'
|
||||
```
|
||||
|
||||
See the [MLX Engine](../../../../../../../docs/development.md#mlx-engine-optional) section of the development docs for full build instructions.
|
||||
420
x/mlxrunner/mlx/include/mlx/c/array.h
Normal file
420
x/mlxrunner/mlx/include/mlx/c/array.h
Normal file
@@ -0,0 +1,420 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ARRAY_H
|
||||
#define MLX_ARRAY_H
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#include <float.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
// Complex number support
|
||||
#ifdef _MSC_VER
|
||||
#define _CRT_USE_C_COMPLEX_H
|
||||
#include <complex.h>
|
||||
typedef _Fcomplex mlx_complex64_t;
|
||||
#else
|
||||
#include <complex.h>
|
||||
typedef float _Complex mlx_complex64_t;
|
||||
#endif
|
||||
|
||||
#include "half.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_array Array
|
||||
* MLX N-dimensional array object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A N-dimensional array object.
|
||||
*/
|
||||
typedef struct mlx_array_ {
|
||||
void* ctx;
|
||||
} mlx_array;
|
||||
|
||||
static mlx_array mlx_array_empty;
|
||||
|
||||
/**
|
||||
* Array element type.
|
||||
*/
|
||||
typedef enum mlx_dtype_ {
|
||||
MLX_BOOL,
|
||||
MLX_UINT8,
|
||||
MLX_UINT16,
|
||||
MLX_UINT32,
|
||||
MLX_UINT64,
|
||||
MLX_INT8,
|
||||
MLX_INT16,
|
||||
MLX_INT32,
|
||||
MLX_INT64,
|
||||
MLX_FLOAT16,
|
||||
MLX_FLOAT32,
|
||||
MLX_FLOAT64,
|
||||
MLX_BFLOAT16,
|
||||
MLX_COMPLEX64,
|
||||
} mlx_dtype;
|
||||
|
||||
/**
|
||||
* Size of given mlx_dtype datatype in bytes.
|
||||
*/
|
||||
size_t mlx_dtype_size(mlx_dtype dtype);
|
||||
|
||||
/**
|
||||
* Get array description.
|
||||
*/
|
||||
int mlx_array_tostring(mlx_string* str, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* New empty array.
|
||||
*/
|
||||
mlx_array mlx_array_new(void);
|
||||
|
||||
/**
|
||||
* Free an array.
|
||||
*/
|
||||
int mlx_array_free(mlx_array arr);
|
||||
|
||||
/**
|
||||
* New array from a bool scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_bool(bool val);
|
||||
/**
|
||||
* New array from a int scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_int(int val);
|
||||
/**
|
||||
* New array from a float32 scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_float32(float val);
|
||||
/**
|
||||
* New array from a float scalar.
|
||||
* Same as float32.
|
||||
*/
|
||||
mlx_array mlx_array_new_float(float val);
|
||||
/**
|
||||
* New array from a float64 scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_float64(double val);
|
||||
/**
|
||||
* New array from a double scalar.
|
||||
* Same as float64.
|
||||
*/
|
||||
mlx_array mlx_array_new_double(double val);
|
||||
/**
|
||||
* New array from a complex scalar.
|
||||
*/
|
||||
mlx_array mlx_array_new_complex(float real_val, float imag_val);
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
*/
|
||||
mlx_array mlx_array_new_data(
|
||||
const void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype);
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
* @param dtor Callback for when the buffer is no longer needed.
|
||||
*/
|
||||
mlx_array mlx_array_new_data_managed(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void (*dtor)(void*));
|
||||
/**
|
||||
* New array from existing buffer.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
* @param payload Payload pointer passed to the `dtor` callback instead of
|
||||
* `data`.
|
||||
* @param dtor Callback for when the buffer is no longer needed.
|
||||
*/
|
||||
mlx_array mlx_array_new_data_managed_payload(
|
||||
void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype,
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
/**
|
||||
* Set array to provided src array.
|
||||
*/
|
||||
int mlx_array_set(mlx_array* arr, const mlx_array src);
|
||||
/**
|
||||
* Set array to a bool scalar.
|
||||
*/
|
||||
int mlx_array_set_bool(mlx_array* arr, bool val);
|
||||
/**
|
||||
* Set array to a int scalar.
|
||||
*/
|
||||
int mlx_array_set_int(mlx_array* arr, int val);
|
||||
/**
|
||||
* Set array to a float32 scalar.
|
||||
*/
|
||||
int mlx_array_set_float32(mlx_array* arr, float val);
|
||||
/**
|
||||
* Set array to a float scalar.
|
||||
*/
|
||||
int mlx_array_set_float(mlx_array* arr, float val);
|
||||
/**
|
||||
* Set array to a float64 scalar.
|
||||
*/
|
||||
int mlx_array_set_float64(mlx_array* arr, double val);
|
||||
/**
|
||||
* Set array to a double scalar.
|
||||
*/
|
||||
int mlx_array_set_double(mlx_array* arr, double val);
|
||||
/**
|
||||
* Set array to a complex scalar.
|
||||
*/
|
||||
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val);
|
||||
/**
|
||||
* Set array to specified data and shape.
|
||||
* @param arr Destination array.
|
||||
* @param data A buffer which will be copied.
|
||||
* @param shape Shape of the array.
|
||||
* @param dim Number of dimensions (size of `shape`).
|
||||
* @param dtype Type of array elements.
|
||||
*/
|
||||
int mlx_array_set_data(
|
||||
mlx_array* arr,
|
||||
const void* data,
|
||||
const int* shape,
|
||||
int dim,
|
||||
mlx_dtype dtype);
|
||||
|
||||
/**
|
||||
* The size of the array's datatype in bytes.
|
||||
*/
|
||||
size_t mlx_array_itemsize(const mlx_array arr);
|
||||
/**
|
||||
* Number of elements in the array.
|
||||
*/
|
||||
size_t mlx_array_size(const mlx_array arr);
|
||||
/**
|
||||
* The number of bytes in the array.
|
||||
*/
|
||||
size_t mlx_array_nbytes(const mlx_array arr);
|
||||
/**
|
||||
* The array's dimension.
|
||||
*/
|
||||
size_t mlx_array_ndim(const mlx_array arr);
|
||||
/**
|
||||
* The shape of the array.
|
||||
* Returns: a pointer to the sizes of each dimension.
|
||||
*/
|
||||
const int* mlx_array_shape(const mlx_array arr);
|
||||
/**
|
||||
* The strides of the array.
|
||||
* Returns: a pointer to the sizes of each dimension.
|
||||
*/
|
||||
const size_t* mlx_array_strides(const mlx_array arr);
|
||||
/**
|
||||
* The shape of the array in a particular dimension.
|
||||
*/
|
||||
int mlx_array_dim(const mlx_array arr, int dim);
|
||||
/**
|
||||
* The array element type.
|
||||
*/
|
||||
mlx_dtype mlx_array_dtype(const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Evaluate the array.
|
||||
*/
|
||||
int mlx_array_eval(mlx_array arr);
|
||||
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_bool(bool* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int8(int8_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int16(int16_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int32(int32_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_int64(int64_t* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float32(float* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float64(double* res, const mlx_array arr);
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_complex64(mlx_complex64_t* res, const mlx_array arr);
|
||||
|
||||
#ifdef HAS_FLOAT16
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
|
||||
#ifdef HAS_BFLOAT16
|
||||
/**
|
||||
* Access the value of a scalar array.
|
||||
*/
|
||||
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `bool*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const bool* mlx_array_data_bool(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint8_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint8_t* mlx_array_data_uint8(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint16_t* mlx_array_data_uint16(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint32_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint32_t* mlx_array_data_uint32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `uint64_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const uint64_t* mlx_array_data_uint64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int8_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int8_t* mlx_array_data_int8(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int16_t* mlx_array_data_int16(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int32_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int32_t* mlx_array_data_int32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `int64_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const int64_t* mlx_array_data_int64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float32*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const float* mlx_array_data_float32(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float64*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const double* mlx_array_data_float64(const mlx_array arr);
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `_Complex*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const mlx_complex64_t* mlx_array_data_complex64(const mlx_array arr);
|
||||
|
||||
#ifdef HAS_FLOAT16
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `float16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const float16_t* mlx_array_data_float16(const mlx_array arr);
|
||||
#endif
|
||||
|
||||
#ifdef HAS_BFLOAT16
|
||||
/**
|
||||
* Returns a pointer to the array data, cast to `bfloat16_t*`.
|
||||
* Array must be evaluated, otherwise returns NULL.
|
||||
*/
|
||||
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Check if the array is available.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_available(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Wait on the array to be available. After this `_mlx_array_is_available`
|
||||
* returns `true`. Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_wait(const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array is contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array's rows are contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**
|
||||
* Whether the array's columns are contiguous in memory.
|
||||
* Internal function: use at your own risk.
|
||||
*/
|
||||
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
197
x/mlxrunner/mlx/include/mlx/c/closure.h
Normal file
197
x/mlxrunner/mlx/include/mlx/c/closure.h
Normal file
@@ -0,0 +1,197 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_CLOSURE_H
|
||||
#define MLX_CLOSURE_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/optional.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_closure Closures
|
||||
* MLX closure objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef struct mlx_closure_ {
|
||||
void* ctx;
|
||||
} mlx_closure;
|
||||
mlx_closure mlx_closure_new(void);
|
||||
int mlx_closure_free(mlx_closure cls);
|
||||
mlx_closure mlx_closure_new_func(
|
||||
int (*fun)(mlx_vector_array*, const mlx_vector_array));
|
||||
mlx_closure mlx_closure_new_func_payload(
|
||||
int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_set(mlx_closure* cls, const mlx_closure src);
|
||||
int mlx_closure_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure cls,
|
||||
const mlx_vector_array input);
|
||||
|
||||
mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array));
|
||||
|
||||
typedef struct mlx_closure_kwargs_ {
|
||||
void* ctx;
|
||||
} mlx_closure_kwargs;
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new(void);
|
||||
int mlx_closure_kwargs_free(mlx_closure_kwargs cls);
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array));
|
||||
mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_map_string_to_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_kwargs_set(
|
||||
mlx_closure_kwargs* cls,
|
||||
const mlx_closure_kwargs src);
|
||||
int mlx_closure_kwargs_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_kwargs cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_map_string_to_array input_1);
|
||||
|
||||
typedef struct mlx_closure_value_and_grad_ {
|
||||
void* ctx;
|
||||
} mlx_closure_value_and_grad;
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void);
|
||||
int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls);
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
|
||||
int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
|
||||
mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_value_and_grad_set(
|
||||
mlx_closure_value_and_grad* cls,
|
||||
const mlx_closure_value_and_grad src);
|
||||
int mlx_closure_value_and_grad_apply(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
mlx_closure_value_and_grad cls,
|
||||
const mlx_vector_array input);
|
||||
|
||||
typedef struct mlx_closure_custom_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom;
|
||||
mlx_closure_custom mlx_closure_custom_new(void);
|
||||
int mlx_closure_custom_free(mlx_closure_custom cls);
|
||||
mlx_closure_custom mlx_closure_custom_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array));
|
||||
mlx_closure_custom mlx_closure_custom_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_set(
|
||||
mlx_closure_custom* cls,
|
||||
const mlx_closure_custom src);
|
||||
int mlx_closure_custom_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_custom cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_vector_array input_1,
|
||||
const mlx_vector_array input_2);
|
||||
|
||||
typedef struct mlx_closure_custom_jvp_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom_jvp;
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void);
|
||||
int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls);
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num));
|
||||
mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
const mlx_vector_array,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_jvp_set(
|
||||
mlx_closure_custom_jvp* cls,
|
||||
const mlx_closure_custom_jvp src);
|
||||
int mlx_closure_custom_jvp_apply(
|
||||
mlx_vector_array* res,
|
||||
mlx_closure_custom_jvp cls,
|
||||
const mlx_vector_array input_0,
|
||||
const mlx_vector_array input_1,
|
||||
const int* input_2,
|
||||
size_t input_2_num);
|
||||
|
||||
typedef struct mlx_closure_custom_vmap_ {
|
||||
void* ctx;
|
||||
} mlx_closure_custom_vmap;
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void);
|
||||
int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls);
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num));
|
||||
mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
|
||||
int (*fun)(
|
||||
mlx_vector_array*,
|
||||
mlx_vector_int*,
|
||||
const mlx_vector_array,
|
||||
const int*,
|
||||
size_t _num,
|
||||
void*),
|
||||
void* payload,
|
||||
void (*dtor)(void*));
|
||||
int mlx_closure_custom_vmap_set(
|
||||
mlx_closure_custom_vmap* cls,
|
||||
const mlx_closure_custom_vmap src);
|
||||
int mlx_closure_custom_vmap_apply(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_int* res_1,
|
||||
mlx_closure_custom_vmap cls,
|
||||
const mlx_vector_array input_0,
|
||||
const int* input_1,
|
||||
size_t input_1_num);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
58
x/mlxrunner/mlx/include/mlx/c/compile.h
Normal file
58
x/mlxrunner/mlx/include/mlx/c/compile.h
Normal file
@@ -0,0 +1,58 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_COMPILE_H
|
||||
#define MLX_COMPILE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup compile Compilation operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef enum mlx_compile_mode_ {
|
||||
MLX_COMPILE_MODE_DISABLED,
|
||||
MLX_COMPILE_MODE_NO_SIMPLIFY,
|
||||
MLX_COMPILE_MODE_NO_FUSE,
|
||||
MLX_COMPILE_MODE_ENABLED
|
||||
} mlx_compile_mode;
|
||||
|
||||
int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless);
|
||||
int mlx_detail_compile(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
uintptr_t fun_id,
|
||||
bool shapeless,
|
||||
const uint64_t* constants,
|
||||
size_t constants_num);
|
||||
int mlx_detail_compile_clear_cache(void);
|
||||
int mlx_detail_compile_erase(uintptr_t fun_id);
|
||||
int mlx_disable_compile(void);
|
||||
int mlx_enable_compile(void);
|
||||
int mlx_set_compile_mode(mlx_compile_mode mode);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
39
x/mlxrunner/mlx/include/mlx/c/cuda.h
Normal file
39
x/mlxrunner/mlx/include/mlx/c/cuda.h
Normal file
@@ -0,0 +1,39 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_CUDA_H
|
||||
#define MLX_CUDA_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup cuda Cuda specific operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_cuda_is_available(bool* res);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
154
x/mlxrunner/mlx/include/mlx/c/device.h
Normal file
154
x/mlxrunner/mlx/include/mlx/c/device.h
Normal file
@@ -0,0 +1,154 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_DEVICE_H
|
||||
#define MLX_DEVICE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_device Device
|
||||
* MLX device object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX device object.
|
||||
*/
|
||||
typedef struct mlx_device_ {
|
||||
void* ctx;
|
||||
} mlx_device;
|
||||
|
||||
/**
|
||||
* Device type.
|
||||
*/
|
||||
typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type;
|
||||
|
||||
/**
|
||||
* Returns a new empty device.
|
||||
*/
|
||||
mlx_device mlx_device_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new device of specified `type`, with specified `index`.
|
||||
*/
|
||||
mlx_device mlx_device_new_type(mlx_device_type type, int index);
|
||||
/**
|
||||
* Free a device.
|
||||
*/
|
||||
int mlx_device_free(mlx_device dev);
|
||||
/**
|
||||
* Set device to provided src device.
|
||||
*/
|
||||
int mlx_device_set(mlx_device* dev, const mlx_device src);
|
||||
/**
|
||||
* Get device description.
|
||||
*/
|
||||
int mlx_device_tostring(mlx_string* str, mlx_device dev);
|
||||
/**
|
||||
* Check if devices are the same.
|
||||
*/
|
||||
bool mlx_device_equal(mlx_device lhs, mlx_device rhs);
|
||||
/**
|
||||
* Returns the index of the device.
|
||||
*/
|
||||
int mlx_device_get_index(int* index, mlx_device dev);
|
||||
/**
|
||||
* Returns the type of the device.
|
||||
*/
|
||||
int mlx_device_get_type(mlx_device_type* type, mlx_device dev);
|
||||
/**
|
||||
* Returns the default MLX device.
|
||||
*/
|
||||
int mlx_get_default_device(mlx_device* dev);
|
||||
/**
|
||||
* Set the default MLX device.
|
||||
*/
|
||||
int mlx_set_default_device(mlx_device dev);
|
||||
/**
|
||||
* Check if device is available.
|
||||
*/
|
||||
int mlx_device_is_available(bool* avail, mlx_device dev);
|
||||
/**
|
||||
* Get the number of available devices for a device type.
|
||||
*/
|
||||
int mlx_device_count(int* count, mlx_device_type type);
|
||||
|
||||
/**
|
||||
* A MLX device info object.
|
||||
* Contains key-value pairs with device properties.
|
||||
* Keys vary by backend but common keys include:
|
||||
* - device_name (string): Device name
|
||||
* - architecture (string): Architecture identifier
|
||||
* Additional keys may be present depending on the backend.
|
||||
*/
|
||||
typedef struct mlx_device_info_ {
|
||||
void* ctx;
|
||||
} mlx_device_info;
|
||||
|
||||
/**
|
||||
* Returns a new empty device info object.
|
||||
*/
|
||||
mlx_device_info mlx_device_info_new(void);
|
||||
/**
|
||||
* Get device information for a device.
|
||||
*/
|
||||
int mlx_device_info_get(mlx_device_info* info, mlx_device dev);
|
||||
/**
|
||||
* Free a device info object.
|
||||
*/
|
||||
int mlx_device_info_free(mlx_device_info info);
|
||||
/**
|
||||
* Check if a key exists in the device info.
|
||||
* Returns 0 on success, 1 on error.
|
||||
* Sets *exists to true if the key exists, false otherwise.
|
||||
*/
|
||||
int mlx_device_info_has_key(
|
||||
bool* exists,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Check if a value is a string type.
|
||||
* Returns 0 on success, 1 on error.
|
||||
* Sets *is_string to true if the value is a string, false if it's a size_t.
|
||||
*/
|
||||
int mlx_device_info_is_string(
|
||||
bool* is_string,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get a string value from device info.
|
||||
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
|
||||
*/
|
||||
int mlx_device_info_get_string(
|
||||
const char** value,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get a size_t value from device info.
|
||||
* Returns 0 on success, 1 on error, 2 if key not found or wrong type.
|
||||
*/
|
||||
int mlx_device_info_get_size(
|
||||
size_t* value,
|
||||
mlx_device_info info,
|
||||
const char* key);
|
||||
/**
|
||||
* Get all keys from device info.
|
||||
* Returns 0 on success, 1 on error.
|
||||
*/
|
||||
int mlx_device_info_get_keys(mlx_vector_string* keys, mlx_device_info info);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
83
x/mlxrunner/mlx/include/mlx/c/distributed.h
Normal file
83
x/mlxrunner/mlx/include/mlx/c/distributed.h
Normal file
@@ -0,0 +1,83 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_DISTRIBUTED_H
|
||||
#define MLX_DISTRIBUTED_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup distributed Distributed collectives
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_distributed_all_gather(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream S);
|
||||
int mlx_distributed_all_max(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_all_min(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_all_sum(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_recv(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
int src,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_recv_like(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int src,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_send(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dst,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_distributed_sum_scatter(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_distributed_group group /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
74
x/mlxrunner/mlx/include/mlx/c/distributed_group.h
Normal file
74
x/mlxrunner/mlx/include/mlx/c/distributed_group.h
Normal file
@@ -0,0 +1,74 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_DISTRIBUTED_GROUP_H
|
||||
#define MLX_DISTRIBUTED_GROUP_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/stream.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_distributed_group MLX distributed
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX distributed group object.
|
||||
*/
|
||||
typedef struct mlx_distributed_group_ {
|
||||
void* ctx;
|
||||
} mlx_distributed_group;
|
||||
|
||||
/**
|
||||
* Create an empty group.
|
||||
*/
|
||||
mlx_distributed_group mlx_distributed_group_new(void);
|
||||
|
||||
/**
|
||||
* Free the group.
|
||||
*/
|
||||
int mlx_distributed_group_free(mlx_distributed_group group);
|
||||
|
||||
/**
|
||||
* Initialize distributed.
|
||||
*/
|
||||
int mlx_distributed_init(
|
||||
mlx_distributed_group* res,
|
||||
bool strict,
|
||||
const char* bk /* may be null */);
|
||||
|
||||
/**
|
||||
* Get the rank.
|
||||
*/
|
||||
int mlx_distributed_group_rank(mlx_distributed_group group);
|
||||
|
||||
/**
|
||||
* Get the group size.
|
||||
*/
|
||||
int mlx_distributed_group_size(mlx_distributed_group group);
|
||||
|
||||
/**
|
||||
* Split the group.
|
||||
*/
|
||||
int mlx_distributed_group_split(
|
||||
mlx_distributed_group* res,
|
||||
mlx_distributed_group group,
|
||||
int color,
|
||||
int key);
|
||||
|
||||
/**
|
||||
* Check if distributed is available.
|
||||
*/
|
||||
bool mlx_distributed_is_available(const char* bk /* may be null */);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
41
x/mlxrunner/mlx/include/mlx/c/error.h
Normal file
41
x/mlxrunner/mlx/include/mlx/c/error.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ERROR_H
|
||||
#define MLX_ERROR_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_error Error management
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef void (*mlx_error_handler_func)(const char* msg, void* data);
|
||||
|
||||
/**
|
||||
* Set the error handler.
|
||||
*/
|
||||
void mlx_set_error_handler(
|
||||
mlx_error_handler_func handler,
|
||||
void* data,
|
||||
void (*dtor)(void*));
|
||||
|
||||
/**
|
||||
* Throw an error.
|
||||
*/
|
||||
void _mlx_error(const char* file, const int line, const char* fmt, ...);
|
||||
|
||||
/**
|
||||
* Throw an error. Macro which passes file name and line number to _mlx_error().
|
||||
*/
|
||||
#define mlx_error(...) _mlx_error(__FILE__, __LINE__, __VA_ARGS__)
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
75
x/mlxrunner/mlx/include/mlx/c/export.h
Normal file
75
x/mlxrunner/mlx/include/mlx/c/export.h
Normal file
@@ -0,0 +1,75 @@
|
||||
/* Copyright © 2023-2025 Apple Inc. */
|
||||
|
||||
#ifndef MLX_EXPORT_H
|
||||
#define MLX_EXPORT_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup export Function serialization
|
||||
*/
|
||||
/**@{*/
|
||||
int mlx_export_function(
|
||||
const char* file,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array args,
|
||||
bool shapeless);
|
||||
int mlx_export_function_kwargs(
|
||||
const char* file,
|
||||
const mlx_closure_kwargs fun,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs,
|
||||
bool shapeless);
|
||||
|
||||
typedef struct mlx_function_exporter_ {
|
||||
void* ctx;
|
||||
} mlx_function_exporter;
|
||||
mlx_function_exporter mlx_function_exporter_new(
|
||||
const char* file,
|
||||
const mlx_closure fun,
|
||||
bool shapeless);
|
||||
int mlx_function_exporter_free(mlx_function_exporter xfunc);
|
||||
int mlx_function_exporter_apply(
|
||||
const mlx_function_exporter xfunc,
|
||||
const mlx_vector_array args);
|
||||
int mlx_function_exporter_apply_kwargs(
|
||||
const mlx_function_exporter xfunc,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs);
|
||||
|
||||
typedef struct mlx_imported_function_ {
|
||||
void* ctx;
|
||||
} mlx_imported_function;
|
||||
mlx_imported_function mlx_imported_function_new(const char* file);
|
||||
int mlx_imported_function_free(mlx_imported_function xfunc);
|
||||
int mlx_imported_function_apply(
|
||||
mlx_vector_array* res,
|
||||
const mlx_imported_function xfunc,
|
||||
const mlx_vector_array args);
|
||||
int mlx_imported_function_apply_kwargs(
|
||||
mlx_vector_array* res,
|
||||
const mlx_imported_function xfunc,
|
||||
const mlx_vector_array args,
|
||||
const mlx_map_string_to_array kwargs);
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
206
x/mlxrunner/mlx/include/mlx/c/fast.h
Normal file
206
x/mlxrunner/mlx/include/mlx/c/fast.h
Normal file
@@ -0,0 +1,206 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_FAST_H
|
||||
#define MLX_FAST_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup fast Fast custom operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef struct mlx_fast_cuda_kernel_config_ {
|
||||
void* ctx;
|
||||
} mlx_fast_cuda_kernel_config;
|
||||
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void);
|
||||
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls);
|
||||
|
||||
int mlx_fast_cuda_kernel_config_add_output_arg(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const int* shape,
|
||||
size_t size,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_cuda_kernel_config_set_grid(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
int grid1,
|
||||
int grid2,
|
||||
int grid3);
|
||||
int mlx_fast_cuda_kernel_config_set_thread_group(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
int thread1,
|
||||
int thread2,
|
||||
int thread3);
|
||||
int mlx_fast_cuda_kernel_config_set_init_value(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
float value);
|
||||
int mlx_fast_cuda_kernel_config_set_verbose(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
bool verbose);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_int(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
int value);
|
||||
int mlx_fast_cuda_kernel_config_add_template_arg_bool(
|
||||
mlx_fast_cuda_kernel_config cls,
|
||||
const char* name,
|
||||
bool value);
|
||||
|
||||
typedef struct mlx_fast_cuda_kernel_ {
|
||||
void* ctx;
|
||||
} mlx_fast_cuda_kernel;
|
||||
|
||||
mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
|
||||
const char* name,
|
||||
const mlx_vector_string input_names,
|
||||
const mlx_vector_string output_names,
|
||||
const char* source,
|
||||
const char* header,
|
||||
bool ensure_row_contiguous,
|
||||
int shared_memory);
|
||||
|
||||
void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls);
|
||||
|
||||
int mlx_fast_cuda_kernel_apply(
|
||||
mlx_vector_array* outputs,
|
||||
mlx_fast_cuda_kernel cls,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_fast_cuda_kernel_config config,
|
||||
const mlx_stream stream);
|
||||
|
||||
int mlx_fast_layer_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array weight /* may be null */,
|
||||
const mlx_array bias /* may be null */,
|
||||
float eps,
|
||||
const mlx_stream s);
|
||||
|
||||
typedef struct mlx_fast_metal_kernel_config_ {
|
||||
void* ctx;
|
||||
} mlx_fast_metal_kernel_config;
|
||||
mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void);
|
||||
void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls);
|
||||
|
||||
int mlx_fast_metal_kernel_config_add_output_arg(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const int* shape,
|
||||
size_t size,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_metal_kernel_config_set_grid(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
int grid1,
|
||||
int grid2,
|
||||
int grid3);
|
||||
int mlx_fast_metal_kernel_config_set_thread_group(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
int thread1,
|
||||
int thread2,
|
||||
int thread3);
|
||||
int mlx_fast_metal_kernel_config_set_init_value(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
float value);
|
||||
int mlx_fast_metal_kernel_config_set_verbose(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
bool verbose);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_dtype(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
mlx_dtype dtype);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_int(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
int value);
|
||||
int mlx_fast_metal_kernel_config_add_template_arg_bool(
|
||||
mlx_fast_metal_kernel_config cls,
|
||||
const char* name,
|
||||
bool value);
|
||||
|
||||
typedef struct mlx_fast_metal_kernel_ {
|
||||
void* ctx;
|
||||
} mlx_fast_metal_kernel;
|
||||
|
||||
mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
|
||||
const char* name,
|
||||
const mlx_vector_string input_names,
|
||||
const mlx_vector_string output_names,
|
||||
const char* source,
|
||||
const char* header,
|
||||
bool ensure_row_contiguous,
|
||||
bool atomic_outputs);
|
||||
|
||||
void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
|
||||
|
||||
int mlx_fast_metal_kernel_apply(
|
||||
mlx_vector_array* outputs,
|
||||
mlx_fast_metal_kernel cls,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_fast_metal_kernel_config config,
|
||||
const mlx_stream stream);
|
||||
|
||||
int mlx_fast_rms_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
const mlx_array weight /* may be null */,
|
||||
float eps,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_rope(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
int offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_rope_dynamic(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int dims,
|
||||
bool traditional,
|
||||
mlx_optional_float base,
|
||||
float scale,
|
||||
const mlx_array offset,
|
||||
const mlx_array freqs /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_fast_scaled_dot_product_attention(
|
||||
mlx_array* res,
|
||||
const mlx_array queries,
|
||||
const mlx_array keys,
|
||||
const mlx_array values,
|
||||
float scale,
|
||||
const char* mask_mode,
|
||||
const mlx_array mask_arr /* may be null */,
|
||||
const mlx_array sinks /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
158
x/mlxrunner/mlx/include/mlx/c/fft.h
Normal file
158
x/mlxrunner/mlx/include/mlx/c/fft.h
Normal file
@@ -0,0 +1,158 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_FFT_H
|
||||
#define MLX_FFT_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup fft FFT operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef enum mlx_fft_norm_ {
|
||||
MLX_FFT_NORM_BACKWARD,
|
||||
MLX_FFT_NORM_ORTHO,
|
||||
MLX_FFT_NORM_FORWARD
|
||||
} mlx_fft_norm;
|
||||
|
||||
int mlx_fft_fft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fftfreq(mlx_array* res, int n, double d, const mlx_stream s);
|
||||
int mlx_fft_fftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_fftshift(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_ifftshift(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_irfftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfft(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
int n,
|
||||
int axis,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfft2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
int mlx_fft_rfftfreq(mlx_array* res, int n, double d, const mlx_stream s);
|
||||
int mlx_fft_rfftn(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* n,
|
||||
size_t n_num,
|
||||
const int* axes,
|
||||
size_t axes_num,
|
||||
mlx_fft_norm norm,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
61
x/mlxrunner/mlx/include/mlx/c/graph_utils.h
Normal file
61
x/mlxrunner/mlx/include/mlx/c/graph_utils.h
Normal file
@@ -0,0 +1,61 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_GRAPH_UTILS_H
|
||||
#define MLX_GRAPH_UTILS_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup graph_utils Graph Utils
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
typedef struct mlx_node_namer_ {
|
||||
void* ctx;
|
||||
} mlx_node_namer;
|
||||
|
||||
mlx_node_namer mlx_node_namer_new();
|
||||
int mlx_node_namer_free(mlx_node_namer namer);
|
||||
int mlx_node_namer_set_name(
|
||||
mlx_node_namer namer,
|
||||
const mlx_array arr,
|
||||
const char* name);
|
||||
int mlx_node_namer_get_name(
|
||||
const char** name,
|
||||
mlx_node_namer namer,
|
||||
const mlx_array arr);
|
||||
|
||||
int mlx_export_to_dot(
|
||||
FILE* os,
|
||||
const mlx_node_namer namer,
|
||||
const mlx_vector_array outputs);
|
||||
int mlx_print_graph(
|
||||
FILE* os,
|
||||
const mlx_node_namer namer,
|
||||
const mlx_vector_array outputs);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
26
x/mlxrunner/mlx/include/mlx/c/half.h
Normal file
26
x/mlxrunner/mlx/include/mlx/c/half.h
Normal file
@@ -0,0 +1,26 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_HALF_H
|
||||
#define MLX_HALF_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__)
|
||||
#define HAS_FLOAT16
|
||||
#include <arm_fp16.h>
|
||||
typedef __fp16 float16_t;
|
||||
#endif
|
||||
|
||||
#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__)
|
||||
#define HAS_BFLOAT16
|
||||
#include <arm_bf16.h>
|
||||
typedef __bf16 bfloat16_t;
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
68
x/mlxrunner/mlx/include/mlx/c/io.h
Normal file
68
x/mlxrunner/mlx/include/mlx/c/io.h
Normal file
@@ -0,0 +1,68 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_IO_H
|
||||
#define MLX_IO_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup io IO operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_load_reader(
|
||||
mlx_array* res,
|
||||
mlx_io_reader in_stream,
|
||||
const mlx_stream s);
|
||||
int mlx_load(mlx_array* res, const char* file, const mlx_stream s);
|
||||
|
||||
int mlx_load_gguf(mlx_io_gguf* gguf, const char* file, const mlx_stream s);
|
||||
|
||||
int mlx_load_safetensors_reader(
|
||||
mlx_map_string_to_array* res_0,
|
||||
mlx_map_string_to_string* res_1,
|
||||
mlx_io_reader in_stream,
|
||||
const mlx_stream s);
|
||||
int mlx_load_safetensors(
|
||||
mlx_map_string_to_array* res_0,
|
||||
mlx_map_string_to_string* res_1,
|
||||
const char* file,
|
||||
const mlx_stream s);
|
||||
int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a);
|
||||
int mlx_save(const char* file, const mlx_array a);
|
||||
int mlx_save_gguf(const char* file, mlx_io_gguf gguf);
|
||||
|
||||
int mlx_save_safetensors_writer(
|
||||
mlx_io_writer in_stream,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata);
|
||||
int mlx_save_safetensors(
|
||||
const char* file,
|
||||
const mlx_map_string_to_array param,
|
||||
const mlx_map_string_to_string metadata);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
150
x/mlxrunner/mlx/include/mlx/c/io_types.h
Normal file
150
x/mlxrunner/mlx/include/mlx/c/io_types.h
Normal file
@@ -0,0 +1,150 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_IO_TYPES_H
|
||||
#define MLX_IO_TYPES_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_io_types IO Types
|
||||
* MLX IO type objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX IO reader object.
|
||||
*/
|
||||
typedef struct mlx_io_reader_ {
|
||||
void* ctx;
|
||||
} mlx_io_reader;
|
||||
/**
|
||||
* A MLX IO writer object.
|
||||
*/
|
||||
typedef struct mlx_io_writer_ {
|
||||
void* ctx;
|
||||
} mlx_io_writer;
|
||||
|
||||
/**
|
||||
* Virtual table for custom IO reader and writer objects.
|
||||
*/
|
||||
typedef struct mlx_io_vtable_ {
|
||||
bool (*is_open)(void*);
|
||||
bool (*good)(void*);
|
||||
size_t (*tell)(void*);
|
||||
void (*seek)(void*, int64_t off, int whence);
|
||||
void (*read)(void*, char* data, size_t n);
|
||||
void (*read_at_offset)(void*, char* data, size_t n, size_t off);
|
||||
void (*write)(void*, const char* data, size_t n);
|
||||
const char* (*label)(void*);
|
||||
void (*free)(void*);
|
||||
} mlx_io_vtable;
|
||||
|
||||
/**
|
||||
* Returns a new custom IO reader.
|
||||
* `vtable` operates on user descriptor `desc`.
|
||||
*/
|
||||
mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable);
|
||||
|
||||
/**
|
||||
* Get IO reader user descriptor.
|
||||
*/
|
||||
int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Get IO reader description.
|
||||
*/
|
||||
int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Free IO reader.
|
||||
*
|
||||
* Note that MLX arrays are lazily evaluated, so the underlying object may
|
||||
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
|
||||
* will be called when the underlying object is actually freed.
|
||||
*/
|
||||
int mlx_io_reader_free(mlx_io_reader io);
|
||||
|
||||
/**
|
||||
* Returns a new custom IO writer.
|
||||
* `vtable` operates on user descriptor `desc`.
|
||||
*/
|
||||
mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable);
|
||||
|
||||
/**
|
||||
* Get IO writer user descriptor.
|
||||
*/
|
||||
int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io);
|
||||
|
||||
/**
|
||||
* Get IO writer description.
|
||||
*/
|
||||
int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io);
|
||||
|
||||
/**
|
||||
* Free IO writer.
|
||||
*
|
||||
* Note that MLX arrays are lazily evaluated, so the underlying object may
|
||||
* be not freed right away. The ``free()`` callback from ``mlx_io_vtable``
|
||||
* will be called when the underlying object is actually freed.
|
||||
*/
|
||||
int mlx_io_writer_free(mlx_io_writer io);
|
||||
|
||||
/**
|
||||
* A MLX GGUF object.
|
||||
*/
|
||||
typedef struct mlx_io_gguf_ {
|
||||
void* ctx;
|
||||
} mlx_io_gguf;
|
||||
|
||||
mlx_io_gguf mlx_io_gguf_new(void);
|
||||
int mlx_io_gguf_free(mlx_io_gguf io);
|
||||
int mlx_io_gguf_get_keys(mlx_vector_string* keys, mlx_io_gguf io);
|
||||
int mlx_io_gguf_get_array(mlx_array* arr, mlx_io_gguf io, const char* key);
|
||||
int mlx_io_gguf_get_metadata_array(
|
||||
mlx_array* arr,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_get_metadata_string(
|
||||
mlx_string* str,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_get_metadata_vector_string(
|
||||
mlx_vector_string* vstr,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_has_metadata_array(bool* flag, mlx_io_gguf io, const char* key);
|
||||
int mlx_io_gguf_has_metadata_string(
|
||||
bool* flag,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_has_metadata_vector_string(
|
||||
bool* flag,
|
||||
mlx_io_gguf io,
|
||||
const char* key);
|
||||
int mlx_io_gguf_set_array(mlx_io_gguf io, const char* key, const mlx_array arr);
|
||||
int mlx_io_gguf_set_metadata_array(
|
||||
mlx_io_gguf io,
|
||||
const char* key,
|
||||
const mlx_array marr);
|
||||
int mlx_io_gguf_set_metadata_string(
|
||||
mlx_io_gguf io,
|
||||
const char* key,
|
||||
const char* mstr);
|
||||
int mlx_io_gguf_set_metadata_vector_string(
|
||||
mlx_io_gguf io,
|
||||
const char* key,
|
||||
const mlx_vector_string mvstr);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
128
x/mlxrunner/mlx/include/mlx/c/linalg.h
Normal file
128
x/mlxrunner/mlx/include/mlx/c/linalg.h
Normal file
@@ -0,0 +1,128 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_LINALG_H
|
||||
#define MLX_LINALG_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup linalg Linear algebra operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_linalg_cholesky(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_cholesky_inv(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_cross(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
int axis,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eig(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eigh(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const char* UPLO,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_eigvalsh(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const char* UPLO,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_lu_factor(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
double ord,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm_matrix(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const char* ord,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_norm_l2(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const int* axis /* may be null */,
|
||||
size_t axis_num,
|
||||
bool keepdims,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s);
|
||||
int mlx_linalg_qr(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array a,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_solve(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_solve_triangular(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
const mlx_array b,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_svd(
|
||||
mlx_vector_array* res,
|
||||
const mlx_array a,
|
||||
bool compute_uv,
|
||||
const mlx_stream s);
|
||||
int mlx_linalg_tri_inv(
|
||||
mlx_array* res,
|
||||
const mlx_array a,
|
||||
bool upper,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
149
x/mlxrunner/mlx/include/mlx/c/map.h
Normal file
149
x/mlxrunner/mlx/include/mlx/c/map.h
Normal file
@@ -0,0 +1,149 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_MAP_H
|
||||
#define MLX_MAP_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_map Maps
|
||||
* MLX map objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A string-to-array map
|
||||
*/
|
||||
typedef struct mlx_map_string_to_array_ {
|
||||
void* ctx;
|
||||
} mlx_map_string_to_array;
|
||||
|
||||
/**
|
||||
* Returns a new empty string-to-array map.
|
||||
*/
|
||||
mlx_map_string_to_array mlx_map_string_to_array_new(void);
|
||||
/**
|
||||
* Set map to provided src map.
|
||||
*/
|
||||
int mlx_map_string_to_array_set(
|
||||
mlx_map_string_to_array* map,
|
||||
const mlx_map_string_to_array src);
|
||||
/**
|
||||
* Free a string-to-array map.
|
||||
*/
|
||||
int mlx_map_string_to_array_free(mlx_map_string_to_array map);
|
||||
/**
|
||||
* Insert a new `value` at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_array_insert(
|
||||
mlx_map_string_to_array map,
|
||||
const char* key,
|
||||
const mlx_array value);
|
||||
/**
|
||||
* Returns the value indexed at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_array_get(
|
||||
mlx_array* value,
|
||||
const mlx_map_string_to_array map,
|
||||
const char* key);
|
||||
|
||||
/**
|
||||
* An iterator over a string-to-array map.
|
||||
*/
|
||||
typedef struct mlx_map_string_to_array_iterator_ {
|
||||
void* ctx;
|
||||
void* map_ctx;
|
||||
} mlx_map_string_to_array_iterator;
|
||||
/**
|
||||
* Returns a new iterator over the given map.
|
||||
*/
|
||||
mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(
|
||||
mlx_map_string_to_array map);
|
||||
/**
|
||||
* Free iterator.
|
||||
*/
|
||||
int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it);
|
||||
/**
|
||||
* Increment iterator.
|
||||
*/
|
||||
int mlx_map_string_to_array_iterator_next(
|
||||
const char** key,
|
||||
mlx_array* value,
|
||||
mlx_map_string_to_array_iterator it);
|
||||
|
||||
/**
|
||||
* A string-to-string map
|
||||
*/
|
||||
typedef struct mlx_map_string_to_string_ {
|
||||
void* ctx;
|
||||
} mlx_map_string_to_string;
|
||||
|
||||
/**
|
||||
* Returns a new empty string-to-string map.
|
||||
*/
|
||||
mlx_map_string_to_string mlx_map_string_to_string_new(void);
|
||||
/**
|
||||
* Set map to provided src map.
|
||||
*/
|
||||
int mlx_map_string_to_string_set(
|
||||
mlx_map_string_to_string* map,
|
||||
const mlx_map_string_to_string src);
|
||||
/**
|
||||
* Free a string-to-string map.
|
||||
*/
|
||||
int mlx_map_string_to_string_free(mlx_map_string_to_string map);
|
||||
/**
|
||||
* Insert a new `value` at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_string_insert(
|
||||
mlx_map_string_to_string map,
|
||||
const char* key,
|
||||
const char* value);
|
||||
/**
|
||||
* Returns the value indexed at the specified `key` in the map.
|
||||
*/
|
||||
int mlx_map_string_to_string_get(
|
||||
const char** value,
|
||||
const mlx_map_string_to_string map,
|
||||
const char* key);
|
||||
|
||||
/**
|
||||
* An iterator over a string-to-string map.
|
||||
*/
|
||||
typedef struct mlx_map_string_to_string_iterator_ {
|
||||
void* ctx;
|
||||
void* map_ctx;
|
||||
} mlx_map_string_to_string_iterator;
|
||||
/**
|
||||
* Returns a new iterator over the given map.
|
||||
*/
|
||||
mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(
|
||||
mlx_map_string_to_string map);
|
||||
/**
|
||||
* Free iterator.
|
||||
*/
|
||||
int mlx_map_string_to_string_iterator_free(
|
||||
mlx_map_string_to_string_iterator it);
|
||||
/**
|
||||
* Increment iterator.
|
||||
*/
|
||||
int mlx_map_string_to_string_iterator_next(
|
||||
const char** key,
|
||||
const char** value,
|
||||
mlx_map_string_to_string_iterator it);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
47
x/mlxrunner/mlx/include/mlx/c/memory.h
Normal file
47
x/mlxrunner/mlx/include/mlx/c/memory.h
Normal file
@@ -0,0 +1,47 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_MEMORY_H
|
||||
#define MLX_MEMORY_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup memory Memory operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_clear_cache(void);
|
||||
int mlx_get_active_memory(size_t* res);
|
||||
int mlx_get_cache_memory(size_t* res);
|
||||
int mlx_get_memory_limit(size_t* res);
|
||||
int mlx_get_peak_memory(size_t* res);
|
||||
int mlx_reset_peak_memory(void);
|
||||
int mlx_set_cache_limit(size_t* res, size_t limit);
|
||||
int mlx_set_memory_limit(size_t* res, size_t limit);
|
||||
int mlx_set_wired_limit(size_t* res, size_t limit);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
41
x/mlxrunner/mlx/include/mlx/c/metal.h
Normal file
41
x/mlxrunner/mlx/include/mlx/c/metal.h
Normal file
@@ -0,0 +1,41 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_METAL_H
|
||||
#define MLX_METAL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup metal Metal specific operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_metal_is_available(bool* res);
|
||||
int mlx_metal_start_capture(const char* path);
|
||||
int mlx_metal_stop_capture(void);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
35
x/mlxrunner/mlx/include/mlx/c/mlx.h
Normal file
35
x/mlxrunner/mlx/include/mlx/c/mlx.h
Normal file
@@ -0,0 +1,35 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_ALL_H
|
||||
#define MLX_ALL_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/compile.h"
|
||||
#include "mlx/c/cuda.h"
|
||||
#include "mlx/c/device.h"
|
||||
#include "mlx/c/distributed.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/error.h"
|
||||
#include "mlx/c/export.h"
|
||||
#include "mlx/c/fast.h"
|
||||
#include "mlx/c/fft.h"
|
||||
#include "mlx/c/graph_utils.h"
|
||||
#include "mlx/c/half.h"
|
||||
#include "mlx/c/io.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/linalg.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/memory.h"
|
||||
#include "mlx/c/metal.h"
|
||||
#include "mlx/c/ops.h"
|
||||
#include "mlx/c/optional.h"
|
||||
#include "mlx/c/random.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/transforms.h"
|
||||
#include "mlx/c/transforms_impl.h"
|
||||
#include "mlx/c/vector.h"
|
||||
#include "mlx/c/version.h"
|
||||
|
||||
#endif
|
||||
1287
x/mlxrunner/mlx/include/mlx/c/ops.h
Normal file
1287
x/mlxrunner/mlx/include/mlx/c/ops.h
Normal file
File diff suppressed because it is too large
Load Diff
51
x/mlxrunner/mlx/include/mlx/c/optional.h
Normal file
51
x/mlxrunner/mlx/include/mlx/c/optional.h
Normal file
@@ -0,0 +1,51 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_OPTIONAL_H
|
||||
#define MLX_OPTIONAL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_optional Optionals
|
||||
* MLX optional scalars.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A int optional.
|
||||
*/
|
||||
typedef struct mlx_optional_int_ {
|
||||
int value;
|
||||
bool has_value;
|
||||
} mlx_optional_int;
|
||||
|
||||
/**
|
||||
* A float optional.
|
||||
*/
|
||||
typedef struct mlx_optional_float_ {
|
||||
float value;
|
||||
bool has_value;
|
||||
} mlx_optional_float;
|
||||
|
||||
/**
|
||||
* A dtype optional.
|
||||
*/
|
||||
typedef struct mlx_optional_dtype_ {
|
||||
mlx_dtype value;
|
||||
bool has_value;
|
||||
} mlx_optional_dtype;
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
166
x/mlxrunner/mlx/include/mlx/c/random.h
Normal file
166
x/mlxrunner/mlx/include/mlx/c/random.h
Normal file
@@ -0,0 +1,166 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_RANDOM_H
|
||||
#define MLX_RANDOM_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup random Random number operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_random_bernoulli(
|
||||
mlx_array* res,
|
||||
const mlx_array p,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_bits(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
int width,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical_shape(
|
||||
mlx_array* res,
|
||||
const mlx_array logits,
|
||||
int axis,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical_num_samples(
|
||||
mlx_array* res,
|
||||
const mlx_array logits_,
|
||||
int axis,
|
||||
int num_samples,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_categorical(
|
||||
mlx_array* res,
|
||||
const mlx_array logits,
|
||||
int axis,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_gumbel(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_key(mlx_array* res, uint64_t seed);
|
||||
int mlx_random_laplace(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
float loc,
|
||||
float scale,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_multivariate_normal(
|
||||
mlx_array* res,
|
||||
const mlx_array mean,
|
||||
const mlx_array cov,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_normal_broadcast(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array loc /* may be null */,
|
||||
const mlx_array scale /* may be null */,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_normal(
|
||||
mlx_array* res,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
float loc,
|
||||
float scale,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_permutation(
|
||||
mlx_array* res,
|
||||
const mlx_array x,
|
||||
int axis,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_permutation_arange(
|
||||
mlx_array* res,
|
||||
int x,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_randint(
|
||||
mlx_array* res,
|
||||
const mlx_array low,
|
||||
const mlx_array high,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_seed(uint64_t seed);
|
||||
int mlx_random_split_num(
|
||||
mlx_array* res,
|
||||
const mlx_array key,
|
||||
int num,
|
||||
const mlx_stream s);
|
||||
int mlx_random_split(
|
||||
mlx_array* res_0,
|
||||
mlx_array* res_1,
|
||||
const mlx_array key,
|
||||
const mlx_stream s);
|
||||
int mlx_random_truncated_normal(
|
||||
mlx_array* res,
|
||||
const mlx_array lower,
|
||||
const mlx_array upper,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
int mlx_random_uniform(
|
||||
mlx_array* res,
|
||||
const mlx_array low,
|
||||
const mlx_array high,
|
||||
const int* shape,
|
||||
size_t shape_num,
|
||||
mlx_dtype dtype,
|
||||
const mlx_array key /* may be null */,
|
||||
const mlx_stream s);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
88
x/mlxrunner/mlx/include/mlx/c/stream.h
Normal file
88
x/mlxrunner/mlx/include/mlx/c/stream.h
Normal file
@@ -0,0 +1,88 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_STREAM_H
|
||||
#define MLX_STREAM_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlx/c/device.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_stream Stream
|
||||
* MLX stream object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX stream object.
|
||||
*/
|
||||
typedef struct mlx_stream_ {
|
||||
void* ctx;
|
||||
} mlx_stream;
|
||||
|
||||
/**
|
||||
* Returns a new empty stream.
|
||||
*/
|
||||
mlx_stream mlx_stream_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new stream on a device.
|
||||
*/
|
||||
mlx_stream mlx_stream_new_device(mlx_device dev);
|
||||
/**
|
||||
* Set stream to provided src stream.
|
||||
*/
|
||||
int mlx_stream_set(mlx_stream* stream, const mlx_stream src);
|
||||
/**
|
||||
* Free a stream.
|
||||
*/
|
||||
int mlx_stream_free(mlx_stream stream);
|
||||
/**
|
||||
* Get stream description.
|
||||
*/
|
||||
int mlx_stream_tostring(mlx_string* str, mlx_stream stream);
|
||||
/**
|
||||
* Check if streams are the same.
|
||||
*/
|
||||
bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs);
|
||||
/**
|
||||
* Return the device of the stream.
|
||||
*/
|
||||
int mlx_stream_get_device(mlx_device* dev, mlx_stream stream);
|
||||
/**
|
||||
* Return the index of the stream.
|
||||
*/
|
||||
int mlx_stream_get_index(int* index, mlx_stream stream);
|
||||
/**
|
||||
* Synchronize with the provided stream.
|
||||
*/
|
||||
int mlx_synchronize(mlx_stream stream);
|
||||
/**
|
||||
* Returns the default stream on the given device.
|
||||
*/
|
||||
int mlx_get_default_stream(mlx_stream* stream, mlx_device dev);
|
||||
/**
|
||||
* Set default stream.
|
||||
*/
|
||||
int mlx_set_default_stream(mlx_stream stream);
|
||||
/**
|
||||
* Returns the current default CPU stream.
|
||||
*/
|
||||
mlx_stream mlx_default_cpu_stream_new(void);
|
||||
|
||||
/**
|
||||
* Returns the current default GPU stream.
|
||||
*/
|
||||
mlx_stream mlx_default_gpu_stream_new(void);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
55
x/mlxrunner/mlx/include/mlx/c/string.h
Normal file
55
x/mlxrunner/mlx/include/mlx/c/string.h
Normal file
@@ -0,0 +1,55 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_STRING_H
|
||||
#define MLX_STRING_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_string String
|
||||
* MLX string object.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A MLX string object.
|
||||
*/
|
||||
typedef struct mlx_string_ {
|
||||
void* ctx;
|
||||
} mlx_string;
|
||||
|
||||
/**
|
||||
* Returns a new empty string.
|
||||
*/
|
||||
mlx_string mlx_string_new(void);
|
||||
|
||||
/**
|
||||
* Returns a new string, copying contents from `str`, which must end with `\0`.
|
||||
*/
|
||||
mlx_string mlx_string_new_data(const char* str);
|
||||
|
||||
/**
|
||||
* Set string to src string.
|
||||
*/
|
||||
int mlx_string_set(mlx_string* str, const mlx_string src);
|
||||
|
||||
/**
|
||||
* Returns a pointer to the string contents.
|
||||
* The pointer is valid for the life duration of the string.
|
||||
*/
|
||||
const char* mlx_string_data(mlx_string str);
|
||||
|
||||
/**
|
||||
* Free string.
|
||||
*/
|
||||
int mlx_string_free(mlx_string str);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
68
x/mlxrunner/mlx/include/mlx/c/transforms.h
Normal file
68
x/mlxrunner/mlx/include/mlx/c/transforms.h
Normal file
@@ -0,0 +1,68 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_TRANSFORMS_H
|
||||
#define MLX_TRANSFORMS_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup transforms Transform operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_async_eval(const mlx_vector_array outputs);
|
||||
int mlx_checkpoint(mlx_closure* res, const mlx_closure fun);
|
||||
int mlx_custom_function(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
const mlx_closure_custom fun_vjp /* may be null */,
|
||||
const mlx_closure_custom_jvp fun_jvp /* may be null */,
|
||||
const mlx_closure_custom_vmap fun_vmap /* may be null */);
|
||||
int mlx_custom_vjp(
|
||||
mlx_closure* res,
|
||||
const mlx_closure fun,
|
||||
const mlx_closure_custom fun_vjp);
|
||||
int mlx_eval(const mlx_vector_array outputs);
|
||||
int mlx_jvp(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array tangents);
|
||||
int mlx_value_and_grad(
|
||||
mlx_closure_value_and_grad* res,
|
||||
const mlx_closure fun,
|
||||
const int* argnums,
|
||||
size_t argnums_num);
|
||||
int mlx_vjp(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array primals,
|
||||
const mlx_vector_array cotangents);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
54
x/mlxrunner/mlx/include/mlx/c/transforms_impl.h
Normal file
54
x/mlxrunner/mlx/include/mlx/c/transforms_impl.h
Normal file
@@ -0,0 +1,54 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_TRANSFORMS_IMPL_H
|
||||
#define MLX_TRANSFORMS_IMPL_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/closure.h"
|
||||
#include "mlx/c/distributed_group.h"
|
||||
#include "mlx/c/io_types.h"
|
||||
#include "mlx/c/map.h"
|
||||
#include "mlx/c/stream.h"
|
||||
#include "mlx/c/string.h"
|
||||
#include "mlx/c/vector.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup transforms_impl Implementation detail operations
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
int mlx_detail_vmap_replace(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_array inputs,
|
||||
const mlx_vector_array s_inputs,
|
||||
const mlx_vector_array s_outputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num,
|
||||
const int* out_axes,
|
||||
size_t out_axes_num);
|
||||
int mlx_detail_vmap_trace(
|
||||
mlx_vector_array* res_0,
|
||||
mlx_vector_array* res_1,
|
||||
const mlx_closure fun,
|
||||
const mlx_vector_array inputs,
|
||||
const int* in_axes,
|
||||
size_t in_axes_num);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
133
x/mlxrunner/mlx/include/mlx/c/vector.h
Normal file
133
x/mlxrunner/mlx/include/mlx/c/vector.h
Normal file
@@ -0,0 +1,133 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
/* */
|
||||
/* This file is auto-generated. Do not edit manually. */
|
||||
/* */
|
||||
|
||||
#ifndef MLX_VECTOR_H
|
||||
#define MLX_VECTOR_H
|
||||
|
||||
#include "mlx/c/array.h"
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \defgroup mlx_vector Vectors
|
||||
* MLX vector objects.
|
||||
*/
|
||||
/**@{*/
|
||||
|
||||
/**
|
||||
* A vector of array.
|
||||
*/
|
||||
typedef struct mlx_vector_array_ {
|
||||
void* ctx;
|
||||
} mlx_vector_array;
|
||||
mlx_vector_array mlx_vector_array_new(void);
|
||||
int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src);
|
||||
int mlx_vector_array_free(mlx_vector_array vec);
|
||||
mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size);
|
||||
mlx_vector_array mlx_vector_array_new_value(const mlx_array val);
|
||||
int mlx_vector_array_set_data(
|
||||
mlx_vector_array* vec,
|
||||
const mlx_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val);
|
||||
int mlx_vector_array_append_data(
|
||||
mlx_vector_array vec,
|
||||
const mlx_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val);
|
||||
size_t mlx_vector_array_size(mlx_vector_array vec);
|
||||
int mlx_vector_array_get(
|
||||
mlx_array* res,
|
||||
const mlx_vector_array vec,
|
||||
size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of vector_array.
|
||||
*/
|
||||
typedef struct mlx_vector_vector_array_ {
|
||||
void* ctx;
|
||||
} mlx_vector_vector_array;
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new(void);
|
||||
int mlx_vector_vector_array_set(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_vector_array src);
|
||||
int mlx_vector_vector_array_free(mlx_vector_vector_array vec);
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new_data(
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
mlx_vector_vector_array mlx_vector_vector_array_new_value(
|
||||
const mlx_vector_array val);
|
||||
int mlx_vector_vector_array_set_data(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_vector_array_set_value(
|
||||
mlx_vector_vector_array* vec,
|
||||
const mlx_vector_array val);
|
||||
int mlx_vector_vector_array_append_data(
|
||||
mlx_vector_vector_array vec,
|
||||
const mlx_vector_array* data,
|
||||
size_t size);
|
||||
int mlx_vector_vector_array_append_value(
|
||||
mlx_vector_vector_array vec,
|
||||
const mlx_vector_array val);
|
||||
size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec);
|
||||
int mlx_vector_vector_array_get(
|
||||
mlx_vector_array* res,
|
||||
const mlx_vector_vector_array vec,
|
||||
size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of int.
|
||||
*/
|
||||
typedef struct mlx_vector_int_ {
|
||||
void* ctx;
|
||||
} mlx_vector_int;
|
||||
mlx_vector_int mlx_vector_int_new(void);
|
||||
int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src);
|
||||
int mlx_vector_int_free(mlx_vector_int vec);
|
||||
mlx_vector_int mlx_vector_int_new_data(int* data, size_t size);
|
||||
mlx_vector_int mlx_vector_int_new_value(int val);
|
||||
int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size);
|
||||
int mlx_vector_int_set_value(mlx_vector_int* vec, int val);
|
||||
int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size);
|
||||
int mlx_vector_int_append_value(mlx_vector_int vec, int val);
|
||||
size_t mlx_vector_int_size(mlx_vector_int vec);
|
||||
int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx);
|
||||
|
||||
/**
|
||||
* A vector of string.
|
||||
*/
|
||||
typedef struct mlx_vector_string_ {
|
||||
void* ctx;
|
||||
} mlx_vector_string;
|
||||
mlx_vector_string mlx_vector_string_new(void);
|
||||
int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src);
|
||||
int mlx_vector_string_free(mlx_vector_string vec);
|
||||
mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size);
|
||||
mlx_vector_string mlx_vector_string_new_value(const char* val);
|
||||
int mlx_vector_string_set_data(
|
||||
mlx_vector_string* vec,
|
||||
const char** data,
|
||||
size_t size);
|
||||
int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val);
|
||||
int mlx_vector_string_append_data(
|
||||
mlx_vector_string vec,
|
||||
const char** data,
|
||||
size_t size);
|
||||
int mlx_vector_string_append_value(mlx_vector_string vec, const char* val);
|
||||
size_t mlx_vector_string_size(mlx_vector_string vec);
|
||||
int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx);
|
||||
|
||||
/**@}*/
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
18
x/mlxrunner/mlx/include/mlx/c/version.h
Normal file
18
x/mlxrunner/mlx/include/mlx/c/version.h
Normal file
@@ -0,0 +1,18 @@
|
||||
/* Copyright © 2023-2024 Apple Inc. */
|
||||
|
||||
#ifndef MLX_VERSION_H
|
||||
#define MLX_VERSION_H
|
||||
|
||||
#include "mlx/c/string.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
int mlx_version(mlx_string* str_);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
164
x/mlxrunner/mlx/io.go
Normal file
164
x/mlxrunner/mlx/io.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"runtime"
|
||||
"sort"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// SafetensorsFile represents a loaded safetensors file.
|
||||
type SafetensorsFile struct {
|
||||
arrays C.mlx_map_string_to_array
|
||||
metadata C.mlx_map_string_to_string
|
||||
}
|
||||
|
||||
func loadSafetensorsStream() C.mlx_stream {
|
||||
if runtime.GOOS == "darwin" {
|
||||
return C.mlx_default_cpu_stream_new()
|
||||
}
|
||||
return C.mlx_default_gpu_stream_new()
|
||||
}
|
||||
|
||||
// LoadSafetensorsNative loads a safetensors file using MLX's native loader.
|
||||
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
|
||||
var arrays C.mlx_map_string_to_array
|
||||
var metadata C.mlx_map_string_to_string
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
stream := loadSafetensorsStream()
|
||||
defer C.mlx_stream_free(stream)
|
||||
|
||||
if C.mlx_load_safetensors(&arrays, &metadata, cPath, stream) != 0 {
|
||||
return nil, fmt.Errorf("failed to load safetensors: %s", path)
|
||||
}
|
||||
|
||||
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
|
||||
}
|
||||
|
||||
// Get retrieves a tensor by name.
|
||||
func (s *SafetensorsFile) Get(name string) *Array {
|
||||
cName := C.CString(name)
|
||||
defer C.free(unsafe.Pointer(cName))
|
||||
|
||||
value := C.mlx_array_new()
|
||||
if C.mlx_map_string_to_array_get(&value, s.arrays, cName) != 0 {
|
||||
return nil
|
||||
}
|
||||
if value.ctx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
arr := New(name)
|
||||
arr.ctx = value
|
||||
return arr
|
||||
}
|
||||
|
||||
// GetMetadata retrieves a metadata value by key.
|
||||
func (s *SafetensorsFile) GetMetadata(key string) string {
|
||||
cKey := C.CString(key)
|
||||
defer C.free(unsafe.Pointer(cKey))
|
||||
|
||||
var cValue *C.char
|
||||
if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
|
||||
return ""
|
||||
}
|
||||
return C.GoString(cValue)
|
||||
}
|
||||
|
||||
// Free releases the loaded safetensors maps.
|
||||
func (s *SafetensorsFile) Free() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
C.mlx_map_string_to_array_free(s.arrays)
|
||||
C.mlx_map_string_to_string_free(s.metadata)
|
||||
}
|
||||
|
||||
func Load(path string) iter.Seq2[string, *Array] {
|
||||
return func(yield func(string, *Array) bool) {
|
||||
sf, err := LoadSafetensorsNative(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer sf.Free()
|
||||
|
||||
it := C.mlx_map_string_to_array_iterator_new(sf.arrays)
|
||||
defer C.mlx_map_string_to_array_iterator_free(it)
|
||||
|
||||
for {
|
||||
var key *C.char
|
||||
value := C.mlx_array_new()
|
||||
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
name := C.GoString(key)
|
||||
arr := New(name)
|
||||
arr.ctx = value
|
||||
if !yield(name, arr) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SaveSafetensors saves arrays to a safetensors file without metadata.
|
||||
func SaveSafetensors(path string, arrays map[string]*Array) error {
|
||||
return SaveSafetensorsWithMetadata(path, arrays, nil)
|
||||
}
|
||||
|
||||
// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata.
|
||||
func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
cArrays := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(cArrays)
|
||||
|
||||
arrayNames := make([]string, 0, len(arrays))
|
||||
for name, arr := range arrays {
|
||||
if arr == nil {
|
||||
continue
|
||||
}
|
||||
arrayNames = append(arrayNames, name)
|
||||
}
|
||||
sort.Strings(arrayNames)
|
||||
|
||||
for _, name := range arrayNames {
|
||||
arr := arrays[name]
|
||||
cName := C.CString(name)
|
||||
C.mlx_map_string_to_array_insert(cArrays, cName, arr.ctx)
|
||||
C.free(unsafe.Pointer(cName))
|
||||
}
|
||||
|
||||
cMetadata := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(cMetadata)
|
||||
|
||||
metadataKeys := make([]string, 0, len(metadata))
|
||||
for key := range metadata {
|
||||
metadataKeys = append(metadataKeys, key)
|
||||
}
|
||||
sort.Strings(metadataKeys)
|
||||
|
||||
for _, key := range metadataKeys {
|
||||
value := metadata[key]
|
||||
cKey := C.CString(key)
|
||||
cValue := C.CString(value)
|
||||
C.mlx_map_string_to_string_insert(cMetadata, cKey, cValue)
|
||||
C.free(unsafe.Pointer(cKey))
|
||||
C.free(unsafe.Pointer(cValue))
|
||||
}
|
||||
|
||||
if C.mlx_save_safetensors(cPath, cArrays, cMetadata) != 0 {
|
||||
return fmt.Errorf("failed to save safetensors: %s", path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
89
x/mlxrunner/mlx/memory.go
Normal file
89
x/mlxrunner/mlx/memory.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func (b Byte) String() string {
|
||||
return strconv.FormatInt(int64(b), 10) + " B"
|
||||
}
|
||||
|
||||
func (b KibiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
|
||||
}
|
||||
|
||||
func (b MebiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
|
||||
}
|
||||
|
||||
func (b GibiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
|
||||
}
|
||||
|
||||
func (b TebiByte) String() string {
|
||||
return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
|
||||
}
|
||||
|
||||
func PrettyBytes(n int) fmt.Stringer {
|
||||
switch {
|
||||
case n < 1<<10:
|
||||
return Byte(n)
|
||||
case n < 1<<(2*10):
|
||||
return KibiByte(n)
|
||||
case n < 1<<(3*10):
|
||||
return MebiByte(n)
|
||||
case n < 1<<(4*10):
|
||||
return GibiByte(n)
|
||||
default:
|
||||
return TebiByte(n)
|
||||
}
|
||||
}
|
||||
|
||||
func ActiveMemory() int {
|
||||
var active C.size_t
|
||||
C.mlx_get_active_memory(&active)
|
||||
return int(active)
|
||||
}
|
||||
|
||||
func CacheMemory() int {
|
||||
var cache C.size_t
|
||||
C.mlx_get_cache_memory(&cache)
|
||||
return int(cache)
|
||||
}
|
||||
|
||||
func PeakMemory() int {
|
||||
var peak C.size_t
|
||||
C.mlx_get_peak_memory(&peak)
|
||||
return int(peak)
|
||||
}
|
||||
|
||||
func ResetPeakMemory() {
|
||||
C.mlx_reset_peak_memory()
|
||||
}
|
||||
|
||||
type Memory struct{}
|
||||
|
||||
func (Memory) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Any("active", PrettyBytes(ActiveMemory())),
|
||||
slog.Any("cache", PrettyBytes(CacheMemory())),
|
||||
slog.Any("peak", PrettyBytes(PeakMemory())),
|
||||
)
|
||||
}
|
||||
|
||||
type (
|
||||
Byte int
|
||||
KibiByte int
|
||||
MebiByte int
|
||||
GibiByte int
|
||||
TebiByte int
|
||||
)
|
||||
|
||||
func ClearCache() {
|
||||
C.mlx_clear_cache()
|
||||
}
|
||||
107
x/mlxrunner/mlx/mlx.go
Normal file
107
x/mlxrunner/mlx/mlx.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package mlx
|
||||
|
||||
//go:generate go run generator/main.go -output=. ./include/mlx/c/*.h
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/include
|
||||
// #cgo LDFLAGS: -lstdc++
|
||||
// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||
// #include "generated.h"
|
||||
// #include <string.h>
|
||||
//
|
||||
// static __thread char _mlx_last_error_msg[1024] = {0};
|
||||
// static __thread int _mlx_last_error_flag = 0;
|
||||
//
|
||||
// static void _mlx_capture_error_handler(const char* msg, void* data) {
|
||||
// (void)data;
|
||||
// strncpy(_mlx_last_error_msg, msg, sizeof(_mlx_last_error_msg) - 1);
|
||||
// _mlx_last_error_msg[sizeof(_mlx_last_error_msg) - 1] = '\0';
|
||||
// _mlx_last_error_flag = 1;
|
||||
// }
|
||||
//
|
||||
// static void mlx_install_capture_handler(void) {
|
||||
// if (mlx_set_error_handler_) {
|
||||
// mlx_set_error_handler_(_mlx_capture_error_handler, NULL, NULL);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// static void mlx_clear_last_error(void) {
|
||||
// _mlx_last_error_flag = 0;
|
||||
// _mlx_last_error_msg[0] = '\0';
|
||||
// }
|
||||
//
|
||||
// static const char* mlx_get_last_error(void) {
|
||||
// return _mlx_last_error_flag ? _mlx_last_error_msg : "";
|
||||
// }
|
||||
import "C"
|
||||
|
||||
import "runtime"
|
||||
|
||||
func init() {
|
||||
// Replace the default exit(-1) error handler with one that captures
|
||||
// the error message so we can surface it in Go.
|
||||
C.mlx_install_capture_handler()
|
||||
}
|
||||
|
||||
// Version returns the MLX core library version string.
|
||||
func Version() string {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_version(&str)
|
||||
return C.GoString(C.mlx_string_data(str))
|
||||
}
|
||||
|
||||
// mlxCheck locks the goroutine to its OS thread, clears the captured error
|
||||
// state, calls fn, and panics with the captured message if fn returns non-zero.
|
||||
// The thread lock ensures the thread-local error state is read from the same
|
||||
// thread that executed the call.
|
||||
func mlxCheck(fallback string, fn func() C.int) {
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
C.mlx_clear_last_error()
|
||||
if fn() != 0 {
|
||||
msg := C.GoString(C.mlx_get_last_error())
|
||||
if msg == "" {
|
||||
msg = fallback
|
||||
}
|
||||
panic("mlx: " + msg)
|
||||
}
|
||||
}
|
||||
|
||||
func doEval(outputs []*Array, async bool) {
|
||||
if len(outputs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
if output != nil && output.Valid() {
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
mlxCheck("eval failed", func() C.int {
|
||||
if async {
|
||||
return C.mlx_async_eval(vector)
|
||||
}
|
||||
return C.mlx_eval(vector)
|
||||
})
|
||||
}
|
||||
|
||||
func AsyncEval(outputs ...*Array) {
|
||||
doEval(outputs, true)
|
||||
}
|
||||
|
||||
func Eval(outputs ...*Array) {
|
||||
doEval(outputs, false)
|
||||
}
|
||||
|
||||
// MetalIsAvailable returns true if a Metal GPU is available.
|
||||
func MetalIsAvailable() bool {
|
||||
var available C._Bool
|
||||
C.mlx_metal_is_available(&available)
|
||||
return bool(available)
|
||||
}
|
||||
36
x/mlxrunner/mlx/nn.go
Normal file
36
x/mlxrunner/mlx/nn.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package mlx
|
||||
|
||||
type Linear struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
Bias *Array `weight:"bias"`
|
||||
}
|
||||
|
||||
// Forward computes the linear transformation: x @ Weight.T + Bias
|
||||
func (m *Linear) Forward(x *Array) *Array {
|
||||
w := m.Weight.Transpose(1, 0)
|
||||
if m.Bias.Valid() {
|
||||
return m.Bias.Addmm(x, w, 1.0, 1.0)
|
||||
}
|
||||
|
||||
return x.Matmul(w)
|
||||
}
|
||||
|
||||
func (m *Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
|
||||
w := m.Weight.Transpose(0, 2, 1)
|
||||
// TODO: bias
|
||||
return x.GatherMM(w, lhs, rhs, sorted)
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
}
|
||||
|
||||
func (e *Embedding) Forward(indices *Array) *Array {
|
||||
return e.Weight.TakeAxis(indices, 0)
|
||||
}
|
||||
|
||||
func (e *Embedding) AsLinear() Linear {
|
||||
return Linear{
|
||||
Weight: e.Weight,
|
||||
}
|
||||
}
|
||||
300
x/mlxrunner/mlx/ops.go
Normal file
300
x/mlxrunner/mlx/ops.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func (t *Array) Abs() *Array {
|
||||
out := New("ABS")
|
||||
C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Add(other *Array) *Array {
|
||||
out := New("ADD")
|
||||
C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
|
||||
out := New("ADDMM")
|
||||
C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Argmax(axis int, keepDims bool) *Array {
|
||||
out := New("ARGMAX")
|
||||
C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
|
||||
out := New("ARGPARTITION")
|
||||
C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ArgsortAxis(axis int) *Array {
|
||||
out := New("ARGSORT_AXIS")
|
||||
C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) AsType(dtype DType) *Array {
|
||||
out := New("AS_TYPE")
|
||||
C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
|
||||
cStrides := make([]C.int64_t, len(strides))
|
||||
for i, s := range strides {
|
||||
cStrides[i] = C.int64_t(s)
|
||||
}
|
||||
|
||||
out := New("AS_STRIDED")
|
||||
C.mlx_as_strided(
|
||||
&out.ctx, t.ctx,
|
||||
unsafe.SliceData(cShape), C.size_t(len(shape)),
|
||||
unsafe.SliceData(cStrides), C.size_t(len(strides)),
|
||||
C.size_t(offset),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Concatenate(axis int, others ...*Array) *Array {
|
||||
if len(others) == 0 {
|
||||
return t.Clone()
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
s := append([]*Array{t}, others...)
|
||||
for _, other := range s {
|
||||
C.mlx_vector_array_append_value(vector, other.ctx)
|
||||
}
|
||||
|
||||
out := New("CONCATENATE")
|
||||
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Cumsum(axis int, reverse, inclusive bool) *Array {
|
||||
out := New("CUMSUM")
|
||||
C.mlx_cumsum(&out.ctx, t.ctx, C.int(axis), C.bool(reverse), C.bool(inclusive), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Divide(other *Array) *Array {
|
||||
out := New("DIVIDE")
|
||||
C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ExpandDims(axis int) *Array {
|
||||
out := New("EXPAND_DIMS")
|
||||
C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Flatten(startAxis, endAxis int) *Array {
|
||||
out := New("FLATTEN")
|
||||
C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) FloorDivide(other *Array) *Array {
|
||||
out := New("FLOOR_DIVIDE")
|
||||
C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
|
||||
if lhs == nil {
|
||||
lhs = New("")
|
||||
}
|
||||
if rhs == nil {
|
||||
rhs = New("")
|
||||
}
|
||||
out := New("GATHER_MM")
|
||||
C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) LogsumexpAxis(axis int, keepDims bool) *Array {
|
||||
out := New("LOGSUMEXP_AXIS")
|
||||
C.mlx_logsumexp_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Equal(other *Array) *Array {
|
||||
out := New("EQUAL")
|
||||
C.mlx_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Greater(other *Array) *Array {
|
||||
out := New("GREATER")
|
||||
C.mlx_greater(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Less(other *Array) *Array {
|
||||
out := New("LESS")
|
||||
C.mlx_less(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) LessEqual(other *Array) *Array {
|
||||
out := New("LESS_EQUAL")
|
||||
C.mlx_less_equal(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) MaxAxis(axis int, keepDims bool) *Array {
|
||||
out := New("MAX_AXIS")
|
||||
C.mlx_max_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Matmul(other *Array) *Array {
|
||||
out := New("MATMUL")
|
||||
C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Multiply(other *Array) *Array {
|
||||
out := New("MULTIPLY")
|
||||
C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Negative() *Array {
|
||||
out := New("NEGATIVE")
|
||||
C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Power(exponent *Array) *Array {
|
||||
out := New("POWER")
|
||||
C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
|
||||
out := New("PUT_ALONG_AXIS")
|
||||
C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) ScatterAddAxis(indices, values *Array, axis int) *Array {
|
||||
out := New("SCATTER_ADD_AXIS")
|
||||
C.mlx_scatter_add_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Reshape(axes ...int) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i := range axes {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
}
|
||||
|
||||
out := New("RESHAPE")
|
||||
C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Sigmoid() *Array {
|
||||
out := New("SIGMOID")
|
||||
C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Sqrt() *Array {
|
||||
out := New("SQRT")
|
||||
C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Squeeze(axis int) *Array {
|
||||
out := New("SQUEEZE")
|
||||
C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) StackAxis(axis int, others ...*Array) *Array {
|
||||
vectorData := make([]C.mlx_array, len(others)+1)
|
||||
vectorData[0] = t.ctx
|
||||
for i := range others {
|
||||
vectorData[i+1] = others[i].ctx
|
||||
}
|
||||
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
out := New("STACK_AXIS")
|
||||
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Subtract(other *Array) *Array {
|
||||
out := New("SUBTRACT")
|
||||
C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) SumAxis(axis int, keepDims bool) *Array {
|
||||
out := New("SUM_AXIS")
|
||||
C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) TakeAxis(indices *Array, axis int) *Array {
|
||||
out := New("TAKE_AXIS")
|
||||
C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
|
||||
out := New("TAKE_ALONG_AXIS")
|
||||
C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Tanh() *Array {
|
||||
out := New("TANH")
|
||||
C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Transpose(axes ...int) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i, axis := range axes {
|
||||
cAxes[i] = C.int(axis)
|
||||
}
|
||||
|
||||
out := New("TRANSPOSE")
|
||||
C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Zeros(dtype DType, shape ...int) *Array {
|
||||
cAxes := make([]C.int, len(shape))
|
||||
for i := range shape {
|
||||
cAxes[i] = C.int(shape[i])
|
||||
}
|
||||
|
||||
t := New("ZEROS")
|
||||
C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return t
|
||||
}
|
||||
666
x/mlxrunner/mlx/ops_extra.go
Normal file
666
x/mlxrunner/mlx/ops_extra.go
Normal file
@@ -0,0 +1,666 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Quantization operations
|
||||
|
||||
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
res := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(res)
|
||||
var globalScale C.mlx_array
|
||||
C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, globalScale, DefaultStream().ctx)
|
||||
|
||||
vecSize := int(C.mlx_vector_array_size(res))
|
||||
w0 := New("QUANTIZE_W")
|
||||
C.mlx_vector_array_get(&w0.ctx, res, 0)
|
||||
w1 := New("QUANTIZE_S")
|
||||
C.mlx_vector_array_get(&w1.ctx, res, 1)
|
||||
if vecSize >= 3 {
|
||||
w2 := New("QUANTIZE_B")
|
||||
C.mlx_vector_array_get(&w2.ctx, res, 2)
|
||||
return w0, w1, w2
|
||||
}
|
||||
return w0, w1, nil
|
||||
}
|
||||
|
||||
func FromFP8(x *Array, dtype DType) *Array {
|
||||
out := New("FROM_FP8")
|
||||
C.mlx_from_fp8(&out.ctx, x.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func ToFP8(x *Array) *Array {
|
||||
out := New("TO_FP8")
|
||||
C.mlx_to_fp8(&out.ctx, x.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
optDtype := C.mlx_optional_dtype{has_value: false}
|
||||
|
||||
var b C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
}
|
||||
|
||||
out := New("DEQUANTIZE")
|
||||
var globalScale C.mlx_array
|
||||
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, globalScale, optDtype, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
|
||||
var b C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
}
|
||||
|
||||
out := New("QUANTIZED_MATMUL")
|
||||
C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
|
||||
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
|
||||
|
||||
var b, lhs, rhs C.mlx_array
|
||||
if biases != nil {
|
||||
b = biases.ctx
|
||||
}
|
||||
if lhsIndices != nil {
|
||||
lhs = lhsIndices.ctx
|
||||
}
|
||||
if rhsIndices != nil {
|
||||
rhs = rhsIndices.ctx
|
||||
}
|
||||
|
||||
out := New("GATHER_QMM")
|
||||
C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Missing tensor ops
|
||||
|
||||
func Tile(a *Array, reps []int32) *Array {
|
||||
cReps := make([]C.int, len(reps))
|
||||
for i, r := range reps {
|
||||
cReps[i] = C.int(r)
|
||||
}
|
||||
out := New("TILE")
|
||||
C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Tri(n, m int32, k int) *Array {
|
||||
out := New("TRI")
|
||||
C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Where(condition, a, b *Array) *Array {
|
||||
out := New("WHERE")
|
||||
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Conv1d(x, weight *Array, bias *Array, stride, padding, dilation, groups int32) *Array {
|
||||
out := New("CONV1D")
|
||||
C.mlx_conv1d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(stride),
|
||||
C.int(padding),
|
||||
C.int(dilation),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
if bias != nil && bias.Valid() {
|
||||
out = Add(out, bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func Contiguous(a *Array, allowColMajor bool) *Array {
|
||||
out := New("CONTIGUOUS")
|
||||
C.mlx_contiguous(&out.ctx, a.ctx, C.bool(allowColMajor), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Conv2d performs 2D convolution: x [N,H,W,C_in], weight [C_out,kH,kW,C_in].
|
||||
// MLX uses NHWC layout.
|
||||
func Conv2d(x, weight *Array, strideH, strideW, padH, padW, dilationH, dilationW, groups int32) *Array {
|
||||
out := New("CONV2D")
|
||||
C.mlx_conv2d(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
weight.ctx,
|
||||
C.int(strideH), C.int(strideW),
|
||||
C.int(padH), C.int(padW),
|
||||
C.int(dilationH), C.int(dilationW),
|
||||
C.int(groups),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// Pad pads array a along the given axes with specified low/high pad sizes.
|
||||
// mode should be "constant", "edge", or "reflect".
|
||||
func Pad(a *Array, axes []int, lowPad, highPad []int, padValue *Array, mode string) *Array {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
cLow := make([]C.int, len(lowPad))
|
||||
cHigh := make([]C.int, len(highPad))
|
||||
for i := range axes {
|
||||
cAxes[i] = C.int(axes[i])
|
||||
cLow[i] = C.int(lowPad[i])
|
||||
cHigh[i] = C.int(highPad[i])
|
||||
}
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
out := New("PAD")
|
||||
C.mlx_pad(
|
||||
&out.ctx,
|
||||
a.ctx,
|
||||
unsafe.SliceData(cAxes), C.size_t(len(cAxes)),
|
||||
unsafe.SliceData(cLow), C.size_t(len(cLow)),
|
||||
unsafe.SliceData(cHigh), C.size_t(len(cHigh)),
|
||||
padValue.ctx,
|
||||
cMode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// PadConstant pads with zeros along the given axes.
|
||||
func PadConstant(a *Array, axes []int, lowPad, highPad []int) *Array {
|
||||
zero := NewScalarArray(float32(0))
|
||||
return Pad(a, axes, lowPad, highPad, zero, "constant")
|
||||
}
|
||||
|
||||
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
|
||||
groups := int32(x.Dim(x.NumDims() - 1))
|
||||
return Conv1d(x, weight, bias, 1, 0, 1, groups)
|
||||
}
|
||||
|
||||
// Maximum returns element-wise maximum of two arrays.
|
||||
func Maximum(a, b *Array) *Array {
|
||||
out := New("MAXIMUM")
|
||||
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Minimum returns element-wise minimum of two arrays.
|
||||
func Minimum(a, b *Array) *Array {
|
||||
out := New("MINIMUM")
|
||||
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Softplus computes log(1 + exp(x)) using logaddexp for numerical stability.
|
||||
func Softplus(a *Array) *Array {
|
||||
return Logaddexp(a, Zeros(a.DType(), a.Dims()...))
|
||||
}
|
||||
|
||||
// ReLU computes max(0, x).
|
||||
func ReLU(a *Array) *Array {
|
||||
return Maximum(a, NewScalarArray(float32(0)))
|
||||
}
|
||||
|
||||
// GLU applies Gated Linear Unit: splits x along last dim into two halves,
|
||||
// returns first * sigmoid(second).
|
||||
func GLU(a *Array) *Array {
|
||||
lastDim := a.NumDims() - 1
|
||||
halfSize := a.Dim(lastDim) / 2
|
||||
first := SliceStartStop(a,
|
||||
make([]int32, lastDim+1), // all zeros for start
|
||||
appendDims(a, lastDim, int32(halfSize)),
|
||||
)
|
||||
second := SliceStartStop(a,
|
||||
appendDimsStart(a, lastDim, int32(halfSize)),
|
||||
appendDims(a, lastDim, int32(a.Dim(lastDim))),
|
||||
)
|
||||
return first.Multiply(second.Sigmoid())
|
||||
}
|
||||
|
||||
// helper: builds stop array for SliceStartStop where the target axis = val
|
||||
func appendDims(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
} else {
|
||||
out[i] = int32(a.Dim(i))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// helper: builds start array for SliceStartStop where the target axis = val
|
||||
func appendDimsStart(a *Array, targetAxis int, val int32) []int32 {
|
||||
n := a.NumDims()
|
||||
out := make([]int32, n)
|
||||
for i := range n {
|
||||
if i == targetAxis {
|
||||
out[i] = val
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Clamp clamps array values to [min, max].
|
||||
func Clamp(a *Array, minVal, maxVal float32) *Array {
|
||||
return Minimum(Maximum(a, NewScalarArray(minVal)), NewScalarArray(maxVal))
|
||||
}
|
||||
|
||||
// Convenience wrappers (function-style for the model code)
|
||||
|
||||
func Stack(arrays []*Array, axis int) *Array {
|
||||
vectorData := make([]C.mlx_array, len(arrays))
|
||||
for i := range arrays {
|
||||
vectorData[i] = arrays[i].ctx
|
||||
}
|
||||
vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
out := New("STACK")
|
||||
C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Neg(a *Array) *Array {
|
||||
return a.Negative()
|
||||
}
|
||||
|
||||
func Sum(a *Array, axis int, keepDims bool) *Array {
|
||||
return a.SumAxis(axis, keepDims)
|
||||
}
|
||||
|
||||
func Argsort(a *Array, axis int) *Array {
|
||||
return a.ArgsortAxis(axis)
|
||||
}
|
||||
|
||||
func Take(a *Array, indices *Array, axis int) *Array {
|
||||
return a.TakeAxis(indices, axis)
|
||||
}
|
||||
|
||||
func RSqrt(a *Array) *Array {
|
||||
out := New("RSQRT")
|
||||
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Mean(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("MEAN_AXIS")
|
||||
C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Argpartition(a *Array, kth int, axis int) *Array {
|
||||
return a.ArgpartitionAxis(kth, axis)
|
||||
}
|
||||
|
||||
func TakeAlongAxis(a, indices *Array, axis int) *Array {
|
||||
return a.TakeAlongAxis(indices, axis)
|
||||
}
|
||||
|
||||
// Function-style wrappers matching imagegen API
|
||||
|
||||
func Add(a, b *Array) *Array {
|
||||
return a.Add(b)
|
||||
}
|
||||
|
||||
func Sub(a, b *Array) *Array {
|
||||
return a.Subtract(b)
|
||||
}
|
||||
|
||||
func Mul(a, b *Array) *Array {
|
||||
return a.Multiply(b)
|
||||
}
|
||||
|
||||
func Div(a, b *Array) *Array {
|
||||
return a.Divide(b)
|
||||
}
|
||||
|
||||
func Matmul(a, b *Array) *Array {
|
||||
return a.Matmul(b)
|
||||
}
|
||||
|
||||
func Reshape(a *Array, shape ...int32) *Array {
|
||||
axes := make([]int, len(shape))
|
||||
for i, s := range shape {
|
||||
axes[i] = int(s)
|
||||
}
|
||||
return a.Reshape(axes...)
|
||||
}
|
||||
|
||||
func Transpose(a *Array, axes ...int) *Array {
|
||||
return a.Transpose(axes...)
|
||||
}
|
||||
|
||||
func ExpandDims(a *Array, axis int) *Array {
|
||||
return a.ExpandDims(axis)
|
||||
}
|
||||
|
||||
func Squeeze(a *Array, axis int) *Array {
|
||||
return a.Squeeze(axis)
|
||||
}
|
||||
|
||||
func Flatten(a *Array) *Array {
|
||||
return a.Flatten(0, -1)
|
||||
}
|
||||
|
||||
func Concatenate(arrays []*Array, axis int) *Array {
|
||||
if len(arrays) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(arrays) == 1 {
|
||||
return arrays[0].Clone()
|
||||
}
|
||||
return arrays[0].Concatenate(axis, arrays[1:]...)
|
||||
}
|
||||
|
||||
func SliceStartStop(a *Array, start, stop []int32) *Array {
|
||||
n := len(start)
|
||||
cStart := make([]C.int, n)
|
||||
cStop := make([]C.int, n)
|
||||
cStrides := make([]C.int, n)
|
||||
for i := 0; i < n; i++ {
|
||||
cStart[i] = C.int(start[i])
|
||||
cStop[i] = C.int(stop[i])
|
||||
cStrides[i] = 1
|
||||
}
|
||||
out := New("SLICE")
|
||||
C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
|
||||
if lhsIndices == nil {
|
||||
lhsIndices = New("")
|
||||
}
|
||||
if rhsIndices == nil {
|
||||
rhsIndices = New("")
|
||||
}
|
||||
return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
|
||||
}
|
||||
|
||||
// RoPEWithBase applies rotary position embeddings to x. offsets is an
|
||||
// int32 array of shape [B] giving each batch row's starting position;
|
||||
// the kernel applies positions offsets[b] + 0..T-1 per row.
|
||||
func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offsets *Array) *Array {
|
||||
return RoPEWithFreqs(x, dims, traditional, base, scale, offsets, nil)
|
||||
}
|
||||
|
||||
// RoPEWithFreqs applies RoPE with optional custom frequencies.
|
||||
// When freqs is non-nil, it is used instead of computing from base.
|
||||
// Note: MLX takes reciprocal(freqs) internally to get inv_freq, so pass
|
||||
// the actual frequencies (base^(2i/dim)), not the inverse frequencies.
|
||||
func RoPEWithFreqs(x *Array, dims int, traditional bool, base, scale float32, offsets *Array, freqs *Array) *Array {
|
||||
var freqsCtx C.mlx_array
|
||||
var optBase C.mlx_optional_float
|
||||
if freqs != nil {
|
||||
freqsCtx = freqs.ctx
|
||||
optBase = C.mlx_optional_float{has_value: C.bool(false)}
|
||||
} else {
|
||||
empty := New("")
|
||||
freqsCtx = empty.ctx
|
||||
optBase = C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C.bool(func() bool { return base != 0 }()),
|
||||
}
|
||||
}
|
||||
out := New("FAST_ROPE")
|
||||
C.mlx_fast_rope_dynamic(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C.bool(traditional),
|
||||
optBase,
|
||||
C.float(scale),
|
||||
offsets.ctx,
|
||||
freqsCtx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func Sigmoid(a *Array) *Array {
|
||||
return a.Sigmoid()
|
||||
}
|
||||
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP")
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG")
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Sin(a *Array) *Array {
|
||||
out := New("SIN")
|
||||
C.mlx_sin(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Cos(a *Array) *Array {
|
||||
out := New("COS")
|
||||
C.mlx_cos(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Clip(a, aMin, aMax *Array) *Array {
|
||||
out := New("CLIP")
|
||||
C.mlx_clip(&out.ctx, a.ctx, aMin.ctx, aMax.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Logaddexp(a, b *Array) *Array {
|
||||
out := New("LOGADDEXP")
|
||||
C.mlx_logaddexp(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func SoftmaxAxis(a *Array, axis int, precise bool) *Array {
|
||||
out := New("SOFTMAX_AXIS")
|
||||
C.mlx_softmax_axis(&out.ctx, a.ctx, C.int(axis), C.bool(precise), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func LayerNormFn(x, weight, bias *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM")
|
||||
var w, b C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
if bias != nil {
|
||||
b = bias.ctx
|
||||
}
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, w, b, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func RMSNormFn(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM")
|
||||
var w C.mlx_array
|
||||
if weight != nil {
|
||||
w = weight.ctx
|
||||
}
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, w, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
|
||||
return c.Addmm(a, b, alpha, beta)
|
||||
}
|
||||
|
||||
// Scalar helpers
|
||||
|
||||
// scalarWithDtype creates a scalar array matching the dtype of a.
|
||||
// Matching dtype is important for graph fusion and avoiding implicit casts.
|
||||
func scalarWithDtype(s float32, a *Array) C.mlx_array {
|
||||
f32 := C.mlx_array_new_float(C.float(s))
|
||||
dtype := a.DType()
|
||||
if dtype == DTypeFloat32 {
|
||||
return f32
|
||||
}
|
||||
casted := C.mlx_array_new()
|
||||
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
C.mlx_array_free(f32)
|
||||
return casted
|
||||
}
|
||||
|
||||
func AddScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("ADD_SCALAR")
|
||||
C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func MulScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("MUL_SCALAR")
|
||||
C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func DivScalar(a *Array, s float32) *Array {
|
||||
scalar := scalarWithDtype(s, a)
|
||||
out := New("DIV_SCALAR")
|
||||
C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
|
||||
C.mlx_array_free(scalar)
|
||||
return out
|
||||
}
|
||||
|
||||
func FloorDivideScalar(a *Array, s int32) *Array {
|
||||
scalar := FromValue(int(s))
|
||||
return a.FloorDivide(scalar)
|
||||
}
|
||||
|
||||
// Array constructors
|
||||
|
||||
func NewArrayInt32(data []int32, shape []int32) *Array {
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
out := New("NEW_ARRAY_INT32")
|
||||
out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
|
||||
return out
|
||||
}
|
||||
|
||||
func NewScalarArray(value float32) *Array {
|
||||
out := New("SCALAR")
|
||||
out.ctx = C.mlx_array_new_float32(C.float(value))
|
||||
return out
|
||||
}
|
||||
|
||||
func ZerosF32(shape []int32) *Array {
|
||||
return Zeros(DTypeFloat32, func() []int {
|
||||
ints := make([]int, len(shape))
|
||||
for i, s := range shape {
|
||||
ints[i] = int(s)
|
||||
}
|
||||
return ints
|
||||
}()...)
|
||||
}
|
||||
|
||||
// Utility
|
||||
|
||||
func Collect(v any) []*Array {
|
||||
var arrays []*Array
|
||||
seen := make(map[uintptr]bool)
|
||||
collect(reflect.ValueOf(v), &arrays, seen)
|
||||
return arrays
|
||||
}
|
||||
|
||||
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
|
||||
if !v.IsValid() {
|
||||
return
|
||||
}
|
||||
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
ptr := v.Pointer()
|
||||
if seen[ptr] {
|
||||
return
|
||||
}
|
||||
seen[ptr] = true
|
||||
|
||||
if arr, ok := v.Interface().(*Array); ok {
|
||||
if arr != nil && arr.Valid() {
|
||||
*arrays = append(*arrays, arr)
|
||||
}
|
||||
return
|
||||
}
|
||||
collect(v.Elem(), arrays, seen)
|
||||
return
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Struct:
|
||||
// Check if this struct IS an Array (not a pointer to one)
|
||||
if arr, ok := v.Addr().Interface().(*Array); ok {
|
||||
if arr != nil && arr.Valid() {
|
||||
*arrays = append(*arrays, arr)
|
||||
}
|
||||
return
|
||||
}
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.CanInterface() {
|
||||
collect(field, arrays, seen)
|
||||
}
|
||||
}
|
||||
case reflect.Slice:
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
collect(v.Index(i), arrays, seen)
|
||||
}
|
||||
case reflect.Map:
|
||||
for _, key := range v.MapKeys() {
|
||||
collect(v.MapIndex(key), arrays, seen)
|
||||
}
|
||||
case reflect.Interface:
|
||||
if !v.IsNil() {
|
||||
collect(v.Elem(), arrays, seen)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func EnableCompile() {
|
||||
C.mlx_enable_compile()
|
||||
}
|
||||
|
||||
func DisableCompile() {
|
||||
C.mlx_disable_compile()
|
||||
}
|
||||
44
x/mlxrunner/mlx/random.go
Normal file
44
x/mlxrunner/mlx/random.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
func RandomKey(seed uint64) *Array {
|
||||
out := New("RANDOM_KEY")
|
||||
C.mlx_random_key(&out.ctx, C.uint64_t(seed))
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) Categorical(axis int) *Array {
|
||||
return t.CategoricalWithKey(axis, nil)
|
||||
}
|
||||
|
||||
func (t *Array) CategoricalWithKey(axis int, key *Array) *Array {
|
||||
if key == nil {
|
||||
key = New("")
|
||||
}
|
||||
out := New("")
|
||||
C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
func Bernoulli(p *Array) *Array {
|
||||
return BernoulliWithKey(p, nil)
|
||||
}
|
||||
|
||||
func BernoulliWithKey(p *Array, key *Array) *Array {
|
||||
dims := p.Dims()
|
||||
shape := make([]C.int, len(dims))
|
||||
for i, d := range dims {
|
||||
shape[i] = C.int(d)
|
||||
}
|
||||
|
||||
if key == nil {
|
||||
key = New("")
|
||||
}
|
||||
out := New("BERNOULLI")
|
||||
C.mlx_random_bernoulli(&out.ctx, p.ctx, unsafe.SliceData(shape), C.size_t(len(shape)), key.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
100
x/mlxrunner/mlx/slice.go
Normal file
100
x/mlxrunner/mlx/slice.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"math"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// End is a sentinel value meaning "to the end of the dimension",
|
||||
// equivalent to an omitted stop in Python (e.g. a[i:]).
|
||||
const End = math.MaxInt32
|
||||
|
||||
type slice struct {
|
||||
args []int
|
||||
}
|
||||
|
||||
func Slice(args ...int) slice {
|
||||
return slice{args: args}
|
||||
}
|
||||
|
||||
func resolve(val, dim int) C.int {
|
||||
if val == End {
|
||||
return C.int(dim)
|
||||
}
|
||||
if val < 0 {
|
||||
return C.int(dim + val)
|
||||
}
|
||||
return C.int(val)
|
||||
}
|
||||
|
||||
func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
|
||||
if len(slices) != len(dims) {
|
||||
panic("number of slice arguments must match number of tensor dimensions")
|
||||
}
|
||||
|
||||
args := [3][]C.int{
|
||||
make([]C.int, len(slices)),
|
||||
make([]C.int, len(slices)),
|
||||
make([]C.int, len(slices)),
|
||||
}
|
||||
|
||||
for i, s := range slices {
|
||||
dim := dims[i]
|
||||
switch len(s.args) {
|
||||
case 0:
|
||||
// slice[:]
|
||||
args[0][i] = C.int(0)
|
||||
args[1][i] = C.int(dim)
|
||||
args[2][i] = C.int(1)
|
||||
case 1:
|
||||
// slice[i]
|
||||
start := resolve(s.args[0], dim)
|
||||
args[0][i] = start
|
||||
args[1][i] = start + 1
|
||||
args[2][i] = C.int(1)
|
||||
case 2:
|
||||
// slice[i:j]
|
||||
args[0][i] = resolve(s.args[0], dim)
|
||||
args[1][i] = resolve(s.args[1], dim)
|
||||
args[2][i] = C.int(1)
|
||||
case 3:
|
||||
// slice[i:j:k]
|
||||
args[0][i] = resolve(s.args[0], dim)
|
||||
args[1][i] = resolve(s.args[1], dim)
|
||||
args[2][i] = C.int(s.args[2])
|
||||
default:
|
||||
panic("invalid slice arguments")
|
||||
}
|
||||
}
|
||||
|
||||
return args[0], args[1], args[2]
|
||||
}
|
||||
|
||||
func (t *Array) Slice(slices ...slice) *Array {
|
||||
starts, stops, strides := makeSlices(t.Dims(), slices...)
|
||||
out := New("SLICE")
|
||||
C.mlx_slice(
|
||||
&out.ctx, t.ctx,
|
||||
unsafe.SliceData(starts), C.size_t(len(starts)),
|
||||
unsafe.SliceData(stops), C.size_t(len(stops)),
|
||||
unsafe.SliceData(strides), C.size_t(len(strides)),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
|
||||
starts, stops, strides := makeSlices(t.Dims(), slices...)
|
||||
out := New("SLICE_UPDATE")
|
||||
C.mlx_slice_update(
|
||||
&out.ctx, t.ctx, other.ctx,
|
||||
unsafe.SliceData(starts), C.size_t(len(starts)),
|
||||
unsafe.SliceData(stops), C.size_t(len(stops)),
|
||||
unsafe.SliceData(strides), C.size_t(len(strides)),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
79
x/mlxrunner/mlx/stream.go
Normal file
79
x/mlxrunner/mlx/stream.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package mlx
|
||||
|
||||
// #include "generated.h"
|
||||
import "C"
|
||||
|
||||
import "log/slog"
|
||||
|
||||
type Device struct {
|
||||
ctx C.mlx_device
|
||||
}
|
||||
|
||||
func (d Device) LogValue() slog.Value {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_device_tostring(&str, d.ctx)
|
||||
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
var (
|
||||
defaultDevice Device
|
||||
defaultDeviceSet bool
|
||||
defaultStream Stream
|
||||
defaultStreamSet bool
|
||||
)
|
||||
|
||||
func resetDefaultStreamCache() {
|
||||
defaultDeviceSet = false
|
||||
defaultStreamSet = false
|
||||
}
|
||||
|
||||
func DefaultDevice() Device {
|
||||
if !defaultDeviceSet {
|
||||
d := C.mlx_device_new()
|
||||
C.mlx_get_default_device(&d)
|
||||
defaultDevice = Device{d}
|
||||
defaultDeviceSet = true
|
||||
}
|
||||
|
||||
return defaultDevice
|
||||
}
|
||||
|
||||
// GPUIsAvailable returns true if a GPU device is available.
|
||||
func GPUIsAvailable() bool {
|
||||
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
|
||||
defer C.mlx_device_free(dev)
|
||||
var avail C.bool
|
||||
C.mlx_device_is_available(&avail, dev)
|
||||
return bool(avail)
|
||||
}
|
||||
|
||||
// SetDefaultDeviceGPU sets the default MLX device to GPU.
|
||||
func SetDefaultDeviceGPU() {
|
||||
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
|
||||
C.mlx_set_default_device(dev)
|
||||
C.mlx_device_free(dev)
|
||||
resetDefaultStreamCache()
|
||||
}
|
||||
|
||||
type Stream struct {
|
||||
ctx C.mlx_stream
|
||||
}
|
||||
|
||||
func (s Stream) LogValue() slog.Value {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_stream_tostring(&str, s.ctx)
|
||||
return slog.StringValue(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
func DefaultStream() Stream {
|
||||
if !defaultStreamSet {
|
||||
s := C.mlx_stream_new()
|
||||
C.mlx_get_default_stream(&s, DefaultDevice().ctx)
|
||||
defaultStream = Stream{s}
|
||||
defaultStreamSet = true
|
||||
}
|
||||
|
||||
return defaultStream
|
||||
}
|
||||
104
x/mlxrunner/mlx/thread_test.go
Normal file
104
x/mlxrunner/mlx/thread_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/internal/mlxthread"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func startMLXThread(t *testing.T) *mlxthread.Thread {
|
||||
t.Helper()
|
||||
|
||||
thread, err := mlxthread.Start("mlx-test", func() error {
|
||||
if err := CheckInit(); err != nil {
|
||||
return err
|
||||
}
|
||||
if GPUIsAvailable() {
|
||||
SetDefaultDeviceGPU()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
|
||||
return thread
|
||||
}
|
||||
|
||||
func stopMLXThread(t *testing.T, thread *mlxthread.Thread) {
|
||||
t.Helper()
|
||||
|
||||
if err := thread.Stop(context.Background(), func() {
|
||||
Sweep()
|
||||
ClearCache()
|
||||
resetDefaultStreamCache()
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func withMLXThread(t *testing.T, fn func()) {
|
||||
t.Helper()
|
||||
|
||||
thread := startMLXThread(t)
|
||||
defer stopMLXThread(t, thread)
|
||||
|
||||
if err := thread.Do(context.Background(), func() error {
|
||||
fn()
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThreadedMLXOperations(t *testing.T) {
|
||||
thread := startMLXThread(t)
|
||||
defer stopMLXThread(t, thread)
|
||||
|
||||
oldProcs := runtime.GOMAXPROCS(8)
|
||||
defer runtime.GOMAXPROCS(oldProcs)
|
||||
|
||||
const goroutines = 8
|
||||
const iterations = 8
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, goroutines)
|
||||
for range goroutines {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for range iterations {
|
||||
if err := thread.Do(context.Background(), func() error {
|
||||
a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
|
||||
b := Matmul(a, a)
|
||||
AsyncEval(b)
|
||||
Eval(b)
|
||||
Sweep()
|
||||
ClearCache()
|
||||
return nil
|
||||
}); err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
202
x/mlxrunner/model/base/base.go
Normal file
202
x/mlxrunner/model/base/base.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the interface that model implementations must satisfy.
|
||||
type Model interface {
|
||||
Forward(b *batch.Batch, cache []cache.Cache) *mlx.Array
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
MaxContextLength() int
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||
// stacking, quantized layer creation) happens here.
|
||||
LoadWeights(tensors map[string]*mlx.Array) error
|
||||
}
|
||||
|
||||
// DraftModel is an auxiliary model stored alongside a target model.
|
||||
type DraftModel interface {
|
||||
LoadWeights(tensors map[string]*mlx.Array) error
|
||||
}
|
||||
|
||||
// MTPDefaults holds model-provided draft-token defaults for speculative
|
||||
// decoding. Environment settings in the runner may override these values.
|
||||
type MTPDefaults struct {
|
||||
InitialDraftTokens int
|
||||
MaxDraftTokens int
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// MTPDefaultsProvider lets a model provide MTP policy defaults from its own
|
||||
// config without teaching the runner model-specific shape heuristics.
|
||||
type MTPDefaultsProvider interface {
|
||||
MTPDraftDefaults(sample bool) MTPDefaults
|
||||
}
|
||||
|
||||
// MTPDraftModel is a draft model capable of Gemma-style multi-token
|
||||
// prediction from target token embeddings, target hidden states, and target KV.
|
||||
type MTPDraftModel interface {
|
||||
Draft(inputEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array)
|
||||
}
|
||||
|
||||
// MTPEmbeddingModel exposes the target token embedding path used by MTP drafts.
|
||||
type MTPEmbeddingModel interface {
|
||||
TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// DFlashTargetModel exposes target-layer hidden states for DFlash drafts.
|
||||
type DFlashTargetModel interface {
|
||||
ForwardDFlash(b *batch.Batch, caches []cache.Cache, layerIDs []int) (hidden, targetHidden *mlx.Array)
|
||||
}
|
||||
|
||||
// DFlashDraftModel is a block-diffusion speculative draft model.
|
||||
type DFlashDraftModel interface {
|
||||
DraftModel
|
||||
|
||||
TargetLayerIDs() []int
|
||||
BlockSize() int
|
||||
MaskTokenID() int32
|
||||
NewCaches() []cache.Cache
|
||||
AppendContext(targetHidden *mlx.Array, caches []cache.Cache)
|
||||
Draft(inputIDs *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
registry = make(map[string]func(root *model.Root) (Model, error))
|
||||
draftRegistry = make(map[string]func(root *model.Root, target Model) (DraftModel, error))
|
||||
)
|
||||
|
||||
// Register registers a model constructor by architecture name.
|
||||
// Called from init() in model packages. Panics on duplicate registration.
|
||||
func Register(arch string, fn func(root *model.Root) (Model, error)) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := registry[arch]; exists {
|
||||
panic(fmt.Sprintf("model architecture %q already registered", arch))
|
||||
}
|
||||
registry[arch] = fn
|
||||
}
|
||||
|
||||
// RegisterDraft registers a draft model constructor by architecture name.
|
||||
func RegisterDraft(arch string, fn func(root *model.Root, target Model) (DraftModel, error)) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := draftRegistry[arch]; exists {
|
||||
panic(fmt.Sprintf("draft model architecture %q already registered", arch))
|
||||
}
|
||||
draftRegistry[arch] = fn
|
||||
}
|
||||
|
||||
// New reads config.json from the manifest, detects the architecture, looks up
|
||||
// the registered constructor, and calls it to create the model (with config
|
||||
// parsed and struct created, but weights not yet loaded).
|
||||
func New(root *model.Root) (Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
var archConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config.json: %w", err)
|
||||
}
|
||||
|
||||
if len(archConfig.Architectures) == 0 {
|
||||
return nil, fmt.Errorf("no architectures found in config.json")
|
||||
}
|
||||
|
||||
arch := archConfig.Architectures[0]
|
||||
slog.Info("Model architecture", "arch", arch)
|
||||
|
||||
mu.Lock()
|
||||
fn, ok := registry[arch]
|
||||
mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported architecture: %s", arch)
|
||||
}
|
||||
|
||||
return fn(root)
|
||||
}
|
||||
|
||||
// NewDraft constructs the draft model described by the manifest config, if any.
|
||||
func NewDraft(root *model.Root, target Model) (DraftModel, error) {
|
||||
if root == nil || root.Draft == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
configPath := root.Draft.Config
|
||||
if configPath == "" {
|
||||
configPath = "draft/config.json"
|
||||
}
|
||||
configData, err := root.Manifest.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
var archConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
arch := root.Draft.Architecture
|
||||
if arch == "" && len(archConfig.Architectures) > 0 {
|
||||
arch = archConfig.Architectures[0]
|
||||
}
|
||||
if arch == "" {
|
||||
arch = archConfig.ModelType
|
||||
}
|
||||
if arch == "" {
|
||||
return nil, fmt.Errorf("no draft architecture found in %s", configPath)
|
||||
}
|
||||
slog.Info("Draft model architecture", "arch", arch)
|
||||
|
||||
mu.Lock()
|
||||
fn, ok := draftRegistry[arch]
|
||||
mu.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported draft architecture: %s", arch)
|
||||
}
|
||||
|
||||
return fn(root, target)
|
||||
}
|
||||
|
||||
// Weights returns a function that loads model weights, then pins all
|
||||
// arrays reachable from the model struct and sweeps everything else.
|
||||
func Weights(m Model) func(map[string]*mlx.Array) error {
|
||||
return func(tensors map[string]*mlx.Array) error {
|
||||
if err := m.LoadWeights(tensors); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
for _, arr := range collected {
|
||||
mlx.Pin(arr)
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
42
x/mlxrunner/model/embedding.go
Normal file
42
x/mlxrunner/model/embedding.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// MakeEmbeddingLayer constructs an embedding layer from a tensor map.
|
||||
//
|
||||
// For quantized tensors (path.weight + path.weight_scale), it returns a
|
||||
// QuantizedEmbedding using the same quant metadata path that linear layers use.
|
||||
// For non-quantized tensors, it returns a standard dense embedding.
|
||||
func MakeEmbeddingLayer(
|
||||
tensors map[string]*mlx.Array,
|
||||
path string,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) nn.EmbeddingLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
return nn.NewQuantizedEmbedding(w, scales, qbiases, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
return nn.NewEmbedding(w)
|
||||
}
|
||||
78
x/mlxrunner/model/embedding_test.go
Normal file
78
x/mlxrunner/model/embedding_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeEmbeddingLayerDense(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
weight := mlx.FromValues([]float32{
|
||||
1, 2, 3, 4,
|
||||
5, 6, 7, 8,
|
||||
}, 2, 4).AsType(mlx.DTypeBFloat16)
|
||||
|
||||
emb := MakeEmbeddingLayer(map[string]*mlx.Array{
|
||||
"model.embed_tokens.weight": weight,
|
||||
}, "model.embed_tokens", 0, 0, "", nil)
|
||||
|
||||
dense, ok := emb.(*nn.Embedding)
|
||||
if !ok {
|
||||
t.Fatalf("embedding type = %T, want *nn.Embedding", emb)
|
||||
}
|
||||
if dense.Weight.DType() != mlx.DTypeBFloat16 {
|
||||
t.Fatalf("embedding dtype = %v, want %v", dense.Weight.DType(), mlx.DTypeBFloat16)
|
||||
}
|
||||
if _, ok := emb.AsLinear().(*nn.Linear); !ok {
|
||||
t.Fatalf("AsLinear type = %T, want *nn.Linear", emb.AsLinear())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeEmbeddingLayerQuantized(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
denseWeight := mlx.FromValues(func() []float32 {
|
||||
out := make([]float32, 2*64)
|
||||
for i := range out {
|
||||
out[i] = float32(i%17) / 8
|
||||
}
|
||||
return out
|
||||
}(), 2, 64).AsType(mlx.DTypeBFloat16)
|
||||
|
||||
qw, scales, qbiases := mlx.Quantize(denseWeight, 64, 4, "affine")
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
|
||||
emb := MakeEmbeddingLayer(map[string]*mlx.Array{
|
||||
"model.embed_tokens.weight": qw,
|
||||
"model.embed_tokens.weight_scale": scales,
|
||||
"model.embed_tokens.weight_qbias": qbiases,
|
||||
}, "model.embed_tokens", 64, 4, "affine", nil)
|
||||
|
||||
qemb, ok := emb.(*nn.QuantizedEmbedding)
|
||||
if !ok {
|
||||
t.Fatalf("embedding type = %T, want *nn.QuantizedEmbedding", emb)
|
||||
}
|
||||
if qemb.GroupSize != 64 || qemb.Bits != 4 || qemb.Mode != "affine" {
|
||||
t.Fatalf("quant params = (%d, %d, %q), want (64, 4, %q)", qemb.GroupSize, qemb.Bits, qemb.Mode, "affine")
|
||||
}
|
||||
|
||||
indices := mlx.FromValues([]int32{1, 0}, 2)
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
if dims := out.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 64 {
|
||||
t.Fatalf("embedding output dims = %v, want [2 64]", dims)
|
||||
}
|
||||
if _, ok := emb.AsLinear().(*nn.QuantizedLinear); !ok {
|
||||
t.Fatalf("AsLinear type = %T, want *nn.QuantizedLinear", emb.AsLinear())
|
||||
}
|
||||
}
|
||||
99
x/mlxrunner/model/linear.go
Normal file
99
x/mlxrunner/model/linear.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
|
||||
type LinearFactory struct {
|
||||
tensors map[string]*mlx.Array
|
||||
defaultGroupSize int
|
||||
defaultBits int
|
||||
defaultMode string
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// NewLinearFactory creates a reusable constructor for model linear layers.
|
||||
func NewLinearFactory(
|
||||
tensors map[string]*mlx.Array,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) LinearFactory {
|
||||
return LinearFactory{
|
||||
tensors: tensors,
|
||||
defaultGroupSize: defaultGroupSize,
|
||||
defaultBits: defaultBits,
|
||||
defaultMode: defaultMode,
|
||||
tensorQuant: tensorQuant,
|
||||
}
|
||||
}
|
||||
|
||||
// Make constructs a linear layer at path.
|
||||
func (f LinearFactory) Make(path string) nn.LinearLayer {
|
||||
return MakeLinearLayer(
|
||||
f.tensors,
|
||||
path,
|
||||
f.defaultGroupSize,
|
||||
f.defaultBits,
|
||||
f.defaultMode,
|
||||
f.tensorQuant,
|
||||
)
|
||||
}
|
||||
|
||||
// MakeLinearLayer constructs a linear layer from a tensor map.
|
||||
//
|
||||
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
|
||||
// quant params via TensorQuant metadata (with shape-based affine fallback).
|
||||
// For non-quantized tensors, it returns a standard nn.Linear.
|
||||
func MakeLinearLayer(
|
||||
tensors map[string]*mlx.Array,
|
||||
path string,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) nn.LinearLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
bias := tensors[path+".bias"]
|
||||
|
||||
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
// Check for per-tensor global scale (NVIDIA double-scale nvfp4).
|
||||
// NVIDIA ModelOpt stores this as "weight_scale_2"; our import
|
||||
// pipeline maps it to "weight.global_scale".
|
||||
globalScale := tensors[path+".weight.global_scale"]
|
||||
if globalScale == nil {
|
||||
globalScale = tensors[path+".weight_scale_2"]
|
||||
}
|
||||
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: w,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GlobalScale: globalScale,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
bias := tensors[path+".bias"]
|
||||
return nn.NewLinear(w, bias)
|
||||
}
|
||||
132
x/mlxrunner/model/quant.go
Normal file
132
x/mlxrunner/model/quant.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
|
||||
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "NVFP4":
|
||||
return 16, 4, "nvfp4"
|
||||
case "MXFP4":
|
||||
return 32, 4, "mxfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
return 64, 4, "affine"
|
||||
case "MXFP8":
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8":
|
||||
return 64, 8, "affine"
|
||||
case "":
|
||||
return 0, 0, ""
|
||||
default:
|
||||
return 32, 8, "affine"
|
||||
}
|
||||
}
|
||||
|
||||
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
|
||||
// when available, otherwise falling back to the provided model defaults.
|
||||
func TensorQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
) (groupSize, bits int, mode string, fromTensor bool) {
|
||||
if tensorQuant != nil {
|
||||
if tq := tensorQuant[tensorName]; tq != nil {
|
||||
groupSize, bits, mode = QuantizationParams(tq.QuantType)
|
||||
if tq.GroupSize > 0 {
|
||||
groupSize = tq.GroupSize
|
||||
}
|
||||
return groupSize, bits, mode, true
|
||||
}
|
||||
}
|
||||
return defaultGroupSize, defaultBits, defaultMode, false
|
||||
}
|
||||
|
||||
// ResolveLinearQuantParams resolves quantization params for a quantized linear
|
||||
// tensor, preferring per-tensor metadata and falling back to shape-based
|
||||
// inference for affine packed tensors.
|
||||
func ResolveLinearQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
weight, scales *mlx.Array,
|
||||
) (groupSize, bits int, mode string) {
|
||||
groupSize, bits, mode, fromTensor := TensorQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
tensorName,
|
||||
)
|
||||
|
||||
if mode == "affine" {
|
||||
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
|
||||
if !fromTensor || groupSize == 0 || bits == 0 {
|
||||
groupSize = inferredGroupSize
|
||||
bits = inferredBits
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groupSize, bits, mode
|
||||
}
|
||||
|
||||
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
|
||||
// tensors from packed weight and scale shapes.
|
||||
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
|
||||
if weight == nil || scales == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightShape := weight.Dims()
|
||||
scaleShape := scales.Dims()
|
||||
if len(weightShape) == 0 || len(scaleShape) == 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightCols := weightShape[len(weightShape)-1]
|
||||
scalesCols := scaleShape[len(scaleShape)-1]
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return 32, 4, true
|
||||
case groupSize8 == 64:
|
||||
return 64, 8, true
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
if hintBits == 8 {
|
||||
return 32, 8, true
|
||||
}
|
||||
if hintBits == 4 {
|
||||
return 64, 4, true
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return groupSize4, 4, true
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return groupSize8, 8, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func isCommonGroupSize(v int) bool {
|
||||
switch v {
|
||||
case 16, 32, 64, 128:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
299
x/mlxrunner/model/root.go
Normal file
299
x/mlxrunner/model/root.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
modeltypes "github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// TensorQuantInfo describes per-tensor quantization metadata.
|
||||
type TensorQuantInfo struct {
|
||||
QuantType string
|
||||
GroupSize int
|
||||
}
|
||||
|
||||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||
type Root struct {
|
||||
Manifest *manifest.ModelManifest
|
||||
Draft *modeltypes.Draft
|
||||
|
||||
// Backwards-compatible model-level quant metadata (first tensor blob).
|
||||
quantType string
|
||||
groupSize int
|
||||
|
||||
// Per-tensor quantization metadata.
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// Open loads a manifest for the given model name and scans tensor blobs for
|
||||
// quantization metadata.
|
||||
func Open(modelName string) (*Root, error) {
|
||||
m, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root := &Root{
|
||||
Manifest: m,
|
||||
tensorQuant: make(map[string]*TensorQuantInfo),
|
||||
}
|
||||
root.Draft = readDraftConfig(m)
|
||||
|
||||
for _, layer := range m.GetTensorLayers("") {
|
||||
blobPath := m.BlobPath(layer.Digest)
|
||||
|
||||
infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for name, info := range infos {
|
||||
root.tensorQuant[name] = info
|
||||
}
|
||||
|
||||
if root.quantType == "" && blobQuantType != "" {
|
||||
root.quantType = strings.ToUpper(blobQuantType)
|
||||
root.groupSize = blobGroupSize
|
||||
if root.groupSize == 0 {
|
||||
root.groupSize = defaultGroupSize(root.quantType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func readDraftConfig(m *manifest.ModelManifest) *modeltypes.Draft {
|
||||
if m == nil || m.Manifest == nil || m.Manifest.Config.Digest == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cfg modeltypes.ConfigV2
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil
|
||||
}
|
||||
if cfg.Draft != nil {
|
||||
return cfg.Draft
|
||||
}
|
||||
|
||||
if m.GetConfigLayer("draft/config.json") != nil {
|
||||
return &modeltypes.Draft{
|
||||
ModelFormat: "safetensors",
|
||||
TensorPrefix: "draft.",
|
||||
Config: "draft/config.json",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op for now (future: release resources).
|
||||
func (r *Root) Close() {}
|
||||
|
||||
// QuantType returns the quantization type detected from the first tensor blob metadata.
|
||||
func (r *Root) QuantType() string { return r.quantType }
|
||||
|
||||
// GroupSize returns the quantization group size detected from the first tensor blob metadata.
|
||||
func (r *Root) GroupSize() int { return r.groupSize }
|
||||
|
||||
// TensorQuant returns per-tensor quantization metadata if available.
|
||||
func (r *Root) TensorQuant(name string) *TensorQuantInfo {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return r.tensorQuant[name]
|
||||
}
|
||||
|
||||
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
|
||||
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
|
||||
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
|
||||
for k, v := range r.tensorQuant {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
copy := *v
|
||||
out[k] = ©
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func defaultGroupSize(quantType string) int {
|
||||
groupSize, _, _ := QuantizationParams(quantType)
|
||||
return groupSize
|
||||
}
|
||||
|
||||
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
if headerSize > 100*1024*1024 {
|
||||
return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
|
||||
}
|
||||
|
||||
data := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(f, data); err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &header); err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||
globalQuantType = strings.ToUpper(globalQuantType)
|
||||
|
||||
// Parse full metadata for per-tensor quant info
|
||||
var metaMap map[string]string
|
||||
if metaRaw, ok := header["__metadata__"]; ok {
|
||||
json.Unmarshal(metaRaw, &metaMap)
|
||||
}
|
||||
|
||||
mainNames := mainTensorNames(header)
|
||||
infos := make(map[string]*TensorQuantInfo)
|
||||
for _, name := range mainNames {
|
||||
if _, ok := header[name+".scale"]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
quantType := globalQuantType
|
||||
groupSize := globalGroupSize
|
||||
|
||||
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
|
||||
if metaMap != nil {
|
||||
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
|
||||
quantType = strings.ToUpper(qt)
|
||||
}
|
||||
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
|
||||
if v, err := strconv.Atoi(gs); err == nil {
|
||||
groupSize = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||
if quantType == "" {
|
||||
quantType = inferredType
|
||||
}
|
||||
if groupSize == 0 {
|
||||
groupSize = inferredGroup
|
||||
}
|
||||
if quantType == "" {
|
||||
continue
|
||||
}
|
||||
if groupSize == 0 {
|
||||
groupSize = defaultGroupSize(quantType)
|
||||
}
|
||||
|
||||
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
|
||||
}
|
||||
|
||||
return infos, globalQuantType, globalGroupSize, nil
|
||||
}
|
||||
|
||||
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
|
||||
metaRaw, ok := header["__metadata__"]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(metaRaw, &meta); err != nil {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
quantType = meta["quant_type"]
|
||||
if gs := meta["group_size"]; gs != "" {
|
||||
groupSize, _ = strconv.Atoi(gs)
|
||||
}
|
||||
return quantType, groupSize
|
||||
}
|
||||
|
||||
func mainTensorNames(header map[string]json.RawMessage) []string {
|
||||
names := make([]string, 0, len(header))
|
||||
for name := range header {
|
||||
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
|
||||
type tensorShape struct {
|
||||
Shape []int64 `json:"shape"`
|
||||
}
|
||||
|
||||
mainRaw, ok := header[tensorName]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
scaleRaw, ok := header[tensorName+".scale"]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var mainInfo tensorShape
|
||||
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var scaleInfo tensorShape
|
||||
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
|
||||
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return "INT4", 32
|
||||
case groupSize8 == 64:
|
||||
return "INT8", 64
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
h := strings.ToUpper(hintQuantType)
|
||||
if strings.Contains(h, "8") {
|
||||
return "INT8", 32
|
||||
}
|
||||
if strings.Contains(h, "4") {
|
||||
return "INT4", 64
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return "INT4", groupSize4
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return "INT8", groupSize8
|
||||
}
|
||||
|
||||
return "", 0
|
||||
}
|
||||
952
x/mlxrunner/mtp.go
Normal file
952
x/mlxrunner/mtp.go
Normal file
@@ -0,0 +1,952 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
)
|
||||
|
||||
const (
|
||||
mtpDefaultInitialDraftTokens = 4
|
||||
mtpDefaultMaxDraftTokens = 16
|
||||
)
|
||||
|
||||
type mtpDraftSchedule string
|
||||
|
||||
const (
|
||||
mtpDraftScheduleHeuristic mtpDraftSchedule = "heuristic"
|
||||
mtpDraftScheduleConstant mtpDraftSchedule = "constant"
|
||||
)
|
||||
|
||||
type mtpStats struct {
|
||||
iterations int
|
||||
drafted int
|
||||
accepted int
|
||||
mismatches int
|
||||
allAccepted int
|
||||
batched int
|
||||
serial int
|
||||
compared int
|
||||
batchSerialMismatches int
|
||||
maxDraft int
|
||||
targetDuration time.Duration
|
||||
draftDuration time.Duration
|
||||
validateDuration time.Duration
|
||||
}
|
||||
|
||||
type mtpOptions struct {
|
||||
initialDraftTokens int
|
||||
maxDraftTokens int
|
||||
draftSchedule mtpDraftSchedule
|
||||
serialValidate bool
|
||||
compareSerialValidate bool
|
||||
}
|
||||
|
||||
func (r *Runner) mtpDefaults(sample bool) base.MTPDefaults {
|
||||
defaults := base.MTPDefaults{
|
||||
InitialDraftTokens: mtpDefaultInitialDraftTokens,
|
||||
MaxDraftTokens: mtpDefaultMaxDraftTokens,
|
||||
Enabled: true,
|
||||
}
|
||||
if p, ok := r.Model.(base.MTPDefaultsProvider); ok {
|
||||
defaults = p.MTPDraftDefaults(sample)
|
||||
}
|
||||
if defaults.InitialDraftTokens <= 0 {
|
||||
defaults.InitialDraftTokens = mtpDefaultInitialDraftTokens
|
||||
}
|
||||
if defaults.MaxDraftTokens <= 0 {
|
||||
defaults.MaxDraftTokens = mtpDefaultMaxDraftTokens
|
||||
}
|
||||
return defaults
|
||||
}
|
||||
|
||||
func (r *Runner) loadMTPOptions(sample bool) mtpOptions {
|
||||
defaults := r.mtpDefaults(sample)
|
||||
|
||||
opts := mtpOptions{
|
||||
initialDraftTokens: defaults.InitialDraftTokens,
|
||||
maxDraftTokens: defaults.MaxDraftTokens,
|
||||
draftSchedule: mtpDraftScheduleConstant,
|
||||
}
|
||||
if v := positiveEnvInt("OLLAMA_MLX_MTP_MAX_DRAFT_TOKENS"); v > 0 {
|
||||
opts.maxDraftTokens = v
|
||||
}
|
||||
if v := positiveEnvInt("OLLAMA_MLX_MTP_INITIAL_DRAFT_TOKENS"); v > 0 {
|
||||
opts.initialDraftTokens = v
|
||||
}
|
||||
if opts.initialDraftTokens > opts.maxDraftTokens {
|
||||
opts.initialDraftTokens = opts.maxDraftTokens
|
||||
}
|
||||
if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil {
|
||||
opts.serialValidate = b
|
||||
}
|
||||
if b, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil {
|
||||
opts.compareSerialValidate = b
|
||||
}
|
||||
switch schedule := strings.ToLower(strings.TrimSpace(os.Getenv("OLLAMA_MLX_MTP_DRAFT_SCHEDULE"))); schedule {
|
||||
case "", string(mtpDraftScheduleConstant):
|
||||
opts.draftSchedule = mtpDraftScheduleConstant
|
||||
case string(mtpDraftScheduleHeuristic):
|
||||
opts.draftSchedule = mtpDraftScheduleHeuristic
|
||||
default:
|
||||
slog.Warn("invalid MTP env setting", "key", "OLLAMA_MLX_MTP_DRAFT_SCHEDULE", "value", schedule)
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func positiveEnvInt(key string) int {
|
||||
raw := os.Getenv(key)
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
v, err := strconv.Atoi(raw)
|
||||
if err != nil || v <= 0 {
|
||||
slog.Warn("invalid MTP env setting", "key", key, "value", raw)
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (r *Runner) useGreedyMTP(opts sampler.Options) bool {
|
||||
if r.Draft == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := r.Draft.(base.MTPDraftModel); !ok {
|
||||
return false
|
||||
}
|
||||
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
|
||||
return false
|
||||
}
|
||||
if !r.mtpDefaults(false).Enabled {
|
||||
return false
|
||||
}
|
||||
if opts.Logprobs || opts.TopLogprobs > 0 {
|
||||
return false
|
||||
}
|
||||
if opts.Temperature != 0 {
|
||||
return false
|
||||
}
|
||||
repeatPenaltyNeutral := opts.RepeatPenalty <= 0 || opts.RepeatPenalty == 1
|
||||
topPNeutral := opts.TopP <= 0 || opts.TopP >= 1
|
||||
topKNeutral := opts.TopK <= 0
|
||||
return repeatPenaltyNeutral && opts.PresencePenalty == 0 && opts.FrequencyPenalty == 0 && topPNeutral && topKNeutral && opts.MinP == 0
|
||||
}
|
||||
|
||||
func (r *Runner) useSampleMTP(opts sampler.Options) bool {
|
||||
if serial, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_SERIAL_VALIDATE")); err == nil && serial {
|
||||
return false
|
||||
}
|
||||
if compare, err := strconv.ParseBool(os.Getenv("OLLAMA_MLX_MTP_COMPARE_SERIAL_VALIDATE")); err == nil && compare {
|
||||
return false
|
||||
}
|
||||
if r.Draft == nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := r.Draft.(base.MTPDraftModel); !ok {
|
||||
return false
|
||||
}
|
||||
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
|
||||
return false
|
||||
}
|
||||
if !r.mtpDefaults(true).Enabled {
|
||||
return false
|
||||
}
|
||||
if opts.Logprobs || opts.TopLogprobs > 0 {
|
||||
return false
|
||||
}
|
||||
return opts.Temperature != 0
|
||||
}
|
||||
|
||||
func (r *Runner) runGreedyMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
|
||||
targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
|
||||
draft := r.Draft.(base.MTPDraftModel)
|
||||
mtpOpts := r.loadMTPOptions(false)
|
||||
stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
|
||||
draftLimit := mtpOpts.initialDraftTokens
|
||||
slog.Info("MTP greedy decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate, "compare_serial_validate", mtpOpts.compareSerialValidate)
|
||||
|
||||
targetForward := func(token *mlx.Array) *mlx.Array {
|
||||
fwd := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, caches)
|
||||
*position += token.Dim(1)
|
||||
return fwd
|
||||
}
|
||||
|
||||
hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
|
||||
current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
|
||||
mlx.Pin(current.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
defer func() {
|
||||
mlx.Unpin(current.Arrays()...)
|
||||
}()
|
||||
|
||||
dec := decoder{tokenizer: r.Tokenizer}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
|
||||
now := started
|
||||
|
||||
generated := 0
|
||||
for generated < request.Options.NumPredict {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
hidden = targetForward(current.Token.ExpandDims(-1))
|
||||
baseLogits := r.lastLogits(hidden)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
|
||||
if generated == 0 {
|
||||
mlx.Eval(current.Arrays()...)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !done {
|
||||
generated++
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
|
||||
stats.iterations++
|
||||
maxDraft := min(draftLimit, request.Options.NumPredict-generated)
|
||||
t0 = time.Now()
|
||||
draftTokens := r.generateMTPDrafts(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
|
||||
draftCount := 0
|
||||
if draftTokens != nil {
|
||||
draftCount = draftTokens.Dim(1)
|
||||
mlx.Pin(baseLogits, draftTokens)
|
||||
mlx.Eval(draftTokens)
|
||||
mlx.Sweep()
|
||||
}
|
||||
stats.draftDuration += time.Since(t0)
|
||||
stats.drafted += draftCount
|
||||
var next sampler.Result
|
||||
if draftCount == 0 {
|
||||
next = sampler.Result{Token: greedyTokenFromLogits(baseLogits)}
|
||||
} else {
|
||||
var accepted int
|
||||
t0 = time.Now()
|
||||
next, accepted, done, err = r.acceptMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, draftTokens, &final, &generated, &stats, mtpOpts)
|
||||
stats.validateDuration += time.Since(t0)
|
||||
mlx.Unpin(baseLogits, draftTokens)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stats.accepted += accepted
|
||||
switch {
|
||||
case mtpOpts.draftSchedule == mtpDraftScheduleConstant:
|
||||
case accepted == draftCount:
|
||||
stats.allAccepted++
|
||||
draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
|
||||
default:
|
||||
stats.mismatches++
|
||||
draftLimit = max(1, draftLimit-1)
|
||||
}
|
||||
if mtpOpts.draftSchedule == mtpDraftScheduleConstant {
|
||||
if accepted == draftCount {
|
||||
stats.allAccepted++
|
||||
} else {
|
||||
stats.mismatches++
|
||||
}
|
||||
}
|
||||
stats.maxDraft = max(stats.maxDraft, draftLimit)
|
||||
if next.Token == nil {
|
||||
mlx.Sweep()
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Pin(next.Arrays()...)
|
||||
old := current
|
||||
current = next
|
||||
mlx.Unpin(old.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
|
||||
if generated%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
final.EvalCount = generated
|
||||
final.EvalDuration = time.Since(now)
|
||||
acceptance := 0.0
|
||||
if stats.drafted > 0 {
|
||||
acceptance = float64(stats.accepted) / float64(stats.drafted)
|
||||
}
|
||||
avgDraft := 0.0
|
||||
avgAccepted := 0.0
|
||||
if stats.iterations > 0 {
|
||||
avgDraft = float64(stats.drafted) / float64(stats.iterations)
|
||||
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
|
||||
}
|
||||
slog.Info("MTP decode stats", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "compared", stats.compared, "batch_serial_mismatches", stats.batchSerialMismatches, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) runSampleMTPDecode(ctx context.Context, request Request, session *cacheSession, caches []cache.Cache, seed []int32, position *int, started time.Time) error {
|
||||
targetEmbeddings := r.Model.(base.MTPEmbeddingModel)
|
||||
draft := r.Draft.(base.MTPDraftModel)
|
||||
mtpOpts := r.loadMTPOptions(true)
|
||||
stats := mtpStats{maxDraft: mtpOpts.initialDraftTokens}
|
||||
draftLimit := mtpOpts.initialDraftTokens
|
||||
slog.Info("MTP sample decode enabled", "initial_draft_tokens", mtpOpts.initialDraftTokens, "max_draft_tokens", mtpOpts.maxDraftTokens, "draft_schedule", mtpOpts.draftSchedule, "serial_validate", mtpOpts.serialValidate)
|
||||
|
||||
targetForward := func(token *mlx.Array) *mlx.Array {
|
||||
fwd := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, caches)
|
||||
*position += token.Dim(1)
|
||||
return fwd
|
||||
}
|
||||
|
||||
hidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
|
||||
current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
|
||||
mlx.Pin(current.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
defer func() {
|
||||
mlx.Unpin(current.Arrays()...)
|
||||
}()
|
||||
|
||||
dec := decoder{tokenizer: r.Tokenizer}
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
|
||||
now := started
|
||||
|
||||
generated := 0
|
||||
for generated < request.Options.NumPredict {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
hidden = targetForward(mtpTokenInput(current.Token))
|
||||
baseLogits := r.lastLogits(hidden)
|
||||
stats.targetDuration += time.Since(t0)
|
||||
|
||||
if generated == 0 {
|
||||
mlx.Eval(current.Arrays()...)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !done {
|
||||
generated++
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
|
||||
stats.iterations++
|
||||
maxDraft := min(draftLimit, request.Options.NumPredict-generated)
|
||||
t0 = time.Now()
|
||||
candidates := r.generateMTPDraftCandidates(draft, targetEmbeddings, current.Token, hidden, caches, int32(*position-1), maxDraft)
|
||||
draftCount := 0
|
||||
var candidateArrays []*mlx.Array
|
||||
if candidates != nil {
|
||||
draftCount = candidates.tokens.Dim(1)
|
||||
candidateArrays = append([]*mlx.Array{baseLogits}, candidates.Arrays()...)
|
||||
mlx.Pin(candidateArrays...)
|
||||
mlx.Sweep()
|
||||
}
|
||||
stats.draftDuration += time.Since(t0)
|
||||
stats.drafted += draftCount
|
||||
|
||||
var next sampler.Result
|
||||
if draftCount == 0 {
|
||||
next = r.Sampler.Sample([]int{pipelineSlot}, baseLogits)
|
||||
} else {
|
||||
var accepted int
|
||||
t0 = time.Now()
|
||||
next, accepted, done, err = r.acceptSampleMTPDrafts(ctx, request, session, &dec, caches, position, baseLogits, candidates, &final, &generated, &stats)
|
||||
stats.validateDuration += time.Since(t0)
|
||||
mlx.Unpin(candidateArrays...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stats.accepted += accepted
|
||||
switch {
|
||||
case mtpOpts.draftSchedule == mtpDraftScheduleConstant:
|
||||
case accepted == draftCount:
|
||||
stats.allAccepted++
|
||||
draftLimit = min(mtpOpts.maxDraftTokens, draftLimit+2)
|
||||
default:
|
||||
stats.mismatches++
|
||||
draftLimit = max(1, draftLimit-1)
|
||||
}
|
||||
if mtpOpts.draftSchedule == mtpDraftScheduleConstant {
|
||||
if accepted == draftCount {
|
||||
stats.allAccepted++
|
||||
} else {
|
||||
stats.mismatches++
|
||||
}
|
||||
}
|
||||
stats.maxDraft = max(stats.maxDraft, draftLimit)
|
||||
if next.Token == nil {
|
||||
mlx.Sweep()
|
||||
}
|
||||
if done || generated >= request.Options.NumPredict {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Pin(next.Arrays()...)
|
||||
old := current
|
||||
current = next
|
||||
mlx.Unpin(old.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(current.Arrays()...)
|
||||
|
||||
if generated%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
final.EvalCount = generated
|
||||
final.EvalDuration = time.Since(now)
|
||||
acceptance := 0.0
|
||||
if stats.drafted > 0 {
|
||||
acceptance = float64(stats.accepted) / float64(stats.drafted)
|
||||
}
|
||||
avgDraft := 0.0
|
||||
avgAccepted := 0.0
|
||||
if stats.iterations > 0 {
|
||||
avgDraft = float64(stats.drafted) / float64(stats.iterations)
|
||||
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
|
||||
}
|
||||
slog.Info("MTP decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", stats.maxDraft, "draft_schedule", mtpOpts.draftSchedule, "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type mtpDraftCandidates struct {
|
||||
tokens *mlx.Array
|
||||
// dist is the proposal distribution used to sample each drafted token.
|
||||
dist sampler.Distribution
|
||||
}
|
||||
|
||||
func (c *mtpDraftCandidates) Arrays() []*mlx.Array {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...)
|
||||
}
|
||||
|
||||
func (r *Runner) generateMTPDrafts(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mlx.Array {
|
||||
if maxDraft <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastToken := token.ExpandDims(-1)
|
||||
lastHidden := hidden
|
||||
draftTokens := make([]*mlx.Array, 0, maxDraft)
|
||||
|
||||
// Gemma4 assistant MTP is trained as "single-position" drafting:
|
||||
// keep the RoPE/cache position anchored at the last target-seen token
|
||||
// while the proposed token and projected hidden state advance.
|
||||
for range maxDraft {
|
||||
tokenEmbedding := target.TokenEmbeddings(lastToken)
|
||||
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
|
||||
logits, projected := draft.Draft(inputs, position, caches)
|
||||
stepLogits := r.lastLogitsFromLogits(logits)
|
||||
nextToken := greedyTokenFromLogits(stepLogits)
|
||||
|
||||
lastToken = nextToken.ExpandDims(-1)
|
||||
lastHidden = projected
|
||||
draftTokens = append(draftTokens, lastToken)
|
||||
}
|
||||
if len(draftTokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
return mlx.Concatenate(draftTokens, 1)
|
||||
}
|
||||
|
||||
func (r *Runner) generateMTPDraftCandidates(draft base.MTPDraftModel, target base.MTPEmbeddingModel, token *mlx.Array, hidden *mlx.Array, caches []cache.Cache, position int32, maxDraft int) *mtpDraftCandidates {
|
||||
if maxDraft <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastToken := mtpTokenInput(token)
|
||||
lastHidden := hidden
|
||||
draftTokens := make([]*mlx.Array, 0, maxDraft)
|
||||
draftDists := make([]sampler.Distribution, 0, maxDraft)
|
||||
var prefix *mlx.Array
|
||||
|
||||
// Gemma4 assistant MTP is trained as "single-position" drafting:
|
||||
// keep the RoPE/cache position anchored at the last target-seen token
|
||||
// while the proposed token and projected hidden state advance.
|
||||
for range maxDraft {
|
||||
tokenEmbedding := target.TokenEmbeddings(lastToken)
|
||||
inputs := tokenEmbedding.Concatenate(-1, lastHidden)
|
||||
logits, projected := draft.Draft(inputs, position, caches)
|
||||
stepLogits := r.lastLogitsFromLogits(logits)
|
||||
dist := r.Sampler.Distribution(pipelineSlot, stepLogits, prefix)
|
||||
nextToken := r.Sampler.SampleDistribution(pipelineSlot, dist)
|
||||
|
||||
lastToken = mtpTokenInput(nextToken)
|
||||
lastHidden = projected
|
||||
draftTokens = append(draftTokens, lastToken)
|
||||
draftDists = append(draftDists, dist)
|
||||
if prefix == nil {
|
||||
prefix = lastToken
|
||||
} else {
|
||||
prefix = prefix.Concatenate(1, lastToken)
|
||||
}
|
||||
}
|
||||
if len(draftTokens) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &mtpDraftCandidates{
|
||||
tokens: mlx.Concatenate(draftTokens, 1),
|
||||
dist: sampler.ConcatenateDistributions(draftDists),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) acceptMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
|
||||
if opts.serialValidate {
|
||||
stats.serial++
|
||||
return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
|
||||
}
|
||||
|
||||
specCaches, spec, ok := cache.BeginSpeculation(caches)
|
||||
if ok {
|
||||
stats.batched++
|
||||
return r.acceptMTPDraftsBatched(ctx, request, session, dec, caches, specCaches, spec, position, baseLogits, draftTokens, final, generated, stats, opts)
|
||||
}
|
||||
|
||||
stats.serial++
|
||||
return r.acceptMTPDraftsSerial(ctx, request, session, dec, caches, position, baseLogits, draftTokens, final, generated)
|
||||
}
|
||||
|
||||
func (r *Runner) acceptMTPDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, liveCaches []cache.Cache, caches []cache.Cache, spec *cache.Speculation, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *mtpStats, opts mtpOptions) (sampler.Result, int, bool, error) {
|
||||
before := *position
|
||||
draftCount := draftTokens.Dim(1)
|
||||
hiddenSeq := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: draftTokens,
|
||||
SeqOffsets: []int32{int32(before)},
|
||||
SeqQueryLens: []int32{int32(draftCount)},
|
||||
}, caches)
|
||||
|
||||
accepted := 0
|
||||
var next sampler.Result
|
||||
done := false
|
||||
|
||||
selectedTokens := r.mtpValidationTokens(baseLogits, hiddenSeq)
|
||||
mlx.Eval(draftTokens, selectedTokens)
|
||||
draftIDs := draftTokens.Ints()
|
||||
selectedIDs := selectedTokens.Ints()
|
||||
if len(selectedIDs) < draftCount+1 {
|
||||
return sampler.Result{}, accepted, false, fmt.Errorf("mtp validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount)
|
||||
}
|
||||
|
||||
for i, id := range draftIDs {
|
||||
if selectedIDs[i] != id {
|
||||
next = sampler.Result{Token: mtpTokenAt(selectedTokens, i)}
|
||||
break
|
||||
}
|
||||
accepted++
|
||||
if r.Tokenizer.IsEOS(int32(id)) {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if opts.compareSerialValidate {
|
||||
spec.Commit(0)
|
||||
r.compareMTPBatchedWithSerial(ctx, liveCaches, before, baseLogits, hiddenSeq, draftIDs, selectedIDs, accepted, draftCount, stats)
|
||||
}
|
||||
spec.Commit(accepted)
|
||||
*position = before + accepted
|
||||
|
||||
for _, id := range draftIDs[:accepted] {
|
||||
if *generated >= request.Options.NumPredict {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
var err error
|
||||
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
if next.Token == nil {
|
||||
next = sampler.Result{Token: mtpTokenAt(selectedTokens, draftCount)}
|
||||
}
|
||||
return next, accepted, false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) acceptSampleMTPDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, candidates *mtpDraftCandidates, final *CompletionResponse, generated *int, stats *mtpStats) (sampler.Result, int, bool, error) {
|
||||
specCaches, spec, ok := cache.BeginSpeculation(caches)
|
||||
if !ok {
|
||||
stats.serial++
|
||||
return r.Sampler.Sample([]int{pipelineSlot}, baseLogits), 0, false, nil
|
||||
}
|
||||
stats.batched++
|
||||
|
||||
before := *position
|
||||
draftCount := candidates.tokens.Dim(1)
|
||||
hiddenSeq := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: candidates.tokens,
|
||||
SeqOffsets: []int32{int32(before)},
|
||||
SeqQueryLens: []int32{int32(draftCount)},
|
||||
}, specCaches)
|
||||
|
||||
targetDist := r.Sampler.Distribution(pipelineSlot, r.mtpValidationLogits(baseLogits, hiddenSeq), candidates.tokens)
|
||||
draftDist := candidates.dist
|
||||
acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens)
|
||||
mlx.Eval(candidates.tokens, acceptedMask)
|
||||
|
||||
draftIDs := candidates.tokens.Ints()
|
||||
acceptedFlags := acceptedMask.Ints()
|
||||
accepted := 0
|
||||
for _, ok := range acceptedFlags {
|
||||
if ok == 0 {
|
||||
break
|
||||
}
|
||||
accepted++
|
||||
}
|
||||
if accepted > draftCount {
|
||||
return sampler.Result{}, 0, false, fmt.Errorf("mtp sample validation accepted %d tokens for %d draft tokens", accepted, draftCount)
|
||||
}
|
||||
|
||||
commitIDs := make([]int32, 0, accepted+1)
|
||||
done := false
|
||||
for i, id := range draftIDs[:accepted] {
|
||||
commitIDs = append(commitIDs, int32(id))
|
||||
if r.Tokenizer.IsEOS(int32(id)) {
|
||||
done = true
|
||||
accepted = i + 1
|
||||
commitIDs = commitIDs[:accepted]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
spec.Commit(accepted)
|
||||
*position = before + accepted
|
||||
|
||||
for _, id := range draftIDs[:accepted] {
|
||||
if *generated >= request.Options.NumPredict {
|
||||
done = true
|
||||
break
|
||||
}
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
var err error
|
||||
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
r.Sampler.Commit(pipelineSlot, commitIDs)
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
|
||||
var nextToken *mlx.Array
|
||||
if accepted == draftCount {
|
||||
nextToken = r.mtpSampleTokenAt(targetDist, draftCount)
|
||||
} else {
|
||||
nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted)
|
||||
}
|
||||
mlx.Eval(nextToken)
|
||||
nextID := int32(tokenID(nextToken))
|
||||
commitIDs = append(commitIDs, nextID)
|
||||
r.Sampler.Commit(pipelineSlot, commitIDs)
|
||||
|
||||
return sampler.Result{Token: nextToken}, accepted, false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) mtpSampleAcceptedMask(targetDist, draftDist sampler.Distribution, draftTokens *mlx.Array) *mlx.Array {
|
||||
p := targetDist.Prob(draftTokens)
|
||||
q := draftDist.Prob(draftTokens)
|
||||
acceptP := mlx.Minimum(p.Divide(q), mlx.FromValue(float32(1)))
|
||||
return r.Sampler.Bernoulli(pipelineSlot, acceptP).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
func (r *Runner) mtpSampleTokenAt(dist sampler.Distribution, index int) *mlx.Array {
|
||||
return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist.SliceRows(index, index+1)))
|
||||
}
|
||||
|
||||
func (r *Runner) mtpSampleResidualToken(targetDist, draftDist sampler.Distribution, index int) *mlx.Array {
|
||||
residual := targetDist.SliceRows(index, index+1).ResidualAgainst(draftDist.SliceRows(index, index+1))
|
||||
return mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, residual))
|
||||
}
|
||||
|
||||
func mtpTokenInput(token *mlx.Array) *mlx.Array {
|
||||
switch token.NumDims() {
|
||||
case 0:
|
||||
return token.Reshape(1, 1)
|
||||
case 1:
|
||||
return token.ExpandDims(-1)
|
||||
case 2:
|
||||
return token
|
||||
default:
|
||||
panic(fmt.Sprintf("mtp token must be rank 0, 1, or 2, got rank %d", token.NumDims()))
|
||||
}
|
||||
}
|
||||
|
||||
func mtpTokenVector(token *mlx.Array) *mlx.Array {
|
||||
switch token.NumDims() {
|
||||
case 0:
|
||||
return token.Reshape(1)
|
||||
case 1:
|
||||
return token
|
||||
default:
|
||||
panic(fmt.Sprintf("mtp sampled token must be rank 0 or 1, got rank %d", token.NumDims()))
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) compareMTPBatchedWithSerial(ctx context.Context, caches []cache.Cache, before int, baseLogits, hiddenSeq *mlx.Array, draftIDs, selectedIDs []int, accepted, draftCount int, stats *mtpStats) {
|
||||
serialCaches, ok := cache.BeginIsolatedSpeculation(caches)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
compareCount := accepted + 1
|
||||
if accepted == draftCount {
|
||||
// Include the target bonus token when every draft was accepted.
|
||||
compareCount = draftCount + 1
|
||||
}
|
||||
|
||||
serialLogits := baseLogits
|
||||
for i := range compareCount {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return
|
||||
}
|
||||
if i >= len(selectedIDs) {
|
||||
return
|
||||
}
|
||||
|
||||
batchedLogits := baseLogits
|
||||
if i > 0 {
|
||||
batchedLogits = r.targetLogitsAt(hiddenSeq, i-1)
|
||||
}
|
||||
|
||||
batchedToken := greedyTokenFromLogits(batchedLogits)
|
||||
serialToken := greedyTokenFromLogits(serialLogits)
|
||||
mlx.Eval(batchedToken, serialToken)
|
||||
|
||||
batchedID := tokenID(batchedToken)
|
||||
vectorizedID := selectedIDs[i]
|
||||
serialID := tokenID(serialToken)
|
||||
stats.compared++
|
||||
if vectorizedID != serialID {
|
||||
firstMismatch := stats.batchSerialMismatches == 0
|
||||
stats.batchSerialMismatches++
|
||||
if !firstMismatch {
|
||||
return
|
||||
}
|
||||
|
||||
draftID := -1
|
||||
if i < draftCount {
|
||||
draftID = draftIDs[i]
|
||||
}
|
||||
batchedTop := top2FromLogits(batchedLogits)
|
||||
serialTop := top2FromLogits(serialLogits)
|
||||
slog.Warn("MTP batched validation differs from serial validation",
|
||||
"position", before+i,
|
||||
"draft", draftID,
|
||||
"batched", vectorizedID,
|
||||
"batched_slice", batchedID,
|
||||
"serial", serialID,
|
||||
"batched_slice_top1", batchedTop.firstToken,
|
||||
"batched_slice_top2", batchedTop.secondToken,
|
||||
"batched_slice_margin", batchedTop.margin,
|
||||
"serial_top1", serialTop.firstToken,
|
||||
"serial_top2", serialTop.secondToken,
|
||||
"serial_margin", serialTop.margin,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if i >= draftCount || i >= accepted {
|
||||
return
|
||||
}
|
||||
|
||||
hidden := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: mlx.FromValues([]int32{int32(draftIDs[i])}, 1, 1),
|
||||
SeqOffsets: []int32{int32(before + i)},
|
||||
SeqQueryLens: []int32{1},
|
||||
}, serialCaches)
|
||||
serialLogits = r.lastLogits(hidden)
|
||||
}
|
||||
}
|
||||
|
||||
type mtpTop2 struct {
|
||||
firstToken int
|
||||
secondToken int
|
||||
margin float64
|
||||
}
|
||||
|
||||
func top2FromLogits(logits *mlx.Array) mtpTop2 {
|
||||
indices := logits.Negative().ArgsortAxis(-1).Slice(mlx.Slice(), mlx.Slice(0, 2))
|
||||
indices32 := indices.AsType(mlx.DTypeInt32)
|
||||
values := logits.TakeAlongAxis(indices, -1).AsType(mlx.DTypeFloat32)
|
||||
mlx.Eval(indices32, values)
|
||||
|
||||
tokenIDs := indices32.Ints()
|
||||
logitValues := values.Floats()
|
||||
if len(tokenIDs) < 2 || len(logitValues) < 2 {
|
||||
return mtpTop2{}
|
||||
}
|
||||
return mtpTop2{
|
||||
firstToken: tokenIDs[0],
|
||||
secondToken: tokenIDs[1],
|
||||
margin: float64(logitValues[0] - logitValues[1]),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Runner) acceptMTPDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, caches []cache.Cache, position *int, baseLogits *mlx.Array, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
||||
logits := baseLogits
|
||||
accepted := 0
|
||||
draftIDs := draftTokens.Ints()
|
||||
|
||||
for _, id := range draftIDs {
|
||||
selected := greedyTokenFromLogits(logits)
|
||||
mlx.Eval(selected)
|
||||
selectedID := tokenID(selected)
|
||||
if selectedID != id {
|
||||
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil
|
||||
}
|
||||
|
||||
hidden := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: mlx.FromValues([]int32{int32(id)}, 1, 1),
|
||||
SeqOffsets: []int32{int32(*position)},
|
||||
SeqQueryLens: []int32{1},
|
||||
}, caches)
|
||||
(*position)++
|
||||
accepted++
|
||||
|
||||
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
||||
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
|
||||
if err != nil {
|
||||
return sampler.Result{}, accepted, done, err
|
||||
}
|
||||
if !done {
|
||||
(*generated)++
|
||||
}
|
||||
if done || *generated >= request.Options.NumPredict {
|
||||
return sampler.Result{}, accepted, true, nil
|
||||
}
|
||||
|
||||
logits = r.lastLogits(hidden)
|
||||
}
|
||||
|
||||
return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) emitMTPToken(ctx context.Context, request Request, session *cacheSession, dec *decoder, res sampler.Result, final *CompletionResponse) (bool, error) {
|
||||
output := int32(tokenID(res.Token))
|
||||
session.outputs = append(session.outputs, output)
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
final.DoneReason = 0
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if resp, ok := dec.decode(res); ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, ctx.Err()
|
||||
case request.Responses <- resp:
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (r *Runner) lastLogits(hidden *mlx.Array) *mlx.Array {
|
||||
logits := r.Model.Unembed(hidden)
|
||||
return r.lastLogitsFromLogits(logits)
|
||||
}
|
||||
|
||||
func (r *Runner) targetLogitsAt(hiddenSeq *mlx.Array, index int) *mlx.Array {
|
||||
hidden := hiddenSeq.Slice(mlx.Slice(), mlx.Slice(index), mlx.Slice())
|
||||
return r.lastLogits(hidden)
|
||||
}
|
||||
|
||||
func (r *Runner) mtpValidationTokens(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
|
||||
return greedyTokenFromLogits(r.mtpValidationLogits(baseLogits, hiddenSeq))
|
||||
}
|
||||
|
||||
func (r *Runner) mtpValidationLogits(baseLogits, hiddenSeq *mlx.Array) *mlx.Array {
|
||||
seqLogits := r.Model.Unembed(hiddenSeq)
|
||||
return baseLogits.ExpandDims(1).Concatenate(1, seqLogits)
|
||||
}
|
||||
|
||||
func mtpTokenAt(tokens *mlx.Array, index int) *mlx.Array {
|
||||
return tokens.Slice(mlx.Slice(), mlx.Slice(index)).Squeeze(0)
|
||||
}
|
||||
|
||||
func (r *Runner) lastLogitsFromLogits(logits *mlx.Array) *mlx.Array {
|
||||
return logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
}
|
||||
|
||||
func greedyTokenFromLogits(logits *mlx.Array) *mlx.Array {
|
||||
return logits.Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
func tokenID(token *mlx.Array) int {
|
||||
if token == nil {
|
||||
return -1
|
||||
}
|
||||
if token.DType() == mlx.DTypeInt32 {
|
||||
ids := token.Ints()
|
||||
if len(ids) > 0 {
|
||||
return ids[0]
|
||||
}
|
||||
}
|
||||
return token.Int()
|
||||
}
|
||||
434
x/mlxrunner/pipeline.go
Normal file
434
x/mlxrunner/pipeline.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
func prefillChunkSize() int {
|
||||
return 2 << 10
|
||||
}
|
||||
|
||||
// Prepare tokenizes the prompt and validates it against the model's
|
||||
// context length. It is safe to call from any goroutine. On success it
|
||||
// populates request.Tokens and adjusts request.Options.NumPredict.
|
||||
func (r *Runner) Prepare(request *Request) error {
|
||||
if r.Model == nil {
|
||||
return errors.New("model not loaded")
|
||||
}
|
||||
|
||||
tokens := r.Tokenizer.Encode(request.Prompt, r.Tokenizer.AddBOS())
|
||||
if len(tokens) == 0 {
|
||||
return errors.New("empty prompt")
|
||||
}
|
||||
|
||||
if len(tokens) >= r.contextLength {
|
||||
return fmt.Errorf("input length (%d tokens) exceeds the model's maximum context length (%d tokens)", len(tokens), r.contextLength)
|
||||
}
|
||||
|
||||
// Cap generation to stay within the model's context length
|
||||
maxGenerate := r.contextLength - len(tokens)
|
||||
if request.Options.NumPredict <= 0 {
|
||||
request.Options.NumPredict = maxGenerate
|
||||
} else {
|
||||
request.Options.NumPredict = min(request.Options.NumPredict, maxGenerate)
|
||||
}
|
||||
|
||||
request.Tokens = tokens
|
||||
return nil
|
||||
}
|
||||
|
||||
// The runner serializes requests today so we just use a fixed slot ID.
|
||||
const pipelineSlot = 0
|
||||
|
||||
func (r *Runner) TextGenerationPipeline(ctx context.Context, request Request) error {
|
||||
mlx.ResetPeakMemory()
|
||||
var sample, nextSample sampler.Result
|
||||
|
||||
defer func() {
|
||||
r.Sampler.Remove(pipelineSlot)
|
||||
mlx.Unpin(sample.Arrays()...)
|
||||
mlx.Unpin(nextSample.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
|
||||
if slog.Default().Enabled(context.TODO(), logutil.LevelTrace) {
|
||||
mlx.LogArrays()
|
||||
r.cache.dumpTree()
|
||||
}
|
||||
slog.Info("peak memory", "size", mlx.PrettyBytes(mlx.PeakMemory()))
|
||||
}()
|
||||
|
||||
inputs := request.Tokens
|
||||
|
||||
session := r.cache.begin(r.Model, inputs)
|
||||
defer session.close()
|
||||
|
||||
caches := session.caches
|
||||
tokens := session.remaining
|
||||
prefillChunk := prefillChunkSize()
|
||||
dflashMode, dflashDisabledReason := r.dflashGate(request.SamplerOpts)
|
||||
dflashEnabled := dflashMode.enabled()
|
||||
var dflashDraft base.DFlashDraftModel
|
||||
var dflashTarget base.DFlashTargetModel
|
||||
var dflashCaches []cache.Cache
|
||||
var dflashSession *cacheSession
|
||||
if dflashEnabled {
|
||||
dflashDraft = r.Draft.(base.DFlashDraftModel)
|
||||
dflashTarget = r.Model.(base.DFlashTargetModel)
|
||||
targetCachedPrefix := len(inputs) - len(tokens)
|
||||
dflashSession = r.dflashCache.beginWithFactoryLimit(inputs, dflashDraft.NewCaches, "DFlash draft", targetCachedPrefix, false)
|
||||
dflashCaches = dflashSession.caches
|
||||
defer func() {
|
||||
dflashSession.outputs = append([]int32(nil), session.outputs...)
|
||||
dflashSession.close()
|
||||
}()
|
||||
} else if _, ok := r.Draft.(base.DFlashDraftModel); ok {
|
||||
slog.Info("DFlash decode disabled",
|
||||
"reason", dflashDisabledReason,
|
||||
"temperature", request.SamplerOpts.Temperature,
|
||||
"top_p", request.SamplerOpts.TopP,
|
||||
"top_k", request.SamplerOpts.TopK,
|
||||
"min_p", request.SamplerOpts.MinP,
|
||||
"repeat_penalty", request.SamplerOpts.RepeatPenalty,
|
||||
"presence_penalty", request.SamplerOpts.PresencePenalty,
|
||||
"frequency_penalty", request.SamplerOpts.FrequencyPenalty,
|
||||
"logprobs", request.SamplerOpts.Logprobs,
|
||||
"top_logprobs", request.SamplerOpts.TopLogprobs,
|
||||
)
|
||||
}
|
||||
|
||||
requestPipelineSnapshots := func(s *cacheSession) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
// Request periodic snapshots during prefill and near the end of the
|
||||
// prompt so that long prompts can be partially restored and
|
||||
// thinking/generation can be retried without full reprocessing.
|
||||
const snapshotInterval = 8192
|
||||
for offset := snapshotInterval; offset < len(inputs); offset += snapshotInterval {
|
||||
s.requestSnapshot(offset)
|
||||
}
|
||||
|
||||
const preThinking = 4
|
||||
if end := len(inputs) - preThinking; end > 0 {
|
||||
s.requestSnapshot(end)
|
||||
}
|
||||
}
|
||||
requestPipelineSnapshots(session)
|
||||
requestPipelineSnapshots(dflashSession)
|
||||
|
||||
nextSnapshotOffset := func() int {
|
||||
next := session.nextPendingSnapshot()
|
||||
if dflashSession != nil {
|
||||
if offset := dflashSession.nextPendingSnapshot(); offset > 0 && (next == 0 || offset < next) {
|
||||
next = offset
|
||||
}
|
||||
}
|
||||
return next
|
||||
}
|
||||
|
||||
snapshotReadySessions := func(position int) {
|
||||
if snapOffset := session.nextPendingSnapshot(); snapOffset > 0 && position >= snapOffset {
|
||||
session.snapshot()
|
||||
}
|
||||
if dflashSession != nil {
|
||||
if snapOffset := dflashSession.nextPendingSnapshot(); snapOffset > 0 && position >= snapOffset {
|
||||
dflashSession.snapshot()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
materializeCaches := func(cacheSets ...[]cache.Cache) {
|
||||
if len(cacheSets) == 0 {
|
||||
cacheSets = [][]cache.Cache{caches}
|
||||
}
|
||||
state := make([]*mlx.Array, 0, 2*len(caches))
|
||||
for _, set := range cacheSets {
|
||||
for _, c := range set {
|
||||
if c == nil {
|
||||
continue
|
||||
}
|
||||
state = append(state, c.State()...)
|
||||
}
|
||||
}
|
||||
if len(state) == 0 {
|
||||
return
|
||||
}
|
||||
mlx.Eval(state...)
|
||||
}
|
||||
|
||||
if dflashEnabled {
|
||||
targetCachedPrefix := len(inputs) - len(tokens)
|
||||
dflashCachedPrefix := len(inputs) - len(dflashSession.remaining)
|
||||
if targetCachedPrefix > dflashCachedPrefix {
|
||||
t0 := time.Now()
|
||||
rebuildCaches := newDFlashTargetCaches(r.Model)
|
||||
rebuildProcessed := 0
|
||||
for targetCachedPrefix-rebuildProcessed > 0 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
freeCacheSet(rebuildCaches)
|
||||
return err
|
||||
}
|
||||
n := min(prefillChunk, targetCachedPrefix-rebuildProcessed)
|
||||
if snapOffset := dflashSession.nextPendingSnapshot(); snapOffset > rebuildProcessed && snapOffset < rebuildProcessed+n {
|
||||
n = snapOffset - rebuildProcessed
|
||||
}
|
||||
start, end := rebuildProcessed, rebuildProcessed+n
|
||||
b := &batch.Batch{
|
||||
InputIDs: mlx.FromValues(inputs[start:end], 1, n),
|
||||
SeqOffsets: []int32{int32(start)},
|
||||
SeqQueryLens: []int32{int32(n)},
|
||||
}
|
||||
_, targetHidden := dflashTarget.ForwardDFlash(b, rebuildCaches, dflashDraft.TargetLayerIDs())
|
||||
if end > dflashCachedPrefix {
|
||||
appendHidden := targetHidden
|
||||
if start < dflashCachedPrefix {
|
||||
appendHidden = targetHidden.Slice(mlx.Slice(), mlx.Slice(dflashCachedPrefix-start, n), mlx.Slice())
|
||||
}
|
||||
dflashDraft.AppendContext(appendHidden, dflashCaches)
|
||||
}
|
||||
mlx.Sweep()
|
||||
materializeCaches(rebuildCaches, dflashCaches)
|
||||
rebuildProcessed = end
|
||||
if snapOffset := dflashSession.nextPendingSnapshot(); snapOffset > 0 && rebuildProcessed >= snapOffset {
|
||||
dflashSession.snapshot()
|
||||
}
|
||||
mlx.ClearCache()
|
||||
}
|
||||
freeCacheSet(rebuildCaches)
|
||||
slog.Info("DFlash draft cache rebuild",
|
||||
"target_cached", targetCachedPrefix,
|
||||
"draft_cached", dflashCachedPrefix,
|
||||
"rebuilt", targetCachedPrefix-dflashCachedPrefix,
|
||||
"draft_offset", r.dflashCache.minCacheOffset(),
|
||||
"duration", time.Since(t0),
|
||||
)
|
||||
} else {
|
||||
slog.Info("DFlash draft cache restored",
|
||||
"target_cached", targetCachedPrefix,
|
||||
"draft_cached", dflashCachedPrefix,
|
||||
"draft_offset", r.dflashCache.minCacheOffset(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
total, processed := len(tokens), 0
|
||||
position := len(inputs) - len(tokens)
|
||||
for total-processed > 1 {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n := min(prefillChunk, total-processed-1)
|
||||
|
||||
// If there's a pending snapshot, split the batch so we can
|
||||
// capture it at the exact offset.
|
||||
if snapOffset := nextSnapshotOffset(); snapOffset > 0 {
|
||||
tokensUntilSnapshot := snapOffset - position
|
||||
if tokensUntilSnapshot > 0 && tokensUntilSnapshot < n {
|
||||
n = tokensUntilSnapshot
|
||||
}
|
||||
}
|
||||
|
||||
b := &batch.Batch{
|
||||
InputIDs: mlx.FromValues(tokens[processed:processed+n], 1, n),
|
||||
SeqOffsets: []int32{int32(position)},
|
||||
SeqQueryLens: []int32{int32(n)},
|
||||
}
|
||||
if dflashEnabled {
|
||||
_, targetHidden := dflashTarget.ForwardDFlash(b, caches, dflashDraft.TargetLayerIDs())
|
||||
dflashDraft.AppendContext(targetHidden, dflashCaches)
|
||||
} else {
|
||||
r.Model.Forward(b, caches)
|
||||
}
|
||||
mlx.Sweep()
|
||||
if dflashEnabled {
|
||||
materializeCaches(caches, dflashCaches)
|
||||
} else {
|
||||
materializeCaches()
|
||||
}
|
||||
processed += n
|
||||
position += n
|
||||
slog.Info("Prompt processing progress", "processed", processed, "total", total)
|
||||
logutil.TraceContext(ctx, "mlx prompt forward", "processed", processed, "total", total, "tokens", n, "memory", mlx.Memory{})
|
||||
|
||||
// Create snapshot if we've reached a pending offset.
|
||||
snapshotReadySessions(position)
|
||||
|
||||
mlx.ClearCache()
|
||||
}
|
||||
|
||||
// Register the sampler after prefill completes.
|
||||
r.Sampler.Add(pipelineSlot, request.SamplerOpts, inputs)
|
||||
if dflashMode == dflashDecodeGreedy {
|
||||
return r.runGreedyDFlashDecode(ctx, request, session, caches, dflashCaches, tokens[processed:], &position, now)
|
||||
}
|
||||
if dflashMode == dflashDecodeSample {
|
||||
return r.runSampleDFlashDecode(ctx, request, session, caches, dflashCaches, tokens[processed:], &position, now)
|
||||
}
|
||||
if r.useGreedyMTP(request.SamplerOpts) {
|
||||
return r.runGreedyMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now)
|
||||
}
|
||||
if r.useSampleMTP(request.SamplerOpts) {
|
||||
return r.runSampleMTPDecode(ctx, request, session, caches, tokens[processed:], &position, now)
|
||||
}
|
||||
|
||||
step := func(token *mlx.Array) sampler.Result {
|
||||
fwd := r.Model.Forward(&batch.Batch{
|
||||
InputIDs: token,
|
||||
SeqOffsets: []int32{int32(position)},
|
||||
SeqQueryLens: []int32{int32(token.Dim(1))},
|
||||
}, caches)
|
||||
position += token.Dim(1)
|
||||
logits := r.Model.Unembed(fwd)
|
||||
logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
|
||||
|
||||
sample := r.Sampler.Sample([]int{pipelineSlot}, logits)
|
||||
mlx.Pin(sample.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.AsyncEval(sample.Arrays()...)
|
||||
return sample
|
||||
}
|
||||
|
||||
sample = step(mlx.FromValues(tokens[processed:], 1, total-processed))
|
||||
logutil.TraceContext(ctx, "mlx decode seed", "tokens", total-processed, "memory", mlx.Memory{})
|
||||
|
||||
dec := decoder{
|
||||
tokenizer: r.Tokenizer,
|
||||
wantLogprobs: request.SamplerOpts.Logprobs,
|
||||
wantTopLogprobs: request.SamplerOpts.TopLogprobs,
|
||||
}
|
||||
|
||||
final := CompletionResponse{Done: true, PromptEvalCount: len(inputs), EvalCount: request.Options.NumPredict, DoneReason: 1}
|
||||
for i := range request.Options.NumPredict {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nextSample = step(sample.Token.ExpandDims(-1))
|
||||
|
||||
if i == 0 {
|
||||
mlx.Eval(sample.Arrays()...)
|
||||
final.PromptEvalDuration = time.Since(now)
|
||||
now = time.Now()
|
||||
}
|
||||
|
||||
output := int32(sample.Token.Int())
|
||||
session.outputs = append(session.outputs, output)
|
||||
if i == 0 {
|
||||
logutil.TraceContext(ctx, "mlx decode first token", "memory", mlx.Memory{})
|
||||
}
|
||||
|
||||
if r.Tokenizer.IsEOS(output) {
|
||||
final.DoneReason = 0
|
||||
final.EvalCount = i
|
||||
break
|
||||
}
|
||||
|
||||
if resp, ok := dec.decode(sample); ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- resp:
|
||||
}
|
||||
}
|
||||
|
||||
mlx.Unpin(sample.Arrays()...)
|
||||
sample, nextSample = nextSample, sampler.Result{}
|
||||
|
||||
if i%256 == 0 {
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
final.EvalDuration = time.Since(now)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case request.Responses <- final:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// decoder serializes sampled tokens into response chunks, holding bytes
|
||||
// whose UTF-8 sequence hasn't completed yet and the logprobs that belong
|
||||
// with those bytes so Content and Logprobs stay aligned when a chunk does
|
||||
// flush.
|
||||
type decoder struct {
|
||||
tokenizer *tokenizer.Tokenizer
|
||||
buf bytes.Buffer
|
||||
logprobs []llm.Logprob
|
||||
wantLogprobs bool
|
||||
wantTopLogprobs int
|
||||
}
|
||||
|
||||
func (d *decoder) decode(res sampler.Result) (CompletionResponse, bool) {
|
||||
output := int32(res.Token.Int())
|
||||
d.buf.WriteString(d.tokenizer.Decode([]int32{output}))
|
||||
d.logprobs = append(d.logprobs, buildLogprob(res, d.wantLogprobs, d.wantTopLogprobs, d.tokenizer.Decode)...)
|
||||
|
||||
content := flushValidUTF8Prefix(&d.buf)
|
||||
if content == "" {
|
||||
return CompletionResponse{}, false
|
||||
}
|
||||
resp := CompletionResponse{Content: content, Logprobs: d.logprobs}
|
||||
d.logprobs = nil
|
||||
return resp, true
|
||||
}
|
||||
|
||||
// buildLogprob converts the sampler's logprob tensors into the wire-format
|
||||
// llm.Logprob entries the caller wants. The sampler populates its logprob
|
||||
// tensors whenever any registered slot requested them, so the caller must
|
||||
// gate emission on its own request config (wantLogprobs / wantTopLogprobs)
|
||||
// rather than on whether the tensors happen to be non-nil.
|
||||
func buildLogprob(sample sampler.Result, wantLogprobs bool, wantTopLogprobs int, decode func([]int32) string) []llm.Logprob {
|
||||
if !wantLogprobs || sample.Logprob == nil {
|
||||
return nil
|
||||
}
|
||||
tok := func(id int32) string { return decode([]int32{id}) }
|
||||
|
||||
out := llm.Logprob{
|
||||
TokenLogprob: llm.TokenLogprob{
|
||||
Token: tok(int32(sample.Token.Int())),
|
||||
Logprob: float64(sample.Logprob.Floats()[0]),
|
||||
},
|
||||
}
|
||||
|
||||
if wantTopLogprobs > 0 && sample.TopTokens != nil {
|
||||
ids := sample.TopTokens.Ints()
|
||||
vals := sample.TopLogprobs.Floats()
|
||||
pairs := make([]llm.TokenLogprob, len(ids))
|
||||
for i, id := range ids {
|
||||
pairs[i] = llm.TokenLogprob{
|
||||
Token: tok(int32(id)),
|
||||
Logprob: float64(vals[i]),
|
||||
}
|
||||
}
|
||||
// The sampler emits the top maxK across registered slots via
|
||||
// Argpartition, which leaves entries unsorted.
|
||||
sort.Slice(pairs, func(i, j int) bool {
|
||||
return pairs[i].Logprob > pairs[j].Logprob
|
||||
})
|
||||
if wantTopLogprobs < len(pairs) {
|
||||
pairs = pairs[:wantTopLogprobs]
|
||||
}
|
||||
out.TopLogprobs = pairs
|
||||
}
|
||||
return []llm.Logprob{out}
|
||||
}
|
||||
210
x/mlxrunner/runner.go
Normal file
210
x/mlxrunner/runner.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/x/internal/mlxthread"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
// Request is a short-lived struct that carries a completion request through
|
||||
// a channel from the HTTP handler to the runner goroutine. The ctx field
|
||||
// must travel with the request so that cancellation propagates across the
|
||||
// channel boundary.
|
||||
type Request struct {
|
||||
CompletionRequest
|
||||
Responses chan CompletionResponse
|
||||
Pipeline func(context.Context, Request) error
|
||||
|
||||
Ctx context.Context //nolint:containedctx
|
||||
Tokens []int32
|
||||
SamplerOpts sample.Options
|
||||
}
|
||||
|
||||
type Runner struct {
|
||||
Model base.Model
|
||||
Draft base.DraftModel
|
||||
Tokenizer *tokenizer.Tokenizer
|
||||
Requests chan Request
|
||||
Sampler *sample.Sampler
|
||||
cache kvCache
|
||||
dflashCache kvCache
|
||||
contextLength int
|
||||
mlxThread *mlxthread.Thread
|
||||
}
|
||||
|
||||
func (r *Runner) Load(modelName string) error {
|
||||
root, err := model.Open(modelName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
m, err := base.New(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Load all tensor blobs from manifest
|
||||
tensors, err := loadTensorsFromManifest(root)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign weights to model (model-specific logic). Target and draft weights
|
||||
// must be loaded before sweeping so tensors from a combined manifest are
|
||||
// not discarded before the draft model can retain them.
|
||||
if err := m.LoadWeights(tensors); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Draft = nil
|
||||
draft, err := base.NewDraft(root, m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if draft != nil {
|
||||
if err := draft.LoadWeights(tensors); err != nil {
|
||||
return err
|
||||
}
|
||||
r.Draft = draft
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
if draft != nil {
|
||||
draftArrays := mlx.Collect(draft)
|
||||
collected = append(collected, draftArrays...)
|
||||
if root.Draft != nil {
|
||||
slog.Info("Loaded draft model", "tensor_prefix", root.Draft.TensorPrefix, "config", root.Draft.Config, "arrays", len(draftArrays))
|
||||
} else {
|
||||
slog.Info("Loaded draft model", "arrays", len(draftArrays))
|
||||
}
|
||||
}
|
||||
for _, arr := range collected {
|
||||
mlx.Pin(arr)
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.Eval(collected...)
|
||||
|
||||
r.Model = m
|
||||
r.Tokenizer = m.Tokenizer()
|
||||
r.contextLength = m.MaxContextLength()
|
||||
r.Sampler = sample.New(r.contextLength)
|
||||
|
||||
mlx.EnableCompile()
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadTensorsFromManifest loads all tensor blobs from the manifest into a
|
||||
// flat map, deduplicating by digest and remapping safetensors key suffixes.
|
||||
//
|
||||
// Uses a two-phase approach: first loads all raw tensors, then remaps
|
||||
// .bias → _qbias with complete knowledge of which base names have .scale
|
||||
// entries. This avoids a race condition where Go map iteration order could
|
||||
// cause .bias to be processed before .scale within the same blob.
|
||||
func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) {
|
||||
// Phase 1: Load all tensors raw from all blobs
|
||||
rawTensors := make(map[string]*mlx.Array)
|
||||
seen := make(map[string]bool)
|
||||
for _, layer := range root.Manifest.GetTensorLayers("") {
|
||||
if seen[layer.Digest] {
|
||||
continue
|
||||
}
|
||||
seen[layer.Digest] = true
|
||||
blobPath := root.Manifest.BlobPath(layer.Digest)
|
||||
for name, arr := range mlx.Load(blobPath) {
|
||||
rawTensors[name] = arr
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Identify all base names that have .scale tensors and remap them
|
||||
scaleBaseNames := make(map[string]bool)
|
||||
allTensors := make(map[string]*mlx.Array, len(rawTensors))
|
||||
for name, arr := range rawTensors {
|
||||
if strings.HasSuffix(name, ".scale") {
|
||||
baseName := strings.TrimSuffix(name, ".scale")
|
||||
allTensors[baseName+"_scale"] = arr
|
||||
scaleBaseNames[baseName] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 3: Process remaining tensors with complete scale knowledge
|
||||
for name, arr := range rawTensors {
|
||||
if strings.HasSuffix(name, ".scale") {
|
||||
continue // already handled
|
||||
}
|
||||
if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
|
||||
baseName := strings.TrimSuffix(name, ".bias")
|
||||
if scaleBaseNames[baseName] {
|
||||
allTensors[baseName+"_qbias"] = arr
|
||||
} else {
|
||||
allTensors[name] = arr
|
||||
}
|
||||
} else {
|
||||
allTensors[name] = arr
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("Loaded tensors from manifest", "count", len(allTensors))
|
||||
return allTensors, nil
|
||||
}
|
||||
|
||||
func (r *Runner) Run(host, port string, mux http.Handler) error {
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
|
||||
g.Go(func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case request := <-r.Requests:
|
||||
err := r.runRequest(request)
|
||||
if err != nil {
|
||||
slog.Info("Request terminated", "error", err)
|
||||
var statusErr api.StatusError
|
||||
if !errors.As(err, &statusErr) {
|
||||
statusErr = api.StatusError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
ErrorMessage: err.Error(),
|
||||
}
|
||||
}
|
||||
select {
|
||||
case request.Responses <- CompletionResponse{Error: &statusErr}:
|
||||
case <-request.Ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
close(request.Responses)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
slog.Info("Starting HTTP server", "host", host, "port", port)
|
||||
return http.ListenAndServe(net.JoinHostPort(host, port), mux)
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (r *Runner) runRequest(request Request) error {
|
||||
if r.mlxThread == nil {
|
||||
return request.Pipeline(request.Ctx, request)
|
||||
}
|
||||
|
||||
return r.mlxThread.Do(request.Ctx, func() error {
|
||||
return request.Pipeline(request.Ctx, request)
|
||||
})
|
||||
}
|
||||
300
x/mlxrunner/sample/logprob_test.go
Normal file
300
x/mlxrunner/sample/logprob_test.go
Normal file
@@ -0,0 +1,300 @@
|
||||
//go:build mlx
|
||||
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// logprobEntry is the (token id, logprob) pair returned by the sampler's
|
||||
// top-K extraction, used after the test-side descending sort.
|
||||
type logprobEntry struct {
|
||||
id int
|
||||
logprob float64
|
||||
}
|
||||
|
||||
// runSampleLogprobs drives Sample on a fresh Sampler configured for logprobs
|
||||
// and returns the greedily-sampled token id, its logprob, and the top-K
|
||||
// entries sorted descending by logprob. Logits must be a [vocab]-shaped
|
||||
// slice; the helper reshapes it to [1, vocab] before calling the sampler.
|
||||
func runSampleLogprobs(t *testing.T, logits []float32, topK int) (int, float64, []logprobEntry) {
|
||||
t.Helper()
|
||||
|
||||
s := New(128)
|
||||
defer func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
}()
|
||||
s.Add(0, Options{Logprobs: true, TopLogprobs: topK}, nil)
|
||||
|
||||
tensor := mlx.FromValues(logits, 1, len(logits))
|
||||
res := s.Sample([]int{0}, tensor)
|
||||
|
||||
mlx.Pin(res.Arrays()...)
|
||||
defer mlx.Unpin(res.Arrays()...)
|
||||
mlx.Sweep()
|
||||
mlx.Eval(res.Arrays()...)
|
||||
|
||||
selected := res.Token.Int()
|
||||
selLP := float64(res.Logprob.Floats()[0])
|
||||
|
||||
var top []logprobEntry
|
||||
if topK > 0 && res.TopTokens != nil {
|
||||
ids := res.TopTokens.Ints()
|
||||
vals := res.TopLogprobs.Floats()
|
||||
top = make([]logprobEntry, len(ids))
|
||||
for i, id := range ids {
|
||||
top[i] = logprobEntry{id: id, logprob: float64(vals[i])}
|
||||
}
|
||||
sort.Slice(top, func(i, j int) bool { return top[i].logprob > top[j].logprob })
|
||||
}
|
||||
return selected, selLP, top
|
||||
}
|
||||
|
||||
func TestSampleLogprobsBasic(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logits []float32
|
||||
topK int
|
||||
wantSelectedID int
|
||||
wantTopLen int
|
||||
}{
|
||||
{
|
||||
name: "single token without top logprobs",
|
||||
logits: []float32{1.0, 0.5, 0.3, 0.1},
|
||||
topK: 0,
|
||||
wantSelectedID: 0,
|
||||
wantTopLen: 0,
|
||||
},
|
||||
{
|
||||
name: "single token with top logprobs",
|
||||
logits: []float32{1.0, 0.5, 0.3, 0.1},
|
||||
topK: 3,
|
||||
wantSelectedID: 0,
|
||||
wantTopLen: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
selected, _, top := runSampleLogprobs(t, tt.logits, tt.topK)
|
||||
if selected != tt.wantSelectedID {
|
||||
t.Errorf("selected = %d, want %d", selected, tt.wantSelectedID)
|
||||
}
|
||||
if len(top) != tt.wantTopLen {
|
||||
t.Errorf("top-K length = %d, want %d", len(top), tt.wantTopLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsNumericalStability(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
logits := []float32{1000.0, 999.0, 998.0}
|
||||
_, selLP, top := runSampleLogprobs(t, logits, 3)
|
||||
|
||||
if math.IsInf(selLP, 0) || math.IsNaN(selLP) {
|
||||
t.Errorf("selected logprob is not finite: %f", selLP)
|
||||
}
|
||||
for i, e := range top {
|
||||
if math.IsInf(e.logprob, 0) || math.IsNaN(e.logprob) {
|
||||
t.Errorf("top[%d] logprob is not finite: %f", i, e.logprob)
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(top); i++ {
|
||||
if top[i].logprob > top[i-1].logprob {
|
||||
t.Errorf("top logprobs not descending: %f > %f", top[i].logprob, top[i-1].logprob)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsProbabilityCorrectness(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logits []float32
|
||||
}{
|
||||
{"uniform", []float32{1.0, 1.0, 1.0, 1.0}},
|
||||
{"different", []float32{2.0, 1.0, 0.5, 0.1}},
|
||||
{"negative", []float32{-1.0, -2.0, -3.0, -4.0}},
|
||||
{"mixed", []float32{5.0, -5.0, 0.0, 2.5}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
selected, selLP, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
|
||||
|
||||
if selLP > 0 {
|
||||
t.Errorf("selected logprob should be <= 0, got %f", selLP)
|
||||
}
|
||||
for i, e := range top {
|
||||
if e.logprob > 0 {
|
||||
t.Errorf("top[%d] logprob should be <= 0, got %f", i, e.logprob)
|
||||
}
|
||||
}
|
||||
|
||||
if tt.name == "uniform" {
|
||||
want := 1.0 / float64(len(tt.logits))
|
||||
got := math.Exp(selLP)
|
||||
if math.Abs(got-want) > 1e-6 {
|
||||
t.Errorf("uniform logits: selected prob = %f, want %f", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 1; i < len(top); i++ {
|
||||
if top[i].logprob > top[i-1].logprob {
|
||||
t.Errorf("top logprobs not descending at %d: %f > %f",
|
||||
i, top[i].logprob, top[i-1].logprob)
|
||||
}
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, e := range top {
|
||||
if e.id == selected {
|
||||
found = true
|
||||
if math.Abs(e.logprob-selLP) > 1e-6 {
|
||||
t.Errorf("selected logprob mismatch: selLP=%f top=%f", selLP, e.logprob)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("selected token %d not present in top-K", selected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsSoftmaxCorrectness(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
logits []float32
|
||||
}{
|
||||
{"small vocabulary", []float32{1.0, 2.0, 3.0}},
|
||||
{"large differences", []float32{10.0, 0.0, -10.0}},
|
||||
{"all equal", []float32{5.0, 5.0, 5.0, 5.0, 5.0}},
|
||||
{"very large values", []float32{500.0, 499.0, 498.0}},
|
||||
{"very small values", []float32{-500.0, -499.0, -498.0}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, _, top := runSampleLogprobs(t, tt.logits, len(tt.logits))
|
||||
if len(top) != len(tt.logits) {
|
||||
t.Fatalf("top-K length = %d, want %d", len(top), len(tt.logits))
|
||||
}
|
||||
|
||||
var sum float64
|
||||
for _, e := range top {
|
||||
p := math.Exp(e.logprob)
|
||||
if p < 0 || p > 1 {
|
||||
t.Errorf("token %d: probability %f out of [0,1]", e.id, p)
|
||||
}
|
||||
sum += p
|
||||
}
|
||||
|
||||
if math.Abs(sum-1.0) > 1e-5 {
|
||||
t.Errorf("probabilities sum = %f, want 1.0", sum)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsSelectedTokenCorrectness(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
logits := []float32{3.0, 1.0, 2.0, 0.5}
|
||||
|
||||
maxIdx := 0
|
||||
for i, v := range logits[1:] {
|
||||
if v > logits[maxIdx] {
|
||||
maxIdx = i + 1
|
||||
}
|
||||
}
|
||||
|
||||
selected, selLP, top := runSampleLogprobs(t, logits, len(logits))
|
||||
|
||||
if selected != maxIdx {
|
||||
t.Errorf("selected = %d, want argmax %d", selected, maxIdx)
|
||||
}
|
||||
|
||||
if top[0].id != maxIdx {
|
||||
t.Errorf("top[0].id = %d, want argmax %d", top[0].id, maxIdx)
|
||||
}
|
||||
if math.Abs(top[0].logprob-selLP) > 1e-6 {
|
||||
t.Errorf("top[0].logprob = %f, want selected %f", top[0].logprob, selLP)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchedLogprobsPerRow verifies that per-row logprobs in a batched
|
||||
// sample call match the per-slot reference. The numerically-stable softmax
|
||||
// must reduce along the last axis only, not over the whole batch.
|
||||
func TestBatchedLogprobsPerRow(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
rowA := []float32{2, 1, 0}
|
||||
rowB := []float32{0, 5, 0}
|
||||
|
||||
_, wantA, _ := runSampleLogprobs(t, rowA, 0)
|
||||
_, wantB, _ := runSampleLogprobs(t, rowB, 0)
|
||||
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
s.Add(1, Options{Logprobs: true}, nil)
|
||||
s.Add(2, Options{Logprobs: true}, nil)
|
||||
|
||||
logits := mlx.FromValues(append(append([]float32{}, rowA...), rowB...), 2, 3)
|
||||
res := s.Sample([]int{1, 2}, logits)
|
||||
mlx.Pin(res.Arrays()...)
|
||||
t.Cleanup(func() { mlx.Unpin(res.Arrays()...) })
|
||||
mlx.Eval(res.Arrays()...)
|
||||
|
||||
got := res.Logprob.Floats()
|
||||
if len(got) != 2 {
|
||||
t.Fatalf("Logprob length = %d, want 2", len(got))
|
||||
}
|
||||
if math.Abs(float64(got[0])-wantA) > 1e-5 {
|
||||
t.Errorf("row 0 logprob = %f, want %f (per-slot reference)", got[0], wantA)
|
||||
}
|
||||
if math.Abs(float64(got[1])-wantB) > 1e-5 {
|
||||
t.Errorf("row 1 logprob = %f, want %f (per-slot reference)", got[1], wantB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSampleLogprobsTopKOrdering(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
// Logits chosen so argmax order differs from index order.
|
||||
logits := []float32{2.0, 5.0, 1.0, 4.0, 3.0}
|
||||
wantOrder := []int{1, 3, 4, 0, 2}
|
||||
|
||||
_, _, top := runSampleLogprobs(t, logits, len(logits))
|
||||
|
||||
if len(top) != len(wantOrder) {
|
||||
t.Fatalf("top-K length = %d, want %d", len(top), len(wantOrder))
|
||||
}
|
||||
for i, e := range top {
|
||||
if e.id != wantOrder[i] {
|
||||
t.Errorf("top[%d].id = %d, want %d", i, e.id, wantOrder[i])
|
||||
}
|
||||
}
|
||||
for i := 1; i < len(top); i++ {
|
||||
if top[i].logprob > top[i-1].logprob {
|
||||
t.Errorf("top[%d].logprob (%f) > top[%d].logprob (%f)",
|
||||
i, top[i].logprob, i-1, top[i-1].logprob)
|
||||
}
|
||||
}
|
||||
}
|
||||
897
x/mlxrunner/sample/sample.go
Normal file
897
x/mlxrunner/sample/sample.go
Normal file
@@ -0,0 +1,897 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Temperature float32
|
||||
TopP float32
|
||||
MinP float32
|
||||
TopK int
|
||||
RepeatLastN int
|
||||
RepeatPenalty float32
|
||||
PresencePenalty float32
|
||||
FrequencyPenalty float32
|
||||
Seed int
|
||||
UseSeed bool
|
||||
|
||||
// Logprobs causes Sample to populate Result.Logprob with the selected
|
||||
// token's log-probability. TopLogprobs (when > 0) adds top-K pairs.
|
||||
Logprobs bool
|
||||
TopLogprobs int
|
||||
}
|
||||
|
||||
// Result bundles the outputs of one decode step. Logprob/TopTokens/
|
||||
// TopLogprobs are populated whenever any registered slot has Logprobs
|
||||
// (respectively TopLogprobs>0). Consumers need to filter by their
|
||||
// per-slot Options.
|
||||
type Result struct {
|
||||
Token *mlx.Array // sampled token ids, shape [B]
|
||||
Logprob *mlx.Array // sampled-token logprobs, shape [B,1]; nil unless any registered slot has Logprobs
|
||||
TopTokens *mlx.Array // top-K token ids, shape [B,maxK]; nil unless any registered slot has TopLogprobs>0
|
||||
TopLogprobs *mlx.Array // top-K logprobs, shape [B,maxK]; same
|
||||
}
|
||||
|
||||
// Arrays returns the tensor fields as a slice so callers can drive the mlx
|
||||
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
|
||||
// fields stay nil; the mlx helpers skip them.
|
||||
func (r Result) Arrays() []*mlx.Array {
|
||||
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
||||
}
|
||||
|
||||
// Distribution is the filtered probability distribution used by the sampler.
|
||||
// When IDs is nil, Probs is dense over the vocabulary. When IDs is set, Probs
|
||||
// is sparse over the token ids in IDs, preserving GPU residency for the
|
||||
// top-k-first path used by normal and speculative sampling.
|
||||
type Distribution struct {
|
||||
IDs *mlx.Array // sparse token ids, shape [B,K]; nil for dense distributions
|
||||
Probs *mlx.Array // probabilities, shape [B,K] or [B,V]
|
||||
}
|
||||
|
||||
// Arrays returns the tensor fields for mlx lifecycle management.
|
||||
func (d Distribution) Arrays() []*mlx.Array {
|
||||
return []*mlx.Array{d.IDs, d.Probs}
|
||||
}
|
||||
|
||||
// Rows returns the number of rows in the distribution.
|
||||
func (d Distribution) Rows() int {
|
||||
if d.Probs == nil {
|
||||
return 0
|
||||
}
|
||||
return d.Probs.Dim(0)
|
||||
}
|
||||
|
||||
// SliceRows returns a row slice while preserving sparse/dense layout.
|
||||
func (d Distribution) SliceRows(start, stop int) Distribution {
|
||||
out := Distribution{Probs: d.Probs.Slice(mlx.Slice(start, stop), mlx.Slice())}
|
||||
if d.IDs != nil {
|
||||
out.IDs = d.IDs.Slice(mlx.Slice(start, stop), mlx.Slice())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SampleWithKey draws one token per row using key when supplied.
|
||||
func (d Distribution) SampleWithKey(key *mlx.Array) *mlx.Array {
|
||||
choice := logitsFromProbs(d.Probs).CategoricalWithKey(-1, key).AsType(mlx.DTypeInt32)
|
||||
if d.IDs == nil {
|
||||
return choice
|
||||
}
|
||||
return d.IDs.TakeAlongAxis(choice.ExpandDims(-1), -1).Squeeze(-1).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
|
||||
// Prob returns the probability assigned to one token per row.
|
||||
func (d Distribution) Prob(tokens *mlx.Array) *mlx.Array {
|
||||
switch tokens.NumDims() {
|
||||
case 2:
|
||||
if tokens.Dim(0) == 1 {
|
||||
tokens = tokens.Squeeze(0)
|
||||
} else if tokens.Dim(1) == 1 {
|
||||
tokens = tokens.Squeeze(1)
|
||||
}
|
||||
case 0:
|
||||
tokens = tokens.Reshape(1)
|
||||
}
|
||||
return d.ProbsForIDs(tokens.ExpandDims(-1)).Squeeze(-1)
|
||||
}
|
||||
|
||||
// ProbsForIDs returns probabilities for each requested token id. ids must be
|
||||
// rank-2 [B,N], matching the distribution rows.
|
||||
func (d Distribution) ProbsForIDs(ids *mlx.Array) *mlx.Array {
|
||||
if d.IDs == nil {
|
||||
return d.Probs.TakeAlongAxis(ids, -1)
|
||||
}
|
||||
eq := d.IDs.ExpandDims(-1).Equal(ids.ExpandDims(1))
|
||||
values := mlx.Where(eq, d.Probs.ExpandDims(-1), mlx.FromValue(float32(0)))
|
||||
return values.SumAxis(1, false)
|
||||
}
|
||||
|
||||
// ResidualAgainst returns the Leviathan/Chen rejection distribution
|
||||
// proportional to max(target - draft, 0). Sparse target distributions stay
|
||||
// sparse over the target support; tokens outside target support have zero mass.
|
||||
func (d Distribution) ResidualAgainst(draft Distribution) Distribution {
|
||||
if d.IDs != nil {
|
||||
diff := d.Probs.Subtract(draft.ProbsForIDs(d.IDs))
|
||||
return Distribution{IDs: d.IDs, Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
|
||||
}
|
||||
if draft.IDs != nil {
|
||||
panic("sample.Distribution.ResidualAgainst: dense target with sparse draft is unsupported")
|
||||
}
|
||||
diff := d.Probs.Subtract(draft.Probs)
|
||||
return Distribution{Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
|
||||
}
|
||||
|
||||
// LogProbs returns dense log-probabilities, scattering sparse distributions
|
||||
// into a full-vocabulary tensor when needed.
|
||||
func (d Distribution) LogProbs(vocab int) *mlx.Array {
|
||||
logProbs := logitsFromProbs(d.Probs)
|
||||
if d.IDs == nil {
|
||||
return logProbs
|
||||
}
|
||||
out := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, d.Probs.Dim(0), vocab), float32(math.Inf(-1)))
|
||||
return out.PutAlongAxis(d.IDs, logProbs, -1)
|
||||
}
|
||||
|
||||
// ConcatenateDistributions concatenates distribution rows. All inputs must use
|
||||
// the same sparse/dense layout.
|
||||
func ConcatenateDistributions(dists []Distribution) Distribution {
|
||||
if len(dists) == 0 {
|
||||
return Distribution{}
|
||||
}
|
||||
probs := make([]*mlx.Array, 0, len(dists))
|
||||
ids := make([]*mlx.Array, 0, len(dists))
|
||||
sparse := dists[0].IDs != nil
|
||||
for _, d := range dists {
|
||||
if (d.IDs != nil) != sparse {
|
||||
panic("sample.ConcatenateDistributions: mixed sparse and dense distributions")
|
||||
}
|
||||
probs = append(probs, d.Probs)
|
||||
if sparse {
|
||||
ids = append(ids, d.IDs)
|
||||
}
|
||||
}
|
||||
out := Distribution{Probs: mlx.Concatenate(probs, 0)}
|
||||
if sparse {
|
||||
out.IDs = mlx.Concatenate(ids, 0)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Sampler is a batched, slot-based sampler. Sequences are registered with
|
||||
// Add and released with Remove. Each Sample call takes a subset of
|
||||
// registered slots (in any order) with their [B,V] logits, samples one
|
||||
// token per row, and appends it to that slot's ring-buffer history. Slots
|
||||
// not named in a given call are untouched.
|
||||
type Sampler struct {
|
||||
slots []*slotState
|
||||
byID map[int]*slotState
|
||||
|
||||
// history is the pooled ring-buffer storage, [B, W] int32. Row i
|
||||
// belongs to slots[i]; W is max(RepeatLastN) across penalty slots.
|
||||
// Allocated on the first penalty slot, rebuilt only in Add/Remove.
|
||||
history *mlx.Array
|
||||
|
||||
// allSameOpts: every registered slot shares Options. When true the
|
||||
// canonical shared value is s.slots[0].opts.
|
||||
allSameOpts bool
|
||||
|
||||
// anyLogprobs / maxTopLogprobs: compute-for-all output config.
|
||||
// Sample populates Logprob (and Top* when maxTopLogprobs>0) whenever
|
||||
// any registered slot requests them, even if that slot isn't in the
|
||||
// current call.
|
||||
anyLogprobs bool
|
||||
maxTopLogprobs int
|
||||
|
||||
// numCtx is the runner's context window; normalize uses it to
|
||||
// resolve the repeat_last_n == -1 sentinel.
|
||||
numCtx int
|
||||
}
|
||||
|
||||
type slotState struct {
|
||||
opts Options
|
||||
historyLen int
|
||||
randomCounter uint64
|
||||
}
|
||||
|
||||
type slotCtx struct {
|
||||
opts Options
|
||||
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
|
||||
}
|
||||
|
||||
// New constructs an empty sampler with no registered slots. numCtx is
|
||||
// the runner's context window and must be positive.
|
||||
func New(numCtx int) *Sampler {
|
||||
return &Sampler{
|
||||
byID: make(map[int]*slotState),
|
||||
allSameOpts: true,
|
||||
numCtx: numCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// historyWidth returns the column count of the pooled history tensor,
|
||||
// or 0 when no penalty slot has forced it to be allocated.
|
||||
func (s *Sampler) historyWidth() int {
|
||||
if s.history == nil {
|
||||
return 0
|
||||
}
|
||||
return s.history.Dim(1)
|
||||
}
|
||||
|
||||
func (o Options) usesHistory() bool {
|
||||
// RepeatLastN == 0 disables the penalty ring per the repeat_last_n API
|
||||
// contract (0 = disabled), overriding any penalty coefficients.
|
||||
if o.RepeatLastN == 0 {
|
||||
return false
|
||||
}
|
||||
return o.RepeatPenalty != 1 || o.PresencePenalty != 0 || o.FrequencyPenalty != 0
|
||||
}
|
||||
|
||||
func (o Options) normalize(numCtx int) Options {
|
||||
if o.RepeatPenalty <= 0 {
|
||||
o.RepeatPenalty = 1
|
||||
}
|
||||
// Resolve the repeat_last_n == -1 sentinel ("-1 = num_ctx") against
|
||||
// the caller's context window.
|
||||
if o.RepeatLastN < 0 {
|
||||
o.RepeatLastN = numCtx
|
||||
}
|
||||
if !o.usesHistory() {
|
||||
// Zero the ring capacity so slots that differ only in a spurious
|
||||
// RepeatLastN still batch together and don't inflate pool width.
|
||||
o.RepeatLastN = 0
|
||||
}
|
||||
if o.Seed < 0 {
|
||||
o.UseSeed = false
|
||||
}
|
||||
if !o.UseSeed {
|
||||
// Keep unseeded callers on the same batching path even when a
|
||||
// meaningless Seed value is present in an Options literal.
|
||||
o.Seed = 0
|
||||
}
|
||||
return o
|
||||
}
|
||||
|
||||
// Add registers a sequence under seqID. The last RepeatLastN entries of
|
||||
// priorTokens seed the ring buffer.
|
||||
func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
|
||||
if _, dup := s.byID[seqID]; dup {
|
||||
panic(fmt.Sprintf("sample.Sampler.Add: seqID %d already registered", seqID))
|
||||
}
|
||||
|
||||
opts = opts.normalize(s.numCtx)
|
||||
slot := &slotState{
|
||||
opts: opts,
|
||||
}
|
||||
|
||||
// Grow the pool to hold this slot's row. The pool is lazy — the first
|
||||
// penalty slot allocates it — and thereafter every registered slot
|
||||
// gets a row (rows for non-penalty slots are zero and never read).
|
||||
// Invariant: s.history is pinned whenever non-nil.
|
||||
if s.history != nil || opts.usesHistory() {
|
||||
targetWidth := max(opts.RepeatLastN, s.historyWidth())
|
||||
newRow := makeHistoryRow(priorTokens, opts.RepeatLastN, targetWidth)
|
||||
|
||||
var pool *mlx.Array
|
||||
switch {
|
||||
case s.history == nil && len(s.slots) == 0:
|
||||
pool = newRow
|
||||
case s.history == nil:
|
||||
// First penalty slot with non-penalty slots already registered;
|
||||
// seed zero rows so s.slots and pool row indices stay aligned.
|
||||
zeros := mlx.Zeros(mlx.DTypeInt32, len(s.slots), targetWidth)
|
||||
pool = zeros.Concatenate(0, newRow)
|
||||
case targetWidth > s.historyWidth():
|
||||
pad := mlx.Zeros(mlx.DTypeInt32, s.history.Dim(0), targetWidth-s.historyWidth())
|
||||
pool = s.history.Concatenate(1, pad).Concatenate(0, newRow)
|
||||
default:
|
||||
pool = s.history.Concatenate(0, newRow)
|
||||
}
|
||||
|
||||
mlx.Pin(pool)
|
||||
mlx.Unpin(s.history)
|
||||
s.history = pool
|
||||
|
||||
if opts.usesHistory() {
|
||||
// Cap on seed so the next write's ring position
|
||||
// (historyLen % RepeatLastN) lands at 0, overwriting the
|
||||
// oldest entry when the ring was filled from priors.
|
||||
slot.historyLen = min(len(priorTokens), opts.RepeatLastN)
|
||||
}
|
||||
}
|
||||
|
||||
s.slots = append(s.slots, slot)
|
||||
s.byID[seqID] = slot
|
||||
s.recomputeInvariants()
|
||||
}
|
||||
|
||||
// makeHistoryRow builds a [1, width] int32 row with the last repeatLastN
|
||||
// entries of priorTokens packed into [0, min(len, repeatLastN)), zeros
|
||||
// elsewhere.
|
||||
func makeHistoryRow(priorTokens []int32, repeatLastN, width int) *mlx.Array {
|
||||
take := min(len(priorTokens), repeatLastN)
|
||||
if take <= 0 {
|
||||
return mlx.Zeros(mlx.DTypeInt32, 1, width)
|
||||
}
|
||||
row := make([]int32, width)
|
||||
copy(row, priorTokens[len(priorTokens)-take:])
|
||||
return mlx.NewArrayInt32(row, []int32{1, int32(width)})
|
||||
}
|
||||
|
||||
// recomputeInvariants refreshes allSameOpts and anyLogprobs/maxTopLogprobs
|
||||
// from s.slots. Called at the end of Add and Remove.
|
||||
func (s *Sampler) recomputeInvariants() {
|
||||
if len(s.slots) == 0 {
|
||||
s.allSameOpts = true
|
||||
s.anyLogprobs = false
|
||||
s.maxTopLogprobs = 0
|
||||
return
|
||||
}
|
||||
first := s.slots[0].opts
|
||||
s.allSameOpts = true
|
||||
s.anyLogprobs = false
|
||||
s.maxTopLogprobs = 0
|
||||
for _, slot := range s.slots {
|
||||
if slot.opts != first {
|
||||
s.allSameOpts = false
|
||||
}
|
||||
if slot.opts.Logprobs {
|
||||
s.anyLogprobs = true
|
||||
if slot.opts.TopLogprobs > s.maxTopLogprobs {
|
||||
s.maxTopLogprobs = slot.opts.TopLogprobs
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove releases the slot. The pool tensor is rebuilt to drop the row.
|
||||
func (s *Sampler) Remove(seqID int) {
|
||||
slot, ok := s.byID[seqID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(s.byID, seqID)
|
||||
|
||||
row := slices.Index(s.slots, slot)
|
||||
s.slots = slices.Delete(s.slots, row, row+1)
|
||||
s.recomputeInvariants()
|
||||
|
||||
if s.history == nil {
|
||||
return
|
||||
}
|
||||
|
||||
n := s.history.Dim(0)
|
||||
var newHistory *mlx.Array
|
||||
switch {
|
||||
case n == 1:
|
||||
newHistory = nil
|
||||
case row == 0:
|
||||
newHistory = s.history.Slice(mlx.Slice(1, n), mlx.Slice())
|
||||
case row == n-1:
|
||||
newHistory = s.history.Slice(mlx.Slice(0, row), mlx.Slice())
|
||||
default:
|
||||
before := s.history.Slice(mlx.Slice(0, row), mlx.Slice())
|
||||
after := s.history.Slice(mlx.Slice(row+1, n), mlx.Slice())
|
||||
newHistory = before.Concatenate(0, after)
|
||||
}
|
||||
|
||||
mlx.Pin(newHistory)
|
||||
mlx.Unpin(s.history)
|
||||
s.history = newHistory
|
||||
}
|
||||
|
||||
// Free releases the pooled history tensor and resets the sampler to the
|
||||
// New-equivalent state so it may be reused.
|
||||
func (s *Sampler) Free() {
|
||||
mlx.Unpin(s.history)
|
||||
*s = Sampler{
|
||||
byID: make(map[int]*slotState),
|
||||
allSameOpts: true,
|
||||
numCtx: s.numCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// Sample draws one token per row of logits ([B,V]); seqIDs[i] names the
|
||||
// slot whose logits live at row i. Each sampled token is appended to its
|
||||
// slot's ring. Slots not named in seqIDs are untouched.
|
||||
func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
|
||||
if len(seqIDs) == 0 {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
slots := make([]*slotState, len(seqIDs))
|
||||
for i, id := range seqIDs {
|
||||
slot, ok := s.byID[id]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("sample.Sampler.Sample: seqID %d not registered", id))
|
||||
}
|
||||
slots[i] = slot
|
||||
}
|
||||
|
||||
var token *mlx.Array
|
||||
if opts0, ok := s.canBatch(slots); ok {
|
||||
token = s.sampleTokensUniform(slots, opts0, logits)
|
||||
} else {
|
||||
token = s.sampleTokensSerial(slots, logits)
|
||||
}
|
||||
|
||||
res := Result{Token: token}
|
||||
if s.anyLogprobs {
|
||||
// Log-softmax over original logits so every row holds a truthful
|
||||
// value (compute-for-all; consumers filter per-slot). Subtract
|
||||
// max first for numerical stability in the logsumexp.
|
||||
lp := logits.AsType(mlx.DTypeFloat32)
|
||||
lp = lp.Subtract(lp.MaxAxis(-1, true))
|
||||
lp = lp.Subtract(lp.LogsumexpAxis(-1, true))
|
||||
res.Logprob = lp.TakeAlongAxis(token.ExpandDims(-1), -1)
|
||||
if s.maxTopLogprobs > 0 {
|
||||
k := s.maxTopLogprobs
|
||||
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
|
||||
k = vocab
|
||||
}
|
||||
// Argpartition on the negated values places the K largest
|
||||
// (unsorted) in positions [0:K].
|
||||
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
|
||||
res.TopTokens = idx.AsType(mlx.DTypeInt32)
|
||||
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Distribution applies this slot's sampling transforms to logits without
|
||||
// mutating sampler state. Row i is built as if draftTokens[:i] had already
|
||||
// been appended to the slot history. logits must be [R,V] or [1,R,V].
|
||||
func (s *Sampler) Distribution(seqID int, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
|
||||
slot, logits, draftTokens := s.speculativeInputs("Distribution", seqID, logits, draftTokens)
|
||||
rows := logits.Dim(0)
|
||||
|
||||
var hist *mlx.Array
|
||||
if slot.opts.usesHistory() {
|
||||
if s.history == nil {
|
||||
panic(fmt.Sprintf("sample.Sampler.Distribution: seqID %d has no history", seqID))
|
||||
}
|
||||
if slot.historyLen < slot.opts.RepeatLastN {
|
||||
return s.speculativeDistributionSerial(slot, logits, draftTokens)
|
||||
}
|
||||
hist = s.speculativeHistory(slot, draftTokens, rows)
|
||||
}
|
||||
|
||||
return slot.distribution(&slotCtx{opts: slot.opts, history: hist}, logits)
|
||||
}
|
||||
|
||||
// SpeculativeScores applies this slot's sampling transforms to logits without
|
||||
// mutating sampler state and returns dense log-probability scores for sampled
|
||||
// decoding. Greedy decoding returns the penalty-adjusted logits.
|
||||
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
||||
slot, logits, draftTokens := s.speculativeInputs("SpeculativeScores", seqID, logits, draftTokens)
|
||||
rows := logits.Dim(0)
|
||||
|
||||
var hist *mlx.Array
|
||||
if slot.opts.usesHistory() {
|
||||
if s.history == nil {
|
||||
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d has no history", seqID))
|
||||
}
|
||||
if slot.historyLen < slot.opts.RepeatLastN {
|
||||
return s.speculativeScoresSerial(slot, logits, draftTokens)
|
||||
}
|
||||
hist = s.speculativeHistory(slot, draftTokens, rows)
|
||||
}
|
||||
|
||||
return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits)
|
||||
}
|
||||
|
||||
// SampleDistribution draws from a precomputed distribution while advancing
|
||||
// seqID's deterministic RNG stream when a seed is configured.
|
||||
func (s *Sampler) SampleDistribution(seqID int, dist Distribution) *mlx.Array {
|
||||
slot := s.mustSlot("SampleDistribution", seqID)
|
||||
return dist.SampleWithKey(slot.nextRandomKey())
|
||||
}
|
||||
|
||||
// Bernoulli samples boolean outcomes while advancing seqID's deterministic RNG
|
||||
// stream when a seed is configured.
|
||||
func (s *Sampler) Bernoulli(seqID int, p *mlx.Array) *mlx.Array {
|
||||
slot := s.mustSlot("Bernoulli", seqID)
|
||||
return mlx.BernoulliWithKey(p, slot.nextRandomKey())
|
||||
}
|
||||
|
||||
func (s *Sampler) mustSlot(caller string, seqID int) *slotState {
|
||||
slot, ok := s.byID[seqID]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("sample.Sampler.%s: seqID %d not registered", caller, seqID))
|
||||
}
|
||||
return slot
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeInputs(caller string, seqID int, logits *mlx.Array, draftTokens *mlx.Array) (*slotState, *mlx.Array, *mlx.Array) {
|
||||
slot := s.mustSlot(caller, seqID)
|
||||
|
||||
if logits.NumDims() == 3 {
|
||||
if logits.Dim(0) != 1 {
|
||||
panic(fmt.Sprintf("sample.Sampler.%s: only batch size 1 is supported", caller))
|
||||
}
|
||||
logits = logits.Squeeze(0)
|
||||
}
|
||||
if logits.NumDims() != 2 {
|
||||
panic(fmt.Sprintf("sample.Sampler.%s: logits must be rank 2 or 3, got rank %d", caller, logits.NumDims()))
|
||||
}
|
||||
|
||||
if draftTokens != nil && draftTokens.NumDims() == 1 {
|
||||
draftTokens = draftTokens.ExpandDims(0)
|
||||
}
|
||||
return slot, logits, draftTokens
|
||||
}
|
||||
|
||||
// Commit appends already-selected tokens to seqID's repeat-penalty history.
|
||||
// It is used after speculative sampling once the accepted continuation is
|
||||
// known. Normal Sample calls continue to mutate history themselves.
|
||||
func (s *Sampler) Commit(seqID int, tokens []int32) {
|
||||
if len(tokens) == 0 {
|
||||
return
|
||||
}
|
||||
slot, ok := s.byID[seqID]
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d not registered", seqID))
|
||||
}
|
||||
if !slot.opts.usesHistory() {
|
||||
return
|
||||
}
|
||||
if s.history == nil {
|
||||
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d has no history", seqID))
|
||||
}
|
||||
|
||||
row := slices.Index(s.slots, slot)
|
||||
width := s.historyWidth()
|
||||
take := min(len(tokens), slot.opts.RepeatLastN)
|
||||
startLen := slot.historyLen + len(tokens) - take
|
||||
writeTokens := tokens[len(tokens)-take:]
|
||||
flatOffsets := make([]int32, take)
|
||||
for i := range take {
|
||||
ringPos := (startLen + i) % slot.opts.RepeatLastN
|
||||
flatOffsets[i] = int32(row*width + ringPos)
|
||||
}
|
||||
|
||||
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(take), 1})
|
||||
values := mlx.NewArrayInt32(writeTokens, []int32{int32(take), 1})
|
||||
flatHist := s.history.Reshape(s.history.Dim(0)*width, 1)
|
||||
s.history.Set(flatHist.PutAlongAxis(flatIdx, values, 0).Reshape(s.history.Dim(0), width))
|
||||
slot.historyLen += len(tokens)
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeDistributionSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
|
||||
rows := logits.Dim(0)
|
||||
draftCount := 0
|
||||
if draftTokens != nil {
|
||||
draftCount = draftTokens.Dim(1)
|
||||
}
|
||||
row := slices.Index(s.slots, slot)
|
||||
baseFill := min(slot.historyLen, slot.opts.RepeatLastN)
|
||||
var base *mlx.Array
|
||||
if baseFill > 0 {
|
||||
base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
|
||||
}
|
||||
|
||||
dists := make([]Distribution, 0, rows)
|
||||
for i := range rows {
|
||||
rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
||||
hist := base
|
||||
prefixLen := min(i, draftCount)
|
||||
if prefixLen > 0 {
|
||||
prefix := draftTokens.Slice(mlx.Slice(), mlx.Slice(0, prefixLen))
|
||||
if hist == nil {
|
||||
hist = prefix
|
||||
} else {
|
||||
hist = hist.Concatenate(1, prefix)
|
||||
}
|
||||
if hist.Dim(1) > slot.opts.RepeatLastN {
|
||||
hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
|
||||
}
|
||||
}
|
||||
dists = append(dists, slot.distribution(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
|
||||
}
|
||||
return ConcatenateDistributions(dists)
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
||||
return s.speculativeDistributionSerial(slot, logits, draftTokens).LogProbs(logits.Dim(logits.NumDims() - 1))
|
||||
}
|
||||
|
||||
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
|
||||
row := slices.Index(s.slots, slot)
|
||||
width := slot.opts.RepeatLastN
|
||||
base := s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, width))
|
||||
base = mlx.Tile(base, []int32{int32(rows), 1})
|
||||
next := slot.historyLen % width
|
||||
draftCount := 0
|
||||
if draftTokens != nil {
|
||||
draftCount = draftTokens.Dim(1)
|
||||
}
|
||||
if draftCount == 0 {
|
||||
return base
|
||||
}
|
||||
|
||||
sourceIdx := make([]int32, rows*width)
|
||||
writeMask := make([]bool, rows*width)
|
||||
for i := range rows {
|
||||
prefixLen := min(i, draftCount)
|
||||
for j := range prefixLen {
|
||||
pos := (next + j) % width
|
||||
sourceIdx[i*width+pos] = int32(j)
|
||||
writeMask[i*width+pos] = true
|
||||
}
|
||||
}
|
||||
|
||||
draftRows := mlx.Tile(draftTokens, []int32{int32(rows), 1})
|
||||
idx := mlx.NewArrayInt32(sourceIdx, []int32{int32(rows), int32(width)})
|
||||
mask := mlx.FromValues(writeMask, rows, width)
|
||||
values := draftRows.TakeAlongAxis(idx, 1)
|
||||
return mlx.Where(mask, values, base)
|
||||
}
|
||||
|
||||
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
||||
if slot.opts.Temperature == 0 {
|
||||
return slot.baseScores(ctx, logits)
|
||||
}
|
||||
return slot.distribution(ctx, logits).LogProbs(logits.Dim(logits.NumDims() - 1))
|
||||
}
|
||||
|
||||
// canBatch reports whether the call can take the uniform batched path.
|
||||
// All slots must share Options; when penalties are active the call must
|
||||
// additionally cover every registered slot in registration order with a
|
||||
// full ring, because the uniform path indexes the pool positionally.
|
||||
func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
|
||||
if !s.allSameOpts {
|
||||
return Options{}, false
|
||||
}
|
||||
// slots is non-empty (Sample guards) and every slot is registered,
|
||||
// so s.slots[0].opts is the canonical shared value.
|
||||
shared := s.slots[0].opts
|
||||
// TODO(pdevine): Before using multi-slot batching with seeded stochastic sampling,
|
||||
// make sure each row gets its own per-slot random key instead of sharing
|
||||
// slots[0]'s key through one batched categorical op.
|
||||
if !shared.usesHistory() {
|
||||
return shared, true
|
||||
}
|
||||
if len(slots) != len(s.slots) {
|
||||
return Options{}, false
|
||||
}
|
||||
for i, slot := range slots {
|
||||
if s.slots[i] != slot || slot.historyLen < shared.RepeatLastN {
|
||||
return Options{}, false
|
||||
}
|
||||
}
|
||||
return shared, true
|
||||
}
|
||||
|
||||
// sampleTokensUniform runs one fused sampling pass over the whole batch.
|
||||
// Reached only when canBatch is true, which lets the pool be used in place
|
||||
// with a single PutAlongAxis write-back and no gather.
|
||||
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
|
||||
B := len(slots)
|
||||
|
||||
var hist *mlx.Array
|
||||
if opts.usesHistory() {
|
||||
hist = s.history
|
||||
if s.historyWidth() > opts.RepeatLastN {
|
||||
hist = hist.Slice(mlx.Slice(), mlx.Slice(0, opts.RepeatLastN))
|
||||
}
|
||||
}
|
||||
|
||||
ctx := &slotCtx{opts: opts, history: hist}
|
||||
token := slots[0].sample(ctx, logits)
|
||||
if opts.UseSeed && opts.Temperature != 0 {
|
||||
// TODO: This only keeps counters aligned; it does not give each slot
|
||||
// an independent key for the batched draw.
|
||||
for _, slot := range slots[1:] {
|
||||
slot.randomCounter++
|
||||
}
|
||||
}
|
||||
|
||||
if !opts.usesHistory() {
|
||||
return token
|
||||
}
|
||||
|
||||
writeIdxData := make([]int32, B)
|
||||
for i, slot := range slots {
|
||||
writeIdxData[i] = int32(slot.historyLen % opts.RepeatLastN)
|
||||
slot.historyLen++
|
||||
}
|
||||
writeIdx := mlx.NewArrayInt32(writeIdxData, []int32{int32(B), 1})
|
||||
|
||||
s.history.Set(s.history.PutAlongAxis(writeIdx, token.ExpandDims(-1), 1))
|
||||
return token
|
||||
}
|
||||
|
||||
// sampleTokensSerial samples each slot against its own row of logits.
|
||||
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
|
||||
perSlotTokens := make([]*mlx.Array, len(slots))
|
||||
|
||||
rowOf := make(map[*slotState]int, len(s.slots))
|
||||
for i, slot := range s.slots {
|
||||
rowOf[slot] = i
|
||||
}
|
||||
|
||||
for i, slot := range slots {
|
||||
row := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
||||
|
||||
var hist *mlx.Array
|
||||
if slot.opts.usesHistory() && slot.historyLen > 0 && s.history != nil {
|
||||
poolRow := rowOf[slot]
|
||||
fill := min(slot.historyLen, slot.opts.RepeatLastN)
|
||||
hist = s.history.Slice(
|
||||
mlx.Slice(poolRow, poolRow+1),
|
||||
mlx.Slice(0, fill),
|
||||
)
|
||||
}
|
||||
|
||||
ctx := &slotCtx{opts: slot.opts, history: hist}
|
||||
perSlotTokens[i] = slot.sample(ctx, row)
|
||||
}
|
||||
|
||||
token := mlx.Concatenate(perSlotTokens, 0)
|
||||
|
||||
if s.history != nil {
|
||||
// For each writing slot collect its flat (row-major) pool offset
|
||||
// and the call-order position of its token. One PutAlongAxis on a
|
||||
// flat view of the pool scatters all writes in a single op.
|
||||
flatOffsets := make([]int32, 0, len(slots))
|
||||
tokenPos := make([]int32, 0, len(slots))
|
||||
for i, slot := range slots {
|
||||
if !slot.opts.usesHistory() {
|
||||
continue
|
||||
}
|
||||
ringPos := slot.historyLen % slot.opts.RepeatLastN
|
||||
flatOffsets = append(flatOffsets, int32(rowOf[slot]*s.historyWidth()+ringPos))
|
||||
tokenPos = append(tokenPos, int32(i))
|
||||
slot.historyLen++
|
||||
}
|
||||
|
||||
if len(flatOffsets) > 0 {
|
||||
m := len(flatOffsets)
|
||||
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(m), 1})
|
||||
writingTokens := token
|
||||
if m != len(slots) {
|
||||
tokenPosIdx := mlx.NewArrayInt32(tokenPos, []int32{int32(m)})
|
||||
writingTokens = token.TakeAxis(tokenPosIdx, 0)
|
||||
}
|
||||
flatHist := s.history.Reshape(s.history.Dim(0)*s.historyWidth(), 1)
|
||||
s.history.Set(flatHist.PutAlongAxis(flatIdx, writingTokens.ExpandDims(-1), 0).Reshape(s.history.Dim(0), s.historyWidth()))
|
||||
}
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func (slot *slotState) sample(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
||||
if slot.opts.Temperature == 0 {
|
||||
return slot.baseScores(ctx, logits).Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
}
|
||||
return slot.distribution(ctx, logits).SampleWithKey(slot.nextRandomKey())
|
||||
}
|
||||
|
||||
func (slot *slotState) nextRandomKey() *mlx.Array {
|
||||
if !slot.opts.UseSeed {
|
||||
return nil
|
||||
}
|
||||
seed := mixSeed(uint64(slot.opts.Seed), slot.randomCounter)
|
||||
slot.randomCounter++
|
||||
return mlx.RandomKey(seed)
|
||||
}
|
||||
|
||||
const (
|
||||
// SplitMix64 constants used to decorrelate nearby (seed, counter) pairs.
|
||||
splitMix64Weyl = 0x9e3779b97f4a7c15
|
||||
splitMix64Mul1 = 0xbf58476d1ce4e5b9
|
||||
splitMix64Mul2 = 0x94d049bb133111eb
|
||||
splitMix64Shift1 = 30
|
||||
splitMix64Shift2 = 27
|
||||
splitMix64FinalShift = 31
|
||||
)
|
||||
|
||||
func mixSeed(seed, counter uint64) uint64 {
|
||||
z := seed + splitMix64Weyl*(counter+1)
|
||||
z = (z ^ (z >> splitMix64Shift1)) * splitMix64Mul1
|
||||
z = (z ^ (z >> splitMix64Shift2)) * splitMix64Mul2
|
||||
return z ^ (z >> splitMix64FinalShift)
|
||||
}
|
||||
|
||||
func (slot *slotState) baseScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
||||
scores := logits
|
||||
if slot.opts.usesHistory() {
|
||||
scores = penalty(ctx, scores)
|
||||
}
|
||||
return scores
|
||||
}
|
||||
|
||||
func (slot *slotState) distribution(ctx *slotCtx, logits *mlx.Array) Distribution {
|
||||
scores := slot.baseScores(ctx, logits)
|
||||
if slot.opts.Temperature <= 0 {
|
||||
ids := scores.Argmax(-1, false).AsType(mlx.DTypeInt32).ExpandDims(-1)
|
||||
probs := mlx.AddScalar(ids.AsType(mlx.DTypeFloat32).Multiply(mlx.FromValue(float32(0))), 1)
|
||||
return Distribution{IDs: ids, Probs: probs}
|
||||
}
|
||||
|
||||
vocab := scores.Dim(scores.NumDims() - 1)
|
||||
if slot.opts.TopK > 0 && slot.opts.TopK < vocab {
|
||||
return sparseDistribution(ctx.opts, scores)
|
||||
}
|
||||
return denseDistribution(ctx.opts, scores)
|
||||
}
|
||||
|
||||
func sparseDistribution(opts Options, scores *mlx.Array) Distribution {
|
||||
ids := scores.Negative().ArgpartitionAxis(opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(0, opts.TopK)).AsType(mlx.DTypeInt32)
|
||||
topScores := scores.TakeAlongAxis(ids, -1).AsType(mlx.DTypeFloat32)
|
||||
probs := mlx.SoftmaxAxis(mlx.DivScalar(topScores, opts.Temperature), -1, true)
|
||||
probs = applyTopPProbs(probs, opts.TopP)
|
||||
probs = applyMinPProbs(probs, opts.MinP)
|
||||
return Distribution{IDs: ids, Probs: normalizeProbs(probs)}
|
||||
}
|
||||
|
||||
func denseDistribution(opts Options, scores *mlx.Array) Distribution {
|
||||
probs := mlx.SoftmaxAxis(mlx.DivScalar(scores.AsType(mlx.DTypeFloat32), opts.Temperature), -1, true)
|
||||
probs = applyTopPProbs(probs, opts.TopP)
|
||||
probs = applyMinPProbs(probs, opts.MinP)
|
||||
return Distribution{Probs: normalizeProbs(probs)}
|
||||
}
|
||||
|
||||
func applyTopPProbs(probs *mlx.Array, topP float32) *mlx.Array {
|
||||
if topP <= 0 || topP >= 1 {
|
||||
return probs
|
||||
}
|
||||
order := probs.Negative().ArgsortAxis(-1)
|
||||
sorted := probs.TakeAlongAxis(order, -1)
|
||||
prevCumProbs := sorted.Cumsum(-1, false, true).Subtract(sorted)
|
||||
keep := prevCumProbs.Less(mlx.FromValue(topP))
|
||||
filtered := mlx.Where(keep, sorted, mlx.FromValue(float32(0)))
|
||||
return mlx.Zeros(probs.DType(), probs.Dims()...).PutAlongAxis(order, filtered, -1)
|
||||
}
|
||||
|
||||
func applyMinPProbs(probs *mlx.Array, minP float32) *mlx.Array {
|
||||
if minP <= 0 || minP > 1 {
|
||||
return probs
|
||||
}
|
||||
threshold := mlx.MulScalar(probs.MaxAxis(-1, true), minP)
|
||||
return mlx.Where(probs.Less(threshold), mlx.FromValue(float32(0)), probs)
|
||||
}
|
||||
|
||||
func normalizeProbs(probs *mlx.Array) *mlx.Array {
|
||||
sum := mlx.Maximum(probs.SumAxis(-1, true), mlx.FromValue(float32(1e-20)))
|
||||
return probs.Divide(sum)
|
||||
}
|
||||
|
||||
func logitsFromProbs(probs *mlx.Array) *mlx.Array {
|
||||
positive := mlx.Maximum(probs, mlx.FromValue(float32(1e-20)))
|
||||
logits := mlx.Log(positive)
|
||||
return mlx.Where(probs.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
|
||||
}
|
||||
|
||||
func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
||||
tokenIndices := ctx.history
|
||||
if tokenIndices == nil {
|
||||
return scores
|
||||
}
|
||||
|
||||
if ctx.opts.RepeatPenalty != 1 || ctx.opts.PresencePenalty != 0 {
|
||||
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
|
||||
if ctx.opts.RepeatPenalty != 1 {
|
||||
factor := mlx.Where(
|
||||
adjusted.Less(mlx.FromValue(float32(0))),
|
||||
mlx.FromValue(ctx.opts.RepeatPenalty),
|
||||
mlx.FromValue(1/ctx.opts.RepeatPenalty),
|
||||
)
|
||||
adjusted = adjusted.Multiply(factor)
|
||||
}
|
||||
if ctx.opts.PresencePenalty != 0 {
|
||||
adjusted = mlx.AddScalar(adjusted, -ctx.opts.PresencePenalty)
|
||||
}
|
||||
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
|
||||
}
|
||||
|
||||
if ctx.opts.FrequencyPenalty != 0 {
|
||||
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-ctx.opts.FrequencyPenalty), -1)
|
||||
}
|
||||
|
||||
return scores
|
||||
}
|
||||
492
x/mlxrunner/sample/sample_test.go
Normal file
492
x/mlxrunner/sample/sample_test.go
Normal file
@@ -0,0 +1,492 @@
|
||||
//go:build mlx
|
||||
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// slotLogits builds a [1, V] logits tensor for a single-slot Sample call.
|
||||
func slotLogits(values []float32) *mlx.Array {
|
||||
return mlx.FromValues(values, 1, len(values))
|
||||
}
|
||||
|
||||
// batchLogits stacks per-row float32 slices of equal length into a [B, V]
|
||||
// logits tensor.
|
||||
func batchLogits(rows ...[]float32) *mlx.Array {
|
||||
v := len(rows[0])
|
||||
flat := make([]float32, 0, len(rows)*v)
|
||||
for _, r := range rows {
|
||||
if len(r) != v {
|
||||
panic("batchLogits: rows must share vocab size")
|
||||
}
|
||||
flat = append(flat, r...)
|
||||
}
|
||||
return mlx.FromValues(flat, len(rows), v)
|
||||
}
|
||||
|
||||
// sampleOne runs Sample on a freshly-added single slot and returns the
|
||||
// sampled token id. Used both for the single-slot options table and as the
|
||||
// reference oracle for the batched-equivalence test.
|
||||
func sampleOne(t *testing.T, opts Options, priorTokens []int32, values []float32) int {
|
||||
t.Helper()
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
s.Add(0, opts, priorTokens)
|
||||
|
||||
got := s.Sample([]int{0}, slotLogits(values)).Token
|
||||
mlx.Eval(got)
|
||||
return got.Int()
|
||||
}
|
||||
|
||||
// logOf returns log(p) as a float32 so tests can build logits that softmax to
|
||||
// a chosen probability distribution.
|
||||
func logOf(p float64) float32 { return float32(math.Log(p)) }
|
||||
|
||||
// TestSampleSingleSlotOptions pins the per-slot behavior of each Options
|
||||
// knob against a concrete expected token. Expected values are worked out by
|
||||
// hand from the math of each transform, not from a second call into the
|
||||
// sampler — so a regression in any single transform shows up here.
|
||||
func TestSampleSingleSlotOptions(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
opts Options
|
||||
priors []int32
|
||||
logits []float32
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "presence penalty",
|
||||
opts: Options{RepeatLastN: 1, PresencePenalty: 6},
|
||||
priors: []int32{1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 2, // token 1: 5 - 6 = -1, argmax shifts to 2
|
||||
},
|
||||
{
|
||||
name: "repeat penalty on positive logits",
|
||||
opts: Options{RepeatLastN: 1, RepeatPenalty: 2},
|
||||
priors: []int32{1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 2, // token 1 positive → divided: 5/2 = 2.5, argmax shifts to 2
|
||||
},
|
||||
{
|
||||
name: "repeat penalty on negative logits",
|
||||
opts: Options{RepeatLastN: 1, RepeatPenalty: 4},
|
||||
priors: []int32{1},
|
||||
logits: []float32{-5, -1, -3},
|
||||
want: 2, // token 1 negative → multiplied: -1*4 = -4, argmax shifts to 2
|
||||
},
|
||||
{
|
||||
name: "frequency penalty",
|
||||
opts: Options{RepeatLastN: 4, FrequencyPenalty: 2},
|
||||
priors: []int32{1, 1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 2, // 5 - 2*count(1)=2*2=4 → 1, argmax shifts to 2
|
||||
},
|
||||
{
|
||||
name: "top-k",
|
||||
opts: Options{Temperature: 1, TopK: 1},
|
||||
logits: []float32{1, 5, 4},
|
||||
want: 1, // only argmax survives → deterministic even with temperature
|
||||
},
|
||||
{
|
||||
name: "top-p",
|
||||
opts: Options{Temperature: 1, TopP: 0.4},
|
||||
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
|
||||
want: 0, // exclusive cumsum below 0.4 keeps only token 0
|
||||
},
|
||||
{
|
||||
name: "min-p",
|
||||
opts: Options{Temperature: 1, MinP: 0.7},
|
||||
logits: []float32{logOf(0.5), logOf(0.3), logOf(0.2)},
|
||||
want: 0, // threshold 0.5*0.7=0.35 drops all but the top token
|
||||
},
|
||||
{
|
||||
name: "RepeatLastN=0 disables penalties",
|
||||
opts: Options{RepeatLastN: 0, RepeatPenalty: 2, PresencePenalty: 10},
|
||||
priors: []int32{1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 1, // 0 = disabled per API contract, argmax unchanged
|
||||
},
|
||||
{
|
||||
name: "RepeatLastN=-1 resolves to num_ctx",
|
||||
opts: Options{RepeatLastN: -1, PresencePenalty: 6},
|
||||
priors: []int32{1},
|
||||
logits: []float32{0, 5, 4},
|
||||
want: 2, // -1 → num_ctx (128); penalty applies, argmax shifts
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := sampleOne(t, tc.opts, tc.priors, tc.logits); got != tc.want {
|
||||
t.Errorf("got %d, want %d", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistributionAppliesTopKBeforeTopP(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
s.Add(0, Options{Temperature: 1, TopK: 2, TopP: 0.7}, nil)
|
||||
|
||||
dist := s.Distribution(0, slotLogits([]float32{logOf(0.6), logOf(0.2), logOf(0.2)}), nil)
|
||||
mlx.Eval(dist.Arrays()...)
|
||||
|
||||
ids := dist.IDs.Ints()
|
||||
probs := dist.Probs.Floats()
|
||||
if len(ids) != 2 || len(probs) != 2 {
|
||||
t.Fatalf("support = ids %v probs %v, want 2 sparse entries", ids, probs)
|
||||
}
|
||||
|
||||
foundTop := false
|
||||
for i, id := range ids {
|
||||
switch id {
|
||||
case 0:
|
||||
foundTop = true
|
||||
if math.Abs(float64(probs[i]-1)) > 1e-5 {
|
||||
t.Fatalf("top token prob = %v, want 1; ids=%v probs=%v", probs[i], ids, probs)
|
||||
}
|
||||
default:
|
||||
if math.Abs(float64(probs[i])) > 1e-5 {
|
||||
t.Fatalf("non-top token %d prob = %v, want 0; ids=%v probs=%v", id, probs[i], ids, probs)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundTop {
|
||||
t.Fatalf("top-k support %v did not include token 0", ids)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDistributionResidualUsesTargetSupport(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
target := Distribution{
|
||||
IDs: mlx.NewArrayInt32([]int32{2, 5}, []int32{1, 2}),
|
||||
Probs: mlx.FromValues([]float32{0.7, 0.3}, 1, 2),
|
||||
}
|
||||
draft := Distribution{
|
||||
IDs: mlx.NewArrayInt32([]int32{2, 4}, []int32{1, 2}),
|
||||
Probs: mlx.FromValues([]float32{0.2, 0.8}, 1, 2),
|
||||
}
|
||||
|
||||
residual := target.ResidualAgainst(draft)
|
||||
mlx.Eval(residual.Arrays()...)
|
||||
|
||||
ids := residual.IDs.Ints()
|
||||
probs := residual.Probs.Floats()
|
||||
want := map[int]float64{2: 0.625, 5: 0.375}
|
||||
if len(ids) != 2 || len(probs) != 2 {
|
||||
t.Fatalf("residual = ids %v probs %v, want 2 sparse entries", ids, probs)
|
||||
}
|
||||
for i, id := range ids {
|
||||
w, ok := want[id]
|
||||
if !ok {
|
||||
t.Fatalf("residual includes token %d outside target support: ids=%v probs=%v", id, ids, probs)
|
||||
}
|
||||
if math.Abs(float64(probs[i])-w) > 1e-5 {
|
||||
t.Fatalf("residual token %d prob = %v, want %v; ids=%v probs=%v", id, probs[i], w, ids, probs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeededSamplingIsReproducible(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
seededSequence := func(seed int) []int {
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
s.Add(0, Options{Temperature: 1, TopK: 4, Seed: seed, UseSeed: true}, nil)
|
||||
|
||||
logits := slotLogits([]float32{0, 0, 0, 0})
|
||||
out := make([]int, 32)
|
||||
for i := range out {
|
||||
token := s.Sample([]int{0}, logits).Token
|
||||
mlx.Eval(token)
|
||||
out[i] = token.Int()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
a := seededSequence(1234)
|
||||
b := seededSequence(1234)
|
||||
if !slices.Equal(a, b) {
|
||||
t.Fatalf("same seed produced different sequences:\n%v\n%v", a, b)
|
||||
}
|
||||
|
||||
c := seededSequence(5678)
|
||||
if slices.Equal(a, c) {
|
||||
t.Fatalf("different seeds produced the same sequence: %v", a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSeededBernoulliIsReproducible(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
seededMask := func() []int {
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
s.Add(0, Options{Seed: 99, UseSeed: true}, nil)
|
||||
|
||||
mask := s.Bernoulli(0, mlx.FromValues([]float32{0.5, 0.5, 0.5, 0.5, 0.5, 0.5}, 6)).AsType(mlx.DTypeInt32)
|
||||
mlx.Eval(mask)
|
||||
return mask.Ints()
|
||||
}
|
||||
|
||||
a := seededMask()
|
||||
b := seededMask()
|
||||
if !slices.Equal(a, b) {
|
||||
t.Fatalf("same seed produced different bernoulli masks:\n%v\n%v", a, b)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSampleHistoryWindow verifies that penalty history respects the
|
||||
// RepeatLastN window: priors longer than RepeatLastN are trimmed on Add,
|
||||
// and once the ring wraps, tokens that rotate out no longer contribute
|
||||
// to penalties.
|
||||
func TestSampleHistoryWindow(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
// RepeatLastN=2 with priors {1, 2, 3}: makeHistoryRow keeps only
|
||||
// {2, 3}. Token 1 was trimmed — its penalty is NOT active.
|
||||
s.Add(0, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 2, 3})
|
||||
|
||||
// Step 1: logits favor token 1 (trimmed). If the trim were broken it
|
||||
// would be penalized and the argmax would move.
|
||||
step1 := s.Sample([]int{0}, slotLogits([]float32{0, 5, 0, 0, 0})).Token
|
||||
mlx.Eval(step1)
|
||||
if got := step1.Int(); got != 1 {
|
||||
t.Fatalf("step 1 = %d, want 1 (token 1 trimmed from priors)", got)
|
||||
}
|
||||
// After step 1 the ring holds {1, 3}; token 2 has rotated out.
|
||||
|
||||
// Step 2: logits favor token 2 (rotated out). If the ring wrap were
|
||||
// wrong, token 2 would still be penalized.
|
||||
step2 := s.Sample([]int{0}, slotLogits([]float32{0, 0, 5, 0, 0})).Token
|
||||
mlx.Eval(step2)
|
||||
if got := step2.Int(); got != 2 {
|
||||
t.Fatalf("step 2 = %d, want 2 (token 2 rotated out of ring)", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpeculativeScoresUsesDraftHistoryWithoutCommit(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
s.Add(0, Options{RepeatLastN: 2, RepeatPenalty: 10}, []int32{1, 2})
|
||||
draftTokens := mlx.NewArrayInt32([]int32{3, 4}, []int32{1, 2})
|
||||
scores := s.SpeculativeScores(0, batchLogits(
|
||||
[]float32{0, 9, 9, 8, 0}, // history {1,2}; token 3 wins
|
||||
[]float32{0, 0, 9, 9, 8}, // history {2,3}; token 4 wins
|
||||
[]float32{0, 0, 9, 9, 8}, // history {3,4}; token 2 wins
|
||||
), draftTokens)
|
||||
tokens := scores.Argmax(-1, false).AsType(mlx.DTypeInt32)
|
||||
mlx.Eval(tokens)
|
||||
|
||||
if got, want := tokens.Ints(), []int{3, 4, 2}; len(got) != len(want) {
|
||||
t.Fatalf("tokens = %v, want %v", got, want)
|
||||
} else {
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("tokens = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.byID[0].historyLen != 2 {
|
||||
t.Fatalf("historyLen = %d, want 2", s.byID[0].historyLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCommitBatchesRingWrites(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
|
||||
s.Add(0, Options{RepeatLastN: 4, RepeatPenalty: 1.1}, []int32{10, 11, 12})
|
||||
s.Commit(0, []int32{20, 21, 22})
|
||||
s.Commit(0, []int32{30, 31, 32, 33, 34})
|
||||
mlx.Eval(s.history)
|
||||
|
||||
got := s.history.Ints()
|
||||
want := []int{32, 33, 34, 31}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("history = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
if s.byID[0].historyLen != 11 {
|
||||
t.Fatalf("historyLen = %d, want 11", s.byID[0].historyLen)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBatchSamplingPreservesPerSlotBehavior is the core equivalence test:
|
||||
// for every representative dispatch branch (uniform, serial on mixed opts,
|
||||
// serial on partial ring, subset/out-of-order), a batched Sample call must
|
||||
// produce the same token per row as running the same slot alone.
|
||||
func TestBatchSamplingPreservesPerSlotBehavior(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
type slot struct {
|
||||
id int
|
||||
opts Options
|
||||
priors []int32
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
slots []slot
|
||||
sample []int
|
||||
rows [][]float32
|
||||
}{
|
||||
{
|
||||
name: "uniform",
|
||||
slots: []slot{
|
||||
{10, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{1, 2}},
|
||||
{20, Options{RepeatLastN: 2, PresencePenalty: 5}, []int32{0, 2}},
|
||||
},
|
||||
sample: []int{10, 20},
|
||||
rows: [][]float32{{0, 5, 4}, {3, 0, 0}},
|
||||
},
|
||||
{
|
||||
name: "serial — mixed opts",
|
||||
slots: []slot{
|
||||
{1, Options{RepeatLastN: 1, RepeatPenalty: 2}, []int32{1}},
|
||||
{2, Options{Temperature: 1, TopK: 1}, nil},
|
||||
},
|
||||
sample: []int{1, 2},
|
||||
rows: [][]float32{{0, 5, 4, 1}, {2, 1, 5, 3}},
|
||||
},
|
||||
{
|
||||
name: "serial — partial ring",
|
||||
slots: []slot{
|
||||
{1, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{1, 1, 1, 1}},
|
||||
{2, Options{RepeatLastN: 4, PresencePenalty: 5}, []int32{2}},
|
||||
},
|
||||
sample: []int{1, 2},
|
||||
rows: [][]float32{{0, 5, 4}, {0, 4, 5}},
|
||||
},
|
||||
{
|
||||
name: "subset out-of-order",
|
||||
slots: []slot{
|
||||
{10, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{1, 1}},
|
||||
{20, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{2, 2}},
|
||||
{30, Options{RepeatLastN: 2, PresencePenalty: 10}, []int32{3, 3}},
|
||||
},
|
||||
sample: []int{30, 10},
|
||||
rows: [][]float32{{5, 5, 5, 0, 5, 5}, {5, 0, 5, 5, 0, 5}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Per-slot reference for each sampled seq.
|
||||
want := make([]int, len(tc.sample))
|
||||
for i, id := range tc.sample {
|
||||
var spec slot
|
||||
for _, s := range tc.slots {
|
||||
if s.id == id {
|
||||
spec = s
|
||||
break
|
||||
}
|
||||
}
|
||||
want[i] = sampleOne(t, spec.opts, spec.priors, tc.rows[i])
|
||||
}
|
||||
|
||||
// Batched call.
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
for _, spec := range tc.slots {
|
||||
s.Add(spec.id, spec.opts, spec.priors)
|
||||
}
|
||||
res := s.Sample(tc.sample, batchLogits(tc.rows...))
|
||||
mlx.Eval(res.Token)
|
||||
got := res.Token.Ints()
|
||||
|
||||
for i, id := range tc.sample {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("seq %d: batched = %d, per-slot = %d", id, got[i], want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemoveDoesNotLeakHistory: after Remove, a newly-added slot at the
|
||||
// recycled row must start from its own priors only — no carryover from
|
||||
// the removed slot's history.
|
||||
func TestRemoveDoesNotLeakHistory(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
opts := Options{RepeatLastN: 1, PresencePenalty: 10}
|
||||
s := New(128)
|
||||
t.Cleanup(func() {
|
||||
s.Free()
|
||||
mlx.Sweep()
|
||||
})
|
||||
s.Add(1, opts, []int32{1})
|
||||
s.Add(2, opts, []int32{2})
|
||||
s.Remove(1)
|
||||
s.Add(3, opts, []int32{0})
|
||||
|
||||
// Slot 2 retains history {2}; slot 3 retains history {0}. With
|
||||
// equal logits and PresencePenalty=10 the argmax drops to the first
|
||||
// unpenalized token.
|
||||
res := s.Sample([]int{2, 3}, batchLogits(
|
||||
[]float32{3, 3, 0},
|
||||
[]float32{3, 3, 0},
|
||||
))
|
||||
mlx.Eval(res.Token)
|
||||
tokens := res.Token.Ints()
|
||||
if tokens[0] != 0 {
|
||||
t.Errorf("slot 2 = %d, want 0 (token 2 penalized)", tokens[0])
|
||||
}
|
||||
if tokens[1] != 1 {
|
||||
t.Errorf("slot 3 = %d, want 1 (token 0 penalized, no slot-1 carryover)", tokens[1])
|
||||
}
|
||||
}
|
||||
247
x/mlxrunner/server.go
Normal file
247
x/mlxrunner/server.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/x/internal/mlxthread"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/sample"
|
||||
)
|
||||
|
||||
func Execute(args []string) error {
|
||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
|
||||
var (
|
||||
modelName string
|
||||
port int
|
||||
)
|
||||
|
||||
flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
|
||||
flagSet.StringVar(&modelName, "model", "", "Model name")
|
||||
flagSet.IntVar(&port, "port", 0, "Port to listen on")
|
||||
_ = flagSet.Bool("verbose", false, "Enable debug logging")
|
||||
flagSet.Parse(args)
|
||||
|
||||
worker, err := mlxthread.Start("mlxrunner", func() error {
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
return fmt.Errorf("MLX not available: %w", err)
|
||||
}
|
||||
|
||||
if mlx.GPUIsAvailable() {
|
||||
mlx.SetDefaultDeviceGPU()
|
||||
slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "gpu")
|
||||
} else {
|
||||
slog.Info("MLX engine initialized", "MLX version", mlx.Version(), "device", "cpu")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer worker.Stop(context.Background(), func() {
|
||||
mlx.Sweep()
|
||||
mlx.ClearCache()
|
||||
})
|
||||
runnerCtx, cancelRunner := context.WithCancel(context.Background())
|
||||
defer cancelRunner()
|
||||
|
||||
runner := Runner{
|
||||
Requests: make(chan Request),
|
||||
mlxThread: worker,
|
||||
}
|
||||
|
||||
if err := worker.Do(context.Background(), func() error {
|
||||
return runner.Load(modelName)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
readMemory := func() (uint64, error) {
|
||||
return uint64(mlx.ActiveMemory() + mlx.CacheMemory()), nil
|
||||
}
|
||||
initialMemory, err := mlxthread.Call(context.Background(), worker, readMemory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
memoryCache := newStatusMemoryCache(
|
||||
runnerCtx,
|
||||
initialMemory,
|
||||
time.Now(),
|
||||
statusMemoryRefreshWait,
|
||||
func() (uint64, error) {
|
||||
return mlxthread.Call(runnerCtx, worker, readMemory)
|
||||
},
|
||||
)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := json.NewEncoder(w).Encode(statusResponse{
|
||||
Status: 0,
|
||||
Progress: 100,
|
||||
ContextLength: runner.contextLength,
|
||||
Memory: memoryCache.Memory(),
|
||||
}); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case "POST":
|
||||
fallthrough
|
||||
case "GET":
|
||||
if err := json.NewEncoder(w).Encode(map[string]any{
|
||||
"Success": true,
|
||||
}); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case "DELETE":
|
||||
// TODO: cleanup model and cache
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
request := Request{Responses: make(chan CompletionResponse)}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&request.CompletionRequest); err != nil {
|
||||
slog.Error("Failed to decode request", "error", err)
|
||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
request.Pipeline = runner.TextGenerationPipeline
|
||||
request.SamplerOpts = sample.Options{
|
||||
Temperature: request.Options.Temperature,
|
||||
TopP: request.Options.TopP,
|
||||
MinP: request.Options.MinP,
|
||||
TopK: request.Options.TopK,
|
||||
RepeatLastN: request.Options.RepeatLastN,
|
||||
RepeatPenalty: request.Options.RepeatPenalty,
|
||||
PresencePenalty: request.Options.PresencePenalty,
|
||||
FrequencyPenalty: request.Options.FrequencyPenalty,
|
||||
Seed: request.Options.Seed,
|
||||
UseSeed: request.Options.Seed >= 0,
|
||||
Logprobs: request.Logprobs,
|
||||
TopLogprobs: request.TopLogprobs,
|
||||
}
|
||||
|
||||
if err := runner.Prepare(&request); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var cancel context.CancelFunc
|
||||
request.Ctx, cancel = context.WithCancel(r.Context())
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case runner.Requests <- request:
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/jsonl")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
enc := json.NewEncoder(w)
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case response, ok := <-request.Responses:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := enc.Encode(response); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if f, ok := w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
|
||||
var b bytes.Buffer
|
||||
if _, err := io.Copy(&b, r.Body); err != nil {
|
||||
slog.Error("Failed to read request body", "error", err)
|
||||
http.Error(w, "Bad Request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
tokens := runner.Tokenizer.Encode(b.String(), runner.Tokenizer.AddBOS())
|
||||
|
||||
if err := json.NewEncoder(w).Encode(tokens); err != nil {
|
||||
slog.Error("Failed to encode response", "error", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
for source, target := range map[string]string{
|
||||
"GET /health": "/v1/status",
|
||||
"POST /load": "/v1/models",
|
||||
"POST /completion": "/v1/completions",
|
||||
} {
|
||||
mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
|
||||
}
|
||||
|
||||
return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
|
||||
t := time.Now()
|
||||
mux.ServeHTTP(recorder, r)
|
||||
|
||||
var level slog.Level
|
||||
switch {
|
||||
case recorder.code >= 500:
|
||||
level = slog.LevelError
|
||||
case recorder.code >= 400:
|
||||
level = slog.LevelWarn
|
||||
case recorder.code >= 300:
|
||||
return
|
||||
}
|
||||
|
||||
slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
|
||||
}))
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *statusRecorder) WriteHeader(code int) {
|
||||
w.code = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (w *statusRecorder) Status() string {
|
||||
return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
|
||||
}
|
||||
|
||||
func (w *statusRecorder) Flush() {
|
||||
if f, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
109
x/mlxrunner/status_memory.go
Normal file
109
x/mlxrunner/status_memory.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const statusMemoryRefreshWait = 50 * time.Millisecond
|
||||
|
||||
type statusMemoryRefreshFunc func() (uint64, error)
|
||||
|
||||
// statusMemoryCache keeps health checks from depending synchronously on the
|
||||
// serialized MLX worker while still refreshing memory telemetry opportunistically.
|
||||
type statusMemoryCache struct {
|
||||
done <-chan struct{}
|
||||
wait time.Duration
|
||||
refresh statusMemoryRefreshFunc
|
||||
|
||||
mu sync.Mutex
|
||||
memory uint64
|
||||
refreshedAt time.Time
|
||||
inFlight chan struct{}
|
||||
}
|
||||
|
||||
func newStatusMemoryCache(ctx context.Context, memory uint64, refreshedAt time.Time, wait time.Duration, refresh statusMemoryRefreshFunc) *statusMemoryCache {
|
||||
return &statusMemoryCache{
|
||||
done: ctx.Done(),
|
||||
wait: wait,
|
||||
refresh: refresh,
|
||||
memory: memory,
|
||||
refreshedAt: refreshedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *statusMemoryCache) Memory() uint64 {
|
||||
done := c.startRefresh()
|
||||
if c.wait <= 0 {
|
||||
<-done
|
||||
memory, _ := c.snapshot()
|
||||
return memory
|
||||
}
|
||||
|
||||
timer := time.NewTimer(c.wait)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-timer.C:
|
||||
memory, refreshedAt := c.snapshot()
|
||||
if refreshedAt.IsZero() {
|
||||
slog.Debug("using cached MLX memory status before first refresh")
|
||||
} else {
|
||||
slog.Debug("using cached MLX memory status", "stale", time.Since(refreshedAt))
|
||||
}
|
||||
return memory
|
||||
case <-c.done:
|
||||
}
|
||||
|
||||
memory, _ := c.snapshot()
|
||||
return memory
|
||||
}
|
||||
|
||||
func (c *statusMemoryCache) startRefresh() chan struct{} {
|
||||
c.mu.Lock()
|
||||
if c.inFlight != nil {
|
||||
done := c.inFlight
|
||||
c.mu.Unlock()
|
||||
return done
|
||||
}
|
||||
|
||||
refreshDone := make(chan struct{})
|
||||
c.inFlight = refreshDone
|
||||
refresh := c.refresh
|
||||
lifecycleDone := c.done
|
||||
c.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
memory, err := refresh()
|
||||
now := time.Now()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
defer close(refreshDone)
|
||||
|
||||
if err != nil {
|
||||
select {
|
||||
case <-lifecycleDone:
|
||||
default:
|
||||
slog.Debug("failed to refresh MLX memory status", "error", err)
|
||||
}
|
||||
c.inFlight = nil
|
||||
return
|
||||
}
|
||||
|
||||
c.memory = memory
|
||||
c.refreshedAt = now
|
||||
c.inFlight = nil
|
||||
}()
|
||||
|
||||
return refreshDone
|
||||
}
|
||||
|
||||
func (c *statusMemoryCache) snapshot() (uint64, time.Time) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.memory, c.refreshedAt
|
||||
}
|
||||
246
x/mlxrunner/status_memory_test.go
Normal file
246
x/mlxrunner/status_memory_test.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStatusMemoryCacheWaitsForFastRefresh(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
|
||||
calls.Add(1)
|
||||
return 42, nil
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 42 {
|
||||
t.Fatalf("got memory %d, want 42", got)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("refresh calls = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheSupportsBlockingWait(t *testing.T) {
|
||||
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), 0, func() (uint64, error) {
|
||||
return 42, nil
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 42 {
|
||||
t.Fatalf("got memory %d, want 42", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheReturnsCachedValueAndRefreshesLater(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
|
||||
cache := newStatusMemoryCache(ctx, 7, time.Now().Add(-time.Minute), time.Millisecond, func() (uint64, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
close(started)
|
||||
}
|
||||
select {
|
||||
case <-release:
|
||||
return 42, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
start := time.Now()
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
if elapsed := time.Since(start); elapsed > time.Second {
|
||||
t.Fatalf("cached memory lookup took too long: %s", elapsed)
|
||||
}
|
||||
|
||||
waitForRefreshStart(t, started)
|
||||
close(release)
|
||||
waitForCachedMemory(t, cache, 42)
|
||||
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("refresh calls = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheReturnsCachedValueBeforeFirstRefresh(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
cache := newStatusMemoryCache(ctx, 7, time.Time{}, time.Millisecond, func() (uint64, error) {
|
||||
close(started)
|
||||
select {
|
||||
case <-release:
|
||||
return 42, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
|
||||
waitForRefreshStart(t, started)
|
||||
close(release)
|
||||
waitForCachedMemory(t, cache, 42)
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheKeepsCachedValueWhenRefreshFails(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
|
||||
calls.Add(1)
|
||||
return 0, errors.New("refresh failed")
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("refresh calls = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheReturnsCachedValueWhenContextDone(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
cache := newStatusMemoryCache(ctx, 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
|
||||
close(started)
|
||||
<-release
|
||||
return 0, ctx.Err()
|
||||
})
|
||||
|
||||
cancel()
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
|
||||
waitForRefreshStart(t, started)
|
||||
close(release)
|
||||
waitForInflightRefresh(t, cache)
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheAllowsRefreshAfterFailure(t *testing.T) {
|
||||
var calls atomic.Int32
|
||||
cache := newStatusMemoryCache(context.Background(), 7, time.Now().Add(-time.Minute), time.Second, func() (uint64, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
return 0, errors.New("refresh failed")
|
||||
}
|
||||
return 42, nil
|
||||
})
|
||||
|
||||
if got := cache.Memory(); got != 7 {
|
||||
t.Fatalf("got memory %d, want cached value 7", got)
|
||||
}
|
||||
if got := cache.Memory(); got != 42 {
|
||||
t.Fatalf("got memory %d after retry, want 42", got)
|
||||
}
|
||||
if got := calls.Load(); got != 2 {
|
||||
t.Fatalf("refresh calls = %d, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusMemoryCacheAllowsOneInflightRefresh(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
|
||||
cache := newStatusMemoryCache(ctx, 11, time.Now().Add(-time.Minute), time.Millisecond, func() (uint64, error) {
|
||||
if calls.Add(1) == 1 {
|
||||
close(started)
|
||||
}
|
||||
select {
|
||||
case <-release:
|
||||
return 99, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
const goroutines = 8
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan string, goroutines)
|
||||
for range goroutines {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if got := cache.Memory(); got != 11 {
|
||||
errCh <- "got non-cached memory value"
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
for err := range errCh {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
waitForRefreshStart(t, started)
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("refresh calls = %d, want 1", got)
|
||||
}
|
||||
|
||||
close(release)
|
||||
waitForCachedMemory(t, cache, 99)
|
||||
}
|
||||
|
||||
func waitForRefreshStart(t *testing.T, started <-chan struct{}) {
|
||||
t.Helper()
|
||||
select {
|
||||
case <-started:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for refresh to start")
|
||||
}
|
||||
}
|
||||
|
||||
func waitForCachedMemory(t *testing.T, cache *statusMemoryCache, want uint64) {
|
||||
t.Helper()
|
||||
deadline := time.After(time.Second)
|
||||
for {
|
||||
got, _ := cache.snapshot()
|
||||
if got == want {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatalf("cached memory = %d, want %d", got, want)
|
||||
case <-time.After(time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func waitForInflightRefresh(t *testing.T, cache *statusMemoryCache) {
|
||||
t.Helper()
|
||||
deadline := time.After(time.Second)
|
||||
for {
|
||||
cache.mu.Lock()
|
||||
inFlight := cache.inFlight
|
||||
cache.mu.Unlock()
|
||||
if inFlight == nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatal("timeout waiting for refresh to finish")
|
||||
case <-time.After(time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
47
x/mlxrunner/utf8_buffer.go
Normal file
47
x/mlxrunner/utf8_buffer.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
|
||||
// currently buffered, leaving any incomplete trailing bytes in place.
|
||||
func flushValidUTF8Prefix(b *bytes.Buffer) string {
|
||||
data := b.Bytes()
|
||||
if len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
prefix := validUTF8PrefixLen(data)
|
||||
if prefix == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
text := string(data[:prefix])
|
||||
b.Next(prefix)
|
||||
return text
|
||||
}
|
||||
|
||||
func validUTF8PrefixLen(data []byte) int {
|
||||
i := 0
|
||||
prefix := 0
|
||||
for i < len(data) {
|
||||
r, size := utf8.DecodeRune(data[i:])
|
||||
if r == utf8.RuneError && size == 1 {
|
||||
if !utf8.FullRune(data[i:]) {
|
||||
break
|
||||
}
|
||||
|
||||
// Invalid UTF-8 byte; consume one byte to guarantee forward progress.
|
||||
i++
|
||||
prefix = i
|
||||
continue
|
||||
}
|
||||
|
||||
i += size
|
||||
prefix = i
|
||||
}
|
||||
|
||||
return prefix
|
||||
}
|
||||
46
x/mlxrunner/utf8_buffer_test.go
Normal file
46
x/mlxrunner/utf8_buffer_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package mlxrunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
|
||||
b.Write([]byte{0xE3, 0x81})
|
||||
if got := flushValidUTF8Prefix(&b); got != "" {
|
||||
t.Fatalf("first flush = %q, want empty", got)
|
||||
}
|
||||
|
||||
b.Write([]byte{0x93, 0xE3})
|
||||
if got := flushValidUTF8Prefix(&b); got != "こ" {
|
||||
t.Fatalf("second flush = %q, want %q", got, "こ")
|
||||
}
|
||||
|
||||
if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
|
||||
t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
|
||||
}
|
||||
|
||||
b.Write([]byte{0x82, 0x93})
|
||||
if got := flushValidUTF8Prefix(&b); got != "ん" {
|
||||
t.Fatalf("third flush = %q, want %q", got, "ん")
|
||||
}
|
||||
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("buffer not empty after third flush: %d", b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
b.WriteString("hello 世界")
|
||||
|
||||
if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
|
||||
t.Fatalf("flush = %q, want %q", got, "hello 世界")
|
||||
}
|
||||
|
||||
if b.Len() != 0 {
|
||||
t.Fatalf("buffer not empty after flush: %d", b.Len())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user