ollama source for Momentry Core verification

This commit is contained in:
Accusys
2026-05-22 17:19:10 +08:00
commit 0b31ff9135
2020 changed files with 1413145 additions and 0 deletions

View File

@@ -0,0 +1,265 @@
package gemma4
import (
"bytes"
"fmt"
"image"
"log/slog"
"slices"
"time"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
*AudioModel `gguf:"a"`
*MultiModalProjector `gguf:"mm"`
*AudioMultimodalProjector `gguf:"mm.a"`
ImageProcessor
imageTokenID int32
imageEndTokenID int32
audioTokenID int32
audioEndTokenID int32
audioOpts *AudioModelOptions
}
var _ model.MultimodalProcessor = (*Model)(nil)
type MultiModalProjector struct {
Projection *ClippableLinear `gguf:"input_projection"`
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
visionOutputs = p.Projection.Forward(ctx, visionOutputs)
// Post-projection RMSNorm without learned weight
visionOutputs = visionOutputs.RMSNorm(ctx, nil, eps)
return visionOutputs
}
func New(c fs.Config) (model.Model, error) {
vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
// Gemma 4 uses BPE with SentencePiece-style ▁ space markers (not GPT-2 byte-level encoding).
// The tokenizer.json has merges and a Replace normalizer (space → ▁), with no pre-tokenizer.
t := tokenizer.NewBytePairEncodingWithOptions(&vocabulary, []string{},
tokenizer.WithSentencePieceNormalizer())
// Look up special token IDs for vision and audio
imageTokenID := int32(-1)
imageEndTokenID := int32(-1)
audioTokenID := int32(-1)
audioEndTokenID := int32(-1)
for i, tok := range vocabulary.Values {
switch tok {
case "<|image>":
imageTokenID = int32(i)
case "<image|>":
imageEndTokenID = int32(i)
case "<|audio>":
audioTokenID = int32(i)
case "<audio|>":
audioEndTokenID = int32(i)
}
}
slog.Info("gemma4: token IDs", "image", imageTokenID, "image_end", imageEndTokenID, "audio", audioTokenID, "audio_end", audioEndTokenID)
m := Model{
Tokenizer: t,
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
AudioModel: newAudioModel(c),
MultiModalProjector: &MultiModalProjector{},
AudioMultimodalProjector: &AudioMultimodalProjector{},
ImageProcessor: newImageProcessor(c),
imageTokenID: imageTokenID,
imageEndTokenID: imageEndTokenID,
audioTokenID: audioTokenID,
audioEndTokenID: audioEndTokenID,
audioOpts: newAudioModelOptions(c),
}
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWAMemCache(slidingWindowLen, 4096, m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
// Audio input: detect WAV format and route to audio encoder.
if isAudioData(multimodalData) {
return m.encodeAudioMultimodal(ctx, multimodalData)
}
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
t0 := time.Now()
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
slog.Info("vision: decode", "elapsed", time.Since(t0), "bounds", img.Bounds())
t1 := time.Now()
f32s, imgW, imgH, err := m.ImageProcessor.ProcessImage(img)
if err != nil {
return nil, err
}
slog.Info("vision: preprocess", "elapsed", time.Since(t1), "size", [2]int{imgW, imgH})
pixelValues := ctx.Input().FromFloats(f32s, imgW, imgH, m.ImageProcessor.numChannels)
slog.Info("vision: pixelValues", "shape", pixelValues.Shape(), "dim0", pixelValues.Dim(0), "dim1", pixelValues.Dim(1), "dim2", pixelValues.Dim(2))
numPatchesX := imgW / m.ImageProcessor.patchSize
numPatchesY := imgH / m.ImageProcessor.patchSize
slog.Info("vision: patches", "patchesX", numPatchesX, "patchesY", numPatchesY, "total", numPatchesX*numPatchesY, "patchSize", m.ImageProcessor.patchSize)
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, numPatchesX, numPatchesY)
visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector, m.VisionModel.StdBias, m.VisionModel.StdScale)
slog.Info("vision: encoded", "elapsed", time.Since(t0), "shape", visionOutputs.Shape())
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostLoad() error {
m.VisionModel.InitClamp(m.MultiModalProjector)
return nil
}
func (m *Model) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
if m.AudioModel == nil || m.audioOpts == nil {
return nil, model.ErrNoVisionModel
}
t0 := time.Now()
samples, err := decodeWAV(data)
if err != nil {
return nil, err
}
slog.Info("audio: decode", "elapsed", time.Since(t0), "samples", len(samples), "duration_s", float64(len(samples))/audioSampleRate)
// Pad waveform to next multiple of 128.
if rem := len(samples) % 128; rem != 0 {
samples = append(samples, make([]float32, 128-rem)...)
}
// Compute mel spectrogram.
melData, numFrames := computeMelSpectrogram(samples)
if numFrames == 0 {
return nil, fmt.Errorf("audio too short to encode")
}
slog.Info("audio: mel", "frames", numFrames, "elapsed", time.Since(t0))
// Create input tensor [melBins, numFrames] (GGML ne order). FromFloats creates F32.
melTensor := ctx.Input().FromFloats(melData, melBins, numFrames)
// Run audio encoder.
audioOutputs := m.AudioModel.ForwardAudio(ctx, melTensor, m.AudioMultimodalProjector, m.audioOpts)
slog.Info("audio: encoded", "elapsed", time.Since(t0), "shape", audioOutputs.Shape())
return []input.Multimodal{{Tensor: audioOutputs, Data: audioTag{}}}, nil
}
// audioTag marks multimodal data as audio (vs vision) for PostTokenize.
type audioTag struct{}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
continue
}
inputMultimodal := inp.Multimodal[0].Tensor
numTokens := inputMultimodal.Dim(1)
// Determine if this is audio or vision based on the tag.
_, isAudio := inp.Multimodal[0].Data.(audioTag)
var beginToken, endToken int32
if isAudio {
beginToken = m.audioTokenID
endToken = m.audioEndTokenID
} else {
beginToken = m.imageTokenID
endToken = m.imageEndTokenID
}
if beginToken >= 0 {
result = append(result, &input.Input{Token: beginToken, SameBatch: numTokens + 2})
}
result = append(result,
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash},
)
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, numTokens-1)...)
if endToken >= 0 {
result = append(result, &input.Input{Token: endToken})
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenState = m.TextModel.Output.Forward(ctx, hiddenState)
if m.TextModel.TextOptions.finalLogitSoftcap > 0.0 {
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextModel.TextOptions.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.TextModel.TextOptions.finalLogitSoftcap))
}
return hiddenState, nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase, ropeDims := m.TextModel.ropeForLayer(layer)
return nn.RoPE(ctx, key, shift, ropeDims, ropeBase, 1.0, rope.WithTypeNeoX()), nil
}
func init() {
model.Register("gemma4", New)
}

View File

