716 lines
24 KiB
Go
716 lines
24 KiB
Go
package mlxrunner
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/batch"
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/model/base"
|
|
sampler "github.com/ollama/ollama/x/mlxrunner/sample"
|
|
)
|
|
|
|
// dflashStats collects the counters and timings that the DFlash decode loops
// report in their final "DFlash decode stats" log line.
type dflashStats struct {
	// iterations counts draft/verify rounds; drafted and accepted are the
	// total draft tokens proposed and accepted across those rounds.
	iterations, drafted, accepted int
	// mismatches counts rounds where fewer tokens were accepted than were
	// drafted; allAccepted counts rounds where every draft was accepted.
	mismatches, allAccepted int
	// batched and serial count how many rounds used the batched
	// (speculation-capable cache) versus the serial verification path.
	batched, serial int
	// Wall-clock time accumulated around target forwards, draft generation
	// and draft validation respectively.
	targetDuration, draftDuration, validateDuration time.Duration
}
|
|
|
|
// dflashDecodeMode names the DFlash decode strategy chosen for a request.
type dflashDecodeMode string

// Decode modes. The empty string means DFlash is not used.
const (
	dflashDecodeDisabled dflashDecodeMode = ""
	dflashDecodeGreedy   dflashDecodeMode = "greedy"
	dflashDecodeSample   dflashDecodeMode = "sample"
)

