// Package dflash implements DFlash block-diffusion draft models for MLX. package dflash import ( "encoding/json" "fmt" "math" "sort" "strings" "github.com/ollama/ollama/x/mlxrunner/batch" "github.com/ollama/ollama/x/mlxrunner/cache" "github.com/ollama/ollama/x/mlxrunner/mlx" "github.com/ollama/ollama/x/mlxrunner/model" "github.com/ollama/ollama/x/mlxrunner/model/base" "github.com/ollama/ollama/x/models/nn" ) func init() { base.RegisterDraft("DFlashDraftModel", newModel) base.RegisterDraft("dflash", newModel) } var _ base.DFlashDraftModel = (*Model)(nil) type Config struct { HiddenSize int32 `json:"hidden_size"` NumHiddenLayers int32 `json:"num_hidden_layers"` NumAttentionHeads int32 `json:"num_attention_heads"` NumKeyValueHeads int32 `json:"num_key_value_heads"` HeadDim int32 `json:"head_dim"` IntermediateSize int32 `json:"intermediate_size"` VocabSize int32 `json:"vocab_size"` RMSNormEps float32 `json:"rms_norm_eps"` RopeTheta float32 `json:"rope_theta"` RopeScaling *nn.RopeParameters `json:"rope_scaling"` RopeParameters *nn.RopeParameters `json:"rope_parameters"` MaxPositionEmbeddings int32 `json:"max_position_embeddings"` BlockSizeValue int32 `json:"block_size"` NumTargetLayers int32 `json:"num_target_layers"` LayerTypes []string `json:"layer_types"` SlidingWindow int32 `json:"sliding_window"` FinalLogitSoftcapping *float32 `json:"final_logit_softcapping"` DFlash struct { TargetLayerIDs []int `json:"target_layer_ids"` MaskTokenID int32 `json:"mask_token_id"` } `json:"dflash_config"` QuantGroupSize int `json:"-"` QuantBits int `json:"-"` QuantMode string `json:"-"` TensorQuant map[string]*model.TensorQuantInfo `json:"-"` Scale float32 `json:"-"` RopeFreqs *mlx.Array `json:"-"` RopeScale float32 `json:"-"` } type Model struct { FC nn.LinearLayer HiddenNorm *nn.RMSNorm Layers []*Layer Norm *nn.RMSNorm target base.Model targetEmbeddings base.MTPEmbeddingModel tensorPrefix string *Config } type Layer struct { Attention *Attention MLP *MLP InputNorm *nn.RMSNorm PostAttentionNorm *nn.RMSNorm } type Attention struct { QProj nn.LinearLayer KProj nn.LinearLayer VProj nn.LinearLayer OProj nn.LinearLayer QNorm *nn.RMSNorm KNorm *nn.RMSNorm Sliding bool } type MLP struct { GateProj nn.LinearLayer UpProj nn.LinearLayer DownProj nn.LinearLayer } func parseConfig(data []byte) (Config, error) { var cfg Config if err := json.Unmarshal(data, &cfg); err != nil { return Config{}, fmt.Errorf("parse dflash config: %w", err) } if cfg.HiddenSize <= 0 { return Config{}, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize) } if cfg.NumHiddenLayers <= 0 { return Config{}, fmt.Errorf("invalid num_hidden_layers: %d", cfg.NumHiddenLayers) } if cfg.NumAttentionHeads <= 0 { return Config{}, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads) } if cfg.NumKeyValueHeads <= 0 { cfg.NumKeyValueHeads = cfg.NumAttentionHeads } if cfg.HeadDim <= 0 { if cfg.HiddenSize%cfg.NumAttentionHeads != 0 { return Config{}, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads) } cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads } if cfg.RMSNormEps == 0 { cfg.RMSNormEps = 1e-6 } if cfg.RopeTheta == 0 { ropeParams := cfg.RopeParameters if ropeParams == nil { ropeParams = cfg.RopeScaling } if ropeParams != nil && ropeParams.RopeTheta > 0 { cfg.RopeTheta = ropeParams.RopeTheta } } if cfg.RopeTheta == 0 { cfg.RopeTheta = 1000000 } if cfg.BlockSizeValue <= 0 { return Config{}, fmt.Errorf("invalid block_size: %d", cfg.BlockSizeValue) } if len(cfg.DFlash.TargetLayerIDs) == 0 { return Config{}, fmt.Errorf("dflash_config.target_layer_ids is required") } if !sort.IntsAreSorted(cfg.DFlash.TargetLayerIDs) { return Config{}, fmt.Errorf("dflash_config.target_layer_ids must be sorted") } if len(cfg.LayerTypes) == 0 { cfg.LayerTypes = make([]string, cfg.NumHiddenLayers) for i := range cfg.LayerTypes { cfg.LayerTypes[i] = "full_attention" } } if len(cfg.LayerTypes) != int(cfg.NumHiddenLayers) { return Config{}, fmt.Errorf("layer_types length %d does not match num_hidden_layers %d", len(cfg.LayerTypes), cfg.NumHiddenLayers) } for i, typ := range cfg.LayerTypes { switch strings.ToLower(typ) { case "full_attention": case "sliding_attention": if cfg.SlidingWindow <= 0 { return Config{}, fmt.Errorf("layer %d uses sliding_attention but sliding_window is not set", i) } default: return Config{}, fmt.Errorf("unsupported layer type %q", typ) } } cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) cfg.RopeScale = 1 ropeParams := cfg.RopeParameters if ropeParams == nil { ropeParams = cfg.RopeScaling } if ropeParams != nil && strings.EqualFold(ropeParams.TypeName(), "yarn") { cfg.RopeFreqs, cfg.RopeScale = nn.BuildYarnRopeFreqs(int(cfg.HeadDim), cfg.RopeTheta, ropeParams) } return cfg, nil } func newModel(root *model.Root, target base.Model) (base.DraftModel, error) { if root == nil || root.Draft == nil { return nil, fmt.Errorf("draft metadata missing") } configPath := root.Draft.Config if configPath == "" { configPath = "draft/config.json" } configData, err := root.Manifest.ReadConfig(configPath) if err != nil { return nil, fmt.Errorf("load dflash config: %w", err) } cfg, err := parseConfig(configData) if err != nil { return nil, err } if target.NumLayers() < int(cfg.NumTargetLayers) { return nil, fmt.Errorf("dflash target expects %d layers, target has %d", cfg.NumTargetLayers, target.NumLayers()) } for _, layerID := range cfg.DFlash.TargetLayerIDs { if layerID < 0 || layerID >= target.NumLayers() { return nil, fmt.Errorf("dflash target layer id %d out of range for %d-layer target", layerID, target.NumLayers()) } } targetEmbeddings, ok := target.(base.MTPEmbeddingModel) if !ok { return nil, fmt.Errorf("dflash draft requires target token embeddings, got %T", target) } if qt := root.QuantType(); qt != "" { cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt) if gs := root.GroupSize(); gs > 0 { cfg.QuantGroupSize = gs } } else { cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("") } cfg.TensorQuant = root.AllTensorQuant() prefix := root.Draft.TensorPrefix if prefix == "" { prefix = "draft." } m := &Model{ Config: &cfg, Layers: make([]*Layer, cfg.NumHiddenLayers), target: target, targetEmbeddings: targetEmbeddings, tensorPrefix: prefix, } for i := range m.Layers { m.Layers[i] = &Layer{Attention: &Attention{}, MLP: &MLP{}} } return m, nil } func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error { prefix := m.tensorPrefix linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant) m.FC = linears.Make(prefix + "fc") if m.FC == nil { return fmt.Errorf("missing dflash fc weight") } if w := tensors[prefix+"hidden_norm.weight"]; w != nil { m.HiddenNorm = nn.NewRMSNorm(w, m.RMSNormEps) } if w := tensors[prefix+"norm.weight"]; w != nil { m.Norm = nn.NewRMSNorm(w, m.RMSNormEps) } if m.HiddenNorm == nil || m.Norm == nil { return fmt.Errorf("missing dflash norm weights") } for i := range m.NumHiddenLayers { layerPrefix := fmt.Sprintf("%slayers.%d", prefix, i) layer := &Layer{ Attention: &Attention{Sliding: strings.ToLower(m.LayerTypes[i]) == "sliding_attention"}, MLP: &MLP{ GateProj: linears.Make(layerPrefix + ".mlp.gate_proj"), UpProj: linears.Make(layerPrefix + ".mlp.up_proj"), DownProj: linears.Make(layerPrefix + ".mlp.down_proj"), }, } if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil { layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps) } if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil { layer.PostAttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps) } layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj") layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj") layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj") layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj") if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil { layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps) } if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil { layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps) } if layer.InputNorm == nil || layer.PostAttentionNorm == nil { return fmt.Errorf("dflash layer %d: missing layer norms", i) } if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil { return fmt.Errorf("dflash layer %d: missing attention projections", i) } if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil { return fmt.Errorf("dflash layer %d: missing attention q/k norms", i) } if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil { return fmt.Errorf("dflash layer %d: missing mlp projections", i) } m.Layers[i] = layer } return nil } func (m *Model) TargetLayerIDs() []int { return append([]int(nil), m.DFlash.TargetLayerIDs...) } func (m *Model) BlockSize() int { return int(m.BlockSizeValue) } func (m *Model) MaskTokenID() int32 { return m.DFlash.MaskTokenID } func (m *Model) NewCaches() []cache.Cache { caches := make([]cache.Cache, len(m.Layers)) for i, typ := range m.LayerTypes { if strings.ToLower(typ) == "sliding_attention" { // RotatingKVCache.View returns maxSize-1 tokens so assistant // paths can append the current query. DFlash uses that same // view for target context, so allocate one extra slot to expose // the draft model's sliding_window-1 context tokens. caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow)) } else { caches[i] = cache.NewKVCache() } } return caches } func (m *Model) AppendContext(targetHidden *mlx.Array, caches []cache.Cache) { if targetHidden == nil || targetHidden.Dim(1) == 0 { return } hCtx := m.HiddenNorm.Forward(m.FC.Forward(targetHidden), m.RMSNormEps) offset := int32(0) if len(caches) > 0 && caches[0] != nil { offset = int32(caches[0].Offset()) } b := &batch.Batch{ InputIDs: mlx.Zeros(mlx.DTypeInt32, targetHidden.Dim(0), targetHidden.Dim(1)), SeqOffsets: []int32{offset}, SeqQueryLens: []int32{int32(targetHidden.Dim(1))}, } positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets)) for i, layer := range m.Layers { if i >= len(caches) || caches[i] == nil { continue } layer.Attention.AppendContext(hCtx, b, positions, caches[i], m.Config) } } func (m *Model) Draft(inputIDs *mlx.Array, caches []cache.Cache) *mlx.Array { dims := inputIDs.Dims() B, L := int32(dims[0]), int32(dims[1]) offset := int32(0) if len(caches) > 0 && caches[0] != nil { offset = int32(caches[0].Offset()) } b := &batch.Batch{ InputIDs: inputIDs, SeqOffsets: []int32{offset}, SeqQueryLens: []int32{L}, } positions := mlx.FromValues(b.SeqOffsets, len(b.SeqOffsets)) h := m.targetEmbeddings.TokenEmbeddings(inputIDs) for i, layer := range m.Layers { var c cache.Cache if i < len(caches) { c = caches[i] } h = layer.Forward(h, b, c, positions, B, L, m.Config) } logits := m.target.Unembed(m.Norm.Forward(h, m.RMSNormEps)) if m.FinalLogitSoftcapping != nil { cap := mlx.FromValue(*m.FinalLogitSoftcapping).AsType(logits.DType()) logits = mlx.LogitSoftcap(logits, cap) } return logits } func (l *Layer) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array { h := mlx.Add(x, l.Attention.Forward(l.InputNorm.Forward(x, cfg.RMSNormEps), b, c, positions, B, L, cfg)) return mlx.Add(h, l.MLP.Forward(l.PostAttentionNorm.Forward(h, cfg.RMSNormEps))) } func (a *Attention) AppendContext(xCtx *mlx.Array, b *batch.Batch, positions *mlx.Array, c cache.Cache, cfg *Config) { B, L := int32(xCtx.Dim(0)), int32(xCtx.Dim(1)) k := a.KProj.Forward(xCtx) v := a.VProj.Forward(xCtx) k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) k = a.KNorm.Forward(k, cfg.RMSNormEps) k = mlx.Transpose(k, 0, 2, 1, 3) v = mlx.Transpose(v, 0, 2, 1, 3) k = nn.ScaleRotaryPart(mlx.RoPEWithFreqs(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions, cfg.RopeFreqs), int(cfg.HeadDim), cfg.RopeScale) c.(cache.Attention).Update(b, k, v) } func (a *Attention) Forward(x *mlx.Array, b *batch.Batch, c cache.Cache, positions *mlx.Array, B, L int32, cfg *Config) *mlx.Array { q := a.QProj.Forward(x) propK := a.KProj.Forward(x) propV := a.VProj.Forward(x) q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim) propK = mlx.Reshape(propK, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) propV = mlx.Reshape(propV, B, L, cfg.NumKeyValueHeads, cfg.HeadDim) q = a.QNorm.Forward(q, cfg.RMSNormEps) propK = a.KNorm.Forward(propK, cfg.RMSNormEps) q = mlx.Transpose(q, 0, 2, 1, 3) propK = mlx.Transpose(propK, 0, 2, 1, 3) propV = mlx.Transpose(propV, 0, 2, 1, 3) q = nn.ScaleRotaryPart(mlx.RoPEWithFreqs(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions, cfg.RopeFreqs), int(cfg.HeadDim), cfg.RopeScale) propK = nn.ScaleRotaryPart(mlx.RoPEWithFreqs(propK, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, positions, cfg.RopeFreqs), int(cfg.HeadDim), cfg.RopeScale) k, v := propK, propV if viewer, ok := c.(cache.Viewer); ok { if history := viewer.View(b); history != nil { k = history.K().Concatenate(2, propK) v = history.V().Concatenate(2, propV) } } mask := nn.AttentionMask{} if a.Sliding { mask = nn.CausalMask() if int(cfg.SlidingWindow) > 0 && k.Dim(2) > int(cfg.SlidingWindow) { mask = mask.Intersect(nn.SlidingWindowMask(b, k.Dim(2), int(cfg.SlidingWindow), q.DType())) } } out := nn.ScaledDotProductAttention(b, q, cfg.Scale, nn.WithKV(k, v, []int32{int32(k.Dim(2))}), nn.WithMask(mask)) out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) return a.OProj.Forward(out) } func (m *MLP) Forward(x *mlx.Array) *mlx.Array { return m.DownProj.Forward(mlx.SwiGLU(m.GateProj.Forward(x), m.UpProj.Forward(x))) }