@@ -0,0 +1,611 @@
package gemma4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// AudioModel holds the audio encoder and configuration.
type AudioModel struct {
// SSCP: Sub-Sample Convolution Projection.
SSCPConv0 *AudioConvBlock `gguf:"conv1d.0"`
SSCPConv1 *AudioConvBlock `gguf:"conv1d.1"`
// SSCP output projection (linear).
SSCPInputProj *nn.Linear `gguf:"pre_encode.out"`
// Conformer blocks.
Layers []AudioConformerBlock `gguf:"blk"`
// Output projection to embedder dimension.
OutputProj *AudioOutputProj `gguf:"output_proj"`
AudioModelOptions
}
type AudioOutputProj struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
// AudioModelOptions holds audio model hyperparameters.
type AudioModelOptions struct {
hiddenSize int
numHeads int
headDim int
ffnSize int
numLayers int
melBins int
chunkSize int
maxPast int
maxFuture int
contextSize int
logitCap float32
residualWeight float32
gradClip float32
convKernelSize int
eps float32
}
// AudioConvBlock is a single 2D convolution block for the SSCP.
type AudioConvBlock struct {
Weight ml.Tensor `gguf:"weight"`
Norm *nn.LayerNorm `gguf:"norm"`
}
// AudioConformerBlock is a single conformer layer.
// All tensors are flat at the block level (a.blk.N.<name>) using underscore naming.
type AudioConformerBlock struct {
// Block-level norm
Norm *nn.RMSNorm `gguf:"layer_pre_norm"`
// FFW start
FFWNorm *nn.RMSNorm `gguf:"ffn_norm"`
FFWUp *AudioClippableLinear `gguf:"ffn_up"`
FFWDown *AudioClippableLinear `gguf:"ffn_down"`
FFWPostNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
// FFW end
FFWNorm1 *nn.RMSNorm `gguf:"ffn_norm_1"`
FFWUp1 *AudioClippableLinear `gguf:"ffn_up_1"`
FFWDown1 *AudioClippableLinear `gguf:"ffn_down_1"`
FFWPostNorm1 *nn.RMSNorm `gguf:"ffn_post_norm_1"`
// Attention
AttnQ *AudioClippableLinear `gguf:"attn_q"`
AttnK *AudioClippableLinear `gguf:"attn_k"`
AttnV *AudioClippableLinear `gguf:"attn_v"`
AttnOut *AudioClippableLinear `gguf:"attn_out"`
AttnPreNorm *nn.RMSNorm `gguf:"ln1"`
AttnPostNorm *nn.RMSNorm `gguf:"ln2"`
LinearPos ml.Tensor `gguf:"linear_pos.weight"`
PerDimScale ml.Tensor `gguf:"per_dim_scale.weight"`
// LightConv1d
ConvPW1 *AudioClippableLinear `gguf:"conv_pw1"`
ConvPW2 *AudioClippableLinear `gguf:"conv_pw2"`
ConvDW ml.Tensor `gguf:"conv_dw.weight"`
ConvNorm *nn.RMSNorm `gguf:"conv_norm"`
NormConv *nn.RMSNorm `gguf:"norm_conv"`
}
// AudioClippableLinear is a linear layer with optional input/output clamping.
type AudioClippableLinear struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
InputMin ml.Tensor `gguf:"input_min"`
InputMax ml.Tensor `gguf:"input_max"`
OutputMin ml.Tensor `gguf:"output_min"`
OutputMax ml.Tensor `gguf:"output_max"`
// Cached scalar clamp values (populated on first forward).
inMin, inMax, outMin, outMax float32
clampsLoaded bool
}
func (l *AudioClippableLinear) loadClamps() {
if l.clampsLoaded {
return
}
l.clampsLoaded = true
if l.InputMin != nil {
vals := l.InputMin.BackendGet()
if len(vals) > 0 {
l.inMin = vals[0]
}
}
if l.InputMax != nil {
vals := l.InputMax.BackendGet()
if len(vals) > 0 {
l.inMax = vals[0]
}
}
if l.OutputMin != nil {
vals := l.OutputMin.BackendGet()
if len(vals) > 0 {
l.outMin = vals[0]
}
}
if l.OutputMax != nil {
vals := l.OutputMax.BackendGet()
if len(vals) > 0 {
l.outMax = vals[0]
}
}
}
func (l *AudioClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
l.loadClamps()
if l.inMax != 0 {
x = x.Clamp(ctx, l.inMin, l.inMax)
}
out := l.Weight.Mulmat(ctx, x)
if l.Bias != nil {
out = out.Add(ctx, l.Bias)
}
if l.outMax != 0 {
out = out.Clamp(ctx, l.outMin, l.outMax)
}
return out
}
// AudioMultimodalProjector is the audio-to-text embedding projector.
type AudioMultimodalProjector struct {
Projection *AudioClippableLinear `gguf:"input_projection"`
FC *AudioFC `gguf:"fc"`
}
type AudioFC struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (p *AudioMultimodalProjector) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
// FC: output projection from conformer to embedder dimension.
x = p.FC.Weight.Mulmat(ctx, x)
if p.FC.Bias != nil {
x = x.Add(ctx, p.FC.Bias)
}
// Pre-projection RMSNorm (without learned weight) — matches Python's embedding_pre_projection_norm.
x = x.RMSNorm(ctx, nil, eps)
// Embedding projection to text hidden size.
x = p.Projection.Forward(ctx, x)
return x
}
// ForwardAudio encodes mel spectrogram features into soft tokens.
// melFeatures: float32 tensor with ne[0]=melBins, ne[1]=numFrames.
// Returns: [hiddenSize, numTokens] tensor.
func (m *AudioModel) ForwardAudio(ctx ml.Context, melFeatures ml.Tensor, proj *AudioMultimodalProjector, opts *AudioModelOptions) ml.Tensor {
// SSCP Conv2D input: ne[0]=F (freq/width), ne[1]=T (time/height), ne[2]=C_in, ne[3]=B
// melFeatures is [melBins, numFrames], add channel and batch dims.
x := melFeatures.Reshape(ctx, melFeatures.Dim(0), melFeatures.Dim(1), 1, 1)
// SSCP Conv block 0: [F, T, 1, 1] → [F', T', C0, 1]
x = forwardConvBlock(ctx, m.SSCPConv0, x, opts)
// SSCP Conv block 1: [F', T', C0, 1] → [F'', T'', C1, 1]
x = forwardConvBlock(ctx, m.SSCPConv1, x, opts)
// After conv blocks, layout is [F'', T'', C_out, B].
// Permute to [C_out*F'', T'', B] for linear projection (channels+freq in ne[0]).
fOut := x.Dim(0)
tOut := x.Dim(1)
cOut := x.Dim(2)
// Permute [F'', T'', C, B] → [C, F'', T'', B]
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
x = x.Reshape(ctx, cOut*fOut, tOut)
// Linear projection to hidden size.
x = m.SSCPInputProj.Forward(ctx, x)
// Build causal-valid mask for conformer attention.
causalMask := buildCausalValidMaskF32(opts.chunkSize, opts.maxPast, opts.maxFuture)
// Run conformer blocks.
for i := range m.Layers {
x = m.Layers[i].Forward(ctx, x, causalMask, opts, i)
}
// Output projection.
if m.OutputProj != nil {
x = m.OutputProj.Weight.Mulmat(ctx, x)
if m.OutputProj.Bias != nil {
x = x.Add(ctx, m.OutputProj.Bias)
}
}
// Audio embedder: project to text embedding space.
if proj != nil {
x = proj.Forward(ctx, x, opts.eps)
}
return x
}
// forwardConvBlock runs a single SSCP Conv2D block.
// Conv2D receiver is the kernel, argument is the input data.
// Input: [F, T, C_in, B]. Output: [F', T', C_out, B].
func forwardConvBlock(ctx ml.Context, block *AudioConvBlock, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
// Conv2D: kernel.Conv2D(ctx, input, s0, s1, p0, p1, d0, d1)
// Kernel is 3x3, stride 2x2, padding 1x1 (matching SSCP config).
// Output layout: [F', T', C_out, B]
// Make weight contiguous — the shape reversal in the converter creates
// a tensor where the physical data order doesn't match ne[]/stride[].
weight := block.Weight.Contiguous(ctx)
x = weight.Conv2D(ctx, x, 2, 2, 1, 1, 1, 1)
// LayerNorm needs channels in ne[0]. Permute [F', T', C_out, B] → [C_out, F', T', B],
// norm, then permute back.
// GGML permute: axis i says where old axis i goes.
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0 → [C_out, F', T', B]
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
x = block.Norm.Forward(ctx, x, opts.eps)
// (2,0,1,3): old[0]→pos2, old[1]→pos0, old[2]→pos1 → [F', T', C_out, B]
x = x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
x = x.RELU(ctx)
return x
}
// Forward runs a single conformer block.
func (cb *AudioConformerBlock) Forward(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
// FFW start (half-residual).
x = cb.forwardFFW(ctx, cb.FFWNorm, cb.FFWUp, cb.FFWDown, cb.FFWPostNorm, x, opts)
// Self-attention.
x = cb.forwardAttention(ctx, x, causalMask, opts, blockIdx)
// Lightweight Conv1d.
x = cb.forwardLightConv(ctx, x, opts, blockIdx)
// FFW end (half-residual).
x = cb.forwardFFW(ctx, cb.FFWNorm1, cb.FFWUp1, cb.FFWDown1, cb.FFWPostNorm1, x, opts)
// Gradient clipping + final norm.
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = cb.Norm.Forward(ctx, x, opts.eps)
return x
}
// forwardFFW runs a feedforward module with half-residual connection.
func (cb *AudioConformerBlock) forwardFFW(ctx ml.Context, preNorm *nn.RMSNorm, up, down *AudioClippableLinear, postNorm *nn.RMSNorm, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
residual := x
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = preNorm.Forward(ctx, x, opts.eps)
x = up.Forward(ctx, x)
x = x.SILU(ctx)
x = down.Forward(ctx, x)
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = postNorm.Forward(ctx, x, opts.eps)
x = x.Scale(ctx, float64(opts.residualWeight))
return residual.Add(ctx, x)
}
// forwardAttention runs the conformer block-local attention with relative position embeddings.
func (cb *AudioConformerBlock) forwardAttention(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
residual := x
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = cb.AttnPreNorm.Forward(ctx, x, opts.eps)
hiddenSize := x.Dim(0)
seqLen := x.Dim(1)
// QKV projections: [hiddenSize, seqLen] → [headDim, numHeads, seqLen]
q := cb.AttnQ.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
k := cb.AttnK.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
v := cb.AttnV.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
// Per-dim scaling for queries: (headDim^-0.5 / log(2)) * softplus(per_dim_scale)
// per_dim_scale is already softplus'd from the converter.
qScale := float64(math.Pow(float64(opts.headDim), -0.5)) / math.Log(2)
q = q.Scale(ctx, qScale)
if cb.PerDimScale != nil {
q = q.Mul(ctx, cb.PerDimScale)
}
// Key scaling: softplus(1) / log(2) — matches the query base scaling convention.
kScale := math.Log(1+math.E) / math.Log(2)
k = k.Scale(ctx, kScale)
// Build sinusoidal position embeddings for the block-local context.
maxSpan := opts.maxPast + opts.maxFuture + 1 // 13 unique relative positions
posEmb := cb.buildPositionEmbeddings(ctx, maxSpan, opts)
// posEmb: [headDim, numHeads, maxSpan]
// Block-local attention: process chunks of size chunkSize.
chunkSize := opts.chunkSize
numChunks := (seqLen + chunkSize - 1) / chunkSize
contextSize := opts.contextSize
// Pad q/k/v to multiple of chunkSize on the time dimension (dim 2).
padT := numChunks*chunkSize - seqLen
if padT > 0 {
q = q.Pad(ctx, 0, 0, padT, 0)
k = k.Pad(ctx, 0, 0, padT, 0)
v = v.Pad(ctx, 0, 0, padT, 0)
}
paddedLen := numChunks * chunkSize
// Pad k/v for context extraction: add maxPast on left, (maxFuture+chunkSize-1) on right.
// Use Pad (right) + PadExt (left) workaround since PadExt+Slice has issues.
// Actually use Concat with zero tensors for reliable left-padding.
padLeft := opts.maxPast
padRight := opts.maxFuture + chunkSize - 1
zeroLeft := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padLeft), opts.headDim, opts.numHeads, padLeft)
zeroRight := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padRight), opts.headDim, opts.numHeads, padRight)
kPadded := zeroLeft.Concat(ctx, k, 2).Concat(ctx, zeroRight, 2)
vPadded := zeroLeft.Concat(ctx, v, 2).Concat(ctx, zeroRight, 2)
// Reshape q into chunks: [headDim, numHeads, numChunks, chunkSize]
qChunked := q.Reshape(ctx, opts.headDim, opts.numHeads, numChunks, chunkSize)
// Process each chunk and collect results.
chunkOutputs := make([]ml.Tensor, numChunks)
for u := range numChunks {
// Extract query block: [headDim, numHeads, 1, chunkSize] → [headDim, numHeads, chunkSize]
qBlock := qChunked.Slice(ctx, 2, u, u+1, 1).Reshape(ctx, opts.headDim, opts.numHeads, chunkSize)
// Extract key/value context: [headDim, numHeads, contextSize]
cStart := u * chunkSize // offset in kPadded (padLeft already accounts for left context)
kCtx := kPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
vCtx := vPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
// Content-content logits: qBlock^T @ kCtx → [chunkSize, contextSize] per head.
// Mulmat(a, b) = a^T @ b. We want Q^T K, so: kCtx.Mulmat(qBlock) but that gives
// [numHeads, chunkSize, contextSize] with wrong batching.
// Instead: permute to [headDim, chunkSize, numHeads] and [headDim, contextSize, numHeads]
// then Mulmat batches over numHeads.
// GGML permute(0,2,1,3): old[0]→0, old[1]→2, old[2]→1
qP := qBlock.Permute(ctx, 0, 2, 1, 3) // [headDim, chunkSize, numHeads]
kP := kCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
termAC := kP.MulmatFullPrec(ctx, qP) // [contextSize, chunkSize, numHeads]
// Content-position logits: qBlock^T @ posEmb → [chunkSize, maxSpan] per head.
pP := posEmb.Permute(ctx, 0, 2, 1, 3) // [headDim, maxSpan, numHeads]
termBDRaw := pP.MulmatFullPrec(ctx, qP) // [maxSpan, chunkSize, numHeads]
// Relative shift: [maxSpan, chunkSize, numHeads] → [contextSize, chunkSize, numHeads]
termBD := cb.relativeShiftGGML(ctx, termBDRaw, maxSpan, chunkSize, contextSize, opts.numHeads)
// Combined logits.
logits := termAC.Add(ctx, termBD)
// Logit softcap: tanh(logits / cap) * cap
logits = logits.Scale(ctx, 1.0/float64(opts.logitCap))
logits = logits.Tanh(ctx)
logits = logits.Scale(ctx, float64(opts.logitCap))
// Apply combined causal + validity mask.
// causalMask [chunkSize * contextSize]: 1=causal-allowed, 0=masked.
// Validity: context positions before the actual sequence start are invalid.
// For chunk u, context position c corresponds to actual time: u*chunkSize + c - padLeft.
// Valid if 0 <= actual_time < seqLen.
// Mask tensor layout: [contextSize, chunkSize, 1] with ne[0]=contextSize contiguous.
// Element at (context=j, chunk=i) is at flat index: i*contextSize + j.
maskData := make([]float32, contextSize*chunkSize)
for i := range chunkSize {
for j := range contextSize {
actualTime := u*chunkSize + j - padLeft
causalOK := causalMask[i*contextSize+j] > 0
validOK := actualTime >= 0 && actualTime < seqLen
if causalOK && validOK {
maskData[i*contextSize+j] = 0
} else {
maskData[i*contextSize+j] = -1e9
}
}
}
mask := ctx.Input().FromFloats(maskData, contextSize, chunkSize, 1) // 3D for broadcasting over numHeads
logits = logits.Add(ctx, mask)
// Softmax over context dimension (dim 0 = contextSize).
logits = logits.Softmax(ctx) // softmax over ne[0]=contextSize
// Weighted sum: logits^T @ vCtx.
// logits: [contextSize, chunkSize, numHeads], vCtx: [headDim, numHeads, contextSize]
// vCtx permuted: [headDim, contextSize, numHeads]
vP := vCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
// Weighted sum: for each head, value[headDim, contextSize] @ weights[contextSize, chunkSize]
// = [headDim, chunkSize].
// Mulmat(a, b) = a^T @ b. Need a=[contextSize, headDim, numHeads], b=[contextSize, chunkSize, numHeads].
vPT := vP.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [contextSize, headDim, numHeads]
chunkOut := vPT.Mulmat(ctx, logits) // [headDim, chunkSize, numHeads]
// Permute back to [headDim, numHeads, chunkSize]
chunkOut = chunkOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
chunkOutputs[u] = chunkOut
}
// Concatenate chunk outputs along time dimension.
var attnOut ml.Tensor
if numChunks == 1 {
attnOut = chunkOutputs[0]
} else {
attnOut = chunkOutputs[0]
for _, co := range chunkOutputs[1:] {
attnOut = attnOut.Concat(ctx, co, 2)
}
}
// Trim to original sequence length if we padded.
if paddedLen > seqLen {
attnOut = attnOut.Slice(ctx, 2, 0, seqLen, 1).Contiguous(ctx)
}
// Reshape to [hiddenSize, seqLen] and project.
attnOut = attnOut.Reshape(ctx, hiddenSize, seqLen)
x = cb.AttnOut.Forward(ctx, attnOut)
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = cb.AttnPostNorm.Forward(ctx, x, opts.eps)
return residual.Add(ctx, x)
}
// buildPositionEmbeddings builds sinusoidal position embeddings and projects through linear_pos.
// Returns [headDim, numHeads, maxSpan] tensor.
func (cb *AudioConformerBlock) buildPositionEmbeddings(ctx ml.Context, maxSpan int, opts *AudioModelOptions) ml.Tensor {
halfDim := opts.hiddenSize / 2
hiddenSize := opts.hiddenSize
// inv_timescales: exp(-i * log(10000) / max(D/2-1, 1))
logInc := math.Log(10000.0) / math.Max(float64(halfDim-1), 1)
// Sinusoidal embeddings for relative positions [maxPast, maxPast-1, ..., -maxFuture].
posData := make([]float32, hiddenSize*maxSpan)
for p := range maxSpan {
relPos := float64(opts.maxPast - p)
for d := range halfDim {
angle := relPos * math.Exp(float64(-d)*logInc)
posData[p*hiddenSize+d] = float32(math.Sin(angle))
posData[p*hiddenSize+halfDim+d] = float32(math.Cos(angle))
}
}
// Create [hiddenSize, maxSpan] input tensor.
posEmb := ctx.Input().FromFloats(posData, hiddenSize, maxSpan)
// Project through linear_pos: [hiddenSize, maxSpan] → Mulmat → [numHeads*headDim, maxSpan]
projPos := cb.LinearPos.Mulmat(ctx, posEmb)
// Reshape to [headDim, numHeads, maxSpan].
return projPos.Reshape(ctx, opts.headDim, opts.numHeads, maxSpan)
}
// relativeShiftGGML performs the relative shift to extract correct position logits.
// Input: [maxSpan, chunkSize, numHeads]. Output: [contextSize, chunkSize, numHeads].
func (cb *AudioConformerBlock) relativeShiftGGML(ctx ml.Context, x ml.Tensor, maxSpan, chunkSize, contextSize, numHeads int) ml.Tensor {
// The shift trick: pad ne[0] to contextSize+1, reshape to flatten first two dims,
// skip first (contextSize+1-maxSpan) elements, take contextSize*chunkSize elements, reshape back.
padAmt := contextSize + 1 - maxSpan
if padAmt > 0 {
x = x.Pad(ctx, padAmt, 0, 0, 0) // [maxSpan+padAmt, chunkSize, numHeads] = [contextSize+1, chunkSize, numHeads]
}
// Reshape to [(contextSize+1)*chunkSize, numHeads]
x = x.Reshape(ctx, (contextSize+1)*chunkSize, numHeads)
// Take the first contextSize*chunkSize elements (the standard relative shift trick).
x = x.Slice(ctx, 0, 0, contextSize*chunkSize, 1).Contiguous(ctx)
// Reshape to [contextSize, chunkSize, numHeads]
return x.Reshape(ctx, contextSize, chunkSize, numHeads)
}
// forwardLightConv runs the lightweight depthwise convolution module.
func (cb *AudioConformerBlock) forwardLightConv(ctx ml.Context, x ml.Tensor, opts *AudioModelOptions, blockIdx int) ml.Tensor {
residual := x
x = cb.ConvNorm.Forward(ctx, x, opts.eps)
x = cb.ConvPW1.Forward(ctx, x) // [2*D, T, B]
// GLU: split in half along dim 0, sigmoid gate, multiply.
d := x.Dim(0) / 2
data := x.Slice(ctx, 0, 0, d, 1).Contiguous(ctx)
gate := x.Slice(ctx, 0, d, d*2, 1).Contiguous(ctx).Sigmoid(ctx)
x = data.Mul(ctx, gate) // [D, T, B]
// Depthwise Conv1d: manual implementation using model weight tensor slices.
// Kernel cb.ConvDW shape: [K=5, D=1024] (ne[0]=K, ne[1]=D) after shape reversal.
// Actually in GGML, ne[0]=K=5 contiguous, ne[1]=D=1024.
// We need per-tap weights [D] and shifted input copies.
kernelSize := cb.ConvDW.Dim(0) // K=5
seqLen := x.Dim(1)
// Transpose kernel to [D, K] for per-tap slicing.
// GGML permute(1,0,2,3): old[0]→pos1, old[1]→pos0 → swap ne[0] and ne[1]
kernelT := cb.ConvDW.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [D, K]
var convOut ml.Tensor
for k := range kernelSize {
shift := kernelSize - 1 - k
var shifted ml.Tensor
if shift == 0 {
shifted = x
} else {
trimmed := x.Slice(ctx, 1, 0, seqLen-shift, 1).Contiguous(ctx)
shifted = trimmed.PadExt(ctx, 0, 0, shift, 0, 0, 0, 0, 0)
}
wk := kernelT.Slice(ctx, 1, k, k+1, 1).Contiguous(ctx) // [D, 1]
term := shifted.Mul(ctx, wk)
if convOut == nil {
convOut = term
} else {
convOut = convOut.Add(ctx, term)
}
}
x = convOut
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
x = cb.NormConv.Forward(ctx, x, opts.eps)
x = x.SILU(ctx)
x = cb.ConvPW2.Forward(ctx, x)
return x.Add(ctx, residual)
}
func newAudioModel(c fs.Config) *AudioModel {
numLayers := int(c.Uint("audio.block_count", 0))
if numLayers == 0 {
return nil
}
return &AudioModel{
Layers: make([]AudioConformerBlock, numLayers),
}
}
func newAudioModelOptions(c fs.Config) *AudioModelOptions {
hiddenSize := int(c.Uint("audio.embedding_length", 0))
if hiddenSize == 0 {
return nil
}
numHeads := int(c.Uint("audio.attention.head_count", 8))
headDim := hiddenSize / numHeads
chunkSize := 12 // default conformer chunk size
maxPast := 12 // conf_attention_context_left - 1
maxFuture := 0 // conf_attention_context_right
convKernel := int(c.Uint("audio.conv_kernel_size", 5))
eps := c.Float("audio.attention.layer_norm_epsilon", 1e-6)
return &AudioModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: headDim,
ffnSize: int(c.Uint("audio.feed_forward_length", uint32(hiddenSize*4))),
numLayers: int(c.Uint("audio.block_count", 12)),
melBins: int(c.Uint("audio.num_mel_bins", 128)),
chunkSize: chunkSize,
maxPast: maxPast,
maxFuture: maxFuture,
contextSize: chunkSize + maxPast + maxFuture,
logitCap: 50.0,
residualWeight: 0.5,
gradClip: 1e10,
convKernelSize: convKernel,
eps: float32(eps),
}
}
// buildCausalValidMaskF32 creates the causal-valid mask for block-local attention.
// Returns flat [chunkSize * contextSize] float32 data (1.0 = allowed, 0.0 = masked).
func buildCausalValidMaskF32(chunkSize, maxPast, maxFuture int) []float32 {
contextSize := chunkSize + maxPast + maxFuture
upperDiag := maxPast + maxFuture
result := make([]float32, chunkSize*contextSize)
for r := range chunkSize {
for c := range contextSize {
lower := (r <= c) // tril(contextSize, chunkSize) transposed
upper := (c <= r+upperDiag) // tril(chunkSize, contextSize, diag=upperDiag)
if lower && upper {
result[r*contextSize+c] = 1.0
}
}
}
return result
}