// enabled reports whether m is any mode other than dflashDecodeDisabled.
func (m dflashDecodeMode) enabled() bool {
	if m == dflashDecodeDisabled {
		return false
	}
	return true
}
|
|
|
|
func newDFlashTargetCaches(m base.Model) []cache.Cache {
|
|
if cacheFactory, ok := m.(interface{ NewCaches() []cache.Cache }); ok {
|
|
return cacheFactory.NewCaches()
|
|
}
|
|
caches := make([]cache.Cache, m.NumLayers())
|
|
for i := range caches {
|
|
caches[i] = cache.NewKVCache()
|
|
}
|
|
return caches
|
|
}
|
|
|
|
func freeCacheSet(caches []cache.Cache) {
|
|
for _, c := range caches {
|
|
if c != nil {
|
|
c.Free()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *Runner) dflashGate(opts sampler.Options) (dflashDecodeMode, string) {
|
|
if r.Draft == nil {
|
|
return dflashDecodeDisabled, "no_draft"
|
|
}
|
|
if _, ok := r.Draft.(base.DFlashDraftModel); !ok {
|
|
return dflashDecodeDisabled, "draft_not_dflash"
|
|
}
|
|
if _, ok := r.Model.(base.DFlashTargetModel); !ok {
|
|
return dflashDecodeDisabled, "target_not_dflash"
|
|
}
|
|
if _, ok := r.Model.(base.MTPEmbeddingModel); !ok {
|
|
return dflashDecodeDisabled, "target_embeddings_missing"
|
|
}
|
|
if opts.Logprobs || opts.TopLogprobs > 0 {
|
|
return dflashDecodeDisabled, "logprobs_requested"
|
|
}
|
|
|
|
if opts.Temperature > 0 || dflashUsesSamplerHistory(opts) {
|
|
return dflashDecodeSample, ""
|
|
}
|
|
|
|
return dflashDecodeGreedy, ""
|
|
}
|
|
|
|
func dflashUsesSamplerHistory(opts sampler.Options) bool {
|
|
if opts.RepeatLastN == 0 {
|
|
return false
|
|
}
|
|
|
|
repeatPenalty := opts.RepeatPenalty
|
|
if repeatPenalty <= 0 {
|
|
repeatPenalty = 1
|
|
}
|
|
return repeatPenalty != 1 || opts.PresencePenalty != 0 || opts.FrequencyPenalty != 0
|
|
}
|
|
|
|
func (r *Runner) runGreedyDFlashDecode(ctx context.Context, request Request, session *cacheSession, targetCaches []cache.Cache, draftCaches []cache.Cache, seed []int32, position *int, started time.Time) error {
|
|
target := r.Model.(base.DFlashTargetModel)
|
|
draft := r.Draft.(base.DFlashDraftModel)
|
|
stats := dflashStats{}
|
|
slog.Info("DFlash greedy decode enabled", "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs())
|
|
|
|
targetForward := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
|
|
InputIDs: token,
|
|
SeqOffsets: []int32{int32(*position)},
|
|
SeqQueryLens: []int32{int32(token.Dim(1))},
|
|
}, targetCaches, draft.TargetLayerIDs())
|
|
*position += token.Dim(1)
|
|
return hidden, targetHidden
|
|
}
|
|
|
|
t0 := time.Now()
|
|
hidden, targetHidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
|
|
draft.AppendContext(targetHidden, draftCaches)
|
|
current := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
|
|
mlx.Pin(current.Arrays()...)
|
|
mlx.Sweep()
|
|
mlx.AsyncEval(current.Arrays()...)
|
|
stats.targetDuration += time.Since(t0)
|
|
defer func() {
|
|
mlx.Unpin(current.Arrays()...)
|
|
}()
|
|
|
|
dec := decoder{tokenizer: r.Tokenizer}
|
|
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
|
|
now := started
|
|
generated := 0
|
|
|
|
for generated < request.Options.NumPredict {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if generated == 0 {
|
|
mlx.Eval(current.Arrays()...)
|
|
final.PromptEvalDuration = time.Since(now)
|
|
now = time.Now()
|
|
}
|
|
|
|
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !done {
|
|
generated++
|
|
}
|
|
if done || generated >= request.Options.NumPredict {
|
|
break
|
|
}
|
|
|
|
draftCount := min(draft.BlockSize()-1, request.Options.NumPredict-generated)
|
|
if draftCount <= 0 {
|
|
t0 = time.Now()
|
|
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
|
|
draft.AppendContext(targetHidden, draftCaches)
|
|
stats.targetDuration += time.Since(t0)
|
|
next := sampler.Result{Token: greedyTokenFromLogits(r.lastLogits(hidden))}
|
|
mlx.Pin(next.Arrays()...)
|
|
old := current
|
|
current = next
|
|
mlx.Unpin(old.Arrays()...)
|
|
mlx.Sweep()
|
|
mlx.AsyncEval(current.Arrays()...)
|
|
continue
|
|
}
|
|
|
|
stats.iterations++
|
|
t0 = time.Now()
|
|
draftTokens := r.generateDFlashDrafts(draft, current.Token, draftCaches, draftCount)
|
|
mlx.Pin(draftTokens)
|
|
mlx.Eval(draftTokens)
|
|
stats.draftDuration += time.Since(t0)
|
|
stats.drafted += draftCount
|
|
|
|
t0 = time.Now()
|
|
next, accepted, done, err := r.acceptDFlashDrafts(ctx, request, session, &dec, target, draft, targetCaches, draftCaches, position, current, draftTokens, &final, &generated, &stats)
|
|
stats.validateDuration += time.Since(t0)
|
|
mlx.Unpin(draftTokens)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
stats.accepted += accepted
|
|
if accepted == draftCount {
|
|
stats.allAccepted++
|
|
} else {
|
|
stats.mismatches++
|
|
}
|
|
if done || generated >= request.Options.NumPredict {
|
|
break
|
|
}
|
|
|
|
mlx.Pin(next.Arrays()...)
|
|
old := current
|
|
current = next
|
|
mlx.Unpin(old.Arrays()...)
|
|
mlx.Sweep()
|
|
mlx.AsyncEval(current.Arrays()...)
|
|
|
|
if generated%256 == 0 {
|
|
mlx.ClearCache()
|
|
}
|
|
}
|
|
|
|
final.EvalCount = generated
|
|
final.EvalDuration = time.Since(now)
|
|
acceptance := 0.0
|
|
if stats.drafted > 0 {
|
|
acceptance = float64(stats.accepted) / float64(stats.drafted)
|
|
}
|
|
avgDraft := 0.0
|
|
avgAccepted := 0.0
|
|
if stats.iterations > 0 {
|
|
avgDraft = float64(stats.drafted) / float64(stats.iterations)
|
|
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
|
|
}
|
|
slog.Info("DFlash decode stats", "mode", "greedy", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", draft.BlockSize()-1, "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case request.Responses <- final:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (r *Runner) runSampleDFlashDecode(ctx context.Context, request Request, session *cacheSession, targetCaches []cache.Cache, draftCaches []cache.Cache, seed []int32, position *int, started time.Time) error {
|
|
target := r.Model.(base.DFlashTargetModel)
|
|
draft := r.Draft.(base.DFlashDraftModel)
|
|
stats := dflashStats{}
|
|
slog.Info("DFlash sample decode enabled",
|
|
"block_size", draft.BlockSize(),
|
|
"target_layers", draft.TargetLayerIDs(),
|
|
"temperature", request.SamplerOpts.Temperature,
|
|
"top_p", request.SamplerOpts.TopP,
|
|
"top_k", request.SamplerOpts.TopK,
|
|
"min_p", request.SamplerOpts.MinP,
|
|
"repeat_penalty", request.SamplerOpts.RepeatPenalty,
|
|
"presence_penalty", request.SamplerOpts.PresencePenalty,
|
|
"frequency_penalty", request.SamplerOpts.FrequencyPenalty,
|
|
)
|
|
|
|
targetForward := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
|
|
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
|
|
InputIDs: token,
|
|
SeqOffsets: []int32{int32(*position)},
|
|
SeqQueryLens: []int32{int32(token.Dim(1))},
|
|
}, targetCaches, draft.TargetLayerIDs())
|
|
*position += token.Dim(1)
|
|
return hidden, targetHidden
|
|
}
|
|
|
|
t0 := time.Now()
|
|
hidden, targetHidden := targetForward(mlx.FromValues(seed, 1, len(seed)))
|
|
draft.AppendContext(targetHidden, draftCaches)
|
|
current := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
|
|
mlx.Pin(current.Arrays()...)
|
|
mlx.Sweep()
|
|
mlx.AsyncEval(current.Arrays()...)
|
|
stats.targetDuration += time.Since(t0)
|
|
defer func() {
|
|
mlx.Unpin(current.Arrays()...)
|
|
}()
|
|
|
|
dec := decoder{tokenizer: r.Tokenizer}
|
|
final := CompletionResponse{Done: true, PromptEvalCount: len(request.Tokens), DoneReason: 1}
|
|
now := started
|
|
generated := 0
|
|
|
|
for generated < request.Options.NumPredict {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if generated == 0 {
|
|
mlx.Eval(current.Arrays()...)
|
|
final.PromptEvalDuration = time.Since(now)
|
|
now = time.Now()
|
|
}
|
|
|
|
done, err := r.emitMTPToken(ctx, request, session, &dec, current, &final)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !done {
|
|
generated++
|
|
}
|
|
if done || generated >= request.Options.NumPredict {
|
|
break
|
|
}
|
|
|
|
draftCount := min(draft.BlockSize()-1, request.Options.NumPredict-generated)
|
|
if draftCount <= 0 {
|
|
t0 = time.Now()
|
|
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
|
|
draft.AppendContext(targetHidden, draftCaches)
|
|
stats.targetDuration += time.Since(t0)
|
|
next := r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
|
|
mlx.Pin(next.Arrays()...)
|
|
old := current
|
|
current = next
|
|
mlx.Unpin(old.Arrays()...)
|
|
mlx.Sweep()
|
|
mlx.AsyncEval(current.Arrays()...)
|
|
continue
|
|
}
|
|
|
|
stats.iterations++
|
|
t0 = time.Now()
|
|
candidates := r.generateDFlashDraftCandidates(draft, current.Token, draftCaches, draftCount)
|
|
var candidateArrays []*mlx.Array
|
|
if candidates != nil {
|
|
draftCount = candidates.tokens.Dim(1)
|
|
candidateArrays = candidates.Arrays()
|
|
mlx.Pin(candidateArrays...)
|
|
mlx.Sweep()
|
|
}
|
|
stats.draftDuration += time.Since(t0)
|
|
stats.drafted += draftCount
|
|
|
|
var next sampler.Result
|
|
if draftCount == 0 {
|
|
t0 = time.Now()
|
|
hidden, targetHidden := targetForward(mtpTokenInput(current.Token))
|
|
draft.AppendContext(targetHidden, draftCaches)
|
|
stats.targetDuration += time.Since(t0)
|
|
next = r.Sampler.Sample([]int{pipelineSlot}, r.lastLogits(hidden))
|
|
} else {
|
|
var accepted int
|
|
t0 = time.Now()
|
|
next, accepted, done, err = r.acceptSampleDFlashDrafts(ctx, request, session, &dec, target, draft, targetCaches, draftCaches, position, current, candidates, &final, &generated, &stats)
|
|
stats.validateDuration += time.Since(t0)
|
|
mlx.Unpin(candidateArrays...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
stats.accepted += accepted
|
|
if accepted == draftCount {
|
|
stats.allAccepted++
|
|
} else {
|
|
stats.mismatches++
|
|
}
|
|
if next.Token == nil {
|
|
mlx.Sweep()
|
|
}
|
|
if done || generated >= request.Options.NumPredict {
|
|
break
|
|
}
|
|
}
|
|
|
|
mlx.Pin(next.Arrays()...)
|
|
old := current
|
|
current = next
|
|
mlx.Unpin(old.Arrays()...)
|
|
mlx.Sweep()
|
|
mlx.AsyncEval(current.Arrays()...)
|
|
|
|
if generated%256 == 0 {
|
|
mlx.ClearCache()
|
|
}
|
|
}
|
|
|
|
final.EvalCount = generated
|
|
final.EvalDuration = time.Since(now)
|
|
acceptance := 0.0
|
|
if stats.drafted > 0 {
|
|
acceptance = float64(stats.accepted) / float64(stats.drafted)
|
|
}
|
|
avgDraft := 0.0
|
|
avgAccepted := 0.0
|
|
if stats.iterations > 0 {
|
|
avgDraft = float64(stats.drafted) / float64(stats.iterations)
|
|
avgAccepted = float64(stats.accepted) / float64(stats.iterations)
|
|
}
|
|
slog.Info("DFlash decode stats", "mode", "sample", "generated", generated, "drafted", stats.drafted, "accepted", stats.accepted, "acceptance", acceptance, "iterations", stats.iterations, "avg_draft", avgDraft, "avg_accepted", avgAccepted, "batched", stats.batched, "serial", stats.serial, "mismatches", stats.mismatches, "all_accepted", stats.allAccepted, "max_draft", draft.BlockSize()-1, "block_size", draft.BlockSize(), "target_layers", draft.TargetLayerIDs(), "target_duration", stats.targetDuration, "draft_duration", stats.draftDuration, "validate_duration", stats.validateDuration)
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case request.Responses <- final:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (r *Runner) dflashDraftLogits(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *mlx.Array {
|
|
blockLen := draftCount + 1
|
|
values := make([]int32, blockLen)
|
|
values[0] = int32(tokenID(current))
|
|
for i := 1; i < blockLen; i++ {
|
|
values[i] = draft.MaskTokenID()
|
|
}
|
|
block := mlx.FromValues(values, 1, blockLen)
|
|
logits := draft.Draft(block, caches)
|
|
return logits.Slice(mlx.Slice(), mlx.Slice(1, blockLen), mlx.Slice())
|
|
}
|
|
|
|
func (r *Runner) generateDFlashDrafts(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *mlx.Array {
|
|
logits := r.dflashDraftLogits(draft, current, caches, draftCount)
|
|
return logits.Argmax(-1, false).AsType(mlx.DTypeInt32)
|
|
}
|
|
|
|
type dflashDraftCandidates struct {
|
|
tokens *mlx.Array
|
|
dist sampler.Distribution
|
|
}
|
|
|
|
func (c *dflashDraftCandidates) Arrays() []*mlx.Array {
|
|
if c == nil {
|
|
return nil
|
|
}
|
|
return append([]*mlx.Array{c.tokens}, c.dist.Arrays()...)
|
|
}
|
|
|
|
func (r *Runner) generateDFlashDraftCandidates(draft base.DFlashDraftModel, current *mlx.Array, caches []cache.Cache, draftCount int) *dflashDraftCandidates {
|
|
if draftCount <= 0 {
|
|
return nil
|
|
}
|
|
|
|
logits := r.dflashDraftLogits(draft, current, caches, draftCount)
|
|
draftTokens := make([]*mlx.Array, 0, draftCount)
|
|
draftDists := make([]sampler.Distribution, 0, draftCount)
|
|
var prefix *mlx.Array
|
|
|
|
for i := range draftCount {
|
|
rows := logits.Slice(mlx.Slice(), mlx.Slice(0, i+1), mlx.Slice())
|
|
dist := r.Sampler.Distribution(pipelineSlot, rows, prefix).SliceRows(i, i+1)
|
|
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, dist))
|
|
nextInput := mtpTokenInput(nextToken)
|
|
|
|
draftTokens = append(draftTokens, nextInput)
|
|
draftDists = append(draftDists, dist)
|
|
if prefix == nil {
|
|
prefix = nextInput
|
|
} else {
|
|
prefix = prefix.Concatenate(1, nextInput)
|
|
}
|
|
}
|
|
if len(draftTokens) == 0 {
|
|
return nil
|
|
}
|
|
return &dflashDraftCandidates{
|
|
tokens: mlx.Concatenate(draftTokens, 1),
|
|
dist: sampler.ConcatenateDistributions(draftDists),
|
|
}
|
|
}
|
|
|
|
func (r *Runner) acceptDFlashDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int, stats *dflashStats) (sampler.Result, int, bool, error) {
|
|
specCaches, spec, ok := cache.BeginSpeculation(targetCaches)
|
|
if !ok {
|
|
stats.serial++
|
|
return r.acceptDFlashDraftsSerial(ctx, request, session, dec, target, draft, targetCaches, draftCaches, position, current, draftTokens, final, generated)
|
|
}
|
|
stats.batched++
|
|
return r.acceptDFlashDraftsBatched(ctx, request, session, dec, target, draft, specCaches, spec, draftCaches, position, current, draftTokens, final, generated)
|
|
}
|
|
|
|
func (r *Runner) acceptDFlashDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, specCaches []cache.Cache, spec *cache.Speculation, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
|
before := *position
|
|
draftCount := draftTokens.Dim(1)
|
|
verifyInput := mtpTokenInput(current.Token).Concatenate(1, draftTokens)
|
|
hiddenSeq, targetHiddenSeq := target.ForwardDFlash(&batch.Batch{
|
|
InputIDs: verifyInput,
|
|
SeqOffsets: []int32{int32(before)},
|
|
SeqQueryLens: []int32{int32(verifyInput.Dim(1))},
|
|
}, specCaches, draft.TargetLayerIDs())
|
|
|
|
selectedTokens := r.Model.Unembed(hiddenSeq).Argmax(-1, false).AsType(mlx.DTypeInt32)
|
|
mlx.Eval(draftTokens, selectedTokens)
|
|
|
|
draftIDs := draftTokens.Ints()
|
|
selectedIDs := selectedTokens.Ints()
|
|
if len(selectedIDs) < draftCount+1 {
|
|
spec.Commit(0)
|
|
return sampler.Result{}, 0, false, fmt.Errorf("dflash validation produced %d tokens for %d draft tokens", len(selectedIDs), draftCount)
|
|
}
|
|
|
|
accepted := 0
|
|
for i, id := range draftIDs {
|
|
if selectedIDs[i] != id {
|
|
break
|
|
}
|
|
accepted++
|
|
if r.Tokenizer.IsEOS(int32(id)) {
|
|
break
|
|
}
|
|
}
|
|
|
|
commitN := accepted + 1
|
|
spec.Commit(0)
|
|
|
|
done := false
|
|
for _, id := range draftIDs[:accepted] {
|
|
if *generated >= request.Options.NumPredict {
|
|
done = true
|
|
break
|
|
}
|
|
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
|
var err error
|
|
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
|
|
if err != nil {
|
|
return sampler.Result{}, accepted, done, err
|
|
}
|
|
if !done {
|
|
(*generated)++
|
|
}
|
|
if done {
|
|
break
|
|
}
|
|
}
|
|
|
|
spec.Commit(commitN)
|
|
*position = before + commitN
|
|
draft.AppendContext(targetHiddenSeq.Slice(mlx.Slice(), mlx.Slice(0, commitN), mlx.Slice()), draftCaches)
|
|
|
|
if done || *generated >= request.Options.NumPredict {
|
|
return sampler.Result{}, accepted, true, nil
|
|
}
|
|
|
|
nextIndex := accepted
|
|
if nextIndex >= len(selectedIDs) {
|
|
nextIndex = len(selectedIDs) - 1
|
|
}
|
|
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedIDs[nextIndex])}, 1)}, accepted, false, nil
|
|
}
|
|
|
|
func (r *Runner) acceptDFlashDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, draftTokens *mlx.Array, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
|
targetForward := func(token *mlx.Array) *mlx.Array {
|
|
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
|
|
InputIDs: token,
|
|
SeqOffsets: []int32{int32(*position)},
|
|
SeqQueryLens: []int32{int32(token.Dim(1))},
|
|
}, targetCaches, draft.TargetLayerIDs())
|
|
*position += token.Dim(1)
|
|
draft.AppendContext(targetHidden, draftCaches)
|
|
return r.lastLogits(hidden)
|
|
}
|
|
|
|
logits := targetForward(mtpTokenInput(current.Token))
|
|
accepted := 0
|
|
for _, id := range draftTokens.Ints() {
|
|
selected := greedyTokenFromLogits(logits)
|
|
mlx.Eval(selected)
|
|
selectedID := tokenID(selected)
|
|
if selectedID != id {
|
|
return sampler.Result{Token: mlx.FromValues([]int32{int32(selectedID)}, 1)}, accepted, false, nil
|
|
}
|
|
|
|
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
|
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
|
|
if err != nil {
|
|
return sampler.Result{}, accepted, done, err
|
|
}
|
|
accepted++
|
|
if !done {
|
|
(*generated)++
|
|
}
|
|
if done || *generated >= request.Options.NumPredict {
|
|
return sampler.Result{}, accepted, true, nil
|
|
}
|
|
|
|
logits = targetForward(mtpTokenInput(res.Token))
|
|
}
|
|
|
|
return sampler.Result{Token: greedyTokenFromLogits(logits)}, accepted, false, nil
|
|
}
|
|
|
|
func (r *Runner) acceptSampleDFlashDrafts(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int, stats *dflashStats) (sampler.Result, int, bool, error) {
|
|
specCaches, spec, ok := cache.BeginSpeculation(targetCaches)
|
|
if !ok {
|
|
stats.serial++
|
|
return r.acceptSampleDFlashDraftsSerial(ctx, request, session, dec, target, draft, targetCaches, draftCaches, position, current, candidates, final, generated)
|
|
}
|
|
stats.batched++
|
|
return r.acceptSampleDFlashDraftsBatched(ctx, request, session, dec, target, draft, specCaches, spec, draftCaches, position, current, candidates, final, generated)
|
|
}
|
|
|
|
func (r *Runner) acceptSampleDFlashDraftsBatched(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, specCaches []cache.Cache, spec *cache.Speculation, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
|
before := *position
|
|
draftCount := candidates.tokens.Dim(1)
|
|
verifyInput := mtpTokenInput(current.Token).Concatenate(1, candidates.tokens)
|
|
hiddenSeq, targetHiddenSeq := target.ForwardDFlash(&batch.Batch{
|
|
InputIDs: verifyInput,
|
|
SeqOffsets: []int32{int32(before)},
|
|
SeqQueryLens: []int32{int32(verifyInput.Dim(1))},
|
|
}, specCaches, draft.TargetLayerIDs())
|
|
|
|
targetDist := r.Sampler.Distribution(pipelineSlot, r.Model.Unembed(hiddenSeq), candidates.tokens)
|
|
draftDist := candidates.dist
|
|
acceptedMask := r.mtpSampleAcceptedMask(targetDist.SliceRows(0, draftCount), draftDist, candidates.tokens)
|
|
mlx.Eval(candidates.tokens, acceptedMask)
|
|
|
|
draftIDs := candidates.tokens.Ints()
|
|
acceptedFlags := acceptedMask.Ints()
|
|
accepted := 0
|
|
for _, ok := range acceptedFlags {
|
|
if ok == 0 {
|
|
break
|
|
}
|
|
accepted++
|
|
}
|
|
if accepted > draftCount {
|
|
spec.Commit(0)
|
|
return sampler.Result{}, 0, false, fmt.Errorf("dflash sample validation accepted %d tokens for %d draft tokens", accepted, draftCount)
|
|
}
|
|
|
|
commitIDs := make([]int32, 0, accepted+1)
|
|
done := false
|
|
for i, id := range draftIDs[:accepted] {
|
|
commitIDs = append(commitIDs, int32(id))
|
|
if r.Tokenizer.IsEOS(int32(id)) {
|
|
done = true
|
|
accepted = i + 1
|
|
commitIDs = commitIDs[:accepted]
|
|
break
|
|
}
|
|
}
|
|
|
|
commitN := accepted + 1
|
|
spec.Commit(0)
|
|
|
|
for _, id := range draftIDs[:accepted] {
|
|
if *generated >= request.Options.NumPredict {
|
|
done = true
|
|
break
|
|
}
|
|
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
|
var err error
|
|
done, err = r.emitMTPToken(ctx, request, session, dec, res, final)
|
|
if err != nil {
|
|
return sampler.Result{}, accepted, done, err
|
|
}
|
|
if !done {
|
|
(*generated)++
|
|
}
|
|
if done {
|
|
break
|
|
}
|
|
}
|
|
|
|
spec.Commit(commitN)
|
|
*position = before + commitN
|
|
draft.AppendContext(targetHiddenSeq.Slice(mlx.Slice(), mlx.Slice(0, commitN), mlx.Slice()), draftCaches)
|
|
|
|
if done || *generated >= request.Options.NumPredict {
|
|
r.Sampler.Commit(pipelineSlot, commitIDs)
|
|
return sampler.Result{}, accepted, true, nil
|
|
}
|
|
|
|
var nextToken *mlx.Array
|
|
if accepted == draftCount {
|
|
nextToken = r.mtpSampleTokenAt(targetDist, draftCount)
|
|
} else {
|
|
nextToken = r.mtpSampleResidualToken(targetDist, draftDist, accepted)
|
|
}
|
|
mlx.Eval(nextToken)
|
|
nextID := int32(tokenID(nextToken))
|
|
commitIDs = append(commitIDs, nextID)
|
|
r.Sampler.Commit(pipelineSlot, commitIDs)
|
|
|
|
return sampler.Result{Token: nextToken}, accepted, false, nil
|
|
}
|
|
|
|
func (r *Runner) acceptSampleDFlashDraftsSerial(ctx context.Context, request Request, session *cacheSession, dec *decoder, target base.DFlashTargetModel, draft base.DFlashDraftModel, targetCaches []cache.Cache, draftCaches []cache.Cache, position *int, current sampler.Result, candidates *dflashDraftCandidates, final *CompletionResponse, generated *int) (sampler.Result, int, bool, error) {
|
|
targetForward := func(token *mlx.Array) *mlx.Array {
|
|
hidden, targetHidden := target.ForwardDFlash(&batch.Batch{
|
|
InputIDs: token,
|
|
SeqOffsets: []int32{int32(*position)},
|
|
SeqQueryLens: []int32{int32(token.Dim(1))},
|
|
}, targetCaches, draft.TargetLayerIDs())
|
|
*position += token.Dim(1)
|
|
draft.AppendContext(targetHidden, draftCaches)
|
|
return r.lastLogits(hidden)
|
|
}
|
|
|
|
mlx.Eval(candidates.tokens)
|
|
draftIDs := candidates.tokens.Ints()
|
|
logits := targetForward(mtpTokenInput(current.Token))
|
|
accepted := 0
|
|
|
|
for i, id := range draftIDs {
|
|
targetDist := r.Sampler.Distribution(pipelineSlot, logits, nil)
|
|
draftDist := candidates.dist.SliceRows(i, i+1)
|
|
draftToken := mlx.FromValues([]int32{int32(id)}, 1)
|
|
acceptedMask := r.mtpSampleAcceptedMask(targetDist, draftDist, draftToken)
|
|
mlx.Eval(acceptedMask)
|
|
|
|
if acceptedMask.Ints()[0] == 0 {
|
|
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, targetDist.ResidualAgainst(draftDist)))
|
|
mlx.Eval(nextToken)
|
|
r.Sampler.Commit(pipelineSlot, []int32{int32(tokenID(nextToken))})
|
|
return sampler.Result{Token: nextToken}, accepted, false, nil
|
|
}
|
|
|
|
accepted++
|
|
r.Sampler.Commit(pipelineSlot, []int32{int32(id)})
|
|
res := sampler.Result{Token: mlx.FromValues([]int32{int32(id)}, 1)}
|
|
done, err := r.emitMTPToken(ctx, request, session, dec, res, final)
|
|
if err != nil {
|
|
return sampler.Result{}, accepted, done, err
|
|
}
|
|
if !done {
|
|
(*generated)++
|
|
}
|
|
if done || *generated >= request.Options.NumPredict {
|
|
return sampler.Result{}, accepted, true, nil
|
|
}
|
|
|
|
logits = targetForward(mtpTokenInput(res.Token))
|
|
}
|
|
|
|
targetDist := r.Sampler.Distribution(pipelineSlot, logits, nil)
|
|
nextToken := mtpTokenVector(r.Sampler.SampleDistribution(pipelineSlot, targetDist))
|
|
mlx.Eval(nextToken)
|
|
r.Sampler.Commit(pipelineSlot, []int32{int32(tokenID(nextToken))})
|
|
return sampler.Result{Token: nextToken}, accepted, false, nil
|
|
}
|