898 lines
29 KiB
Go
898 lines
29 KiB
Go
package sample
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"slices"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
// Options is the per-sequence sampling configuration. Slots are compared by
// Options value (==) when deciding whether they can share one batched pass, so
// the struct must remain comparable.
type Options struct {
	// Temperature divides the scores before the softmax. A value <= 0 makes
	// the distribution collapse onto the arg-max token.
	Temperature float32
	// TopP is the nucleus threshold. It is ignored unless 0 < TopP < 1.
	TopP float32
	// MinP drops tokens whose probability is below MinP times the row
	// maximum. It is ignored unless 0 < MinP <= 1.
	MinP float32
	// TopK restricts sampling to the TopK best-scoring tokens. It is ignored
	// unless 0 < TopK < vocabulary size.
	TopK int
	// RepeatLastN is the number of most recent tokens the penalties look at.
	// 0 disables the penalties; a negative value is resolved to the context
	// window during normalization.
	RepeatLastN int
	// RepeatPenalty rescales the scores of recently seen tokens. 1 means no
	// effect, and values <= 0 are normalized to 1.
	RepeatPenalty float32
	// PresencePenalty is subtracted from the score of every recently seen
	// token, regardless of how often it occurred.
	PresencePenalty float32
	// FrequencyPenalty is subtracted from a token's score once per
	// occurrence in the recent history.
	FrequencyPenalty float32
	// Seed feeds the slot's deterministic random stream when UseSeed is set.
	// A negative Seed turns UseSeed off during normalization.
	Seed int
	// UseSeed selects seeded (reproducible) sampling.
	UseSeed bool

	// Logprobs asks Sample to fill Result.Logprob with the log-probability
	// of the selected token.
	Logprobs bool
	// TopLogprobs, when > 0 and Logprobs is set, additionally requests the
	// top-K (token, log-probability) pairs.
	TopLogprobs int
}
|
|
|
|
// Result bundles the outputs of one decode step. Logprob/TopTokens/
|
|
// TopLogprobs are populated whenever any registered slot has Logprobs
|
|
// (respectively TopLogprobs>0). Consumers need to filter by their
|
|
// per-slot Options.
|
|
type Result struct {
|
|
Token *mlx.Array // sampled token ids, shape [B]
|
|
Logprob *mlx.Array // sampled-token logprobs, shape [B,1]; nil unless any registered slot has Logprobs
|
|
TopTokens *mlx.Array // top-K token ids, shape [B,maxK]; nil unless any registered slot has TopLogprobs>0
|
|
TopLogprobs *mlx.Array // top-K logprobs, shape [B,maxK]; same
|
|
}
|
|
|
|
// Arrays returns the tensor fields as a slice so callers can drive the mlx
|
|
// lifecycle verbs (Pin, Unpin, Eval, AsyncEval) over the whole group. Unset
|
|
// fields stay nil; the mlx helpers skip them.
|
|
func (r Result) Arrays() []*mlx.Array {
|
|
return []*mlx.Array{r.Token, r.Logprob, r.TopTokens, r.TopLogprobs}
|
|
}
|
|
|
|
// Distribution is the filtered probability distribution used by the sampler.
|
|
// When IDs is nil, Probs is dense over the vocabulary. When IDs is set, Probs
|
|
// is sparse over the token ids in IDs, preserving GPU residency for the
|
|
// top-k-first path used by normal and speculative sampling.
|
|
type Distribution struct {
|
|
IDs *mlx.Array // sparse token ids, shape [B,K]; nil for dense distributions
|
|
Probs *mlx.Array // probabilities, shape [B,K] or [B,V]
|
|
}
|
|
|
|
// Arrays returns the tensor fields for mlx lifecycle management.
|
|
func (d Distribution) Arrays() []*mlx.Array {
|
|
return []*mlx.Array{d.IDs, d.Probs}
|
|
}
|
|
|
|
// Rows returns the number of rows in the distribution.
|
|
func (d Distribution) Rows() int {
|
|
if d.Probs == nil {
|
|
return 0
|
|
}
|
|
return d.Probs.Dim(0)
|
|
}
|
|
|
|
// SliceRows returns a row slice while preserving sparse/dense layout.
|
|
func (d Distribution) SliceRows(start, stop int) Distribution {
|
|
out := Distribution{Probs: d.Probs.Slice(mlx.Slice(start, stop), mlx.Slice())}
|
|
if d.IDs != nil {
|
|
out.IDs = d.IDs.Slice(mlx.Slice(start, stop), mlx.Slice())
|
|
}
|
|
return out
|
|
}
|
|
|
|
// SampleWithKey draws one token per row using key when supplied.
|
|
func (d Distribution) SampleWithKey(key *mlx.Array) *mlx.Array {
|
|
choice := logitsFromProbs(d.Probs).CategoricalWithKey(-1, key).AsType(mlx.DTypeInt32)
|
|
if d.IDs == nil {
|
|
return choice
|
|
}
|
|
return d.IDs.TakeAlongAxis(choice.ExpandDims(-1), -1).Squeeze(-1).AsType(mlx.DTypeInt32)
|
|
}
|
|
|
|
// Prob returns the probability assigned to one token per row.
|
|
func (d Distribution) Prob(tokens *mlx.Array) *mlx.Array {
|
|
switch tokens.NumDims() {
|
|
case 2:
|
|
if tokens.Dim(0) == 1 {
|
|
tokens = tokens.Squeeze(0)
|
|
} else if tokens.Dim(1) == 1 {
|
|
tokens = tokens.Squeeze(1)
|
|
}
|
|
case 0:
|
|
tokens = tokens.Reshape(1)
|
|
}
|
|
return d.ProbsForIDs(tokens.ExpandDims(-1)).Squeeze(-1)
|
|
}
|
|
|
|
// ProbsForIDs returns probabilities for each requested token id. ids must be
|
|
// rank-2 [B,N], matching the distribution rows.
|
|
func (d Distribution) ProbsForIDs(ids *mlx.Array) *mlx.Array {
|
|
if d.IDs == nil {
|
|
return d.Probs.TakeAlongAxis(ids, -1)
|
|
}
|
|
eq := d.IDs.ExpandDims(-1).Equal(ids.ExpandDims(1))
|
|
values := mlx.Where(eq, d.Probs.ExpandDims(-1), mlx.FromValue(float32(0)))
|
|
return values.SumAxis(1, false)
|
|
}
|
|
|
|
// ResidualAgainst returns the Leviathan/Chen rejection distribution
|
|
// proportional to max(target - draft, 0). Sparse target distributions stay
|
|
// sparse over the target support; tokens outside target support have zero mass.
|
|
func (d Distribution) ResidualAgainst(draft Distribution) Distribution {
|
|
if d.IDs != nil {
|
|
diff := d.Probs.Subtract(draft.ProbsForIDs(d.IDs))
|
|
return Distribution{IDs: d.IDs, Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
|
|
}
|
|
if draft.IDs != nil {
|
|
panic("sample.Distribution.ResidualAgainst: dense target with sparse draft is unsupported")
|
|
}
|
|
diff := d.Probs.Subtract(draft.Probs)
|
|
return Distribution{Probs: normalizeProbs(mlx.Maximum(diff, mlx.FromValue(float32(0))))}
|
|
}
|
|
|
|
// LogProbs returns dense log-probabilities, scattering sparse distributions
|
|
// into a full-vocabulary tensor when needed.
|
|
func (d Distribution) LogProbs(vocab int) *mlx.Array {
|
|
logProbs := logitsFromProbs(d.Probs)
|
|
if d.IDs == nil {
|
|
return logProbs
|
|
}
|
|
out := mlx.AddScalar(mlx.Zeros(mlx.DTypeFloat32, d.Probs.Dim(0), vocab), float32(math.Inf(-1)))
|
|
return out.PutAlongAxis(d.IDs, logProbs, -1)
|
|
}
|
|
|
|
// ConcatenateDistributions concatenates distribution rows. All inputs must use
|
|
// the same sparse/dense layout.
|
|
func ConcatenateDistributions(dists []Distribution) Distribution {
|
|
if len(dists) == 0 {
|
|
return Distribution{}
|
|
}
|
|
probs := make([]*mlx.Array, 0, len(dists))
|
|
ids := make([]*mlx.Array, 0, len(dists))
|
|
sparse := dists[0].IDs != nil
|
|
for _, d := range dists {
|
|
if (d.IDs != nil) != sparse {
|
|
panic("sample.ConcatenateDistributions: mixed sparse and dense distributions")
|
|
}
|
|
probs = append(probs, d.Probs)
|
|
if sparse {
|
|
ids = append(ids, d.IDs)
|
|
}
|
|
}
|
|
out := Distribution{Probs: mlx.Concatenate(probs, 0)}
|
|
if sparse {
|
|
out.IDs = mlx.Concatenate(ids, 0)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// Sampler is a batched, slot-based sampler. Sequences are registered with
|
|
// Add and released with Remove. Each Sample call takes a subset of
|
|
// registered slots (in any order) with their [B,V] logits, samples one
|
|
// token per row, and appends it to that slot's ring-buffer history. Slots
|
|
// not named in a given call are untouched.
|
|
type Sampler struct {
|
|
slots []*slotState
|
|
byID map[int]*slotState
|
|
|
|
// history is the pooled ring-buffer storage, [B, W] int32. Row i
|
|
// belongs to slots[i]; W is max(RepeatLastN) across penalty slots.
|
|
// Allocated on the first penalty slot, rebuilt only in Add/Remove.
|
|
history *mlx.Array
|
|
|
|
// allSameOpts: every registered slot shares Options. When true the
|
|
// canonical shared value is s.slots[0].opts.
|
|
allSameOpts bool
|
|
|
|
// anyLogprobs / maxTopLogprobs: compute-for-all output config.
|
|
// Sample populates Logprob (and Top* when maxTopLogprobs>0) whenever
|
|
// any registered slot requests them, even if that slot isn't in the
|
|
// current call.
|
|
anyLogprobs bool
|
|
maxTopLogprobs int
|
|
|
|
// numCtx is the runner's context window; normalize uses it to
|
|
// resolve the repeat_last_n == -1 sentinel.
|
|
numCtx int
|
|
}
|
|
|
|
type slotState struct {
|
|
opts Options
|
|
historyLen int
|
|
randomCounter uint64
|
|
}
|
|
|
|
type slotCtx struct {
|
|
opts Options
|
|
history *mlx.Array // 2D [B, W] when penalties are configured; nil otherwise
|
|
}
|
|
|
|
// New constructs an empty sampler with no registered slots. numCtx is
|
|
// the runner's context window and must be positive.
|
|
func New(numCtx int) *Sampler {
|
|
return &Sampler{
|
|
byID: make(map[int]*slotState),
|
|
allSameOpts: true,
|
|
numCtx: numCtx,
|
|
}
|
|
}
|
|
|
|
// historyWidth returns the column count of the pooled history tensor,
|
|
// or 0 when no penalty slot has forced it to be allocated.
|
|
func (s *Sampler) historyWidth() int {
|
|
if s.history == nil {
|
|
return 0
|
|
}
|
|
return s.history.Dim(1)
|
|
}
|
|
|
|
func (o Options) usesHistory() bool {
|
|
// RepeatLastN == 0 disables the penalty ring per the repeat_last_n API
|
|
// contract (0 = disabled), overriding any penalty coefficients.
|
|
if o.RepeatLastN == 0 {
|
|
return false
|
|
}
|
|
return o.RepeatPenalty != 1 || o.PresencePenalty != 0 || o.FrequencyPenalty != 0
|
|
}
|
|
|
|
func (o Options) normalize(numCtx int) Options {
|
|
if o.RepeatPenalty <= 0 {
|
|
o.RepeatPenalty = 1
|
|
}
|
|
// Resolve the repeat_last_n == -1 sentinel ("-1 = num_ctx") against
|
|
// the caller's context window.
|
|
if o.RepeatLastN < 0 {
|
|
o.RepeatLastN = numCtx
|
|
}
|
|
if !o.usesHistory() {
|
|
// Zero the ring capacity so slots that differ only in a spurious
|
|
// RepeatLastN still batch together and don't inflate pool width.
|
|
o.RepeatLastN = 0
|
|
}
|
|
if o.Seed < 0 {
|
|
o.UseSeed = false
|
|
}
|
|
if !o.UseSeed {
|
|
// Keep unseeded callers on the same batching path even when a
|
|
// meaningless Seed value is present in an Options literal.
|
|
o.Seed = 0
|
|
}
|
|
return o
|
|
}
|
|
|
|
// Add registers a sequence under seqID. The last RepeatLastN entries of
|
|
// priorTokens seed the ring buffer.
|
|
func (s *Sampler) Add(seqID int, opts Options, priorTokens []int32) {
|
|
if _, dup := s.byID[seqID]; dup {
|
|
panic(fmt.Sprintf("sample.Sampler.Add: seqID %d already registered", seqID))
|
|
}
|
|
|
|
opts = opts.normalize(s.numCtx)
|
|
slot := &slotState{
|
|
opts: opts,
|
|
}
|
|
|
|
// Grow the pool to hold this slot's row. The pool is lazy — the first
|
|
// penalty slot allocates it — and thereafter every registered slot
|
|
// gets a row (rows for non-penalty slots are zero and never read).
|
|
// Invariant: s.history is pinned whenever non-nil.
|
|
if s.history != nil || opts.usesHistory() {
|
|
targetWidth := max(opts.RepeatLastN, s.historyWidth())
|
|
newRow := makeHistoryRow(priorTokens, opts.RepeatLastN, targetWidth)
|
|
|
|
var pool *mlx.Array
|
|
switch {
|
|
case s.history == nil && len(s.slots) == 0:
|
|
pool = newRow
|
|
case s.history == nil:
|
|
// First penalty slot with non-penalty slots already registered;
|
|
// seed zero rows so s.slots and pool row indices stay aligned.
|
|
zeros := mlx.Zeros(mlx.DTypeInt32, len(s.slots), targetWidth)
|
|
pool = zeros.Concatenate(0, newRow)
|
|
case targetWidth > s.historyWidth():
|
|
pad := mlx.Zeros(mlx.DTypeInt32, s.history.Dim(0), targetWidth-s.historyWidth())
|
|
pool = s.history.Concatenate(1, pad).Concatenate(0, newRow)
|
|
default:
|
|
pool = s.history.Concatenate(0, newRow)
|
|
}
|
|
|
|
mlx.Pin(pool)
|
|
mlx.Unpin(s.history)
|
|
s.history = pool
|
|
|
|
if opts.usesHistory() {
|
|
// Cap on seed so the next write's ring position
|
|
// (historyLen % RepeatLastN) lands at 0, overwriting the
|
|
// oldest entry when the ring was filled from priors.
|
|
slot.historyLen = min(len(priorTokens), opts.RepeatLastN)
|
|
}
|
|
}
|
|
|
|
s.slots = append(s.slots, slot)
|
|
s.byID[seqID] = slot
|
|
s.recomputeInvariants()
|
|
}
|
|
|
|
// makeHistoryRow builds a [1, width] int32 row with the last repeatLastN
|
|
// entries of priorTokens packed into [0, min(len, repeatLastN)), zeros
|
|
// elsewhere.
|
|
func makeHistoryRow(priorTokens []int32, repeatLastN, width int) *mlx.Array {
|
|
take := min(len(priorTokens), repeatLastN)
|
|
if take <= 0 {
|
|
return mlx.Zeros(mlx.DTypeInt32, 1, width)
|
|
}
|
|
row := make([]int32, width)
|
|
copy(row, priorTokens[len(priorTokens)-take:])
|
|
return mlx.NewArrayInt32(row, []int32{1, int32(width)})
|
|
}
|
|
|
|
// recomputeInvariants refreshes allSameOpts and anyLogprobs/maxTopLogprobs
|
|
// from s.slots. Called at the end of Add and Remove.
|
|
func (s *Sampler) recomputeInvariants() {
|
|
if len(s.slots) == 0 {
|
|
s.allSameOpts = true
|
|
s.anyLogprobs = false
|
|
s.maxTopLogprobs = 0
|
|
return
|
|
}
|
|
first := s.slots[0].opts
|
|
s.allSameOpts = true
|
|
s.anyLogprobs = false
|
|
s.maxTopLogprobs = 0
|
|
for _, slot := range s.slots {
|
|
if slot.opts != first {
|
|
s.allSameOpts = false
|
|
}
|
|
if slot.opts.Logprobs {
|
|
s.anyLogprobs = true
|
|
if slot.opts.TopLogprobs > s.maxTopLogprobs {
|
|
s.maxTopLogprobs = slot.opts.TopLogprobs
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Remove releases the slot. The pool tensor is rebuilt to drop the row.
|
|
func (s *Sampler) Remove(seqID int) {
|
|
slot, ok := s.byID[seqID]
|
|
if !ok {
|
|
return
|
|
}
|
|
delete(s.byID, seqID)
|
|
|
|
row := slices.Index(s.slots, slot)
|
|
s.slots = slices.Delete(s.slots, row, row+1)
|
|
s.recomputeInvariants()
|
|
|
|
if s.history == nil {
|
|
return
|
|
}
|
|
|
|
n := s.history.Dim(0)
|
|
var newHistory *mlx.Array
|
|
switch {
|
|
case n == 1:
|
|
newHistory = nil
|
|
case row == 0:
|
|
newHistory = s.history.Slice(mlx.Slice(1, n), mlx.Slice())
|
|
case row == n-1:
|
|
newHistory = s.history.Slice(mlx.Slice(0, row), mlx.Slice())
|
|
default:
|
|
before := s.history.Slice(mlx.Slice(0, row), mlx.Slice())
|
|
after := s.history.Slice(mlx.Slice(row+1, n), mlx.Slice())
|
|
newHistory = before.Concatenate(0, after)
|
|
}
|
|
|
|
mlx.Pin(newHistory)
|
|
mlx.Unpin(s.history)
|
|
s.history = newHistory
|
|
}
|
|
|
|
// Free releases the pooled history tensor and resets the sampler to the
|
|
// New-equivalent state so it may be reused.
|
|
func (s *Sampler) Free() {
|
|
mlx.Unpin(s.history)
|
|
*s = Sampler{
|
|
byID: make(map[int]*slotState),
|
|
allSameOpts: true,
|
|
numCtx: s.numCtx,
|
|
}
|
|
}
|
|
|
|
// Sample draws one token per row of logits ([B,V]); seqIDs[i] names the
|
|
// slot whose logits live at row i. Each sampled token is appended to its
|
|
// slot's ring. Slots not named in seqIDs are untouched.
|
|
func (s *Sampler) Sample(seqIDs []int, logits *mlx.Array) Result {
|
|
if len(seqIDs) == 0 {
|
|
return Result{}
|
|
}
|
|
|
|
slots := make([]*slotState, len(seqIDs))
|
|
for i, id := range seqIDs {
|
|
slot, ok := s.byID[id]
|
|
if !ok {
|
|
panic(fmt.Sprintf("sample.Sampler.Sample: seqID %d not registered", id))
|
|
}
|
|
slots[i] = slot
|
|
}
|
|
|
|
var token *mlx.Array
|
|
if opts0, ok := s.canBatch(slots); ok {
|
|
token = s.sampleTokensUniform(slots, opts0, logits)
|
|
} else {
|
|
token = s.sampleTokensSerial(slots, logits)
|
|
}
|
|
|
|
res := Result{Token: token}
|
|
if s.anyLogprobs {
|
|
// Log-softmax over original logits so every row holds a truthful
|
|
// value (compute-for-all; consumers filter per-slot). Subtract
|
|
// max first for numerical stability in the logsumexp.
|
|
lp := logits.AsType(mlx.DTypeFloat32)
|
|
lp = lp.Subtract(lp.MaxAxis(-1, true))
|
|
lp = lp.Subtract(lp.LogsumexpAxis(-1, true))
|
|
res.Logprob = lp.TakeAlongAxis(token.ExpandDims(-1), -1)
|
|
if s.maxTopLogprobs > 0 {
|
|
k := s.maxTopLogprobs
|
|
if vocab := lp.Dim(lp.NumDims() - 1); k > vocab {
|
|
k = vocab
|
|
}
|
|
// Argpartition on the negated values places the K largest
|
|
// (unsorted) in positions [0:K].
|
|
idx := lp.Negative().ArgpartitionAxis(k-1, -1).Slice(mlx.Slice(), mlx.Slice(0, k))
|
|
res.TopTokens = idx.AsType(mlx.DTypeInt32)
|
|
res.TopLogprobs = lp.TakeAlongAxis(idx, -1)
|
|
}
|
|
}
|
|
return res
|
|
}
|
|
|
|
// Distribution applies this slot's sampling transforms to logits without
|
|
// mutating sampler state. Row i is built as if draftTokens[:i] had already
|
|
// been appended to the slot history. logits must be [R,V] or [1,R,V].
|
|
func (s *Sampler) Distribution(seqID int, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
|
|
slot, logits, draftTokens := s.speculativeInputs("Distribution", seqID, logits, draftTokens)
|
|
rows := logits.Dim(0)
|
|
|
|
var hist *mlx.Array
|
|
if slot.opts.usesHistory() {
|
|
if s.history == nil {
|
|
panic(fmt.Sprintf("sample.Sampler.Distribution: seqID %d has no history", seqID))
|
|
}
|
|
if slot.historyLen < slot.opts.RepeatLastN {
|
|
return s.speculativeDistributionSerial(slot, logits, draftTokens)
|
|
}
|
|
hist = s.speculativeHistory(slot, draftTokens, rows)
|
|
}
|
|
|
|
return slot.distribution(&slotCtx{opts: slot.opts, history: hist}, logits)
|
|
}
|
|
|
|
// SpeculativeScores applies this slot's sampling transforms to logits without
|
|
// mutating sampler state and returns dense log-probability scores for sampled
|
|
// decoding. Greedy decoding returns the penalty-adjusted logits.
|
|
func (s *Sampler) SpeculativeScores(seqID int, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
|
slot, logits, draftTokens := s.speculativeInputs("SpeculativeScores", seqID, logits, draftTokens)
|
|
rows := logits.Dim(0)
|
|
|
|
var hist *mlx.Array
|
|
if slot.opts.usesHistory() {
|
|
if s.history == nil {
|
|
panic(fmt.Sprintf("sample.Sampler.SpeculativeScores: seqID %d has no history", seqID))
|
|
}
|
|
if slot.historyLen < slot.opts.RepeatLastN {
|
|
return s.speculativeScoresSerial(slot, logits, draftTokens)
|
|
}
|
|
hist = s.speculativeHistory(slot, draftTokens, rows)
|
|
}
|
|
|
|
return slot.speculativeScores(&slotCtx{opts: slot.opts, history: hist}, logits)
|
|
}
|
|
|
|
// SampleDistribution draws from a precomputed distribution while advancing
|
|
// seqID's deterministic RNG stream when a seed is configured.
|
|
func (s *Sampler) SampleDistribution(seqID int, dist Distribution) *mlx.Array {
|
|
slot := s.mustSlot("SampleDistribution", seqID)
|
|
return dist.SampleWithKey(slot.nextRandomKey())
|
|
}
|
|
|
|
// Bernoulli samples boolean outcomes while advancing seqID's deterministic RNG
|
|
// stream when a seed is configured.
|
|
func (s *Sampler) Bernoulli(seqID int, p *mlx.Array) *mlx.Array {
|
|
slot := s.mustSlot("Bernoulli", seqID)
|
|
return mlx.BernoulliWithKey(p, slot.nextRandomKey())
|
|
}
|
|
|
|
func (s *Sampler) mustSlot(caller string, seqID int) *slotState {
|
|
slot, ok := s.byID[seqID]
|
|
if !ok {
|
|
panic(fmt.Sprintf("sample.Sampler.%s: seqID %d not registered", caller, seqID))
|
|
}
|
|
return slot
|
|
}
|
|
|
|
func (s *Sampler) speculativeInputs(caller string, seqID int, logits *mlx.Array, draftTokens *mlx.Array) (*slotState, *mlx.Array, *mlx.Array) {
|
|
slot := s.mustSlot(caller, seqID)
|
|
|
|
if logits.NumDims() == 3 {
|
|
if logits.Dim(0) != 1 {
|
|
panic(fmt.Sprintf("sample.Sampler.%s: only batch size 1 is supported", caller))
|
|
}
|
|
logits = logits.Squeeze(0)
|
|
}
|
|
if logits.NumDims() != 2 {
|
|
panic(fmt.Sprintf("sample.Sampler.%s: logits must be rank 2 or 3, got rank %d", caller, logits.NumDims()))
|
|
}
|
|
|
|
if draftTokens != nil && draftTokens.NumDims() == 1 {
|
|
draftTokens = draftTokens.ExpandDims(0)
|
|
}
|
|
return slot, logits, draftTokens
|
|
}
|
|
|
|
// Commit appends already-selected tokens to seqID's repeat-penalty history.
|
|
// It is used after speculative sampling once the accepted continuation is
|
|
// known. Normal Sample calls continue to mutate history themselves.
|
|
func (s *Sampler) Commit(seqID int, tokens []int32) {
|
|
if len(tokens) == 0 {
|
|
return
|
|
}
|
|
slot, ok := s.byID[seqID]
|
|
if !ok {
|
|
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d not registered", seqID))
|
|
}
|
|
if !slot.opts.usesHistory() {
|
|
return
|
|
}
|
|
if s.history == nil {
|
|
panic(fmt.Sprintf("sample.Sampler.Commit: seqID %d has no history", seqID))
|
|
}
|
|
|
|
row := slices.Index(s.slots, slot)
|
|
width := s.historyWidth()
|
|
take := min(len(tokens), slot.opts.RepeatLastN)
|
|
startLen := slot.historyLen + len(tokens) - take
|
|
writeTokens := tokens[len(tokens)-take:]
|
|
flatOffsets := make([]int32, take)
|
|
for i := range take {
|
|
ringPos := (startLen + i) % slot.opts.RepeatLastN
|
|
flatOffsets[i] = int32(row*width + ringPos)
|
|
}
|
|
|
|
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(take), 1})
|
|
values := mlx.NewArrayInt32(writeTokens, []int32{int32(take), 1})
|
|
flatHist := s.history.Reshape(s.history.Dim(0)*width, 1)
|
|
s.history.Set(flatHist.PutAlongAxis(flatIdx, values, 0).Reshape(s.history.Dim(0), width))
|
|
slot.historyLen += len(tokens)
|
|
}
|
|
|
|
func (s *Sampler) speculativeDistributionSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) Distribution {
|
|
rows := logits.Dim(0)
|
|
draftCount := 0
|
|
if draftTokens != nil {
|
|
draftCount = draftTokens.Dim(1)
|
|
}
|
|
row := slices.Index(s.slots, slot)
|
|
baseFill := min(slot.historyLen, slot.opts.RepeatLastN)
|
|
var base *mlx.Array
|
|
if baseFill > 0 {
|
|
base = s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, baseFill))
|
|
}
|
|
|
|
dists := make([]Distribution, 0, rows)
|
|
for i := range rows {
|
|
rowLogits := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
|
hist := base
|
|
prefixLen := min(i, draftCount)
|
|
if prefixLen > 0 {
|
|
prefix := draftTokens.Slice(mlx.Slice(), mlx.Slice(0, prefixLen))
|
|
if hist == nil {
|
|
hist = prefix
|
|
} else {
|
|
hist = hist.Concatenate(1, prefix)
|
|
}
|
|
if hist.Dim(1) > slot.opts.RepeatLastN {
|
|
hist = hist.Slice(mlx.Slice(), mlx.Slice(hist.Dim(1)-slot.opts.RepeatLastN, mlx.End))
|
|
}
|
|
}
|
|
dists = append(dists, slot.distribution(&slotCtx{opts: slot.opts, history: hist}, rowLogits))
|
|
}
|
|
return ConcatenateDistributions(dists)
|
|
}
|
|
|
|
func (s *Sampler) speculativeScoresSerial(slot *slotState, logits *mlx.Array, draftTokens *mlx.Array) *mlx.Array {
|
|
return s.speculativeDistributionSerial(slot, logits, draftTokens).LogProbs(logits.Dim(logits.NumDims() - 1))
|
|
}
|
|
|
|
func (s *Sampler) speculativeHistory(slot *slotState, draftTokens *mlx.Array, rows int) *mlx.Array {
|
|
row := slices.Index(s.slots, slot)
|
|
width := slot.opts.RepeatLastN
|
|
base := s.history.Slice(mlx.Slice(row, row+1), mlx.Slice(0, width))
|
|
base = mlx.Tile(base, []int32{int32(rows), 1})
|
|
next := slot.historyLen % width
|
|
draftCount := 0
|
|
if draftTokens != nil {
|
|
draftCount = draftTokens.Dim(1)
|
|
}
|
|
if draftCount == 0 {
|
|
return base
|
|
}
|
|
|
|
sourceIdx := make([]int32, rows*width)
|
|
writeMask := make([]bool, rows*width)
|
|
for i := range rows {
|
|
prefixLen := min(i, draftCount)
|
|
for j := range prefixLen {
|
|
pos := (next + j) % width
|
|
sourceIdx[i*width+pos] = int32(j)
|
|
writeMask[i*width+pos] = true
|
|
}
|
|
}
|
|
|
|
draftRows := mlx.Tile(draftTokens, []int32{int32(rows), 1})
|
|
idx := mlx.NewArrayInt32(sourceIdx, []int32{int32(rows), int32(width)})
|
|
mask := mlx.FromValues(writeMask, rows, width)
|
|
values := draftRows.TakeAlongAxis(idx, 1)
|
|
return mlx.Where(mask, values, base)
|
|
}
|
|
|
|
func (slot *slotState) speculativeScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
|
if slot.opts.Temperature == 0 {
|
|
return slot.baseScores(ctx, logits)
|
|
}
|
|
return slot.distribution(ctx, logits).LogProbs(logits.Dim(logits.NumDims() - 1))
|
|
}
|
|
|
|
// canBatch reports whether the call can take the uniform batched path.
|
|
// All slots must share Options; when penalties are active the call must
|
|
// additionally cover every registered slot in registration order with a
|
|
// full ring, because the uniform path indexes the pool positionally.
|
|
func (s *Sampler) canBatch(slots []*slotState) (Options, bool) {
|
|
if !s.allSameOpts {
|
|
return Options{}, false
|
|
}
|
|
// slots is non-empty (Sample guards) and every slot is registered,
|
|
// so s.slots[0].opts is the canonical shared value.
|
|
shared := s.slots[0].opts
|
|
// TODO(pdevine): Before using multi-slot batching with seeded stochastic sampling,
|
|
// make sure each row gets its own per-slot random key instead of sharing
|
|
// slots[0]'s key through one batched categorical op.
|
|
if !shared.usesHistory() {
|
|
return shared, true
|
|
}
|
|
if len(slots) != len(s.slots) {
|
|
return Options{}, false
|
|
}
|
|
for i, slot := range slots {
|
|
if s.slots[i] != slot || slot.historyLen < shared.RepeatLastN {
|
|
return Options{}, false
|
|
}
|
|
}
|
|
return shared, true
|
|
}
|
|
|
|
// sampleTokensUniform runs one fused sampling pass over the whole batch.
|
|
// Reached only when canBatch is true, which lets the pool be used in place
|
|
// with a single PutAlongAxis write-back and no gather.
|
|
func (s *Sampler) sampleTokensUniform(slots []*slotState, opts Options, logits *mlx.Array) *mlx.Array {
|
|
B := len(slots)
|
|
|
|
var hist *mlx.Array
|
|
if opts.usesHistory() {
|
|
hist = s.history
|
|
if s.historyWidth() > opts.RepeatLastN {
|
|
hist = hist.Slice(mlx.Slice(), mlx.Slice(0, opts.RepeatLastN))
|
|
}
|
|
}
|
|
|
|
ctx := &slotCtx{opts: opts, history: hist}
|
|
token := slots[0].sample(ctx, logits)
|
|
if opts.UseSeed && opts.Temperature != 0 {
|
|
// TODO: This only keeps counters aligned; it does not give each slot
|
|
// an independent key for the batched draw.
|
|
for _, slot := range slots[1:] {
|
|
slot.randomCounter++
|
|
}
|
|
}
|
|
|
|
if !opts.usesHistory() {
|
|
return token
|
|
}
|
|
|
|
writeIdxData := make([]int32, B)
|
|
for i, slot := range slots {
|
|
writeIdxData[i] = int32(slot.historyLen % opts.RepeatLastN)
|
|
slot.historyLen++
|
|
}
|
|
writeIdx := mlx.NewArrayInt32(writeIdxData, []int32{int32(B), 1})
|
|
|
|
s.history.Set(s.history.PutAlongAxis(writeIdx, token.ExpandDims(-1), 1))
|
|
return token
|
|
}
|
|
|
|
// sampleTokensSerial samples each slot against its own row of logits.
|
|
func (s *Sampler) sampleTokensSerial(slots []*slotState, logits *mlx.Array) *mlx.Array {
|
|
perSlotTokens := make([]*mlx.Array, len(slots))
|
|
|
|
rowOf := make(map[*slotState]int, len(s.slots))
|
|
for i, slot := range s.slots {
|
|
rowOf[slot] = i
|
|
}
|
|
|
|
for i, slot := range slots {
|
|
row := logits.Slice(mlx.Slice(i, i+1), mlx.Slice())
|
|
|
|
var hist *mlx.Array
|
|
if slot.opts.usesHistory() && slot.historyLen > 0 && s.history != nil {
|
|
poolRow := rowOf[slot]
|
|
fill := min(slot.historyLen, slot.opts.RepeatLastN)
|
|
hist = s.history.Slice(
|
|
mlx.Slice(poolRow, poolRow+1),
|
|
mlx.Slice(0, fill),
|
|
)
|
|
}
|
|
|
|
ctx := &slotCtx{opts: slot.opts, history: hist}
|
|
perSlotTokens[i] = slot.sample(ctx, row)
|
|
}
|
|
|
|
token := mlx.Concatenate(perSlotTokens, 0)
|
|
|
|
if s.history != nil {
|
|
// For each writing slot collect its flat (row-major) pool offset
|
|
// and the call-order position of its token. One PutAlongAxis on a
|
|
// flat view of the pool scatters all writes in a single op.
|
|
flatOffsets := make([]int32, 0, len(slots))
|
|
tokenPos := make([]int32, 0, len(slots))
|
|
for i, slot := range slots {
|
|
if !slot.opts.usesHistory() {
|
|
continue
|
|
}
|
|
ringPos := slot.historyLen % slot.opts.RepeatLastN
|
|
flatOffsets = append(flatOffsets, int32(rowOf[slot]*s.historyWidth()+ringPos))
|
|
tokenPos = append(tokenPos, int32(i))
|
|
slot.historyLen++
|
|
}
|
|
|
|
if len(flatOffsets) > 0 {
|
|
m := len(flatOffsets)
|
|
flatIdx := mlx.NewArrayInt32(flatOffsets, []int32{int32(m), 1})
|
|
writingTokens := token
|
|
if m != len(slots) {
|
|
tokenPosIdx := mlx.NewArrayInt32(tokenPos, []int32{int32(m)})
|
|
writingTokens = token.TakeAxis(tokenPosIdx, 0)
|
|
}
|
|
flatHist := s.history.Reshape(s.history.Dim(0)*s.historyWidth(), 1)
|
|
s.history.Set(flatHist.PutAlongAxis(flatIdx, writingTokens.ExpandDims(-1), 0).Reshape(s.history.Dim(0), s.historyWidth()))
|
|
}
|
|
}
|
|
return token
|
|
}
|
|
|
|
func (slot *slotState) sample(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
|
if slot.opts.Temperature == 0 {
|
|
return slot.baseScores(ctx, logits).Argmax(-1, false).AsType(mlx.DTypeInt32)
|
|
}
|
|
return slot.distribution(ctx, logits).SampleWithKey(slot.nextRandomKey())
|
|
}
|
|
|
|
func (slot *slotState) nextRandomKey() *mlx.Array {
|
|
if !slot.opts.UseSeed {
|
|
return nil
|
|
}
|
|
seed := mixSeed(uint64(slot.opts.Seed), slot.randomCounter)
|
|
slot.randomCounter++
|
|
return mlx.RandomKey(seed)
|
|
}
|
|
|
|
// SplitMix64 parameters, used to decorrelate nearby (seed, counter) pairs.
const (
	// splitMix64Weyl is the additive step applied once per counter value.
	splitMix64Weyl = 0x9e3779b97f4a7c15

	// Multipliers of the two mixing rounds.
	splitMix64Mul1 = 0xbf58476d1ce4e5b9
	splitMix64Mul2 = 0x94d049bb133111eb

	// Right-shift distances of the xor-shift steps, in application order.
	splitMix64Shift1     = 30
	splitMix64Shift2     = 27
	splitMix64FinalShift = 31
)

// mixSeed derives a 64-bit key from a seed and a stream counter. It advances
// the seed by counter+1 Weyl steps and applies the SplitMix64 finalizer, so
// mixSeed(seed, n) equals the (n+1)-th output of a SplitMix64 generator that
// was initialized with seed. All arithmetic wraps modulo 2^64.
func mixSeed(seed, counter uint64) uint64 {
	mixed := seed + (counter+1)*splitMix64Weyl
	mixed ^= mixed >> splitMix64Shift1
	mixed *= splitMix64Mul1
	mixed ^= mixed >> splitMix64Shift2
	mixed *= splitMix64Mul2
	mixed ^= mixed >> splitMix64FinalShift
	return mixed
}
|
|
|
|
func (slot *slotState) baseScores(ctx *slotCtx, logits *mlx.Array) *mlx.Array {
|
|
scores := logits
|
|
if slot.opts.usesHistory() {
|
|
scores = penalty(ctx, scores)
|
|
}
|
|
return scores
|
|
}
|
|
|
|
func (slot *slotState) distribution(ctx *slotCtx, logits *mlx.Array) Distribution {
|
|
scores := slot.baseScores(ctx, logits)
|
|
if slot.opts.Temperature <= 0 {
|
|
ids := scores.Argmax(-1, false).AsType(mlx.DTypeInt32).ExpandDims(-1)
|
|
probs := mlx.AddScalar(ids.AsType(mlx.DTypeFloat32).Multiply(mlx.FromValue(float32(0))), 1)
|
|
return Distribution{IDs: ids, Probs: probs}
|
|
}
|
|
|
|
vocab := scores.Dim(scores.NumDims() - 1)
|
|
if slot.opts.TopK > 0 && slot.opts.TopK < vocab {
|
|
return sparseDistribution(ctx.opts, scores)
|
|
}
|
|
return denseDistribution(ctx.opts, scores)
|
|
}
|
|
|
|
func sparseDistribution(opts Options, scores *mlx.Array) Distribution {
|
|
ids := scores.Negative().ArgpartitionAxis(opts.TopK-1, -1).Slice(mlx.Slice(), mlx.Slice(0, opts.TopK)).AsType(mlx.DTypeInt32)
|
|
topScores := scores.TakeAlongAxis(ids, -1).AsType(mlx.DTypeFloat32)
|
|
probs := mlx.SoftmaxAxis(mlx.DivScalar(topScores, opts.Temperature), -1, true)
|
|
probs = applyTopPProbs(probs, opts.TopP)
|
|
probs = applyMinPProbs(probs, opts.MinP)
|
|
return Distribution{IDs: ids, Probs: normalizeProbs(probs)}
|
|
}
|
|
|
|
func denseDistribution(opts Options, scores *mlx.Array) Distribution {
|
|
probs := mlx.SoftmaxAxis(mlx.DivScalar(scores.AsType(mlx.DTypeFloat32), opts.Temperature), -1, true)
|
|
probs = applyTopPProbs(probs, opts.TopP)
|
|
probs = applyMinPProbs(probs, opts.MinP)
|
|
return Distribution{Probs: normalizeProbs(probs)}
|
|
}
|
|
|
|
func applyTopPProbs(probs *mlx.Array, topP float32) *mlx.Array {
|
|
if topP <= 0 || topP >= 1 {
|
|
return probs
|
|
}
|
|
order := probs.Negative().ArgsortAxis(-1)
|
|
sorted := probs.TakeAlongAxis(order, -1)
|
|
prevCumProbs := sorted.Cumsum(-1, false, true).Subtract(sorted)
|
|
keep := prevCumProbs.Less(mlx.FromValue(topP))
|
|
filtered := mlx.Where(keep, sorted, mlx.FromValue(float32(0)))
|
|
return mlx.Zeros(probs.DType(), probs.Dims()...).PutAlongAxis(order, filtered, -1)
|
|
}
|
|
|
|
func applyMinPProbs(probs *mlx.Array, minP float32) *mlx.Array {
|
|
if minP <= 0 || minP > 1 {
|
|
return probs
|
|
}
|
|
threshold := mlx.MulScalar(probs.MaxAxis(-1, true), minP)
|
|
return mlx.Where(probs.Less(threshold), mlx.FromValue(float32(0)), probs)
|
|
}
|
|
|
|
func normalizeProbs(probs *mlx.Array) *mlx.Array {
|
|
sum := mlx.Maximum(probs.SumAxis(-1, true), mlx.FromValue(float32(1e-20)))
|
|
return probs.Divide(sum)
|
|
}
|
|
|
|
func logitsFromProbs(probs *mlx.Array) *mlx.Array {
|
|
positive := mlx.Maximum(probs, mlx.FromValue(float32(1e-20)))
|
|
logits := mlx.Log(positive)
|
|
return mlx.Where(probs.LessEqual(mlx.FromValue(float32(0))), mlx.FromValue(float32(math.Inf(-1))), logits)
|
|
}
|
|
|
|
func penalty(ctx *slotCtx, scores *mlx.Array) *mlx.Array {
|
|
tokenIndices := ctx.history
|
|
if tokenIndices == nil {
|
|
return scores
|
|
}
|
|
|
|
if ctx.opts.RepeatPenalty != 1 || ctx.opts.PresencePenalty != 0 {
|
|
adjusted := scores.TakeAlongAxis(tokenIndices, -1)
|
|
if ctx.opts.RepeatPenalty != 1 {
|
|
factor := mlx.Where(
|
|
adjusted.Less(mlx.FromValue(float32(0))),
|
|
mlx.FromValue(ctx.opts.RepeatPenalty),
|
|
mlx.FromValue(1/ctx.opts.RepeatPenalty),
|
|
)
|
|
adjusted = adjusted.Multiply(factor)
|
|
}
|
|
if ctx.opts.PresencePenalty != 0 {
|
|
adjusted = mlx.AddScalar(adjusted, -ctx.opts.PresencePenalty)
|
|
}
|
|
scores = scores.PutAlongAxis(tokenIndices, adjusted, -1)
|
|
}
|
|
|
|
if ctx.opts.FrequencyPenalty != 0 {
|
|
scores = scores.ScatterAddAxis(tokenIndices, mlx.FromValue(-ctx.opts.FrequencyPenalty), -1)
|
|
}
|
|
|
|
return scores
|
|
}
|