View File

@@ -0,0 +1,475 @@
package gemma4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
const (
cacheTypeSWA = iota
cacheTypeCausal
)
type TextOptions struct {
hiddenSize int
numHeads, numKVHeads int
numGlobalKVHeads int
headDim, globalHeadDim int
hiddenLayers int
hiddenSizePerLayerInput int
eps float32
ropeBase float32
ropeLocalBase float32
partialRotaryDims int // RoPE dims for full-attention (global) layers
slidingWindowPattern []bool
// kvDonorMap maps shared layer index -> donor layer index.
// Donor is the last non-shared layer of the same type (sliding/full).
kvDonorMap map[int]int
finalLogitSoftcap float32
numExperts int
numExpertsUsed int
}
func (o *TextOptions) isLocal(layer int) bool {
if layer < len(o.slidingWindowPattern) {
return o.slidingWindowPattern[layer]
}
return false
}
func (o *TextOptions) ropeForLayer(layer int) (base float32, dims int) {
if o.isLocal(layer) {
return o.ropeLocalBase, o.headDim
}
return o.ropeBase, o.partialRotaryDims
}
func (o *TextOptions) kvHeadsForLayer(layer int) int {
if o.isLocal(layer) {
return o.numKVHeads
}
if o.numGlobalKVHeads > 0 {
return o.numGlobalKVHeads
}
return o.numKVHeads
}
func (o *TextOptions) headDimForLayer(layer int) int {
if o.isLocal(layer) {
return o.headDim
}
return o.globalHeadDim
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
*PerLayerProjector
Layers []TextLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
TextOptions
}
func newTextModel(c fs.Config) *TextModel {
numLayers := int(c.Uint("block_count"))
// Head dimensions: key_length is global head dim, key_length_swa is local (SWA) head dim.
globalHeadDim := int(c.Uint("attention.key_length", 512))
headDim := int(c.Uint("attention.key_length_swa", 256))
// RoPE dimensions for global (full attention) layers with proportional RoPE.
// The freq_factors tensor handles partial rotation (1.0 for rotated pairs,
// 1e30 for non-rotated), so ropeDims equals the full global head dim.
partialRotaryDims := int(c.Uint("rope.dimension_count", 0))
if partialRotaryDims == 0 {
partialFactor := c.Float("rope.partial_rotary_factor", 1.0)
partialRotaryDims = int(float32(globalHeadDim) * partialFactor)
}
ropeBase := c.Float("rope.freq_base", 1000000.0)
ropeLocalBase := c.Float("rope.freq_base_swa", 0)
if ropeLocalBase == 0 {
ropeLocalBase = c.Float("rope.local.freq_base", 10000.0)
}
numGlobalKVHeads := int(c.Uint("attention.global_head_count_kv", 0))
slidingPattern := c.Bools("attention.sliding_window_pattern")
// KV heads: try per-layer array first (MoE models), then fall back to scalar
numKVHeads := 0
kvHeadsArray := c.Ints("attention.head_count_kv")
if len(kvHeadsArray) > 0 {
numKVHeads = int(kvHeadsArray[0])
if numGlobalKVHeads == 0 && len(slidingPattern) > 0 {
for i, isLocal := range slidingPattern {
if !isLocal && i < len(kvHeadsArray) {
numGlobalKVHeads = int(kvHeadsArray[i])
break
}
}
}
}
if numKVHeads == 0 {
numKVHeads = int(c.Uint("attention.head_count_kv", 0))
}
// Compute KV sharing donor map (same logic as MLX)
sharedLayers := int(c.Uint("attention.shared_kv_layers", 0))
kvDonorMap := make(map[int]int)
if sharedLayers > 0 && len(slidingPattern) > 0 {
firstShared := numLayers - sharedLayers
for i := firstShared; i < numLayers; i++ {
isLocal := slidingPattern[i]
// Find last non-shared layer of same type
for j := firstShared - 1; j >= 0; j-- {
if slidingPattern[j] == isLocal {
kvDonorMap[i] = j
break
}
}
}
}
return &TextModel{
Layers: make([]TextLayer, numLayers),
TextOptions: TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: numKVHeads,
numGlobalKVHeads: numGlobalKVHeads,
headDim: headDim,
globalHeadDim: globalHeadDim,
hiddenLayers: numLayers,
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input", 0)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeBase: ropeBase,
ropeLocalBase: ropeLocalBase,
partialRotaryDims: partialRotaryDims,
slidingWindowPattern: slidingPattern,
kvDonorMap: kvDonorMap,
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
numExperts: int(c.Uint("expert_count", 0)),
numExpertsUsed: int(c.Uint("expert_used_count", 0)),
},
}
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
// Inject vision embeddings into the hidden state
var except []int
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal[0].Tensor
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
for i := range visionOutputs.Dim(1) {
except = append(except, image.Index+i)
}
}
// PLE
var perLayerInputs ml.Tensor
if m.PerLayerProjector != nil {
perLayerInputs = m.PerLayerProjector.Forward(ctx, batch, hiddenState, &m.TextOptions)
}
for i := range len(m.Layers) {
layer := m.Layers[i]
if cache != nil {
cache.SetLayer(i)
cacheType := cacheTypeSWA
if !m.isLocal(i) {
cacheType = cacheTypeCausal
}
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
}
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = batch.Outputs
}
var perLayerInput ml.Tensor
if perLayerInputs != nil {
perLayerInput = perLayerInputs.View(ctx, i*perLayerInputs.Stride(1), perLayerInputs.Dim(0), perLayerInputs.Stride(2), perLayerInputs.Dim(2))
}
// KV sharing: layers >= firstShared reuse K/V from donor layers
isShared := false
if donorLayer, ok := m.kvDonorMap[i]; ok {
// Set cache layer to donor so Get() reads donor's K/V
cache.SetLayer(donorLayer)
isShared = true
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, perLayerInput, lastLayerOutputs, cache, isShared, &m.TextOptions)
}
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
}
// PerLayerProjector implements PLE.
type PerLayerProjector struct {
TokenEmbedding *nn.Embedding `gguf:"per_layer_token_embd"`
Projector *nn.Linear `gguf:"per_layer_model_proj"`
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
}
func (p *PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs)
inputsPerLayer = inputsPerLayer.Scale(ctx, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
// Reshape to [pleDim, numLayers, numTokens] — matching projection shape
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
perLayerProjection := p.Projector.Forward(ctx, inputs)
perLayerProjection = perLayerProjection.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
if inputsPerLayer != nil {
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
}
return perLayerProjection
}
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` // proportional RoPE freq_factors
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
hd := opts.headDimForLayer(layer)
kvHeads := opts.kvHeadsForLayer(layer)
ropeBase, ropeDims := opts.ropeForLayer(layer)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, hd, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
var k, v ml.Tensor
if !sharedKV {
k = sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, hd, kvHeads, batchSize)
if sa.Value != nil {
v = sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, hd, kvHeads, batchSize)
} else {
// K=V: use raw K projection (before K norm) as V
v = k
}
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
v = v.RMSNorm(ctx, nil, opts.eps) // V norm: unweighted RMSNorm
}
// RoPE with proportional freq_factors on global layers
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if sa.RopeFactors != nil && !opts.isLocal(layer) {
ropeOpts = append(ropeOpts, rope.WithFactors(sa.RopeFactors))
}
q = nn.RoPE(ctx, q, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
if k != nil {
k = nn.RoPE(ctx, k, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
}
attention := nn.Attention(ctx, q, k, v, 1.0, cache)
attention = attention.Reshape(ctx, hd*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
// TextRouter implements the Gemma 4 MoE router.
type TextRouter struct {
Proj *nn.Linear `gguf:"ffn_gate_inp"`
Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
}
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
// RMSNorm without learned weight
x := hiddenState.RMSNorm(ctx, nil, opts.eps)
// Scale by 1/sqrt(hidden_size)
x = x.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
// Multiply by learned scale parameter
x = x.Mul(ctx, r.Scale)
// Project to expert logits
expertScores := r.Proj.Forward(ctx, x)
// Softmax over experts
routingWeights = expertScores.Softmax(ctx)
// TopK expert selection
selectedExperts = routingWeights.TopK(ctx, opts.numExpertsUsed)
return routingWeights, selectedExperts
}
// TextMoEBlock implements the Gemma 4 sparse MoE.
type TextMoEBlock struct {
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
DownScale ml.Tensor `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"`
}
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
// Select routing weights for chosen experts and renormalize
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
// Expert computation using LinearBatch (MulmatID selecting experts by index)
var gateOut, upOut ml.Tensor
if moe.GateUp != nil && moe.GateUp.Weight != nil {
gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
nFF := gateUp.Dim(0) / 2
gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
} else {
gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
}
hiddenState = gateOut.GELU(ctx, upOut)
experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)
// Apply per-expert down projection scale when present.
if moe.DownScale != nil {
expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
experts = experts.Mul(ctx, expertScales)
}
// Apply routing weights
experts = experts.Mul(ctx, routingWeights)
// Sum across experts
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
type TextLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm,alt:attn_post_norm"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm,alt:ffn_pre_norm"`
MLP *TextMLP
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm,alt:ffn_post_norm"`
// MoE (present only for models with enable_moe_block=true)
Router *TextRouter
MoE *TextMoEBlock
MoENorm *nn.RMSNorm `gguf:"pre_ffw_norm_2,alt:ffn_pre_norm_2"`
PostMoENorm *nn.RMSNorm `gguf:"post_ffw_norm_2,alt:ffn_post_norm_2"`
PostMLPNorm1 *nn.RMSNorm `gguf:"post_ffw_norm_1,alt:ffn_post_norm_1"` // used instead of PostMLPNorm when MoE is present
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
PerLayerProjection *nn.Linear `gguf:"proj"`
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
LayerScalar ml.Tensor `gguf:"layer_scalar,alt:layer_output_scale.weight"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, perLayerInput, outputs ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positions, cache, sharedKV, opts)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
if perLayerInput != nil {
perLayerInput = perLayerInput.Rows(ctx, outputs)
}
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// MLP (+ optional MoE in parallel)
hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
// MoE layers: run MLP and MoE in parallel, sum results
mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
mlpState = l.MLP.Forward(ctx, mlpState)
mlpState = l.PostMLPNorm1.Forward(ctx, mlpState, opts.eps)
routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)
// Combine MLP + MoE, apply outer post-FFN norm, then add residual
combined := mlpState.Add(ctx, moeState)
combined = l.PostMLPNorm.Forward(ctx, combined, opts.eps)
hiddenState = combined.Add(ctx, residual)
} else {
// Dense layers: MLP only
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState)
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = hiddenState.Add(ctx, residual)
}
// PLE injection (after MLP residual)
if perLayerInput != nil && l.PerLayerInputGate != nil {
pleState := l.PerLayerInputGate.Forward(ctx, hiddenState)
pleState = pleState.GELU(ctx, perLayerInput)
pleState = l.PerLayerProjection.Forward(ctx, pleState)
pleState = l.PostPerLayerNorm.Forward(ctx, pleState, opts.eps)
hiddenState = hiddenState.Add(ctx, pleState)
}
// Layer scalar applied at end of layer (full-attention layers only)
if l.LayerScalar != nil {
hiddenState = hiddenState.Mul(ctx, l.LayerScalar)
}
return hiddenState
}

View File

@@ -0,0 +1,384 @@
package gemma4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
)
const batchSize = 1
// ClippableLinear is a linear layer with optional input/output clamping.
// Required by Gemma4 vision encoder for numerical stability with F16 weights.
type ClippableLinear struct {
Weight ml.Tensor `gguf:"weight"`
InputMin ml.Tensor `gguf:"input_min"`
InputMax ml.Tensor `gguf:"input_max"`
OutputMin ml.Tensor `gguf:"output_min"`
OutputMax ml.Tensor `gguf:"output_max"`
inMin, inMax, outMin, outMax float32
hasClamp bool
clampsLoaded bool
}
func scalarValue(t ml.Tensor) (float32, bool) {
if t == nil {
return 0, false
}
data := t.BackendGet()
if len(data) == 0 {
return 0, false
}
return data[0], true
}
func (l *ClippableLinear) loadClampFromScalars() {
if l.clampsLoaded {
return
}
l.clampsLoaded = true
const (
defaultMin = -math.MaxFloat32
defaultMax = math.MaxFloat32
)
inMin, hasInMin := scalarValue(l.InputMin)
inMax, hasInMax := scalarValue(l.InputMax)
outMin, hasOutMin := scalarValue(l.OutputMin)
outMax, hasOutMax := scalarValue(l.OutputMax)
if !(hasInMin || hasInMax || hasOutMin || hasOutMax) {
return
}
l.hasClamp = true
l.inMin = defaultMin
l.inMax = defaultMax
l.outMin = defaultMin
l.outMax = defaultMax
if hasInMin {
l.inMin = inMin
}
if hasInMax {
l.inMax = inMax
}
if hasOutMin {
l.outMin = outMin
}
if hasOutMax {
l.outMax = outMax
}
}
func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
if l.hasClamp {
x = x.Clamp(ctx, l.inMin, l.inMax)
}
out := l.Weight.Mulmat(ctx, x)
if l.hasClamp {
out = out.Clamp(ctx, l.outMin, l.outMax)
}
return out
}
// InitClamp distributes packed clamp values from v.clamp_data to ClippableLinear structs.
// If scalar clamp tensors (input_min/max, output_min/max) are present, they are used too.
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
// then 4 floats for the projector.
func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
if m.clampInitDone {
return
}
m.clampInitDone = true
linears := func(l *VisionEncoderLayer) []*ClippableLinear {
return []*ClippableLinear{
l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
}
}
for i := range m.Layers {
for _, cl := range linears(&m.Layers[i]) {
if cl != nil {
cl.loadClampFromScalars()
}
}
}
if proj != nil && proj.Projection != nil {
proj.Projection.loadClampFromScalars()
}
// Load packed clamp data when present (legacy Ollama format).
if m.ClampData == nil {
return
}
// Read all clamp values from packed F32 tensor
data := m.ClampData.BackendGet()
if len(data) == 0 {
return
}
// Distribute to layer linears: 7 per layer × 4 values each
for i := range m.Layers {
for li, cl := range linears(&m.Layers[i]) {
if cl == nil {
continue
}
idx := (i*7 + li) * 4
if idx+3 < len(data) {
cl.inMin = data[idx]
cl.inMax = data[idx+1]
cl.outMin = data[idx+2]
cl.outMax = data[idx+3]
cl.hasClamp = true
}
}
}
// Projector clamp values (last 4 floats)
if proj != nil && proj.Projection != nil {
projIdx := len(m.Layers) * 7 * 4
if projIdx+3 < len(data) {
proj.Projection.inMin = data[projIdx]
proj.Projection.inMax = data[projIdx+1]
proj.Projection.outMin = data[projIdx+2]
proj.Projection.outMax = data[projIdx+3]
proj.Projection.hasClamp = true
}
}
}
type VisionSelfAttention struct {
Query *ClippableLinear `gguf:"attn_q"`
Key *ClippableLinear `gguf:"attn_k"`
Value *ClippableLinear `gguf:"attn_v"`
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Output *ClippableLinear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
numPatches := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
// Q/K norms (Gemma-style: x * (1 + weight) / rms(x))
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
// V norm (RMSNorm without learned weights)
value = value.RMSNorm(ctx, nil, opts.eps)
// 2D RoPE: split head dim in half, apply NeoX RoPE with x positions to first half,
// y positions to second half, then concatenate.
halfDim := headDim / 2
ropeOpts := rope.WithTypeNeoX()
qFirst := query.View(ctx, 0, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
qFirst = nn.RoPE(ctx, qFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)
kFirst := key.View(ctx, 0, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
kFirst = nn.RoPE(ctx, kFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)
halfOffset := halfDim * query.Stride(0)
qSecond := query.View(ctx, halfOffset, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
qSecond = nn.RoPE(ctx, qSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)
halfOffsetK := halfDim * key.Stride(0)
kSecond := key.View(ctx, halfOffsetK, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
kSecond = nn.RoPE(ctx, kSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)
query = qFirst.Concat(ctx, qSecond, 0)
key = kFirst.Concat(ctx, kSecond, 0)
// Use flash attention for numerical stability (handles large attention scores
// from unclamped RMSNorm weights, e.g. 26B has addOne weights up to 19.5)
attention := nn.Attention(ctx, query, key, value, 1.0, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *ClippableLinear `gguf:"ffn_gate"`
Up *ClippableLinear `gguf:"ffn_up"`
Down *ClippableLinear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
gate := mlp.Gate.Forward(ctx, hiddenState)
up := mlp.Up.Forward(ctx, hiddenState)
hiddenState = gate.QuickGELU(ctx, up)
return mlp.Down.Forward(ctx, hiddenState)
}
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"attn_post_norm"`
FFNNorm *nn.RMSNorm `gguf:"ln2"`
MLP *VisionMLP
PostFFNNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
LayerOutputScale ml.Tensor `gguf:"out_scale.weight"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// Pre-attention norm -> self attention -> post-attention norm
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, posX, posY, attnMask, opts)
hiddenState = e.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
// Residual connection
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// Pre-FFN norm -> FFN -> post-FFN norm
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState)
hiddenState = e.PostFFNNorm.Forward(ctx, hiddenState, opts.eps)
// Residual connection
hiddenState = hiddenState.Add(ctx, residual)
// Per-layer output scale
if e.LayerOutputScale != nil {
hiddenState = hiddenState.Mul(ctx, e.LayerOutputScale)
}
return hiddenState
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
patchSize int
nMerge int
eps float32
ropeTheta float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
PositionEmbedding ml.Tensor `gguf:"position_embd.weight"`
ClampData ml.Tensor `gguf:"clamp_data"`
StdBias ml.Tensor `gguf:"std_bias"`
StdScale ml.Tensor `gguf:"std_scale"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
clampInitDone bool
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, numPatchesX, numPatchesY int) ml.Tensor {
numPatches := numPatchesX * numPatchesY
// Patch embedding via Conv2D
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Conv2D with F16 weights produces F16 output via im2col; cast to F32 for encoder precision
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
// 2D positional embeddings from 3D tensor [nEmbd, maxPos, 2]
posSize := m.PositionEmbedding.Dim(1)
nb1 := m.PositionEmbedding.Stride(1)
tblX := m.PositionEmbedding.View(ctx, 0, m.hiddenSize, nb1, posSize)
tblY := m.PositionEmbedding.View(ctx, posSize*nb1, m.hiddenSize, nb1, posSize)
// Position indices for patches
posXData := make([]int32, numPatches)
posYData := make([]int32, numPatches)
for i := range numPatches {
posXData[i] = int32(i % numPatchesX)
posYData[i] = int32(i / numPatchesX)
}
posXEmb := ctx.Input().FromInts(posXData, numPatches)
posYEmb := ctx.Input().FromInts(posYData, numPatches)
hiddenState = hiddenState.Add(ctx, tblX.Rows(ctx, posXEmb))
hiddenState = hiddenState.Add(ctx, tblY.Rows(ctx, posYEmb))
// No attention mask — all positions are real patches
var attnMask ml.Tensor
// RoPE positions
posXRope := ctx.Input().FromInts(posXData, numPatches)
posYRope := ctx.Input().FromInts(posYData, numPatches)
// Vision transformer layers
for i := range m.Layers {
hiddenState = m.Layers[i].Forward(ctx, hiddenState, posXRope, posYRope, attnMask, m.VisionModelOptions)
}
return hiddenState
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
patchSize: int(c.Uint("vision.patch_size", 16)),
nMerge: int(c.Uint("vision.projector.scale_factor", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
ropeTheta: 100.0,
},
}
}
func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector, stdBias, stdScale ml.Tensor) ml.Tensor {
hiddenSize := opts.hiddenSize
// Reshape from [hiddenSize, numPatches] to spatial layout for pooling
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = hiddenState.Reshape(ctx, numPatchesX, numPatchesY, hiddenSize)
// AvgPool2D with kernel=stride=nMerge
hiddenState = hiddenState.AvgPool2D(ctx, opts.nMerge, opts.nMerge, 0)
// Reshape back to [hiddenSize, numMergedPatches]
mergedX := numPatchesX / opts.nMerge
mergedY := numPatchesY / opts.nMerge
hiddenState = hiddenState.Reshape(ctx, mergedX*mergedY, hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(hiddenSize)))
// Optional vision standardization before projection.
if stdBias != nil && stdScale != nil {
hiddenState = hiddenState.Sub(ctx, stdBias)
hiddenState = hiddenState.Mul(ctx, stdScale)
}
// Project to text embedding dimension
hiddenState = proj.Forward(ctx, hiddenState, opts.eps)
return hiddenState
}

