ollama source for Momentry Core verification
This commit is contained in:
475
model/models/gemma4/model_text.go
Normal file
475
model/models/gemma4/model_text.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTypeSWA = iota
|
||||
cacheTypeCausal
|
||||
)
|
||||
|
||||
type TextOptions struct {
|
||||
hiddenSize int
|
||||
numHeads, numKVHeads int
|
||||
numGlobalKVHeads int
|
||||
headDim, globalHeadDim int
|
||||
hiddenLayers int
|
||||
hiddenSizePerLayerInput int
|
||||
|
||||
eps float32
|
||||
ropeBase float32
|
||||
ropeLocalBase float32
|
||||
partialRotaryDims int // RoPE dims for full-attention (global) layers
|
||||
|
||||
slidingWindowPattern []bool
|
||||
// kvDonorMap maps shared layer index -> donor layer index.
|
||||
// Donor is the last non-shared layer of the same type (sliding/full).
|
||||
kvDonorMap map[int]int
|
||||
|
||||
finalLogitSoftcap float32
|
||||
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
}
|
||||
|
||||
func (o *TextOptions) isLocal(layer int) bool {
|
||||
if layer < len(o.slidingWindowPattern) {
|
||||
return o.slidingWindowPattern[layer]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (o *TextOptions) ropeForLayer(layer int) (base float32, dims int) {
|
||||
if o.isLocal(layer) {
|
||||
return o.ropeLocalBase, o.headDim
|
||||
}
|
||||
return o.ropeBase, o.partialRotaryDims
|
||||
}
|
||||
|
||||
func (o *TextOptions) kvHeadsForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.numKVHeads
|
||||
}
|
||||
if o.numGlobalKVHeads > 0 {
|
||||
return o.numGlobalKVHeads
|
||||
}
|
||||
return o.numKVHeads
|
||||
}
|
||||
|
||||
func (o *TextOptions) headDimForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.headDim
|
||||
}
|
||||
return o.globalHeadDim
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
*PerLayerProjector
|
||||
Layers []TextLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
TextOptions
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
|
||||
// Head dimensions: key_length is global head dim, key_length_swa is local (SWA) head dim.
|
||||
globalHeadDim := int(c.Uint("attention.key_length", 512))
|
||||
headDim := int(c.Uint("attention.key_length_swa", 256))
|
||||
|
||||
// RoPE dimensions for global (full attention) layers with proportional RoPE.
|
||||
// The freq_factors tensor handles partial rotation (1.0 for rotated pairs,
|
||||
// 1e30 for non-rotated), so ropeDims equals the full global head dim.
|
||||
partialRotaryDims := int(c.Uint("rope.dimension_count", 0))
|
||||
if partialRotaryDims == 0 {
|
||||
partialFactor := c.Float("rope.partial_rotary_factor", 1.0)
|
||||
partialRotaryDims = int(float32(globalHeadDim) * partialFactor)
|
||||
}
|
||||
|
||||
ropeBase := c.Float("rope.freq_base", 1000000.0)
|
||||
ropeLocalBase := c.Float("rope.freq_base_swa", 0)
|
||||
if ropeLocalBase == 0 {
|
||||
ropeLocalBase = c.Float("rope.local.freq_base", 10000.0)
|
||||
}
|
||||
|
||||
numGlobalKVHeads := int(c.Uint("attention.global_head_count_kv", 0))
|
||||
slidingPattern := c.Bools("attention.sliding_window_pattern")
|
||||
|
||||
// KV heads: try per-layer array first (MoE models), then fall back to scalar
|
||||
numKVHeads := 0
|
||||
kvHeadsArray := c.Ints("attention.head_count_kv")
|
||||
if len(kvHeadsArray) > 0 {
|
||||
numKVHeads = int(kvHeadsArray[0])
|
||||
if numGlobalKVHeads == 0 && len(slidingPattern) > 0 {
|
||||
for i, isLocal := range slidingPattern {
|
||||
if !isLocal && i < len(kvHeadsArray) {
|
||||
numGlobalKVHeads = int(kvHeadsArray[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if numKVHeads == 0 {
|
||||
numKVHeads = int(c.Uint("attention.head_count_kv", 0))
|
||||
}
|
||||
|
||||
// Compute KV sharing donor map (same logic as MLX)
|
||||
sharedLayers := int(c.Uint("attention.shared_kv_layers", 0))
|
||||
kvDonorMap := make(map[int]int)
|
||||
if sharedLayers > 0 && len(slidingPattern) > 0 {
|
||||
firstShared := numLayers - sharedLayers
|
||||
for i := firstShared; i < numLayers; i++ {
|
||||
isLocal := slidingPattern[i]
|
||||
// Find last non-shared layer of same type
|
||||
for j := firstShared - 1; j >= 0; j-- {
|
||||
if slidingPattern[j] == isLocal {
|
||||
kvDonorMap[i] = j
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &TextModel{
|
||||
Layers: make([]TextLayer, numLayers),
|
||||
TextOptions: TextOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: numKVHeads,
|
||||
numGlobalKVHeads: numGlobalKVHeads,
|
||||
headDim: headDim,
|
||||
globalHeadDim: globalHeadDim,
|
||||
hiddenLayers: numLayers,
|
||||
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input", 0)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeBase: ropeBase,
|
||||
ropeLocalBase: ropeLocalBase,
|
||||
partialRotaryDims: partialRotaryDims,
|
||||
slidingWindowPattern: slidingPattern,
|
||||
kvDonorMap: kvDonorMap,
|
||||
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
|
||||
numExperts: int(c.Uint("expert_count", 0)),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count", 0)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
// Inject vision embeddings into the hidden state
|
||||
var except []int
|
||||
for _, image := range batch.Multimodal {
|
||||
visionOutputs := image.Multimodal[0].Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
except = append(except, image.Index+i)
|
||||
}
|
||||
}
|
||||
|
||||
// PLE
|
||||
var perLayerInputs ml.Tensor
|
||||
if m.PerLayerProjector != nil {
|
||||
perLayerInputs = m.PerLayerProjector.Forward(ctx, batch, hiddenState, &m.TextOptions)
|
||||
}
|
||||
|
||||
for i := range len(m.Layers) {
|
||||
layer := m.Layers[i]
|
||||
if cache != nil {
|
||||
cache.SetLayer(i)
|
||||
cacheType := cacheTypeSWA
|
||||
if !m.isLocal(i) {
|
||||
cacheType = cacheTypeCausal
|
||||
}
|
||||
wc := cache.(*kvcache.WrapperCache)
|
||||
wc.SetLayerType(cacheType)
|
||||
|
||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
}
|
||||
}
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
var perLayerInput ml.Tensor
|
||||
if perLayerInputs != nil {
|
||||
perLayerInput = perLayerInputs.View(ctx, i*perLayerInputs.Stride(1), perLayerInputs.Dim(0), perLayerInputs.Stride(2), perLayerInputs.Dim(2))
|
||||
}
|
||||
|
||||
// KV sharing: layers >= firstShared reuse K/V from donor layers
|
||||
isShared := false
|
||||
if donorLayer, ok := m.kvDonorMap[i]; ok {
|
||||
// Set cache layer to donor so Get() reads donor's K/V
|
||||
cache.SetLayer(donorLayer)
|
||||
isShared = true
|
||||
}
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, perLayerInput, lastLayerOutputs, cache, isShared, &m.TextOptions)
|
||||
}
|
||||
|
||||
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
}
|
||||
|
||||
// PerLayerProjector implements PLE.
|
||||
type PerLayerProjector struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"per_layer_token_embd"`
|
||||
Projector *nn.Linear `gguf:"per_layer_model_proj"`
|
||||
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
|
||||
}
|
||||
|
||||
func (p *PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
inputsPerLayer = inputsPerLayer.Scale(ctx, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
||||
// Reshape to [pleDim, numLayers, numTokens] — matching projection shape
|
||||
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
|
||||
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
||||
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
||||
|
||||
if inputsPerLayer != nil {
|
||||
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
||||
}
|
||||
|
||||
return perLayerProjection
|
||||
}
|
||||
|
||||
type TextSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` // proportional RoPE freq_factors
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
hd := opts.headDimForLayer(layer)
|
||||
kvHeads := opts.kvHeadsForLayer(layer)
|
||||
ropeBase, ropeDims := opts.ropeForLayer(layer)
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, hd, opts.numHeads, batchSize)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
|
||||
var k, v ml.Tensor
|
||||
if !sharedKV {
|
||||
k = sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, hd, kvHeads, batchSize)
|
||||
|
||||
if sa.Value != nil {
|
||||
v = sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, hd, kvHeads, batchSize)
|
||||
} else {
|
||||
// K=V: use raw K projection (before K norm) as V
|
||||
v = k
|
||||
}
|
||||
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
v = v.RMSNorm(ctx, nil, opts.eps) // V norm: unweighted RMSNorm
|
||||
}
|
||||
|
||||
// RoPE with proportional freq_factors on global layers
|
||||
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if sa.RopeFactors != nil && !opts.isLocal(layer) {
|
||||
ropeOpts = append(ropeOpts, rope.WithFactors(sa.RopeFactors))
|
||||
}
|
||||
q = nn.RoPE(ctx, q, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
|
||||
if k != nil {
|
||||
k = nn.RoPE(ctx, k, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
|
||||
}
|
||||
|
||||
attention := nn.Attention(ctx, q, k, v, 1.0, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, hd*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// TextRouter implements the Gemma 4 MoE router.
|
||||
type TextRouter struct {
|
||||
Proj *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
|
||||
}
|
||||
|
||||
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
|
||||
// RMSNorm without learned weight
|
||||
x := hiddenState.RMSNorm(ctx, nil, opts.eps)
|
||||
// Scale by 1/sqrt(hidden_size)
|
||||
x = x.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
||||
// Multiply by learned scale parameter
|
||||
x = x.Mul(ctx, r.Scale)
|
||||
// Project to expert logits
|
||||
expertScores := r.Proj.Forward(ctx, x)
|
||||
// Softmax over experts
|
||||
routingWeights = expertScores.Softmax(ctx)
|
||||
// TopK expert selection
|
||||
selectedExperts = routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
return routingWeights, selectedExperts
|
||||
}
|
||||
|
||||
// TextMoEBlock implements the Gemma 4 sparse MoE.
|
||||
type TextMoEBlock struct {
|
||||
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
DownScale ml.Tensor `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"`
|
||||
}
|
||||
|
||||
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
// Select routing weights for chosen experts and renormalize
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
|
||||
|
||||
// Expert computation using LinearBatch (MulmatID selecting experts by index)
|
||||
var gateOut, upOut ml.Tensor
|
||||
if moe.GateUp != nil && moe.GateUp.Weight != nil {
|
||||
gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
|
||||
nFF := gateUp.Dim(0) / 2
|
||||
gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
|
||||
upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
|
||||
} else {
|
||||
gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
|
||||
upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
|
||||
}
|
||||
hiddenState = gateOut.GELU(ctx, upOut)
|
||||
experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)
|
||||
|
||||
// Apply per-expert down projection scale when present.
|
||||
if moe.DownScale != nil {
|
||||
expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
|
||||
expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
|
||||
expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
|
||||
expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
|
||||
expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
|
||||
experts = experts.Mul(ctx, expertScales)
|
||||
}
|
||||
|
||||
// Apply routing weights
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
// Sum across experts
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
return nextStates
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *TextSelfAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm,alt:attn_post_norm"`
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm,alt:ffn_pre_norm"`
|
||||
MLP *TextMLP
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm,alt:ffn_post_norm"`
|
||||
|
||||
// MoE (present only for models with enable_moe_block=true)
|
||||
Router *TextRouter
|
||||
MoE *TextMoEBlock
|
||||
MoENorm *nn.RMSNorm `gguf:"pre_ffw_norm_2,alt:ffn_pre_norm_2"`
|
||||
PostMoENorm *nn.RMSNorm `gguf:"post_ffw_norm_2,alt:ffn_post_norm_2"`
|
||||
PostMLPNorm1 *nn.RMSNorm `gguf:"post_ffw_norm_1,alt:ffn_post_norm_1"` // used instead of PostMLPNorm when MoE is present
|
||||
|
||||
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
|
||||
PerLayerProjection *nn.Linear `gguf:"proj"`
|
||||
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
|
||||
LayerScalar ml.Tensor `gguf:"layer_scalar,alt:layer_output_scale.weight"`
|
||||
}
|
||||
|
||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, perLayerInput, outputs ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positions, cache, sharedKV, opts)
|
||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
if perLayerInput != nil {
|
||||
perLayerInput = perLayerInput.Rows(ctx, outputs)
|
||||
}
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// MLP (+ optional MoE in parallel)
|
||||
hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
|
||||
hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
|
||||
if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
|
||||
// MoE layers: run MLP and MoE in parallel, sum results
|
||||
mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
mlpState = l.MLP.Forward(ctx, mlpState)
|
||||
mlpState = l.PostMLPNorm1.Forward(ctx, mlpState, opts.eps)
|
||||
|
||||
routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
|
||||
moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
|
||||
moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
|
||||
moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)
|
||||
|
||||
// Combine MLP + MoE, apply outer post-FFN norm, then add residual
|
||||
combined := mlpState.Add(ctx, moeState)
|
||||
combined = l.PostMLPNorm.Forward(ctx, combined, opts.eps)
|
||||
hiddenState = combined.Add(ctx, residual)
|
||||
} else {
|
||||
// Dense layers: MLP only
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState)
|
||||
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
// PLE injection (after MLP residual)
|
||||
if perLayerInput != nil && l.PerLayerInputGate != nil {
|
||||
pleState := l.PerLayerInputGate.Forward(ctx, hiddenState)
|
||||
pleState = pleState.GELU(ctx, perLayerInput)
|
||||
pleState = l.PerLayerProjection.Forward(ctx, pleState)
|
||||
pleState = l.PostPerLayerNorm.Forward(ctx, pleState, opts.eps)
|
||||
hiddenState = hiddenState.Add(ctx, pleState)
|
||||
}
|
||||
|
||||
// Layer scalar applied at end of layer (full-attention layers only)
|
||||
if l.LayerScalar != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, l.LayerScalar)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
Reference in New Issue
Block a user