ollama source for Momentry Core verification
This commit is contained in:
752
kvcache/recurrent.go
Normal file
752
kvcache/recurrent.go
Normal file
@@ -0,0 +1,752 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultCheckpointCount = 24
|
||||
DefaultCheckpointMinPos = int32(16)
|
||||
DefaultCheckpointInterval = int32(1664)
|
||||
)
|
||||
|
||||
var ErrInvalidRecurrentShape = errors.New("kvcache: invalid recurrent state shape")
|
||||
|
||||
// Config configures a shared hybrid recurrent cache.
|
||||
type RecurrentConfig struct {
|
||||
Shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
|
||||
ConvDim int
|
||||
ConvChannels int
|
||||
RecurrentStateSize int
|
||||
CheckpointLogPrefix string
|
||||
}
|
||||
|
||||
var (
|
||||
_ Cache = (*Recurrent)(nil)
|
||||
_ CheckpointCache = (*Recurrent)(nil)
|
||||
)
|
||||
|
||||
// Cache stores:
|
||||
// - a standard causal KV cache
|
||||
// - per-sequence conv state for recurrent operators
|
||||
// - per-sequence recurrent state for recurrent operators
|
||||
//
|
||||
// Conv state shape (per layer, per sequence): [convDim, convChannels]
|
||||
// Recurrent state shape (per layer, per sequence): [recurrentStateSize]
|
||||
type Recurrent struct {
|
||||
kv *Causal
|
||||
|
||||
backend ml.Backend
|
||||
dtype ml.DType
|
||||
maxSequences int
|
||||
|
||||
// Conv state dimensions
|
||||
convDim int
|
||||
convChannels int
|
||||
|
||||
// Recurrent state dimensions
|
||||
recurrentStateSize int
|
||||
|
||||
logPrefix string
|
||||
|
||||
// slot mapping for recurrent state (copy-on-write)
|
||||
slotForSeq map[int]int
|
||||
refCount []int
|
||||
freeSlots []int
|
||||
seqCounts map[int]int
|
||||
slotScratch [1]int32
|
||||
|
||||
// per-layer conv state buffers (allocated lazily)
|
||||
convCtxs map[int]ml.Context
|
||||
convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
|
||||
|
||||
// per-layer recurrent state buffers (allocated lazily)
|
||||
recurrentCtxs map[int]ml.Context
|
||||
recurrentStates map[int]ml.Tensor // [recurrentStateSize, maxSlots]
|
||||
|
||||
// recurrent checkpoints (per slot)
|
||||
checkpointCount int
|
||||
checkpointMinPos int32
|
||||
checkpointInterval int32
|
||||
checkpointCtxSize int
|
||||
checkpoints map[int]*slotCheckpointStore
|
||||
pendingRestore map[int]checkpointRestore
|
||||
curCheckpointPos []int32
|
||||
curCheckpointSlots map[int]int
|
||||
reserveCheckpoints bool
|
||||
checkpointConvCtxs map[int]ml.Context
|
||||
checkpointRecurCtxs map[int]ml.Context
|
||||
checkpointReserved map[int]struct{}
|
||||
|
||||
// current forward batch (derived in StartForward)
|
||||
curSeqs []int
|
||||
curSlots []int
|
||||
curSlotsInput ml.Tensor
|
||||
curSeqTokens int
|
||||
|
||||
// track if EnsureWritable has been called for this forward pass
|
||||
writableEnsured bool
|
||||
writableError error
|
||||
}
|
||||
|
||||
func NewRecurrentCache(config RecurrentConfig) *Recurrent {
|
||||
return &Recurrent{
|
||||
kv: NewCausalCache(config.Shift),
|
||||
convDim: config.ConvDim,
|
||||
convChannels: config.ConvChannels,
|
||||
recurrentStateSize: config.RecurrentStateSize,
|
||||
logPrefix: config.CheckpointLogPrefix,
|
||||
slotForSeq: make(map[int]int),
|
||||
seqCounts: make(map[int]int),
|
||||
convCtxs: make(map[int]ml.Context),
|
||||
convStates: make(map[int]ml.Tensor),
|
||||
recurrentCtxs: make(map[int]ml.Context),
|
||||
recurrentStates: make(map[int]ml.Tensor),
|
||||
checkpointCount: DefaultCheckpointCount,
|
||||
checkpointMinPos: DefaultCheckpointMinPos,
|
||||
checkpointInterval: DefaultCheckpointInterval,
|
||||
checkpoints: make(map[int]*slotCheckpointStore),
|
||||
pendingRestore: make(map[int]checkpointRestore),
|
||||
curCheckpointSlots: make(map[int]int),
|
||||
checkpointConvCtxs: make(map[int]ml.Context),
|
||||
checkpointRecurCtxs: make(map[int]ml.Context),
|
||||
checkpointReserved: make(map[int]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
|
||||
c.backend = backend
|
||||
c.dtype = dtype
|
||||
c.maxSequences = maxSequences
|
||||
c.checkpoints = make(map[int]*slotCheckpointStore)
|
||||
c.pendingRestore = make(map[int]checkpointRestore)
|
||||
c.curCheckpointPos = c.curCheckpointPos[:0]
|
||||
c.curCheckpointSlots = make(map[int]int)
|
||||
c.checkpointReserved = make(map[int]struct{})
|
||||
c.checkpointCtxSize = c.checkpointCount * c.maxSequences
|
||||
if c.checkpointCtxSize < 8 {
|
||||
c.checkpointCtxSize = 8
|
||||
}
|
||||
|
||||
// initialize slot allocator
|
||||
c.refCount = make([]int, maxSequences)
|
||||
c.freeSlots = c.freeSlots[:0]
|
||||
for i := maxSequences - 1; i >= 0; i-- {
|
||||
c.freeSlots = append(c.freeSlots, i)
|
||||
}
|
||||
|
||||
c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Close() {
|
||||
for _, ctx := range c.convCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.recurrentCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointConvCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
for _, ctx := range c.checkpointRecurCtxs {
|
||||
ctx.Close()
|
||||
}
|
||||
c.kv.Close()
|
||||
}
|
||||
|
||||
func (c *Recurrent) SetConfig(config ml.CacheConfig) {
|
||||
c.kv.SetConfig(config)
|
||||
}
|
||||
|
||||
func (c *Recurrent) SetLayer(layer int) {
|
||||
c.kv.SetLayer(layer)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
return c.kv.Get(ctx)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.kv.Put(ctx, key, value)
|
||||
}
|
||||
|
||||
func (c *Recurrent) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nTokens := len(batch.Sequences)
|
||||
if nTokens == 0 {
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
c.curSlots = c.curSlots[:0]
|
||||
c.curSlotsInput = nil
|
||||
c.curSeqTokens = 0
|
||||
c.reserveCheckpoints = false
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fast path for single-sequence batches (common during decode and prefill).
|
||||
firstSeq := batch.Sequences[0]
|
||||
singleSeq := true
|
||||
for _, s := range batch.Sequences[1:] {
|
||||
if s != firstSeq {
|
||||
singleSeq = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if singleSeq {
|
||||
return c.startForwardSingleSeq(ctx, firstSeq, nTokens, batch, reserve)
|
||||
}
|
||||
|
||||
// Derive equal-length sequence layout for recurrent layers.
|
||||
seqCounts := c.seqCounts
|
||||
for s := range seqCounts {
|
||||
delete(seqCounts, s)
|
||||
}
|
||||
|
||||
c.curSeqs = c.curSeqs[:0]
|
||||
for _, s := range batch.Sequences {
|
||||
if seqCounts[s] == 0 {
|
||||
c.curSeqs = append(c.curSeqs, s)
|
||||
}
|
||||
seqCounts[s]++
|
||||
}
|
||||
|
||||
nSeqs := len(c.curSeqs)
|
||||
want := nTokens / nSeqs
|
||||
for _, s := range c.curSeqs {
|
||||
if seqCounts[s] != want {
|
||||
return ErrNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
c.curSeqTokens = want
|
||||
|
||||
if reserve {
|
||||
c.curSlots = c.curSlots[:0]
|
||||
for i := range nSeqs {
|
||||
c.curSlots = append(c.curSlots, i)
|
||||
}
|
||||
c.finalizeStartForward(ctx, batch, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure slots exist for sequences in this batch.
|
||||
c.curSlots = c.curSlots[:0]
|
||||
var newSlots []int
|
||||
for _, s := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[s]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.slotForSeq[s] = slot
|
||||
c.refCount[slot] = 1
|
||||
newSlots = append(newSlots, slot)
|
||||
}
|
||||
c.curSlots = append(c.curSlots, slot)
|
||||
}
|
||||
|
||||
if len(newSlots) > 0 {
|
||||
c.zeroSlots(ctx, newSlots)
|
||||
}
|
||||
|
||||
c.finalizeStartForward(ctx, batch, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) startForwardSingleSeq(ctx ml.Context, seq, seqTokens int, batch input.Batch, reserve bool) error {
|
||||
c.curSeqs = append(c.curSeqs[:0], seq)
|
||||
c.curSeqTokens = seqTokens
|
||||
|
||||
if reserve {
|
||||
c.curSlots = append(c.curSlots[:0], 0)
|
||||
c.finalizeStartForward(ctx, batch, true)
|
||||
return nil
|
||||
}
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
var err error
|
||||
slot, err = c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.slotForSeq[seq] = slot
|
||||
c.refCount[slot] = 1
|
||||
slotList := [1]int{slot}
|
||||
c.zeroSlots(ctx, slotList[:])
|
||||
}
|
||||
|
||||
c.curSlots = append(c.curSlots[:0], slot)
|
||||
c.finalizeStartForward(ctx, batch, false)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) finalizeStartForward(ctx ml.Context, batch input.Batch, reserve bool) {
|
||||
c.setCurSlotsInput(ctx)
|
||||
c.writableEnsured = false
|
||||
c.writableError = nil
|
||||
c.reserveCheckpoints = reserve
|
||||
c.planCheckpoints(batch)
|
||||
}
|
||||
|
||||
func (c *Recurrent) setCurSlotsInput(ctx ml.Context) {
|
||||
c.curSlotsInput = c.slotsInput(ctx, c.curSlots)
|
||||
}
|
||||
|
||||
func (c *Recurrent) slotsInput(ctx ml.Context, slots []int) ml.Tensor {
|
||||
switch len(slots) {
|
||||
case 0:
|
||||
return nil
|
||||
case 1:
|
||||
c.slotScratch[0] = int32(slots[0])
|
||||
return ctx.Input().FromInts(c.slotScratch[:], 1)
|
||||
default:
|
||||
slotIndices := make([]int32, len(slots))
|
||||
for i, v := range slots {
|
||||
slotIndices[i] = int32(v)
|
||||
}
|
||||
return ctx.Input().FromInts(slotIndices, len(slotIndices))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) allocSlot() (int, error) {
|
||||
if len(c.freeSlots) == 0 {
|
||||
return 0, ErrKvCacheFull
|
||||
}
|
||||
slot := c.freeSlots[len(c.freeSlots)-1]
|
||||
c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) freeSlot(slot int) {
|
||||
if slot >= 0 && slot < c.maxSequences {
|
||||
c.freeSlots = append(c.freeSlots, slot)
|
||||
}
|
||||
}
|
||||
|
||||
// zeroSlots zeros recurrent state for the given slots across all cached layers.
|
||||
func (c *Recurrent) zeroSlots(ctx ml.Context, slots []int) {
|
||||
if len(slots) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
inputCtx := ctx.Input()
|
||||
slotsTensor := c.slotsInput(ctx, slots)
|
||||
|
||||
if len(c.convStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
|
||||
for _, buf := range c.convStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.recurrentStates) > 0 {
|
||||
zeros := inputCtx.Zeros(ml.DTypeF32, c.recurrentStateSize, len(slots))
|
||||
for _, buf := range c.recurrentStates {
|
||||
ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// EnsureWritable ensures sequences have private slots (copy-on-write).
|
||||
func (c *Recurrent) EnsureWritable(ctx ml.Context) error {
|
||||
for i, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if slot < 0 || slot >= len(c.refCount) {
|
||||
continue
|
||||
}
|
||||
|
||||
if c.refCount[slot] <= 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
c.curSlots[i] = newSlot
|
||||
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
}
|
||||
|
||||
c.setCurSlotsInput(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
|
||||
src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
|
||||
dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
|
||||
|
||||
for _, buf := range c.convStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
if rows.DType() != ml.DTypeF32 {
|
||||
rows = rows.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
||||
}
|
||||
|
||||
for _, buf := range c.recurrentStates {
|
||||
rows := buf.Rows(ctx, src)
|
||||
if rows.DType() != ml.DTypeF32 {
|
||||
rows = rows.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
ctx.Forward(buf.SetRows(ctx, rows, dst))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
|
||||
c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
|
||||
|
||||
if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
|
||||
if c.validSlot(dstSlot) {
|
||||
c.refCount[dstSlot]--
|
||||
if c.refCount[dstSlot] <= 0 {
|
||||
c.refCount[dstSlot] = 0
|
||||
c.freeSlot(dstSlot)
|
||||
}
|
||||
}
|
||||
delete(c.slotForSeq, dstSeq)
|
||||
}
|
||||
|
||||
srcSlot, ok := c.slotForSeq[srcSeq]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if c.validSlot(srcSlot) {
|
||||
c.slotForSeq[dstSeq] = srcSlot
|
||||
c.refCount[srcSlot]++
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Recurrent) CanResume(seq int, pos int32) bool {
|
||||
if !c.kv.CanResume(seq, pos) {
|
||||
return false
|
||||
}
|
||||
if pos == 0 {
|
||||
return true
|
||||
}
|
||||
return c.hasCheckpoint(seq, pos)
|
||||
}
|
||||
|
||||
func (c *Recurrent) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if beginIndex > 0 && endIndex != math.MaxInt32 {
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
delete(c.pendingRestore, seq)
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok || !c.validSlot(slot) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Detach shared recurrent state/checkpoints before mutating checkpoint positions.
|
||||
if c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
slot = newSlot
|
||||
}
|
||||
|
||||
c.shiftCheckpoints(slot, beginIndex, endIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore, ok := c.pendingRestore[seq]
|
||||
if !ok || restore.pos+1 != beginIndex {
|
||||
return ErrNotSupported
|
||||
}
|
||||
if !c.restoreComplete(restore) {
|
||||
return ErrNotSupported
|
||||
}
|
||||
if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
|
||||
newSlot, err := c.allocSlot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx := c.backend.NewContext()
|
||||
c.copyRecurrentState(ctx, slot, newSlot)
|
||||
c.copyCheckpoints(ctx, slot, newSlot)
|
||||
if len(c.convStates) > 0 || len(c.recurrentStates) > 0 {
|
||||
ctx.Compute()
|
||||
}
|
||||
ctx.Close()
|
||||
|
||||
c.refCount[slot]--
|
||||
c.refCount[newSlot] = 1
|
||||
c.slotForSeq[seq] = newSlot
|
||||
|
||||
restore.slot = newSlot
|
||||
c.pendingRestore[seq] = restore
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if beginIndex > 0 {
|
||||
restore := c.pendingRestore[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
return c.applyCheckpointRestore(restore)
|
||||
}
|
||||
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
delete(c.pendingRestore, seq)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.validSlot(slot) {
|
||||
delete(c.slotForSeq, seq)
|
||||
return nil
|
||||
}
|
||||
|
||||
c.refCount[slot]--
|
||||
if c.refCount[slot] <= 0 {
|
||||
c.refCount[slot] = 0
|
||||
c.clearCheckpoints(slot)
|
||||
c.freeSlot(slot)
|
||||
}
|
||||
delete(c.slotForSeq, seq)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Recurrent) validSlot(slot int) bool {
|
||||
return slot >= 0 && slot < len(c.refCount)
|
||||
}
|
||||
|
||||
func (c *Recurrent) SlotsTensor() ml.Tensor {
|
||||
return c.curSlotsInput
|
||||
}
|
||||
|
||||
// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
|
||||
func (c *Recurrent) contiguousSlots() (int, bool) {
|
||||
if len(c.curSlots) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
start := c.curSlots[0]
|
||||
for i, s := range c.curSlots {
|
||||
if s != start+i {
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
return start, true
|
||||
}
|
||||
|
||||
func (c *Recurrent) SeqTokens() int {
|
||||
return c.curSeqTokens
|
||||
}
|
||||
|
||||
func (c *Recurrent) NumSeqs() int {
|
||||
return len(c.curSeqs)
|
||||
}
|
||||
|
||||
func (c *Recurrent) convBuffer(layer int) ml.Tensor {
|
||||
if buf, ok := c.convStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.convCtxs[layer]; !ok {
|
||||
c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
|
||||
c.convStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *Recurrent) recurrentBuffer(layer int) ml.Tensor {
|
||||
if buf, ok := c.recurrentStates[layer]; ok {
|
||||
return buf
|
||||
}
|
||||
|
||||
if _, ok := c.recurrentCtxs[layer]; !ok {
|
||||
c.recurrentCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
|
||||
}
|
||||
|
||||
buf := c.recurrentCtxs[layer].Zeros(ml.DTypeF32, c.recurrentStateSize, c.maxSequences)
|
||||
c.recurrentStates[layer] = buf
|
||||
return buf
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureWritable(ctx ml.Context) error {
|
||||
c.ensureWritableOnce(ctx)
|
||||
return c.writableError
|
||||
}
|
||||
|
||||
func (c *Recurrent) currentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int) ml.Tensor {
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
offset := start * buf.Stride(1)
|
||||
return buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
|
||||
}
|
||||
|
||||
return buf.Rows(ctx, c.SlotsTensor())
|
||||
}
|
||||
|
||||
func (c *Recurrent) writeCurrentSlotRows(ctx ml.Context, buf ml.Tensor, rowSize int, src ml.Tensor) {
|
||||
if start, ok := c.contiguousSlots(); ok {
|
||||
offset := start * buf.Stride(1)
|
||||
view := buf.View(ctx, offset, rowSize, buf.Stride(1), c.NumSeqs())
|
||||
ctx.Forward(src.Copy(ctx, view))
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Forward(buf.SetRows(ctx, src, c.SlotsTensor()))
|
||||
}
|
||||
|
||||
func (c *Recurrent) ensureWritableOnce(ctx ml.Context) {
|
||||
if !c.writableEnsured {
|
||||
needsWritable := false
|
||||
for _, seq := range c.curSeqs {
|
||||
slot, ok := c.slotForSeq[seq]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
|
||||
needsWritable = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if needsWritable {
|
||||
if err := c.EnsureWritable(ctx); err != nil {
|
||||
c.writableError = err
|
||||
}
|
||||
}
|
||||
c.writableEnsured = true
|
||||
}
|
||||
}
|
||||
|
||||
// ConvState returns conv state for current batch sequences as [convDim, convChannels, nSeqs].
|
||||
func (c *Recurrent) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := c.convBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.convDim*c.convChannels)
|
||||
return cur.Reshape(ctx, c.convDim, c.convChannels, c.NumSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateConvState writes new conv state for current batch sequences.
|
||||
func (c *Recurrent) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.convBuffer(layer)
|
||||
src := newState.Reshape(ctx, c.convDim*c.convChannels, c.NumSeqs())
|
||||
srcF32 := src
|
||||
if src.DType() != ml.DTypeF32 {
|
||||
srcF32 = src.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
c.writeCurrentSlotRows(ctx, buf, c.convDim*c.convChannels, srcF32)
|
||||
|
||||
c.captureConvCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// RecurrentState returns recurrent state for current batch sequences with shape [dims..., nSeqs].
|
||||
func (c *Recurrent) RecurrentState(ctx ml.Context, layer int, dims ...int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(dims) == 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
|
||||
size := 1
|
||||
for _, d := range dims {
|
||||
if d <= 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
size *= d
|
||||
}
|
||||
if size != c.recurrentStateSize {
|
||||
return nil, fmt.Errorf("%w: got %v (size %d), want size %d", ErrInvalidRecurrentShape, dims, size, c.recurrentStateSize)
|
||||
}
|
||||
|
||||
buf := c.recurrentBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
|
||||
shape := make([]int, 0, len(dims)+1)
|
||||
shape = append(shape, dims...)
|
||||
shape = append(shape, c.NumSeqs())
|
||||
return cur.Reshape(ctx, shape...), nil
|
||||
}
|
||||
|
||||
// RecurrentState4D returns recurrent state as [dim0, dim1, dim2, nSeqs].
|
||||
func (c *Recurrent) RecurrentState4D(ctx ml.Context, layer int, dim0, dim1, dim2 int) (ml.Tensor, error) {
|
||||
if err := c.ensureWritable(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dim0 <= 0 || dim1 <= 0 || dim2 <= 0 {
|
||||
return nil, ErrInvalidRecurrentShape
|
||||
}
|
||||
|
||||
size := dim0 * dim1 * dim2
|
||||
if size != c.recurrentStateSize {
|
||||
return nil, fmt.Errorf("%w: got [%d %d %d] (size %d), want size %d", ErrInvalidRecurrentShape, dim0, dim1, dim2, size, c.recurrentStateSize)
|
||||
}
|
||||
|
||||
buf := c.recurrentBuffer(layer)
|
||||
cur := c.currentSlotRows(ctx, buf, c.recurrentStateSize)
|
||||
return cur.Reshape(ctx, dim0, dim1, dim2, c.NumSeqs()), nil
|
||||
}
|
||||
|
||||
// UpdateRecurrentState writes new recurrent state for current batch sequences.
|
||||
func (c *Recurrent) UpdateRecurrentState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
buf := c.recurrentBuffer(layer)
|
||||
src := newState.Reshape(ctx, c.recurrentStateSize, c.NumSeqs())
|
||||
srcF32 := src
|
||||
if src.DType() != ml.DTypeF32 {
|
||||
srcF32 = src.Cast(ctx, ml.DTypeF32)
|
||||
}
|
||||
c.writeCurrentSlotRows(ctx, buf, c.recurrentStateSize, srcF32)
|
||||
|
||||
c.captureRecurrentCheckpoint(ctx, layer, srcF32)
|
||||
}
|
||||
|
||||
// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
|
||||
func (c *Recurrent) IsSupportedForBatch() bool {
|
||||
return c.curSeqTokens > 0 && len(c.curSeqs) > 0
|
||||
}
|
||||
|
||||
// Seqs returns the ordered unique sequences for the current forward pass.
|
||||
func (c *Recurrent) Seqs() []int {
|
||||
return slices.Clone(c.curSeqs)
|
||||
}
|
||||
Reference in New Issue
Block a user