View File

@@ -0,0 +1,280 @@
package gemma4
import (
"encoding/binary"
"fmt"
"math"
"math/cmplx"
)
// Audio preprocessing constants.
const (
audioSampleRate = 16000
melBins = 128
frameLengthMs = 20.0
hopLengthMs = 10.0
minFrequency = 0.0
maxFrequency = 8000.0
melFloor = 1e-3
maxAudioSoftTokens = 750
)
// Computed from the above constants.
var (
frameLength = int(math.Round(audioSampleRate * frameLengthMs / 1000.0)) // 320
hopLength = int(math.Round(audioSampleRate * hopLengthMs / 1000.0)) // 160
)
// decodeWAV extracts mono float32 PCM samples from a WAV file, resampled to 16kHz.
func decodeWAV(data []byte) ([]float32, error) {
if len(data) < 12 {
return nil, fmt.Errorf("WAV file too short")
}
if string(data[0:4]) != "RIFF" || string(data[8:12]) != "WAVE" {
return nil, fmt.Errorf("not a WAV file")
}
var audioFormat uint16
var numChannels, sampleRate, bitsPerSample int
var audioData []byte
foundFmt := false
offset := 12
for offset+8 <= len(data) {
chunkID := string(data[offset : offset+4])
chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
chunkData := data[offset+8 : min(offset+8+chunkSize, len(data))]
switch chunkID {
case "fmt ":
if len(chunkData) < 16 {
return nil, fmt.Errorf("fmt chunk too short")
}
audioFormat = binary.LittleEndian.Uint16(chunkData[0:2])
numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4]))
sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8]))
bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16]))
if audioFormat == 0xFFFE && len(chunkData) >= 26 {
audioFormat = binary.LittleEndian.Uint16(chunkData[24:26])
}
foundFmt = true
case "data":
audioData = chunkData
}
offset += 8 + chunkSize
if chunkSize%2 != 0 {
offset++
}
}
if !foundFmt {
return nil, fmt.Errorf("no fmt chunk found in WAV file")
}
if audioFormat != 1 && audioFormat != 3 {
return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat)
}
if audioData == nil {
return nil, fmt.Errorf("no data chunk found in WAV file")
}
samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels)
if sampleRate != audioSampleRate {
samples = resampleLinear(samples, sampleRate, audioSampleRate)
}
return samples, nil
}
func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 {
bytesPerSample := bits / 8
totalSamples := len(data) / (bytesPerSample * channels)
mono := make([]float32, totalSamples)
for i := range totalSamples {
var sum float64
for ch := range channels {
off := (i*channels + ch) * bytesPerSample
if off+bytesPerSample > len(data) {
break
}
switch {
case format == 1 && bits == 16:
v := int16(binary.LittleEndian.Uint16(data[off : off+2]))
sum += float64(v) / 32768.0
case format == 1 && bits == 32:
v := int32(binary.LittleEndian.Uint32(data[off : off+4]))
sum += float64(v) / 2147483648.0
case format == 1 && bits == 24:
v := int32(data[off]) | int32(data[off+1])<<8 | int32(data[off+2])<<16
if v&0x800000 != 0 {
v |= ^0xFFFFFF
}
sum += float64(v) / 8388608.0
case format == 3 && bits == 32:
v := math.Float32frombits(binary.LittleEndian.Uint32(data[off : off+4]))
sum += float64(v)
case format == 1 && bits == 8:
sum += (float64(data[off]) - 128.0) / 128.0
}
}
mono[i] = float32(sum / float64(channels))
}
return mono
}
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
n := int(float64(len(samples)) / float64(fromRate) * float64(toRate))
out := make([]float32, n)
for i := range n {
pos := float64(i) * float64(len(samples)-1) / float64(n-1)
idx := int(pos)
frac := float32(pos - float64(idx))
if idx+1 < len(samples) {
out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
} else {
out[i] = samples[idx]
}
}
return out
}
// computeMelSpectrogram computes the log mel spectrogram from PCM samples.
// Returns shape [numFrames, melBins] as float32 slice, and numFrames.
func computeMelSpectrogram(samples []float32) ([]float32, int) {
fftLen := 1
for fftLen < frameLength {
fftLen <<= 1
}
fftLen *= 2 // fft_overdrive=True
// Hanning-nonzero window.
window := make([]float64, frameLength)
arg := math.Pi * 2.0 / float64(frameLength)
for i := range frameLength {
window[i] = 0.5 - 0.5*math.Cos(arg*(float64(i)+0.5))
}
numFreqBins := fftLen/2 + 1
melFilters := buildMelFilterBank(numFreqBins, melBins, minFrequency, maxFrequency, audioSampleRate)
frameSizeForUnfold := frameLength + 1
numFrames := (len(samples) - frameSizeForUnfold) / hopLength
if numFrames <= 0 {
return nil, 0
}
result := make([]float32, numFrames*melBins)
fftInput := make([]complex128, fftLen)
for f := range numFrames {
start := f * hopLength
for i := range frameLength {
fftInput[i] = complex(float64(samples[start+i])*window[i], 0)
}
for i := frameLength; i < fftLen; i++ {
fftInput[i] = 0
}
fft(fftInput)
for m := range melBins {
var melVal float64
for k := range numFreqBins {
mag := cmplx.Abs(fftInput[k])
melVal += mag * float64(melFilters[k*melBins+m])
}
if melVal < melFloor {
melVal = melFloor
}
result[f*melBins+m] = float32(math.Log(melVal))
}
}
return result, numFrames
}
func buildMelFilterBank(numFreqBins, numMels int, fMin, fMax float64, sr int) []float32 {
hzToMel := func(f float64) float64 {
return 2595.0 * math.Log10(1.0+f/700.0)
}
melToHz := func(m float64) float64 {
return 700.0 * (math.Pow(10.0, m/2595.0) - 1.0)
}
melMin := hzToMel(fMin)
melMax := hzToMel(fMax)
melPts := make([]float64, numMels+2)
for i := range melPts {
melPts[i] = melMin + float64(i)*(melMax-melMin)/float64(numMels+1)
}
filterFreqs := make([]float64, numMels+2)
for i, m := range melPts {
filterFreqs[i] = melToHz(m)
}
fftFreqs := make([]float64, numFreqBins)
for i := range fftFreqs {
fftFreqs[i] = float64(i) * float64(sr) / float64(2*(numFreqBins-1))
}
filters := make([]float32, numFreqBins*numMels)
for m := range numMels {
fLeft := filterFreqs[m]
fCenter := filterFreqs[m+1]
fRight := filterFreqs[m+2]
for k := range numFreqBins {
f := fftFreqs[k]
var v float64
if f >= fLeft && f <= fCenter && fCenter > fLeft {
v = (f - fLeft) / (fCenter - fLeft)
} else if f > fCenter && f <= fRight && fRight > fCenter {
v = (fRight - f) / (fRight - fCenter)
}
if v > 0 {
filters[k*numMels+m] = float32(v)
}
}
}
return filters
}
// fft performs an in-place Cooley-Tukey radix-2 FFT.
func fft(x []complex128) {
n := len(x)
if n <= 1 {
return
}
j := 0
for i := 1; i < n; i++ {
bit := n >> 1
for j&bit != 0 {
j ^= bit
bit >>= 1
}
j ^= bit
if i < j {
x[i], x[j] = x[j], x[i]
}
}
for size := 2; size <= n; size <<= 1 {
halfSize := size / 2
w := complex(math.Cos(2*math.Pi/float64(size)), -math.Sin(2*math.Pi/float64(size)))
for start := 0; start < n; start += size {
wn := complex(1, 0)
for k := range halfSize {
t := wn * x[start+k+halfSize]
x[start+k+halfSize] = x[start+k] - t
x[start+k] = x[start+k] + t
wn *= w
}
}
}
}
// isAudioData checks if the data starts with WAV magic bytes.
func isAudioData(data []byte) bool {
return len(data) >= 12 && string(data[0:4]) == "RIFF" && string(data[8:12]) == "WAVE"
}

