ollama source for Momentry Core verification
This commit is contained in:
55
model/models/nemotronh/cache.go
Normal file
55
model/models/nemotronh/cache.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package nemotronh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
|
||||
// with the layer requirements.
|
||||
var ErrUnsupportedBatchLayout = errors.New("nemotronh: unsupported batch layout")
|
||||
|
||||
var (
|
||||
_ kvcache.Cache = (*HybridCache)(nil)
|
||||
_ kvcache.CheckpointCache = (*HybridCache)(nil)
|
||||
)
|
||||
|
||||
// HybridCache adapts the shared recurrent cache base for Nemotron-H naming.
|
||||
type HybridCache struct {
|
||||
*kvcache.Recurrent
|
||||
}
|
||||
|
||||
func NewHybridCache(convDim, convChannels, ssmStateSize int) *HybridCache {
|
||||
base := kvcache.NewRecurrentCache(kvcache.RecurrentConfig{
|
||||
Shift: Shift,
|
||||
ConvDim: convDim,
|
||||
ConvChannels: convChannels,
|
||||
RecurrentStateSize: ssmStateSize,
|
||||
CheckpointLogPrefix: "nemotronh",
|
||||
})
|
||||
return &HybridCache{Recurrent: base}
|
||||
}
|
||||
|
||||
// SSMState returns the SSM state for current batch sequences.
|
||||
func (c *HybridCache) SSMState(ctx ml.Context, layer int, dState, headDim, nHead int) (ml.Tensor, error) {
|
||||
return c.RecurrentState4D(ctx, layer, dState, headDim, nHead)
|
||||
}
|
||||
|
||||
// UpdateSSMState writes a new SSM state for current batch sequences.
|
||||
func (c *HybridCache) UpdateSSMState(ctx ml.Context, layer int, newState ml.Tensor) {
|
||||
c.UpdateRecurrentState(ctx, layer, newState)
|
||||
}
|
||||
|
||||
func (c *HybridCache) slotsTensor() ml.Tensor {
|
||||
return c.SlotsTensor()
|
||||
}
|
||||
|
||||
func (c *HybridCache) seqTokens() int {
|
||||
return c.SeqTokens()
|
||||
}
|
||||
|
||||
func (c *HybridCache) numSeqs() int {
|
||||
return c.NumSeqs()
|
||||
}
|
||||
Reference in New Issue
Block a user