260 lines
8.1 KiB
Go
260 lines
8.1 KiB
Go
package nn
|
|
|
|
import (
|
|
"github.com/ollama/ollama/x/mlxrunner/batch"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
)
|
|
|
|
// RecurrentOption configures a call to CausalConv1D or GatedDelta.
|
|
type RecurrentOption func(*recurrentConfig)
|
|
|
|
// recurrentConfig is the resolved set of inputs supplied via
|
|
// RecurrentOption. Exactly one of history or (convState/deltaState)
|
|
// must be supplied per call.
|
|
type recurrentConfig struct {
|
|
history *RecurrentHistory
|
|
convState *mlx.Array
|
|
deltaState *mlx.Array
|
|
}
|
|
|
|
// WithRecurrentHistory supplies a cache's per-layer view of conv and
|
|
// delta state. The cache hides any storage layout (per-row, paged,
|
|
// gather/scatter) behind the history.
|
|
func WithRecurrentHistory(h *RecurrentHistory) RecurrentOption {
|
|
return func(c *recurrentConfig) { c.history = h }
|
|
}
|
|
|
|
// WithRecurrentState supplies explicit conv and delta state tensors
|
|
// for the no-cache path. Each wrapper consumes one of the two — pass
|
|
// nil for the unused slot when calling only one wrapper.
|
|
func WithRecurrentState(convState, deltaState *mlx.Array) RecurrentOption {
|
|
return func(c *recurrentConfig) {
|
|
c.convState = convState
|
|
c.deltaState = deltaState
|
|
}
|
|
}
|
|
|
|
// resolve applies opts and panics if WithRecurrentHistory and
|
|
// WithRecurrentState were combined or neither was supplied.
|
|
func resolveRecurrentConfig(opts []RecurrentOption) recurrentConfig {
|
|
var cfg recurrentConfig
|
|
for _, opt := range opts {
|
|
opt(&cfg)
|
|
}
|
|
|
|
haveHistory := cfg.history != nil
|
|
haveState := cfg.convState != nil || cfg.deltaState != nil
|
|
if haveHistory && haveState {
|
|
panic("WithRecurrentHistory and WithRecurrentState are mutually exclusive")
|
|
}
|
|
if !haveHistory && !haveState {
|
|
panic("no recurrent state supplied (use WithRecurrentHistory or WithRecurrentState)")
|
|
}
|
|
|
|
return cfg
|
|
}
|
|
|
|
// CausalConv1D runs a depthwise causal 1D convolution with recurrent
|
|
// state management. Prepends the prior conv state along axis 1, runs
|
|
// the conv, and returns (output, nextConv). nextConv is the trailing
|
|
// convTail positions of the concat — write it back to the cache via
|
|
// Put alongside the scan's new delta state.
|
|
//
|
|
// Conv selection: when conv is non-nil (a full nn.Conv1d layer), it
|
|
// runs through conv.Forward. Otherwise weight is treated as the bare
|
|
// depthwise kernel [C, K] and the fallback manual implementation runs.
|
|
// Exactly one of conv or weight should be non-nil.
|
|
//
|
|
// Shapes: input [B, L, D]; prior state [B, convTail, D]; output
|
|
// [B, L, D] (the causal conv strips the prepended state).
|
|
//
|
|
// Prior state comes from exactly one of WithRecurrentHistory (cache
|
|
// path) or WithRecurrentState (no-cache path).
|
|
func CausalConv1D(b *batch.Batch, input *mlx.Array, conv *Conv1d, weight *mlx.Array, convTail int, opts ...RecurrentOption) (out, nextConv *mlx.Array) {
|
|
cfg := resolveRecurrentConfig(opts)
|
|
var prior *mlx.Array
|
|
if cfg.history != nil {
|
|
prior = cfg.history.ConvState()
|
|
} else {
|
|
prior = cfg.convState
|
|
}
|
|
|
|
mask := paddingMask(b, int32(input.Dim(1)))
|
|
if mask != nil {
|
|
zero := mlx.FromValue(float32(0)).AsType(input.DType())
|
|
input = mlx.Where(mlx.ExpandDims(mask, 2), input, zero)
|
|
}
|
|
|
|
concat := mlx.Concatenate([]*mlx.Array{prior, input}, 1)
|
|
if conv != nil {
|
|
out = conv.Forward(concat)
|
|
} else {
|
|
out = depthwiseCausalConv1d(concat, weight, int32(input.Dim(1)))
|
|
}
|
|
|
|
B := int32(concat.Dim(0))
|
|
total := int32(concat.Dim(1))
|
|
D := int32(concat.Dim(2))
|
|
|
|
// Gather the tail from each of the non-padded sequence ends
|
|
if mask != nil && convTail > 0 {
|
|
offsets := make([]int32, int(B)*convTail)
|
|
|
|
for i := range int(B) {
|
|
end := b.SeqQueryLens[i]
|
|
|
|
for k := range convTail {
|
|
offsets[i*convTail+k] = end + int32(k)
|
|
}
|
|
}
|
|
|
|
positions := mlx.NewArrayInt32(offsets, []int32{B, int32(convTail), 1})
|
|
nextConv = mlx.TakeAlongAxis(concat, positions, 1)
|
|
} else {
|
|
nextConv = mlx.SliceStartStop(concat,
|
|
[]int32{0, total - int32(convTail), 0},
|
|
[]int32{B, total, D})
|
|
}
|
|
|
|
return out, nextConv
|
|
}
|
|
|
|
// depthwiseCausalConv1d implements a depthwise 1D causal convolution
|
|
// manually as a sum of kernel-offset multiplies. x has shape
|
|
// [B, inLen, C], weight has shape [C, K]; output has shape [B, outLen, C]
|
|
// where outLen = inLen - K + 1 (the caller passes outLen to avoid the
|
|
// subtraction). Used as the fallback path in CausalConv1D when no
|
|
// full Conv1d layer is configured.
|
|
func depthwiseCausalConv1d(x, w *mlx.Array, outLen int32) *mlx.Array {
|
|
if x == nil || w == nil {
|
|
return nil
|
|
}
|
|
if w.NumDims() != 2 {
|
|
return nil
|
|
}
|
|
B := int32(x.Dim(0))
|
|
C := int32(w.Dim(0))
|
|
K := int32(w.Dim(1))
|
|
var out *mlx.Array
|
|
for i := range K {
|
|
seg := mlx.SliceStartStop(x, []int32{0, i, 0}, []int32{B, i + outLen, C})
|
|
wi := mlx.SliceStartStop(w, []int32{0, i}, []int32{C, i + 1})
|
|
wi = mlx.Reshape(wi, 1, 1, C)
|
|
term := mlx.Mul(seg, wi)
|
|
if out == nil {
|
|
out = term
|
|
} else {
|
|
out = mlx.Add(out, term)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
// GatedDelta wraps mlx.FastGatedDelta with recurrent state management.
|
|
// Reads prior delta state from the supplied option and returns
|
|
// (output, newDelta). Write newDelta back via the cache's Put
|
|
// alongside the conv wrapper's nextConv.
|
|
//
|
|
// Shape conventions:
|
|
//
|
|
// q: [B, L, numKeyHeads, headKDim]
|
|
// k: [B, L, numKeyHeads, headKDim]
|
|
// v: [B, L, numValueHeads, headVDim]
|
|
// state: [B, numValueHeads, headVDim, headKDim]
|
|
//
|
|
// Prior state comes from exactly one of WithRecurrentHistory (cache
|
|
// path) or WithRecurrentState (no-cache path).
|
|
func GatedDelta(b *batch.Batch, q, k, v, gDecay, beta *mlx.Array, opts ...RecurrentOption) (out, newDelta *mlx.Array) {
|
|
cfg := resolveRecurrentConfig(opts)
|
|
var state *mlx.Array
|
|
if cfg.history != nil {
|
|
state = cfg.history.DeltaState()
|
|
} else {
|
|
state = cfg.deltaState
|
|
}
|
|
|
|
return mlx.FastGatedDelta(q, k, v, gDecay, beta, state, paddingMask(b, int32(q.Dim(1))))
|
|
}
|
|
|
|
// RecurrentHistory is an opaque per-forward view a recurrent cache
|
|
// hands to the SSM kernel wrappers — prior conv and delta state
|
|
// tensors. Models do not construct this directly; pass it through
|
|
// via WithRecurrentHistory, or use WithRecurrentState on the no-cache
|
|
// path.
|
|
//
|
|
// Opaque structure to model code; accessors ConvState/DeltaState
|
|
// provide the escape hatch for custom SSM paths.
|
|
type RecurrentHistory struct {
|
|
convState, deltaState *mlx.Array
|
|
}
|
|
|
|
// NewRecurrentHistory constructs a RecurrentHistory. Intended for
|
|
// cache implementations across packages; model code uses
|
|
// WithRecurrentHistory / WithRecurrentState instead.
|
|
func NewRecurrentHistory(convState, deltaState *mlx.Array) *RecurrentHistory {
|
|
return &RecurrentHistory{convState: convState, deltaState: deltaState}
|
|
}
|
|
|
|
// ConvState returns the current convolution state tensor.
|
|
//
|
|
// Last-resort escape hatch for custom SSM paths — may force a slow
|
|
// materialization to canonical form depending on the cache's
|
|
// internal storage. Prefer CausalConv1D via WithRecurrentHistory.
|
|
func (h *RecurrentHistory) ConvState() *mlx.Array { return h.convState }
|
|
|
|
// DeltaState returns the current delta state tensor.
|
|
//
|
|
// Last-resort escape hatch for custom SSM paths — may force a slow
|
|
// materialization to canonical form depending on the cache's
|
|
// internal storage. Prefer GatedDelta via WithRecurrentHistory.
|
|
func (h *RecurrentHistory) DeltaState() *mlx.Array { return h.deltaState }
|
|
|
|
type paddingMaskInputs struct {
|
|
batch *batch.Batch
|
|
L int32
|
|
}
|
|
|
|
func (in paddingMaskInputs) build() *mlx.Array {
|
|
B := len(in.batch.SeqQueryLens)
|
|
|
|
needed := false
|
|
for i := range B {
|
|
if in.batch.SeqQueryLens[i] < in.L {
|
|
needed = true
|
|
break
|
|
}
|
|
}
|
|
if !needed {
|
|
return nil
|
|
}
|
|
|
|
L := int(in.L)
|
|
vals := make([]bool, B*L)
|
|
for i := range B {
|
|
n := int(in.batch.SeqQueryLens[i])
|
|
|
|
base := i * L
|
|
for j := range n {
|
|
vals[base+j] = true
|
|
}
|
|
}
|
|
|
|
return mlx.FromValues(vals, B, L)
|
|
}
|
|
|
|
// paddingMask derives a [B, L] bool mask from b.SeqQueryLens for
|
|
// right-padded inputs (real tokens at [0, len_i), padding at
|
|
// [len_i, L)). Returns nil when b has no rows or every row is full —
|
|
// the no-padding fast path that costs nothing extra.
|
|
func paddingMask(b *batch.Batch, L int32) *mlx.Array {
|
|
inputs := paddingMaskInputs{batch: b, L: L}
|
|
if cached, ok := b.Memo.Get(inputs); ok {
|
|
return cached.(*mlx.Array)
|
|
}
|
|
|
|
mask := inputs.build()
|
|
b.Memo.Put(inputs, mask)
|
|
|
|
return mask
|
|
}
|