Files
ollama/x/models/dflash/dflash.go
2026-05-22 17:19:10 +08:00

440 lines
15 KiB
Go

// Package dflash implements DFlash block-diffusion draft models for MLX.
package dflash
import (
"encoding/json"
"fmt"
"math"
"sort"
"strings"
"github.com/ollama/ollama/x/mlxrunner/batch"
"github.com/ollama/ollama/x/mlxrunner/cache"
"github.com/ollama/ollama/x/mlxrunner/mlx"
"github.com/ollama/ollama/x/mlxrunner/model"
"github.com/ollama/ollama/x/mlxrunner/model/base"
"github.com/ollama/ollama/x/models/nn"
)
// init registers newModel as the draft-model constructor under both the
// "DFlashDraftModel" and "dflash" names.
func init() {
	for _, name := range []string{"DFlashDraftModel", "dflash"} {
		base.RegisterDraft(name, newModel)
	}
}
// Compile-time check that *Model satisfies base.DFlashDraftModel.
var _ base.DFlashDraftModel = (*Model)(nil)
// Config is the DFlash draft model configuration. parseConfig decodes it from
// the draft's config.json, fills in defaults and validates it. Fields tagged
// `json:"-"` are not read from JSON; they are derived by parseConfig or set by
// newModel.
type Config struct {
	// Transformer dimensions. parseConfig defaults NumKeyValueHeads to
	// NumAttentionHeads and HeadDim to HiddenSize/NumAttentionHeads when they
	// are unset (<= 0).
	HiddenSize        int32 `json:"hidden_size"`
	NumHiddenLayers   int32 `json:"num_hidden_layers"`
	NumAttentionHeads int32 `json:"num_attention_heads"`
	NumKeyValueHeads  int32 `json:"num_key_value_heads"`
	HeadDim           int32 `json:"head_dim"`
	IntermediateSize  int32 `json:"intermediate_size"`
	VocabSize         int32 `json:"vocab_size"`

	// RMSNormEps is used by every RMSNorm in the draft; defaults to 1e-6.
	RMSNormEps float32 `json:"rms_norm_eps"`

	// Rotary embedding settings. RopeParameters takes precedence over
	// RopeScaling; the selected block supplies rope_theta when RopeTheta is
	// unset (falling back to 1e6) and enables YaRN frequencies when its type
	// is "yarn".
	RopeTheta             float32            `json:"rope_theta"`
	RopeScaling           *nn.RopeParameters `json:"rope_scaling"`
	RopeParameters        *nn.RopeParameters `json:"rope_parameters"`
	MaxPositionEmbeddings int32              `json:"max_position_embeddings"`

	// BlockSizeValue is the draft block size reported by BlockSize; it must be
	// positive. NumTargetLayers is the minimum layer count newModel requires of
	// the target model.
	BlockSizeValue  int32 `json:"block_size"`
	NumTargetLayers int32 `json:"num_target_layers"`

	// LayerTypes holds one entry per draft layer, either "full_attention" or
	// "sliding_attention" (case-insensitive); it defaults to all
	// "full_attention". SlidingWindow must be positive if any layer slides.
	LayerTypes    []string `json:"layer_types"`
	SlidingWindow int32    `json:"sliding_window"`

	// FinalLogitSoftcapping, when set, soft-caps the logits returned by Draft.
	FinalLogitSoftcapping *float32 `json:"final_logit_softcapping"`

	// DFlash holds the DFlash-specific settings. TargetLayerIDs must be
	// non-empty, sorted and within the target's layer range (presumably the
	// target layers whose hidden states feed AppendContext — confirm with
	// caller). MaskTokenID is reported by MaskTokenID.
	DFlash struct {
		TargetLayerIDs []int `json:"target_layer_ids"`
		MaskTokenID    int32 `json:"mask_token_id"`
	} `json:"dflash_config"`

	// Quantization settings filled in by newModel from the model root.
	QuantGroupSize int                               `json:"-"`
	QuantBits      int                               `json:"-"`
	QuantMode      string                            `json:"-"`
	TensorQuant    map[string]*model.TensorQuantInfo `json:"-"`

	// Derived by parseConfig: the attention scale 1/sqrt(HeadDim), and the
	// YaRN rope frequencies/scale (RopeFreqs stays nil and RopeScale is 1
	// without YaRN).
	Scale     float32    `json:"-"`
	RopeFreqs *mlx.Array `json:"-"`
	RopeScale float32    `json:"-"`
}
// Model is a DFlash block-diffusion draft model. It owns its context
// projection, norms and transformer layers. It takes token embeddings and the
// output projection (Unembed) from the target model.
type Model struct {
	// FC projects target hidden states in AppendContext, and HiddenNorm
	// normalizes that projection before keys/values are derived from it.
	FC         nn.LinearLayer
	HiddenNorm *nn.RMSNorm

	// Layers holds one draft transformer layer per num_hidden_layers.
	Layers []*Layer

	// Norm is applied to the final hidden state before target.Unembed.
	Norm *nn.RMSNorm

	// target supplies Unembed for Draft's logits. targetEmbeddings is the
	// same target viewed as an embedding provider; it embeds draft input
	// tokens.
	target           base.Model
	targetEmbeddings base.MTPEmbeddingModel

	// tensorPrefix is prepended to every weight name in LoadWeights
	// (newModel defaults it to "draft.").
	tensorPrefix string

	*Config
}
// Layer is one pre-norm draft transformer layer. InputNorm precedes
// Attention, and PostAttentionNorm precedes MLP; each sub-block adds a
// residual connection.
type Layer struct {
	Attention         *Attention
	MLP               *MLP
	InputNorm         *nn.RMSNorm
	PostAttentionNorm *nn.RMSNorm
}
// Attention is a draft self-attention block. Queries and keys are
// RMS-normalized per head (QNorm, KNorm) and rotated with RoPE.
type Attention struct {
	QProj nn.LinearLayer
	KProj nn.LinearLayer
	VProj nn.LinearLayer
	OProj nn.LinearLayer
	QNorm *nn.RMSNorm
	KNorm *nn.RMSNorm

	// Sliding marks a "sliding_attention" layer. Forward masks it causally
	// and adds a sliding-window mask once the key length exceeds
	// sliding_window.
	Sliding bool
}
// MLP is the gated feed-forward block: DownProj(SwiGLU(GateProj(x), UpProj(x))).
type MLP struct {
	GateProj nn.LinearLayer
	UpProj   nn.LinearLayer
	DownProj nn.LinearLayer
}
// parseConfig decodes a DFlash draft config.json, applies defaults and
// validates the fields the draft model relies on.
//
// Defaults, applied when a field is absent or non-positive:
//   - num_key_value_heads: num_attention_heads
//   - head_dim: hidden_size / num_attention_heads (which must divide evenly)
//   - rms_norm_eps: 1e-6
//   - rope_theta: rope_parameters.rope_theta (or rope_scaling's), else 1e6
//   - layer_types: "full_attention" for every layer
//
// It also derives the attention scale 1/sqrt(head_dim). When the rope
// parameters select YaRN, it precomputes the YaRN rope frequencies and scale.
// On error the returned Config is the zero value.
func parseConfig(data []byte) (Config, error) {
	var cfg Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return Config{}, fmt.Errorf("parse dflash config: %w", err)
	}

	if cfg.HiddenSize <= 0 {
		return Config{}, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
	}
	if cfg.NumHiddenLayers <= 0 {
		return Config{}, fmt.Errorf("invalid num_hidden_layers: %d", cfg.NumHiddenLayers)
	}
	if cfg.NumAttentionHeads <= 0 {
		return Config{}, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
	}
	if cfg.NumKeyValueHeads <= 0 {
		cfg.NumKeyValueHeads = cfg.NumAttentionHeads
	}
	// Grouped-query attention shares each KV head across an equal number of
	// query heads. Reject uneven head counts here instead of letting
	// attention fail later at run time.
	if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
		return Config{}, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
	}
	if cfg.HeadDim <= 0 {
		if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
			return Config{}, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
		}
		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	}
	if cfg.RMSNormEps == 0 {
		cfg.RMSNormEps = 1e-6
	}

	// rope_parameters takes precedence over the legacy rope_scaling key. The
	// selected block supplies both the rope_theta fallback and the YaRN
	// settings used below.
	ropeParams := cfg.RopeParameters
	if ropeParams == nil {
		ropeParams = cfg.RopeScaling
	}
	if cfg.RopeTheta == 0 && ropeParams != nil && ropeParams.RopeTheta > 0 {
		cfg.RopeTheta = ropeParams.RopeTheta
	}
	if cfg.RopeTheta == 0 {
		cfg.RopeTheta = 1000000
	}

	if cfg.BlockSizeValue <= 0 {
		return Config{}, fmt.Errorf("invalid block_size: %d", cfg.BlockSizeValue)
	}
	if len(cfg.DFlash.TargetLayerIDs) == 0 {
		return Config{}, fmt.Errorf("dflash_config.target_layer_ids is required")
	}
	if !sort.IntsAreSorted(cfg.DFlash.TargetLayerIDs) {
		return Config{}, fmt.Errorf("dflash_config.target_layer_ids must be sorted")
	}

	if len(cfg.LayerTypes) == 0 {
		cfg.LayerTypes = make([]string, cfg.NumHiddenLayers)
		for i := range cfg.LayerTypes {
			cfg.LayerTypes[i] = "full_attention"
		}
	}
	if len(cfg.LayerTypes) != int(cfg.NumHiddenLayers) {
		return Config{}, fmt.Errorf("layer_types length %d does not match num_hidden_layers %d", len(cfg.LayerTypes), cfg.NumHiddenLayers)
	}
	for i, typ := range cfg.LayerTypes {
		switch strings.ToLower(typ) {
		case "full_attention":
		case "sliding_attention":
			if cfg.SlidingWindow <= 0 {
				return Config{}, fmt.Errorf("layer %d uses sliding_attention but sliding_window is not set", i)
			}
		default:
			return Config{}, fmt.Errorf("unsupported layer type %q", typ)
		}
	}

	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	cfg.RopeScale = 1
	if ropeParams != nil && strings.EqualFold(ropeParams.TypeName(), "yarn") {
		cfg.RopeFreqs, cfg.RopeScale = nn.BuildYarnRopeFreqs(int(cfg.HeadDim), cfg.RopeTheta, ropeParams)
	}
	return cfg, nil
}
// newModel builds an unloaded DFlash draft model for target from the draft
// metadata in root.
//
// It reads the draft config (root.Draft.Config, default "draft/config.json")
// and validates it. It checks that target has at least num_target_layers
// layers, that every target_layer_id is in range, and that target exposes
// token embeddings. It then resolves quantization settings and the tensor name
// prefix (default "draft."). Layer weights are not bound here; LoadWeights
// does that.
func newModel(root *model.Root, target base.Model) (base.DraftModel, error) {
	if root == nil || root.Draft == nil {
		return nil, fmt.Errorf("draft metadata missing")
	}
	// Without this guard a nil target panics at target.NumLayers() below.
	if target == nil {
		return nil, fmt.Errorf("dflash draft requires a target model")
	}

	configPath := root.Draft.Config
	if configPath == "" {
		configPath = "draft/config.json"
	}
	configData, err := root.Manifest.ReadConfig(configPath)
	if err != nil {
		return nil, fmt.Errorf("load dflash config: %w", err)
	}
	cfg, err := parseConfig(configData)
	if err != nil {
		return nil, err
	}

	targetLayers := target.NumLayers()
	if targetLayers < int(cfg.NumTargetLayers) {
		return nil, fmt.Errorf("dflash target expects %d layers, target has %d", cfg.NumTargetLayers, targetLayers)
	}
	for _, layerID := range cfg.DFlash.TargetLayerIDs {
		if layerID < 0 || layerID >= targetLayers {
			return nil, fmt.Errorf("dflash target layer id %d out of range for %d-layer target", layerID, targetLayers)
		}
	}
	targetEmbeddings, ok := target.(base.MTPEmbeddingModel)
	if !ok {
		return nil, fmt.Errorf("dflash draft requires target token embeddings, got %T", target)
	}

	// The root's group size overrides the quant-type default only when the
	// model is actually quantized.
	qt := root.QuantType()
	cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
	if qt != "" {
		if gs := root.GroupSize(); gs > 0 {
			cfg.QuantGroupSize = gs
		}
	}
	cfg.TensorQuant = root.AllTensorQuant()

	prefix := root.Draft.TensorPrefix
	if prefix == "" {
		prefix = "draft."
	}
	m := &Model{
		Config:           &cfg,
		Layers:           make([]*Layer, cfg.NumHiddenLayers),
		target:           target,
		targetEmbeddings: targetEmbeddings,
		tensorPrefix:     prefix,
	}
	// Placeholder layers; LoadWeights replaces each one with a fully bound layer.
	for i := range m.Layers {
		m.Layers[i] = &Layer{Attention: &Attention{}, MLP: &MLP{}}
	}
	return m, nil
}
// LoadWeights binds the draft parameters from tensors. Every name is looked up
// under the model's tensor prefix, and linear layers are built with the
// configured quantization settings.
//
// Required weights are the fc projection, hidden_norm, norm, and per layer the
// input/post-attention norms, the q/k/v/o projections, the q/k norms and the
// gate/up/down MLP projections. The first missing group is reported as an
// error. Each layer is stored in m.Layers only after it validates.
func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
	p := m.tensorPrefix
	factory := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)

	// rmsNorm builds the named RMSNorm, or returns nil when the weight is absent.
	rmsNorm := func(name string) *nn.RMSNorm {
		w := tensors[name]
		if w == nil {
			return nil
		}
		return nn.NewRMSNorm(w, m.RMSNormEps)
	}

	m.FC = factory.Make(p + "fc")
	if m.FC == nil {
		return fmt.Errorf("missing dflash fc weight")
	}
	// Only overwrite the model-level norms when the weight is present.
	if w := tensors[p+"hidden_norm.weight"]; w != nil {
		m.HiddenNorm = nn.NewRMSNorm(w, m.RMSNormEps)
	}
	if w := tensors[p+"norm.weight"]; w != nil {
		m.Norm = nn.NewRMSNorm(w, m.RMSNormEps)
	}
	if m.HiddenNorm == nil || m.Norm == nil {
		return fmt.Errorf("missing dflash norm weights")
	}

	for i := range m.NumHiddenLayers {
		lp := fmt.Sprintf("%slayers.%d", p, i)

		mlp := &MLP{
			GateProj: factory.Make(lp + ".mlp.gate_proj"),
			UpProj:   factory.Make(lp + ".mlp.up_proj"),
			DownProj: factory.Make(lp + ".mlp.down_proj"),
		}
		inputNorm := rmsNorm(lp + ".input_layernorm.weight")
		postNorm := rmsNorm(lp + ".post_attention_layernorm.weight")
		attn := &Attention{
			QProj:   factory.Make(lp + ".self_attn.q_proj"),
			KProj:   factory.Make(lp + ".self_attn.k_proj"),
			VProj:   factory.Make(lp + ".self_attn.v_proj"),
			OProj:   factory.Make(lp + ".self_attn.o_proj"),
			QNorm:   rmsNorm(lp + ".self_attn.q_norm.weight"),
			KNorm:   rmsNorm(lp + ".self_attn.k_norm.weight"),
			Sliding: strings.ToLower(m.LayerTypes[i]) == "sliding_attention",
		}

		switch {
		case inputNorm == nil || postNorm == nil:
			return fmt.Errorf("dflash layer %d: missing layer norms", i)
		case attn.QProj == nil || attn.KProj == nil || attn.VProj == nil || attn.OProj == nil:
			return fmt.Errorf("dflash layer %d: missing attention projections", i)
		case attn.QNorm == nil || attn.KNorm == nil:
			return fmt.Errorf("dflash layer %d: missing attention q/k norms", i)
		case mlp.GateProj == nil || mlp.UpProj == nil || mlp.DownProj == nil:
			return fmt.Errorf("dflash layer %d: missing mlp projections", i)
		}

		m.Layers[i] = &Layer{
			Attention:         attn,
			MLP:               mlp,
			InputNorm:         inputNorm,
			PostAttentionNorm: postNorm,
		}
	}
	return nil
}
// TargetLayerIDs returns a copy of the configured target layer indices, so
// callers cannot mutate the model's config. An empty list yields nil.
func (m *Model) TargetLayerIDs() []int {
	src := m.DFlash.TargetLayerIDs
	if len(src) == 0 {
		return nil
	}
	ids := make([]int, len(src))
	copy(ids, src)
	return ids
}
// BlockSize reports the configured draft block size (block_size).
func (m *Model) BlockSize() int {
	size := m.Config.BlockSizeValue
	return int(size)
}
// MaskTokenID reports dflash_config.mask_token_id from the draft config.
func (m *Model) MaskTokenID() int32 {
	id := m.Config.DFlash.MaskTokenID
	return id
}
// NewCaches allocates one cache per draft layer. Layers typed
// "sliding_attention" get a rotating cache of sliding_window slots; every other
// layer gets a regular KV cache.
func (m *Model) NewCaches() []cache.Cache {
	out := make([]cache.Cache, len(m.Layers))
	for i := range m.LayerTypes {
		if strings.ToLower(m.LayerTypes[i]) != "sliding_attention" {
			out[i] = cache.NewKVCache()
			continue
		}
		// RotatingKVCache.View returns maxSize-1 tokens so assistant paths can
		// append the current query. DFlash reads target context through that
		// same view. A cache of sliding_window slots therefore exposes
		// sliding_window-1 context tokens to the draft layer: one slot more
		// than the context it reads.
		out[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
	}
	return out
}
// AppendContext projects target hidden states into the draft's hidden space
// (fc, then hidden_norm) and appends the resulting keys/values to each layer's
// cache.
//
// A nil targetHidden, or one whose second dimension is empty, is a no-op.
// Positions continue from the first cache's offset. Layers without a cache at
// the same index are skipped.
func (m *Model) AppendContext(targetHidden *mlx.Array, caches []cache.Cache) {
	if targetHidden == nil || targetHidden.Dim(1) == 0 {
		return
	}
	projected := m.HiddenNorm.Forward(m.FC.Forward(targetHidden), m.RMSNormEps)

	var start int32
	if len(caches) != 0 && caches[0] != nil {
		start = int32(caches[0].Offset())
	}

	batchDim, seqDim := targetHidden.Dim(0), targetHidden.Dim(1)
	// The input IDs are zero placeholders: only the offsets and lengths matter
	// for context appends.
	b := &batch.Batch{
		InputIDs:     mlx.Zeros(mlx.DTypeInt32, batchDim, seqDim),
		SeqOffsets:   []int32{start},
		SeqQueryLens: []int32{int32(seqDim)},
	}
	positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))

	for i, layer := range m.Layers {
		if i < len(caches) && caches[i] != nil {
			layer.Attention.AppendContext(projected, b, positions, caches[i], m.Config)
		}
	}
}
// Draft runs the draft network over inputIDs and returns vocabulary logits.
//
// inputIDs is read as (B, L). NOTE(review): the batch describes a single
// sequence (one SeqOffsets entry), so B is presumably expected to be 1.
//
// Processing steps:
//   - Tokens are embedded with the target model's token embeddings.
//   - They pass through every draft layer; layer i uses caches[i] when that
//     index exists, otherwise no cache.
//   - The result is normalized and projected to logits with the target's
//     Unembed.
//
// Positions start at the first cache's offset, or 0 when there is none. When
// final_logit_softcapping is configured, the logits are soft-capped.
func (m *Model) Draft(inputIDs *mlx.Array, caches []cache.Cache) *mlx.Array {
	dims := inputIDs.Dims()
	B, L := int32(dims[0]), int32(dims[1])
	offset := int32(0)
	if len(caches) > 0 && caches[0] != nil {
		offset = int32(caches[0].Offset())
	}
	b := &batch.Batch{
		InputIDs:     inputIDs,
		SeqOffsets:   []int32{offset},
		SeqQueryLens: []int32{L},
	}
	positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets))

	h := m.targetEmbeddings.TokenEmbeddings(inputIDs)
	for i, layer := range m.Layers {
		var c cache.Cache
		if i < len(caches) {
			c = caches[i]
		}
		h = layer.Forward(h, b, c, positions, B, L, m.Config)
	}

	logits := m.target.Unembed(m.Norm.Forward(h, m.RMSNormEps))
	if m.FinalLogitSoftcapping != nil {
		// Named softcap rather than cap so the predeclared builtin is not shadowed.
		softcap := mlx.FromValue(*m.FinalLogitSoftcapping).AsType(logits.DType())
		logits = mlx.LogitSoftcap(logits, softcap)
	}
	return logits
}
// Forward applies one pre-norm transformer layer. First, attention over the
// input-normed x is added back to x. Then the MLP over the post-attention-normed
// result is added back to that result.
func (l *Layer) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
	attnIn := l.InputNorm.Forward(x, cfg.RMSNormEps)
	residual := mlx.Add(x, l.Attention.Forward(attnIn, b, c, positions, B, L, cfg))

	mlpIn := l.PostAttentionNorm.Forward(residual, cfg.RMSNormEps)
	return mlx.Add(residual, l.MLP.Forward(mlpIn))
}
// AppendContext computes keys and values from the context features xCtx and
// writes them into c. It reads (B, L) from the first two dimensions of xCtx.
//
// Keys are split into KV heads, passed through k_norm, moved to a heads-first
// layout and rotated with RoPE at positions. Values are only split and
// transposed. No queries are computed. c must implement cache.Attention;
// otherwise the type assertion panics.
func (a *Attention) AppendContext(xCtx *mlx.Array, b *batch.Batch, positions *mlx.Array, c cache.Cache, cfg *Config) {
	nB, nL := int32(xCtx.Dim(0)), int32(xCtx.Dim(1))
	headDim := int(cfg.HeadDim)

	keys := mlx.Reshape(a.KProj.Forward(xCtx), nB, nL, cfg.NumKeyValueHeads, cfg.HeadDim)
	keys = mlx.Transpose(a.KNorm.Forward(keys, cfg.RMSNormEps), 0, 2, 1, 3)
	keys = mlx.RoPEWithFreqs(keys, headDim, false, cfg.RopeTheta, 1.0, positions, cfg.RopeFreqs)
	keys = nn.ScaleRotaryPart(keys, headDim, cfg.RopeScale)

	values := mlx.Reshape(a.VProj.Forward(xCtx), nB, nL, cfg.NumKeyValueHeads, cfg.HeadDim)
	values = mlx.Transpose(values, 0, 2, 1, 3)

	c.(cache.Attention).Update(b, keys, values)
}
// Forward runs draft self-attention for x, which is reshaped to (B, L, heads,
// head_dim) per projection.
//
// Queries and proposal keys are RMS-normalized and rotated with RoPE at
// positions. When c implements cache.Viewer and its View is non-nil, the
// cached history is prepended to the proposal keys/values along the sequence
// axis. The proposals themselves are not written to the cache here.
//
// Non-sliding layers pass an empty nn.AttentionMask{} (presumably no masking).
// Sliding layers use a causal mask, intersected with a sliding-window mask once
// the key length exceeds sliding_window. The attended output is merged back to
// (B, L, heads*head_dim) and projected by o_proj.
func (a *Attention) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array {
	headDim := int(cfg.HeadDim)

	// split reshapes a projection into (B, L, heads, head_dim).
	split := func(t *mlx.Array, heads int32) *mlx.Array {
		return mlx.Reshape(t, B, L, heads, cfg.HeadDim)
	}
	// headsFirst swaps the sequence and head axes.
	headsFirst := func(t *mlx.Array) *mlx.Array {
		return mlx.Transpose(t, 0, 2, 1, 3)
	}
	// rope applies rotary embeddings at positions, then the rotary-part scale.
	rope := func(t *mlx.Array) *mlx.Array {
		rotated := mlx.RoPEWithFreqs(t, headDim, false, cfg.RopeTheta, 1.0, positions, cfg.RopeFreqs)
		return nn.ScaleRotaryPart(rotated, headDim, cfg.RopeScale)
	}

	q := rope(headsFirst(a.QNorm.Forward(split(a.QProj.Forward(x), cfg.NumAttentionHeads), cfg.RMSNormEps)))
	newK := rope(headsFirst(a.KNorm.Forward(split(a.KProj.Forward(x), cfg.NumKeyValueHeads), cfg.RMSNormEps)))
	newV := headsFirst(split(a.VProj.Forward(x), cfg.NumKeyValueHeads))

	k, v := newK, newV
	if viewer, ok := c.(cache.Viewer); ok {
		if hist := viewer.View(b); hist != nil {
			k = hist.K().Concatenate(2, newK)
			v = hist.V().Concatenate(2, newV)
		}
	}

	kvLen := k.Dim(2)
	mask := nn.AttentionMask{}
	if a.Sliding {
		window := int(cfg.SlidingWindow)
		mask = nn.CausalMask()
		if window > 0 && kvLen > window {
			mask = mask.Intersect(nn.SlidingWindowMask(b, kvLen, window, q.DType()))
		}
	}

	attended := nn.ScaledDotProductAttention(b, q, cfg.Scale, nn.WithKV(k, v, []int32{int32(kvLen)}), nn.WithMask(mask))
	merged := mlx.Reshape(headsFirst(attended), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(merged)
}
// Forward applies the gated feed-forward block:
// down_proj(SwiGLU(gate_proj(x), up_proj(x))).
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
	gate := m.GateProj.Forward(x)
	up := m.UpProj.Forward(x)
	return m.DownProj.Forward(mlx.SwiGLU(gate, up))
}