476 lines
17 KiB
Go
476 lines
17 KiB
Go
package gemma4
|
|
|
|
import (
|
|
"math"
|
|
|
|
"github.com/ollama/ollama/fs"
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/ml/nn"
|
|
"github.com/ollama/ollama/ml/nn/rope"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
const (
|
|
cacheTypeSWA = iota
|
|
cacheTypeCausal
|
|
)
|
|
|
|
type TextOptions struct {
|
|
hiddenSize int
|
|
numHeads, numKVHeads int
|
|
numGlobalKVHeads int
|
|
headDim, globalHeadDim int
|
|
hiddenLayers int
|
|
hiddenSizePerLayerInput int
|
|
|
|
eps float32
|
|
ropeBase float32
|
|
ropeLocalBase float32
|
|
partialRotaryDims int // RoPE dims for full-attention (global) layers
|
|
|
|
slidingWindowPattern []bool
|
|
// kvDonorMap maps shared layer index -> donor layer index.
|
|
// Donor is the last non-shared layer of the same type (sliding/full).
|
|
kvDonorMap map[int]int
|
|
|
|
finalLogitSoftcap float32
|
|
|
|
numExperts int
|
|
numExpertsUsed int
|
|
}
|
|
|
|
func (o *TextOptions) isLocal(layer int) bool {
|
|
if layer < len(o.slidingWindowPattern) {
|
|
return o.slidingWindowPattern[layer]
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (o *TextOptions) ropeForLayer(layer int) (base float32, dims int) {
|
|
if o.isLocal(layer) {
|
|
return o.ropeLocalBase, o.headDim
|
|
}
|
|
return o.ropeBase, o.partialRotaryDims
|
|
}
|
|
|
|
func (o *TextOptions) kvHeadsForLayer(layer int) int {
|
|
if o.isLocal(layer) {
|
|
return o.numKVHeads
|
|
}
|
|
if o.numGlobalKVHeads > 0 {
|
|
return o.numGlobalKVHeads
|
|
}
|
|
return o.numKVHeads
|
|
}
|
|
|
|
func (o *TextOptions) headDimForLayer(layer int) int {
|
|
if o.isLocal(layer) {
|
|
return o.headDim
|
|
}
|
|
return o.globalHeadDim
|
|
}
|
|
|
|
type TextModel struct {
|
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
|
*PerLayerProjector
|
|
Layers []TextLayer `gguf:"blk"`
|
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
|
TextOptions
|
|
}
|
|
|
|
func newTextModel(c fs.Config) *TextModel {
|
|
numLayers := int(c.Uint("block_count"))
|
|
|
|
// Head dimensions: key_length is global head dim, key_length_swa is local (SWA) head dim.
|
|
globalHeadDim := int(c.Uint("attention.key_length", 512))
|
|
headDim := int(c.Uint("attention.key_length_swa", 256))
|
|
|
|
// RoPE dimensions for global (full attention) layers with proportional RoPE.
|
|
// The freq_factors tensor handles partial rotation (1.0 for rotated pairs,
|
|
// 1e30 for non-rotated), so ropeDims equals the full global head dim.
|
|
partialRotaryDims := int(c.Uint("rope.dimension_count", 0))
|
|
if partialRotaryDims == 0 {
|
|
partialFactor := c.Float("rope.partial_rotary_factor", 1.0)
|
|
partialRotaryDims = int(float32(globalHeadDim) * partialFactor)
|
|
}
|
|
|
|
ropeBase := c.Float("rope.freq_base", 1000000.0)
|
|
ropeLocalBase := c.Float("rope.freq_base_swa", 0)
|
|
if ropeLocalBase == 0 {
|
|
ropeLocalBase = c.Float("rope.local.freq_base", 10000.0)
|
|
}
|
|
|
|
numGlobalKVHeads := int(c.Uint("attention.global_head_count_kv", 0))
|
|
slidingPattern := c.Bools("attention.sliding_window_pattern")
|
|
|
|
// KV heads: try per-layer array first (MoE models), then fall back to scalar
|
|
numKVHeads := 0
|
|
kvHeadsArray := c.Ints("attention.head_count_kv")
|
|
if len(kvHeadsArray) > 0 {
|
|
numKVHeads = int(kvHeadsArray[0])
|
|
if numGlobalKVHeads == 0 && len(slidingPattern) > 0 {
|
|
for i, isLocal := range slidingPattern {
|
|
if !isLocal && i < len(kvHeadsArray) {
|
|
numGlobalKVHeads = int(kvHeadsArray[i])
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if numKVHeads == 0 {
|
|
numKVHeads = int(c.Uint("attention.head_count_kv", 0))
|
|
}
|
|
|
|
// Compute KV sharing donor map (same logic as MLX)
|
|
sharedLayers := int(c.Uint("attention.shared_kv_layers", 0))
|
|
kvDonorMap := make(map[int]int)
|
|
if sharedLayers > 0 && len(slidingPattern) > 0 {
|
|
firstShared := numLayers - sharedLayers
|
|
for i := firstShared; i < numLayers; i++ {
|
|
isLocal := slidingPattern[i]
|
|
// Find last non-shared layer of same type
|
|
for j := firstShared - 1; j >= 0; j-- {
|
|
if slidingPattern[j] == isLocal {
|
|
kvDonorMap[i] = j
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return &TextModel{
|
|
Layers: make([]TextLayer, numLayers),
|
|
TextOptions: TextOptions{
|
|
hiddenSize: int(c.Uint("embedding_length")),
|
|
numHeads: int(c.Uint("attention.head_count")),
|
|
numKVHeads: numKVHeads,
|
|
numGlobalKVHeads: numGlobalKVHeads,
|
|
headDim: headDim,
|
|
globalHeadDim: globalHeadDim,
|
|
hiddenLayers: numLayers,
|
|
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input", 0)),
|
|
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
|
ropeBase: ropeBase,
|
|
ropeLocalBase: ropeLocalBase,
|
|
partialRotaryDims: partialRotaryDims,
|
|
slidingWindowPattern: slidingPattern,
|
|
kvDonorMap: kvDonorMap,
|
|
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
|
|
numExperts: int(c.Uint("expert_count", 0)),
|
|
numExpertsUsed: int(c.Uint("expert_used_count", 0)),
|
|
},
|
|
}
|
|
}
|
|
|
|
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
|
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
|
|
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
|
|
|
// Inject vision embeddings into the hidden state
|
|
var except []int
|
|
for _, image := range batch.Multimodal {
|
|
visionOutputs := image.Multimodal[0].Tensor
|
|
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
|
|
|
for i := range visionOutputs.Dim(1) {
|
|
except = append(except, image.Index+i)
|
|
}
|
|
}
|
|
|
|
// PLE
|
|
var perLayerInputs ml.Tensor
|
|
if m.PerLayerProjector != nil {
|
|
perLayerInputs = m.PerLayerProjector.Forward(ctx, batch, hiddenState, &m.TextOptions)
|
|
}
|
|
|
|
for i := range len(m.Layers) {
|
|
layer := m.Layers[i]
|
|
if cache != nil {
|
|
cache.SetLayer(i)
|
|
cacheType := cacheTypeSWA
|
|
if !m.isLocal(i) {
|
|
cacheType = cacheTypeCausal
|
|
}
|
|
wc := cache.(*kvcache.WrapperCache)
|
|
wc.SetLayerType(cacheType)
|
|
|
|
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
|
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
|
}
|
|
}
|
|
|
|
var lastLayerOutputs ml.Tensor
|
|
if i == len(m.Layers)-1 {
|
|
lastLayerOutputs = batch.Outputs
|
|
}
|
|
|
|
var perLayerInput ml.Tensor
|
|
if perLayerInputs != nil {
|
|
perLayerInput = perLayerInputs.View(ctx, i*perLayerInputs.Stride(1), perLayerInputs.Dim(0), perLayerInputs.Stride(2), perLayerInputs.Dim(2))
|
|
}
|
|
|
|
// KV sharing: layers >= firstShared reuse K/V from donor layers
|
|
isShared := false
|
|
if donorLayer, ok := m.kvDonorMap[i]; ok {
|
|
// Set cache layer to donor so Get() reads donor's K/V
|
|
cache.SetLayer(donorLayer)
|
|
isShared = true
|
|
}
|
|
hiddenState = layer.Forward(ctx, i, hiddenState, positions, perLayerInput, lastLayerOutputs, cache, isShared, &m.TextOptions)
|
|
}
|
|
|
|
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
|
}
|
|
|
|
// PerLayerProjector implements PLE.
|
|
type PerLayerProjector struct {
|
|
TokenEmbedding *nn.Embedding `gguf:"per_layer_token_embd"`
|
|
Projector *nn.Linear `gguf:"per_layer_model_proj"`
|
|
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
|
|
}
|
|
|
|
func (p *PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs)
|
|
inputsPerLayer = inputsPerLayer.Scale(ctx, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
|
// Reshape to [pleDim, numLayers, numTokens] — matching projection shape
|
|
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
|
|
|
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
|
perLayerProjection = perLayerProjection.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
|
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
|
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
|
|
|
if inputsPerLayer != nil {
|
|
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
|
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
|
}
|
|
|
|
return perLayerProjection
|
|
}
|
|
|
|
type TextSelfAttention struct {
|
|
Query *nn.Linear `gguf:"attn_q"`
|
|
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
|
Key *nn.Linear `gguf:"attn_k"`
|
|
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
|
Value *nn.Linear `gguf:"attn_v"`
|
|
Output *nn.Linear `gguf:"attn_output"`
|
|
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` // proportional RoPE freq_factors
|
|
}
|
|
|
|
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
|
|
batchSize := hiddenState.Dim(1)
|
|
hd := opts.headDimForLayer(layer)
|
|
kvHeads := opts.kvHeadsForLayer(layer)
|
|
ropeBase, ropeDims := opts.ropeForLayer(layer)
|
|
|
|
q := sa.Query.Forward(ctx, hiddenState)
|
|
q = q.Reshape(ctx, hd, opts.numHeads, batchSize)
|
|
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
|
|
|
var k, v ml.Tensor
|
|
if !sharedKV {
|
|
k = sa.Key.Forward(ctx, hiddenState)
|
|
k = k.Reshape(ctx, hd, kvHeads, batchSize)
|
|
|
|
if sa.Value != nil {
|
|
v = sa.Value.Forward(ctx, hiddenState)
|
|
v = v.Reshape(ctx, hd, kvHeads, batchSize)
|
|
} else {
|
|
// K=V: use raw K projection (before K norm) as V
|
|
v = k
|
|
}
|
|
|
|
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
|
v = v.RMSNorm(ctx, nil, opts.eps) // V norm: unweighted RMSNorm
|
|
}
|
|
|
|
// RoPE with proportional freq_factors on global layers
|
|
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
|
|
if sa.RopeFactors != nil && !opts.isLocal(layer) {
|
|
ropeOpts = append(ropeOpts, rope.WithFactors(sa.RopeFactors))
|
|
}
|
|
q = nn.RoPE(ctx, q, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
|
|
if k != nil {
|
|
k = nn.RoPE(ctx, k, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
|
|
}
|
|
|
|
attention := nn.Attention(ctx, q, k, v, 1.0, cache)
|
|
|
|
attention = attention.Reshape(ctx, hd*opts.numHeads, batchSize)
|
|
return sa.Output.Forward(ctx, attention)
|
|
}
|
|
|
|
type TextMLP struct {
|
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
|
Up *nn.Linear `gguf:"ffn_up"`
|
|
Down *nn.Linear `gguf:"ffn_down"`
|
|
}
|
|
|
|
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
|
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
|
return mlp.Down.Forward(ctx, hiddenState)
|
|
}
|
|
|
|
// TextRouter implements the Gemma 4 MoE router.
|
|
type TextRouter struct {
|
|
Proj *nn.Linear `gguf:"ffn_gate_inp"`
|
|
Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
|
|
}
|
|
|
|
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
|
|
// RMSNorm without learned weight
|
|
x := hiddenState.RMSNorm(ctx, nil, opts.eps)
|
|
// Scale by 1/sqrt(hidden_size)
|
|
x = x.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
|
// Multiply by learned scale parameter
|
|
x = x.Mul(ctx, r.Scale)
|
|
// Project to expert logits
|
|
expertScores := r.Proj.Forward(ctx, x)
|
|
// Softmax over experts
|
|
routingWeights = expertScores.Softmax(ctx)
|
|
// TopK expert selection
|
|
selectedExperts = routingWeights.TopK(ctx, opts.numExpertsUsed)
|
|
return routingWeights, selectedExperts
|
|
}
|
|
|
|
// TextMoEBlock implements the Gemma 4 sparse MoE.
|
|
type TextMoEBlock struct {
|
|
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
|
|
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
|
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
|
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
|
DownScale ml.Tensor `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"`
|
|
}
|
|
|
|
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
// Select routing weights for chosen experts and renormalize
|
|
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
|
|
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
|
|
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
|
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
|
|
|
|
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
|
|
|
|
// Expert computation using LinearBatch (MulmatID selecting experts by index)
|
|
var gateOut, upOut ml.Tensor
|
|
if moe.GateUp != nil && moe.GateUp.Weight != nil {
|
|
gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
|
|
nFF := gateUp.Dim(0) / 2
|
|
gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
|
|
upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
|
|
} else {
|
|
gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
|
|
upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
|
|
}
|
|
hiddenState = gateOut.GELU(ctx, upOut)
|
|
experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)
|
|
|
|
// Apply per-expert down projection scale when present.
|
|
if moe.DownScale != nil {
|
|
expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
|
|
expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
|
|
expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
|
|
expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
|
|
expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
|
|
experts = experts.Mul(ctx, expertScales)
|
|
}
|
|
|
|
// Apply routing weights
|
|
experts = experts.Mul(ctx, routingWeights)
|
|
|
|
// Sum across experts
|
|
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
|
for i := 1; i < opts.numExpertsUsed; i++ {
|
|
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
|
}
|
|
|
|
return nextStates
|
|
}
|
|
|
|
type TextLayer struct {
|
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
|
SelfAttention *TextSelfAttention
|
|
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm,alt:attn_post_norm"`
|
|
MLPNorm *nn.RMSNorm `gguf:"ffn_norm,alt:ffn_pre_norm"`
|
|
MLP *TextMLP
|
|
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm,alt:ffn_post_norm"`
|
|
|
|
// MoE (present only for models with enable_moe_block=true)
|
|
Router *TextRouter
|
|
MoE *TextMoEBlock
|
|
MoENorm *nn.RMSNorm `gguf:"pre_ffw_norm_2,alt:ffn_pre_norm_2"`
|
|
PostMoENorm *nn.RMSNorm `gguf:"post_ffw_norm_2,alt:ffn_post_norm_2"`
|
|
PostMLPNorm1 *nn.RMSNorm `gguf:"post_ffw_norm_1,alt:ffn_post_norm_1"` // used instead of PostMLPNorm when MoE is present
|
|
|
|
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
|
|
PerLayerProjection *nn.Linear `gguf:"proj"`
|
|
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
|
|
LayerScalar ml.Tensor `gguf:"layer_scalar,alt:layer_output_scale.weight"`
|
|
}
|
|
|
|
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, perLayerInput, outputs ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
|
|
residual := hiddenState
|
|
|
|
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positions, cache, sharedKV, opts)
|
|
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
|
|
if outputs != nil {
|
|
hiddenState = hiddenState.Rows(ctx, outputs)
|
|
residual = residual.Rows(ctx, outputs)
|
|
if perLayerInput != nil {
|
|
perLayerInput = perLayerInput.Rows(ctx, outputs)
|
|
}
|
|
}
|
|
|
|
hiddenState = hiddenState.Add(ctx, residual)
|
|
residual = hiddenState
|
|
|
|
// MLP (+ optional MoE in parallel)
|
|
hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
|
|
hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
|
|
if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
|
|
// MoE layers: run MLP and MoE in parallel, sum results
|
|
mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
mlpState = l.MLP.Forward(ctx, mlpState)
|
|
mlpState = l.PostMLPNorm1.Forward(ctx, mlpState, opts.eps)
|
|
|
|
routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
|
|
moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
|
|
moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
|
|
moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)
|
|
|
|
// Combine MLP + MoE, apply outer post-FFN norm, then add residual
|
|
combined := mlpState.Add(ctx, moeState)
|
|
combined = l.PostMLPNorm.Forward(ctx, combined, opts.eps)
|
|
hiddenState = combined.Add(ctx, residual)
|
|
} else {
|
|
// Dense layers: MLP only
|
|
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = l.MLP.Forward(ctx, hiddenState)
|
|
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = hiddenState.Add(ctx, residual)
|
|
}
|
|
|
|
// PLE injection (after MLP residual)
|
|
if perLayerInput != nil && l.PerLayerInputGate != nil {
|
|
pleState := l.PerLayerInputGate.Forward(ctx, hiddenState)
|
|
pleState = pleState.GELU(ctx, perLayerInput)
|
|
pleState = l.PerLayerProjection.Forward(ctx, pleState)
|
|
pleState = l.PostPerLayerNorm.Forward(ctx, pleState, opts.eps)
|
|
hiddenState = hiddenState.Add(ctx, pleState)
|
|
}
|
|
|
|
// Layer scalar applied at end of layer (full-attention layers only)
|
|
if l.LayerScalar != nil {
|
|
hiddenState = hiddenState.Mul(ctx, l.LayerScalar)
|
|
}
|
|
|
|
return hiddenState
|
|
}
|