package mlxrunner import ( "context" "fmt" "log/slog" "time" "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model/base" sampler "github.com/ollama/ollama/x/mlxrunner/sample" ) type dflashStats struct { iterations int drafted int accepted int mismatches int allAccepted int batched int serial int targetDuration time.Duration draftDuration time.Duration validateDuration time.Duration } type dflashDecodeMode string const ( dflashDecodeDisabled dflashDecodeMode = "" dflashDecodeGreedy dflashDecodeMode = "greedy" dflashDecodeSample dflashDecodeMode = "sample" ) func (m dflashDecodeMode) enabled() bool { return m != dflashDecodeDisabled } func newDFlashTargetCaches(m base.Model) []cache.Cache { if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok { return cacheFactory.NewCaches() } caches := make([]cache.Cache, m.NumLayers()) for i := range caches { caches[i] = cache.NewKVCache() } return caches } func freeCacheSet(caches []cache.Cache) { for _, c := range caches { if c != nil { c.Free() } } } func (r *Runner) dflashGate(opts sampler.Options) (dflashDecodeMode, string) { if r.Draft == nil { return dflashDecodeDisabled, "no_draft" } if _, ok := r.Draft.(base.DFlashDraftModel); !ok { return dflashDecodeDisabled, "draft_not_dflash" } if _, ok := r.Model.(base.DFlashTargetModel); !ok { return dflashDecodeDisabled, "target_not_dflash" } if _, ok := r.Model.(base.MTPEmbeddingModel); !ok { return dflashDecodeDisabled, "target_embeddings_missing" } if opts.Logprobs || opts.TopLogprobs > 0 { return dflashDecodeDisabled, "logprobs_requested" } if opts.Temperature > 0 || dflashUsesSamplerHistory(opts) { return dflashDecodeSample, "" } return dflashDecodeGreedy, "" } func dflashUsesSamplerHistory(opts sampler.Options) bool { if opts.RepeatLastN == 0 { return false } repeatPenalty := opts.RepeatPenalty if repeatPenalty <= 0 { repeatPenalty = 1 } return repeatPenalty != 1 || opts.PresencePenalty != 0 || opts.FrequencyPenalty != 0 } func (r *Runner) runGreedyDFlashDecode(ctx context.Context, request Request, session *cacheSession, targetCaches []cache.Cache, draftCaches []cache.Cache, seed []int32, position *int, started time.Time) error { target := r.Model.(base.DFlashTargetModel) draft := r.Draft.(base.DFlashDraftModel) stats := dflashStats{} slog.Info("DFlash greedy decode enabled", "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs()) targetForward := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { hidden, targetHidden := target.ForwardDFlash(&batch.Batch{ InputIDs: token, SeqOffsets: []int32{int32(*position)}, SeqQueryLens: []int32{int32(token.Dim(1))}, }, targetCaches, draft.TargetLayerIDs()) *position += token.Dim(1) return hidden, targetHidden } t0 := time.Now() hidden, targetHidden := targetForward(mlx.FromValues(seed, 1, len(seed))) draft.AppendContext(targetHidden, draftCaches) current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))} mlx.Pin(current.Arrays()...) mlx.Sweep() mlx.AsyncEval(current.Arrays()...) stats.targetDuration += time.Since(t0) defer func() { mlx.Unpin(current.Arrays()...) }() dec := decoder{tokenizer: r.Tokenizer} final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1} now := started generated := 0 for generated < request.Options.NumPredict { if err := ctx.Err(); err != nil { return err } if generated == 0 { mlx.Eval(current.Arrays()...) final.PromptEvalDuration = time.Since(now) now = time.Now() } done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final) if err != nil { return err } if !done { generated++ } if done || generated >= request.Options.NumPredict { break } draftCount := min(draft.BlockSize()-1, request.Options.NumPredict-generated) if draftCount <= 0 { t0 = time.Now() hidden, targetHidden := targetForward(mtpTokenInput(current.Token)) draft.AppendContext(targetHidden, draftCaches) stats.targetDuration += time.Since(t0) next := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))} mlx.Pin(next.Arrays()...) old := current current = next mlx.Unpin(old.Arrays()...) mlx.Sweep() mlx.AsyncEval(current.Arrays()...) continue } stats.iterations++ t0 = time.Now() draftTokens := r.generateDFlashDrafts(draft, current.Token, draftCaches, draftCount) mlx.Pin(draftTokens) mlx.Eval(draftTokens) stats.draftDuration += time.Since(t0) stats.drafted += draftCount t0 = time.Now() next, accepted, done, err := r.acceptDFlashDrafts(ctx, request, session, &dec, target, draft, targetCaches, draftCaches, position, current, draftTokens, &final, &generated, &stats) stats.validateDuration += time.Since(t0) mlx.Unpin(draftTokens) if err != nil { return err } stats.accepted += accepted if accepted == draftCount { stats.allAccepted++ } else { stats.mismatches++ } if done || generated >= request.Options.NumPredict { break } mlx.Pin(next.Arrays()...) old := current current = next mlx.Unpin(old.Arrays()...) mlx.Sweep() mlx.AsyncEval(current.Arrays()...) if generated%256 == 0 { mlx.ClearCache() } } final.EvalCount = generated final.EvalDuration = time.Since(now) acceptance := 0.0 if stats.drafted > 0 { acceptance = float64(stats.accepted) / float64(stats.drafted) } avgDraft := 0.0 avgAccepted := 0.0 if stats.iterations > 0 { avgDraft = float64(stats.drafted) / float64(stats.iterations) avgAccepted = float64(stats.accepted) / float64(stats.iterations) } slog.Info("DFlash decode stats", "mode", "greedy", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", draft.BlockSize()-1, "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration) select { case <-ctx.Done(): return ctx.Err() case request.Responses <- final: return nil } } func (r *Runner) runSampleDFlashDecode(ctx context.Context, request Request, session *cacheSession, targetCaches []cache.Cache, draftCaches []cache.Cache, seed []int32, position *int, started time.Time) error { target := r.Model.(base.DFlashTargetModel) draft := r.Draft.(base.DFlashDraftModel) stats := dflashStats{} slog.Info("DFlash sample decode enabled", "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "temperature", request.SamplerOpts.Temperature, "top_p", request.SamplerOpts.TopP, "top_k", request.SamplerOpts.TopK, "min_p", request.SamplerOpts.MinP, "repeat_penalty", request.SamplerOpts.RepeatPenalty, "presence_penalty", request.SamplerOpts.PresencePenalty, "frequency_penalty", request.SamplerOpts.FrequencyPenalty, ) targetForward := func(token *mlx.Array) (*mlx.Array, *mlx.Array) { hidden, targetHidden := target.ForwardDFlash(&batch.Batch{ InputIDs: token, SeqOffsets: []int32{int32(*position)}, SeqQueryLens: []int32{int32(token.Dim(1))}, }, targetCaches, draft.TargetLayerIDs()) *position += token.Dim(1) return hidden, targetHidden } t0 := time.Now() hidden, targetHidden := targetForward(mlx.FromValues(seed, 1, len(seed))) draft.AppendContext(targetHidden, draftCaches) current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden)) mlx.Pin(current.Arrays()...) mlx.Sweep() mlx.AsyncEval(current.Arrays()...) stats.targetDuration += time.Since(t0) defer func() { mlx.Unpin(current.Arrays()...) }() dec := decoder{tokenizer: r.Tokenizer} final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1} now := started generated := 0 for generated < request.Options.NumPredict { if err := ctx.Err(); err != nil { return err } if generated == 0 { mlx.Eval(current.Arrays()...) final.PromptEvalDuration = time.Since(now) now = time.Now() } done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final) if err != nil { return err } if !done { generated++ } if done || generated >= request.Options.NumPredict { break } draftCount := min(draft.BlockSize()-1, request.Options.NumPredict-generated) if draftCount <= 0 { t0 = time.Now() hidden, targetHidden := targetForward(mtpTokenInput(current.Token)) draft.AppendContext(targetHidden, draftCaches) stats.targetDuration += time.Since(t0) next := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden)) mlx.Pin(next.Arrays()...) old := current current = next mlx.Unpin(old.Arrays()...) mlx.Sweep() mlx.AsyncEval(current.Arrays()...) continue } stats.iterations++ t0 = time.Now() candidates := r.generateDFlashDraftCandidates(draft, current.Token, draftCaches, draftCount) var candidateArrays []*mlx.Array if candidates != nil { draftCount = candidates.tokens.Dim(1) candidateArrays = candidates.Arrays() mlx.Pin(candidateArrays...) mlx.Sweep() } stats.draftDuration += time.Since(t0) stats.drafted += draftCount var next sampler.Result if draftCount == 0 { t0 = time.Now() hidden, targetHidden := targetForward(mtpTokenInput(current.Token)) draft.AppendContext(targetHidden, draftCaches) stats.targetDuration += time.Since(t0) next = r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden)) } else { var accepted int t0 = time.Now() next, accepted, done, err = r.acceptSampleDFlashDrafts(ctx, request, session, &dec, target, draft, targetCaches, draftCaches, position, current, candidates, &final, &generated, &stats) stats.validateDuration += time.Since(t0) mlx.Unpin(candidateArrays...) if err != nil { return err } stats.accepted += accepted if accepted == draftCount { stats.allAccepted++ } else { stats.mismatches++ } if next.Token == nil { mlx.Sweep() } if done || generated >= request.Options.NumPredict { break } } mlx.Pin(next.Arrays()...) old := current current = next mlx.Unpin(old.Arrays()...) mlx.Sweep() mlx.AsyncEval(current.Arrays()...) if generated%256 == 0 { mlx.ClearCache() } } final.EvalCount = generated final.EvalDuration = time.Since(now) acceptance := 0.0 if stats.drafted > 0 { acceptance = float64(stats.accepted) / float64(stats.drafted) } avgDraft := 0.0 avgAccepted := 0.0 if stats.iterations > 0 { avgDraft = float64(stats.drafted) / float64(stats.iterations) avgAccepted = float64(stats.accepted) / float64(stats.iterations) } slog.Info("DFlash decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", draft.BlockSize()-1, "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration) select { case <-ctx.Done(): return ctx.Err() case request.Responses <- final: return nil } } func (r *Runner) dflashDraftLogits(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *mlx.Array { blockLen := draftCount + 1 values := make([]int32, blockLen) values[0] = int32(tokenID(current)) for i := 1; i < blockLen; i++ { values[i] = draft.MaskTokenID() } block := mlx.FromValues(values, 1, blockLen) logits := draft.Draft(block, caches) return logits.Slice(mlx.Slice(), mlx.Slice(1, blockLen), mlx.Slice()) } func (r *Runner) generateDFlashDrafts(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *mlx.Array { logits := r.dflashDraftLogits(draft, current, caches, draftCount) return logits.Argmax(-1, false).AsType(mlx.DTypeInt32) } type dflashDraftCandidates struct { tokens *mlx.Array dist sampler.Distribution } func (c *dflashDraftCandidates) Arrays() []*mlx.Array { if c == nil { return nil } return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...) } func (r *Runner) generateDFlashDraftCandidates(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *dflashDraftCandidates { if draftCount <= 0 { return nil } logits := r.dflashDraftLogits(draft, current, caches, draftCount) draftTokens := make([]*mlx.Array, 0, draftCount) draftDists := make([]sampler.Distribution, 0, draftCount) var prefix *mlx.Array for i := range draftCount { rows := logits.Slice(mlx.Slice(), mlx.Slice(0, i+1), mlx.Slice()) dist := r.Sampler.Distribution(pipelineSlot, rows, prefix).SliceRows(i, i+1) nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist)) nextInput := mtpTokenInput(nextToken) draftTokens = append(draftTokens, nextInput) draftDists = append(draftDists, dist) if prefix == nil { prefix = nextInput } else { prefix = prefix.Concatenate(1, nextInput) } } if len(draftTokens) == 0 { return nil } return &dflashDraftCandidates{ tokens: mlx.Concatenate(draftTokens, 1), dist: sampler.ConcatenateDistributions(draftDists), } } func (r *Runner) acceptDFlashDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *dflashStats) (sampler.Result, int, bool, error) { specCaches, spec, ok := cache.BeginSpeculation(targetCaches) if !ok { stats.serial++ return r.acceptDFlashDraftsSerial(ctx, request, session, dec, target, draft, targetCaches, draftCaches, position, current, draftTokens, final, generated) } stats.batched++ return r.acceptDFlashDraftsBatched(ctx, request, session, dec, target, draft, specCaches, spec, draftCaches, position, current, draftTokens, final, generated) } func (r *Runner) acceptDFlashDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, specCaches []cache.Cache, spec *cache.Speculation, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) { before := *position draftCount := draftTokens.Dim(1) verifyInput := mtpTokenInput(current.Token).Concatenate(1, draftTokens) hiddenSeq, targetHiddenSeq := target.ForwardDFlash(&batch.Batch{ InputIDs: verifyInput, SeqOffsets: []int32{int32(before)}, SeqQueryLens: []int32{int32(verifyInput.Dim(1))}, }, specCaches, draft.TargetLayerIDs()) selectedTokens := r.Model.Unembed(hiddenSeq).Argmax(-1, false).AsType(mlx.DTypeInt32) mlx.Eval(draftTokens, selectedTokens) draftIDs := draftTokens.Ints() selectedIDs := selectedTokens.Ints() if len(selectedIDs) < draftCount+1 { spec.Commit(0) return sampler.Result{}, 0, false, fmt.Errorf("dflash validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount) } accepted := 0 for i, id := range draftIDs { if selectedIDs[i] != id { break } accepted++ if r.Tokenizer.IsEOS(int32(id)) { break } } commitN := accepted + 1 spec.Commit(0) done := false for _, id := range draftIDs[:accepted] { if *generated >= request.Options.NumPredict { done = true break } res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)} var err error done, err = r.emitMTPToken(ctx, request, session, dec, res, final) if err != nil { return sampler.Result{}, accepted, done, err } if !done { (*generated)++ } if done { break } } spec.Commit(commitN) *position = before + commitN draft.AppendContext(targetHiddenSeq.Slice(mlx.Slice(), mlx.Slice(0, commitN), mlx.Slice()), draftCaches) if done || *generated >= request.Options.NumPredict { return sampler.Result{}, accepted, true, nil } nextIndex := accepted if nextIndex >= len(selectedIDs) { nextIndex = len(selectedIDs) - 1 } return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedIDs[nextIndex])}, 1)}, accepted, false, nil } func (r *Runner) acceptDFlashDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) { targetForward := func(token *mlx.Array) *mlx.Array { hidden, targetHidden := target.ForwardDFlash(&batch.Batch{ InputIDs: token, SeqOffsets: []int32{int32(*position)}, SeqQueryLens: []int32{int32(token.Dim(1))}, }, targetCaches, draft.TargetLayerIDs()) *position += token.Dim(1) draft.AppendContext(targetHidden, draftCaches) return r.lastLogits(hidden) } logits := targetForward(mtpTokenInput(current.Token)) accepted := 0 for _, id := range draftTokens.Ints() { selected := greedyTokenFromLogits(logits) mlx.Eval(selected) selectedID := tokenID(selected) if selectedID != id { return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil } res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)} done, err := r.emitMTPToken(ctx, request, session, dec, res, final) if err != nil { return sampler.Result{}, accepted, done, err } accepted++ if !done { (*generated)++ } if done || *generated >= request.Options.NumPredict { return sampler.Result{}, accepted, true, nil } logits = targetForward(mtpTokenInput(res.Token)) } return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil } func (r *Runner) acceptSampleDFlashDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int, stats *dflashStats) (sampler.Result, int, bool, error) { specCaches, spec, ok := cache.BeginSpeculation(targetCaches) if !ok { stats.serial++ return r.acceptSampleDFlashDraftsSerial(ctx, request, session, dec, target, draft, targetCaches, draftCaches, position, current, candidates, final, generated) } stats.batched++ return r.acceptSampleDFlashDraftsBatched(ctx, request, session, dec, target, draft, specCaches, spec, draftCaches, position, current, candidates, final, generated) } func (r *Runner) acceptSampleDFlashDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, specCaches []cache.Cache, spec *cache.Speculation, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) { before := *position draftCount := candidates.tokens.Dim(1) verifyInput := mtpTokenInput(current.Token).Concatenate(1, candidates.tokens) hiddenSeq, targetHiddenSeq := target.ForwardDFlash(&batch.Batch{ InputIDs: verifyInput, SeqOffsets: []int32{int32(before)}, SeqQueryLens: []int32{int32(verifyInput.Dim(1))}, }, specCaches, draft.TargetLayerIDs()) targetDist := r.Sampler.Distribution(pipelineSlot, r.Model.Unembed(hiddenSeq), candidates.tokens) draftDist := candidates.dist acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens) mlx.Eval(candidates.tokens, acceptedMask) draftIDs := candidates.tokens.Ints() acceptedFlags := acceptedMask.Ints() accepted := 0 for _, ok := range acceptedFlags { if ok == 0 { break } accepted++ } if accepted > draftCount { spec.Commit(0) return sampler.Result{}, 0, false, fmt.Errorf("dflash sample validation accepted %d tokens for %d draft tokens", accepted, draftCount) } commitIDs := make([]int32, 0, accepted+1) done := false for i, id := range draftIDs[:accepted] { commitIDs = append(commitIDs, int32(id)) if r.Tokenizer.IsEOS(int32(id)) { done = true accepted = i + 1 commitIDs = commitIDs[:accepted] break } } commitN := accepted + 1 spec.Commit(0) for _, id := range draftIDs[:accepted] { if *generated >= request.Options.NumPredict { done = true break } res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)} var err error done, err = r.emitMTPToken(ctx, request, session, dec, res, final) if err != nil { return sampler.Result{}, accepted, done, err } if !done { (*generated)++ } if done { break } } spec.Commit(commitN) *position = before + commitN draft.AppendContext(targetHiddenSeq.Slice(mlx.Slice(), mlx.Slice(0, commitN), mlx.Slice()), draftCaches) if done || *generated >= request.Options.NumPredict { r.Sampler.Commit(pipelineSlot, commitIDs) return sampler.Result{}, accepted, true, nil } var nextToken *mlx.Array if accepted == draftCount { nextToken = r.mtpSampleTokenAt(targetDist, draftCount) } else { nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted) } mlx.Eval(nextToken) nextID := int32(tokenID(nextToken)) commitIDs = append(commitIDs, nextID) r.Sampler.Commit(pipelineSlot, commitIDs) return sampler.Result{Token: nextToken}, accepted, false, nil } func (r *Runner) acceptSampleDFlashDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) { targetForward := func(token *mlx.Array) *mlx.Array { hidden, targetHidden := target.ForwardDFlash(&batch.Batch{ InputIDs: token, SeqOffsets: []int32{int32(*position)}, SeqQueryLens: []int32{int32(token.Dim(1))}, }, targetCaches, draft.TargetLayerIDs()) *position += token.Dim(1) draft.AppendContext(targetHidden, draftCaches) return r.lastLogits(hidden) } mlx.Eval(candidates.tokens) draftIDs := candidates.tokens.Ints() logits := targetForward(mtpTokenInput(current.Token)) accepted := 0 for i, id := range draftIDs { targetDist := r.Sampler.Distribution(pipelineSlot, logits, nil) draftDist := candidates.dist.SliceRows(i, i+1) draftToken := mlx.FromValues([]int32{int32(id)}, 1) acceptedMask := r.mtpSampleAcceptedMask(targetDist, draftDist, draftToken) mlx.Eval(acceptedMask) if acceptedMask.Ints()[0] == 0 { nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, targetDist.ResidualAgainst(draftDist))) mlx.Eval(nextToken) r.Sampler.Commit(pipelineSlot, []int32{int32(tokenID(nextToken))}) return sampler.Result{Token: nextToken}, accepted, false, nil } accepted++ r.Sampler.Commit(pipelineSlot, []int32{int32(id)}) res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)} done, err := r.emitMTPToken(ctx, request, session, dec, res, final) if err != nil { return sampler.Result{}, accepted, done, err } if !done { (*generated)++ } if done || *generated >= request.Options.NumPredict { return sampler.Result{}, accepted, true, nil } logits = targetForward(mtpTokenInput(res.Token)) } targetDist := r.Sampler.Distribution(pipelineSlot, logits, nil) nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, targetDist)) mlx.Eval(nextToken) r.Sampler.Commit(pipelineSlot, []int32{int32(tokenID(nextToken))}) return sampler.Result{Token: nextToken}, accepted, false, nil }