View File

@@ -0,0 +1,103 @@
package gemma4
import (
"image"
"math"
"golang.org/x/image/draw"
"github.com/ollama/ollama/fs"
)
type ImageProcessor struct {
patchSize int
numChannels int
nMerge int
minPixels int
maxPixels int
}
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 16))
nMerge := int(c.Uint("vision.projector.scale_factor", 3))
numChannels := int(c.Uint("vision.num_channels", 3))
// Token limits from reference: min=40, max=280 output tokens after pooling.
// Convert to pixel counts: tokens * nMerge^2 * patchSize^2
minTokens := 40
maxTokens := 280
patchArea := patchSize * patchSize * nMerge * nMerge
minPixels := minTokens * patchArea
maxPixels := maxTokens * patchArea
return ImageProcessor{
patchSize: patchSize,
numChannels: numChannels,
nMerge: nMerge,
minPixels: minPixels,
maxPixels: maxPixels,
}
}
// ProcessImage resizes an image preserving aspect ratio, aligning dimensions
// to (patchSize * nMerge) boundaries, and normalizes pixels to [-1, 1].
// Returns the float32 pixel data and the actual output dimensions.
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, error) {
// Compute target size preserving aspect ratio
alignSize := p.patchSize * p.nMerge
targetW, targetH := p.smartResize(img.Bounds().Dx(), img.Bounds().Dy(), alignSize)
// Resize directly without alpha compositing, matching MLX reference.
dst := image.NewRGBA(image.Rect(0, 0, targetW, targetH))
draw.BiLinear.Scale(dst, dst.Bounds(), img, img.Bounds(), draw.Over, nil)
// Normalize to [-1, 1] using mean=0.5, std=0.5: (pixel/255 - 0.5) / 0.5 = 2*pixel/255 - 1
data := p.pack(dst)
return data, targetW, targetH, nil
}
// smartResize computes target dimensions that preserve aspect ratio and
// align to alignSize boundaries. It scales the image to fill the maximum
// patch budget (maxPixels), matching the MLX reference.
func (p *ImageProcessor) smartResize(origW, origH, alignSize int) (int, int) {
totalPx := origW * origH
var targetW, targetH int
if p.maxPixels > 0 && totalPx > 0 {
factor := math.Sqrt(float64(p.maxPixels) / float64(totalPx))
targetH = max(alignSize, int(math.Floor(factor*float64(origH)/float64(alignSize)))*alignSize)
targetW = max(alignSize, int(math.Floor(factor*float64(origW)/float64(alignSize)))*alignSize)
} else {
targetH = max(alignSize, (origH/alignSize)*alignSize)
targetW = max(alignSize, (origW/alignSize)*alignSize)
}
return targetW, targetH
}
// pack extracts RGB values from an image and normalizes to [-1, 1].
// Returns channel-first layout: [R..., G..., B...].
func (p *ImageProcessor) pack(img image.Image) []float32 {
bounds := img.Bounds()
w := bounds.Dx()
h := bounds.Dy()
size := w * h
pixelVals := make([]float32, 3*size)
rOff, gOff, bOff := 0, size, 2*size
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := img.At(x, y)
r, g, b, _ := c.RGBA()
idx := (y-bounds.Min.Y)*w + (x - bounds.Min.X)
// Normalize [0, 255] -> [-1, 1]: 2 * (val/255) - 1
pixelVals[rOff+idx] = float32(r>>8)/255.0*2.0 - 1.0
pixelVals[gOff+idx] = float32(g>>8)/255.0*2.0 - 1.0
pixelVals[bOff+idx] = float32(b>>8)/255.0*2.0 - 1.0
}
}
return pixelVals
}

View File

@@ -0,0 +1,102 @@
package gemma4
import (
"os"
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/tokenizer"
)
// TestTokenizerMatchesHF compares our tokenizer output against HuggingFace reference tokens.
func TestTokenizerMatchesHF(t *testing.T) {
modelPath := os.Getenv("GEMMA4_MODEL_PATH")
if modelPath == "" {
t.Skip("set GEMMA4_MODEL_PATH to a gemma4 GGUF file")
}
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatalf("Failed to load model: %v", err)
}
defer m.Backend().Close()
tok := m.(tokenizer.Tokenizer)
tests := []struct {
name string
input string
expected []int32
}{
{
name: "simple",
input: "Hello, world!",
expected: []int32{9259, 236764, 1902, 236888},
},
{
name: "special_tokens",
input: "<|turn>user\nWhat is 2+2?<turn|>\n<|turn>model\n",
expected: []int32{105, 2364, 107, 3689, 563, 236743, 236778, 236862, 236778, 236881, 106, 107, 105, 4368, 107},
},
{
name: "tool_declaration",
input: "<|tool>declaration:bash{description:<|\"|>Run a command<|\"|>}<tool|>",
expected: []int32{46, 163688, 236787, 42422, 236782, 7777, 236787, 52, 7306, 496, 4991, 52, 236783, 47},
},
{
name: "tool_call",
input: "<|tool_call>call:bash{command:<|\"|>ls -la<|\"|>}<tool_call|>",
expected: []int32{48, 6639, 236787, 42422, 236782, 7674, 236787, 52, 5629, 753, 2149, 52, 236783, 49},
},
{
name: "thinking",
input: "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
expected: []int32{100, 45518, 107, 6481, 786, 1751, 1003, 672, 1390, 101, 818, 3890, 563, 236743, 236812, 236778, 236761},
},
{
name: "code",
input: "func main() { fmt.Println(\"hello\") }",
expected: []int32{6823, 1689, 825, 642, 22766, 236761, 29006, 885, 23391, 1373, 682},
},
{
name: "numbers",
input: "The answer is 42, not 43.5 or -1",
expected: []int32{818, 3890, 563, 236743, 236812, 236778, 236764, 711, 236743, 236812, 236800, 236761, 236810, 653, 753, 236770},
},
{
name: "mixed_chat_with_tools",
input: "<|turn>system\nYou are a helpful assistant.\n<|tool>declaration:get_weather{description:<|\"|>Get weather<|\"|>,parameters:{properties:{city:{type:<|\"|>STRING<|\"|>}},type:<|\"|>OBJECT<|\"|>}}<tool|><turn|>\n<|turn>user\nWhat's the weather in Paris?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
expected: []int32{105, 9731, 107, 3048, 659, 496, 11045, 16326, 236761, 107, 46, 163688, 236787, 828, 236779, 19323, 236782, 7777, 236787, 52, 3407, 7606, 52, 236764, 19031, 29616, 15921, 29616, 13319, 29616, 2084, 236787, 52, 35410, 52, 5237, 2084, 236787, 52, 60688, 52, 1807, 47, 106, 107, 105, 2364, 107, 3689, 236789, 236751, 506, 7606, 528, 9079, 236881, 106, 107, 105, 4368, 107, 100, 45518, 107, 101},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens, err := tok.Encode(tt.input, false) // no BOS
if err != nil {
t.Fatalf("encode error: %v", err)
}
if len(tokens) != len(tt.expected) {
t.Errorf("token count mismatch: got %d, want %d", len(tokens), len(tt.expected))
t.Logf("got: %v", tokens)
t.Logf("want: %v", tt.expected)
return
}
mismatches := 0
for i := range tokens {
if tokens[i] != tt.expected[i] {
mismatches++
if mismatches <= 5 {
t.Errorf("mismatch at [%d]: got %d, want %d", i, tokens[i], tt.expected[i])
}
}
}
if mismatches > 5 {
t.Errorf("... and %d more mismatches", mismatches-5)
}
})
}
}

View File

@@ -0,0 +1,341 @@
package gemma4
// TestGemma4TokenizerMatchesReference verifies our BPE tokenizer matches
// the Rust tokenizers library (the reference implementation) for Gemma 4.
//
// The test loads vocabulary from any local ollama gemma4 GGUF model.
// Skips if no gemma4 model is installed.
//
// Set VERIFY_HF_TOKENIZER=1 to verify against the Rust tokenizers library
// via Python. Requires python3 with tokenizers>=0.21 on PATH:
//
// VERIFY_HF_TOKENIZER=1 go test ./model/models/gemma4/ -run TestGemma4Tokenizer -v
//
// Workflow for adding a new test case:
// 1. Add {name: "...", input: "..."} to the test list (no want field)
// 2. Run with VERIFY_HF_TOKENIZER=1 — it prints the reference IDs
// 3. Paste those IDs into the want field
// 4. Run without VERIFY_HF_TOKENIZER — our tokenizer must match
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
"github.com/ollama/ollama/tokenizer"
)
type tokenizerRefCase struct {
name string
input string
want []int32
}
// Reference token IDs generated by the Rust tokenizers library using
// vocab/merges from a gemma4 GGUF with add_special_tokens=False.
var gemma4TokenizerRefCases = []tokenizerRefCase{
// Basic ASCII
{name: "basic word", input: "hello", want: []int32{23391}},
{name: "two words", input: "hello world", want: []int32{23391, 1902}},
{name: "punctuation", input: "Hello, World!", want: []int32{9259, 236764, 4109, 236888}},
// Space handling (pretokenizer bug: GPT-2 splitter mangled leading/multiple spaces)
{name: "leading space", input: " hello", want: []int32{29104}},
{name: "double leading space", input: " hello", want: []int32{138, 23391}},
{name: "double space between words", input: "hello world", want: []int32{23391, 138, 12392}},
{name: "only spaces", input: " ", want: []int32{139}},
{name: "repeated spaces", input: " ", want: []int32{142}},
{name: "leading spaces phrase", input: " leading spaces", want: []int32{5830, 9952}},
{name: "multiple interior spaces", input: "multiple spaces", want: []int32{43819, 140, 35220}},
// Polish diacritics (issue #15231 — Decode mangled U+0105-U+0142)
{name: "polish diacritics", input: "ąęśćżźółń", want: []int32{237198, 237202, 14732, 237277, 238992, 24875, 238041}},
{name: "polish sentence", input: "Zażółć gęślą jaźń", want: []int32{236953, 40512, 24875, 237289, 549, 237202, 62081, 237198, 4828, 238992, 238041}},
// French accents (issue #15229 — Decode mangled U+00E0-U+00FF)
{name: "french accents", input: "café résumé naïve", want: []int32{123125, 236859, 118515, 120362}},
{name: "french with apostrophe", input: "L'élève a mangé", want: []int32{236798, 236789, 161654, 496, 14695, 236859}},
// German umlauts
{name: "german umlauts", input: "über Straße Größe", want: []int32{28223, 80176, 112880}},
// Codepoints in GPT-2 byte reversal range (U+0100-U+0142)
{name: "codepoints in gpt2 byte range", input: "ąęćł", want: []int32{237198, 226110, 237114}},
{name: "latin extended A", input: "ĀāĂ㥹", want: []int32{241920, 237448, 241645, 237106, 243514, 237198}},
// CJK & Japanese
{name: "chinese", input: "你好世界", want: []int32{144626, 12811}},
{name: "japanese hiragana", input: "こんにちは", want: []int32{85141}},
// Mixed scripts
{name: "mixed scripts", input: "hello ąęść world café 你好", want: []int32{23391, 236743, 237198, 237202, 14732, 1902, 33443, 43758, 237389}},
// Whitespace
{name: "empty string", input: "", want: []int32{}},
{name: "newlines", input: "\n\n", want: []int32{108}},
{name: "tabs", input: "\t\t", want: []int32{255969}},
// Code-like content
{name: "python code", input: "def foo(x): return x + 1", want: []int32{2063, 46293, 236769, 236781, 1473, 994, 1123, 900, 236743, 236770}},
{name: "json", input: `{"key": "value"}`, want: []int32{14937, 2478, 1083, 623, 2394, 25938}},
// Misc
{name: "repeated char", input: "aaaaaa", want: []int32{50354, 9236}},
{name: "emoji", input: "hello 👋 world", want: []int32{23391, 155818, 1902}},
{name: "digits", input: "12345", want: []int32{236770, 236778, 236800, 236812, 236810}},
{name: "float", input: "3.14159", want: []int32{236800, 236761, 236770, 236812, 236770, 236810, 236819}},
}
// findGemma4GGUF looks for any gemma4 model GGUF in the local ollama store.
func findGemma4GGUF() (string, error) {
modelsDir := envconfig.Models()
manifestDir := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "gemma4")
entries, err := os.ReadDir(manifestDir)
if err != nil {
return "", fmt.Errorf("no gemma4 manifests in %s: %w", manifestDir, err)
}
blobDir := filepath.Join(modelsDir, "blobs")
for _, entry := range entries {
if entry.IsDir() {
continue
}
data, err := os.ReadFile(filepath.Join(manifestDir, entry.Name()))
if err != nil {
continue
}
var manifest struct {
Layers []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
} `json:"layers"`
}
if err := json.Unmarshal(data, &manifest); err != nil {
continue
}
for _, layer := range manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.model" {
blobPath := filepath.Join(blobDir, strings.Replace(layer.Digest, ":", "-", 1))
if _, err := os.Stat(blobPath); err == nil {
return blobPath, nil
}
}
}
}
return "", fmt.Errorf("no gemma4 model blob found in %s", modelsDir)
}
// loadGemma4Tokenizer opens a GGUF and builds a BPE tokenizer from its
// tokenizer metadata — the same configuration used at inference time.
func loadGemma4Tokenizer(t *testing.T, ggufPath string) tokenizer.BytePairEncoding {
t.Helper()
f, err := gguf.Open(ggufPath)
if err != nil {
t.Fatalf("gguf.Open: %v", err)
}
defer f.Close()
tokens := f.KeyValue("tokenizer.ggml.tokens").Strings()
if len(tokens) == 0 {
t.Fatal("no tokenizer.ggml.tokens in GGUF")
}
scores64 := f.KeyValue("tokenizer.ggml.scores").Floats()
scores := make([]float32, len(scores64))
for i, s := range scores64 {
scores[i] = float32(s)
}
types64 := f.KeyValue("tokenizer.ggml.token_type").Ints()
types := make([]int32, len(types64))
for i, tt := range types64 {
types[i] = int32(tt)
}
merges := f.KeyValue("tokenizer.ggml.merges").Strings()
vocab := &tokenizer.Vocabulary{
Values: tokens,
Types: types,
Scores: scores,
Merges: merges,
BOS: []int32{2},
EOS: []int32{1},
AddBOS: false,
}
return tokenizer.NewBytePairEncodingWithOptions(vocab, []string{},
tokenizer.WithSentencePieceNormalizer())
}
// writeTokenizerJSON reconstructs a tokenizer.json from GGUF metadata
// for the Rust tokenizers library to load as an independent reference.
func writeTokenizerJSON(t *testing.T, ggufPath string) string {
t.Helper()
f, err := gguf.Open(ggufPath)
if err != nil {
t.Fatalf("gguf.Open: %v", err)
}
defer f.Close()
tokens := f.KeyValue("tokenizer.ggml.tokens").Strings()
mergeStrs := f.KeyValue("tokenizer.ggml.merges").Strings()
vocab := make(map[string]int, len(tokens))
for i, tok := range tokens {
vocab[tok] = i
}
merges := make([][2]string, len(mergeStrs))
for i, m := range mergeStrs {
parts := strings.SplitN(m, " ", 2)
if len(parts) == 2 {
merges[i] = [2]string{parts[0], parts[1]}
}
}
tj := map[string]any{
"version": "1.0",
"model": map[string]any{
"type": "BPE",
"vocab": vocab,
"merges": merges,
},
"normalizer": map[string]any{
"type": "Replace",
"pattern": map[string]string{"String": " "},
"content": "\u2581",
},
}
tmpFile, err := os.CreateTemp(t.TempDir(), "gemma4_tokenizer_*.json")
if err != nil {
t.Fatalf("create temp file: %v", err)
}
if err := json.NewEncoder(tmpFile).Encode(tj); err != nil {
tmpFile.Close()
t.Fatalf("encode tokenizer.json: %v", err)
}
tmpFile.Close()
return tmpFile.Name()
}
func TestGemma4TokenizerMatchesReference(t *testing.T) {
ggufPath, err := findGemma4GGUF()
if err != nil {
t.Skipf("skipping: %v", err)
}
t.Logf("using GGUF: %s", ggufPath)
tok := loadGemma4Tokenizer(t, ggufPath)
verify := os.Getenv("VERIFY_HF_TOKENIZER") != ""
var tokenizerJSONPath string
if verify {
if err := exec.Command("python3", "-c", "from tokenizers import Tokenizer").Run(); err != nil {
t.Fatal("VERIFY_HF_TOKENIZER=1 requires python3 with tokenizers>=0.21 on PATH")
}
tokenizerJSONPath = writeTokenizerJSON(t, ggufPath)
defer os.Remove(tokenizerJSONPath)
t.Log("VERIFY_HF_TOKENIZER=1: verifying against Rust tokenizers library")
}
for _, tc := range gemma4TokenizerRefCases {
t.Run(tc.name, func(t *testing.T) {
ids, err := tok.Encode(tc.input, false)
if err != nil {
t.Fatalf("Encode(%q): %v", tc.input, err)
}
if tc.want != nil {
if fmt.Sprint(ids) != fmt.Sprint(tc.want) {
t.Errorf("Encode(%q):\n got: %v\n want: %v", tc.input, ids, tc.want)
}
} else {
t.Errorf("no expected IDs for %q; our tokenizer produced: %v", tc.input, ids)
}
if len(ids) > 0 {
decoded, err := tok.Decode(ids)
if err != nil {
t.Fatalf("Decode: %v", err)
}
if decoded != tc.input {
t.Errorf("roundtrip %q: Decode(Encode) = %q", tc.input, decoded)
}
}
if verify {
refIDs := encodeWithRustTokenizer(t, tokenizerJSONPath, tc.input)
if fmt.Sprint(refIDs) != fmt.Sprint(ids) {
fmt.Fprintf(os.Stderr, "\nREFERENCE OUTPUT for %s (copy-paste as want):\nwant: []int32{%s},\n\n",
tc.name, int32SliceStr(refIDs))
}
if tc.want != nil && fmt.Sprint(refIDs) != fmt.Sprint(tc.want) {
t.Errorf("hardcoded expected IDs don't match reference for %q:\n ref: %v\n hardcoded: %v",
tc.input, refIDs, tc.want)
}
}
})
}
}
func encodeWithRustTokenizer(t *testing.T, tokenizerPath, text string) []int32 {
t.Helper()
if text == "" {
return nil
}
script := fmt.Sprintf(`
from tokenizers import Tokenizer
t = Tokenizer.from_file(%q)
ids = t.encode(%q, add_special_tokens=False).ids
print(",".join(str(i) for i in ids))
`, tokenizerPath, text)
cmd := exec.Command("python3", "-c", script)
var stdout, stderr strings.Builder
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
t.Fatalf("python3 failed: %v\nstderr: %s", err, stderr.String())
}
parts := strings.Split(strings.TrimSpace(stdout.String()), ",")
var ids []int32
for _, p := range parts {
if p == "" {
continue
}
var id int32
fmt.Sscanf(p, "%d", &id)
ids = append(ids, id)
}
return ids
}
func int32SliceStr(ids []int32) string {
parts := make([]string, len(ids))
for i, id := range ids {
parts[i] = fmt.Sprintf("%d", id)
}
return strings.Join(parts, ", ")
}