ollama source for Momentry Core verification
This commit is contained in:
265
model/models/gemma4/model.go
Normal file
265
model/models/gemma4/model.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
tokenizer.Tokenizer
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
*AudioModel `gguf:"a"`
|
||||
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
*AudioMultimodalProjector `gguf:"mm.a"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
imageTokenID int32
|
||||
imageEndTokenID int32
|
||||
audioTokenID int32
|
||||
audioEndTokenID int32
|
||||
|
||||
audioOpts *AudioModelOptions
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
type MultiModalProjector struct {
|
||||
Projection *ClippableLinear `gguf:"input_projection"`
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||
visionOutputs = p.Projection.Forward(ctx, visionOutputs)
|
||||
// Post-projection RMSNorm without learned weight
|
||||
visionOutputs = visionOutputs.RMSNorm(ctx, nil, eps)
|
||||
return visionOutputs
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := tokenizer.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
|
||||
// Gemma 4 uses BPE with SentencePiece-style ▁ space markers (not GPT-2 byte-level encoding).
|
||||
// The tokenizer.json has merges and a Replace normalizer (space → ▁), with no pre-tokenizer.
|
||||
t := tokenizer.NewBytePairEncodingWithOptions(&vocabulary, []string{},
|
||||
tokenizer.WithSentencePieceNormalizer())
|
||||
|
||||
// Look up special token IDs for vision and audio
|
||||
imageTokenID := int32(-1)
|
||||
imageEndTokenID := int32(-1)
|
||||
audioTokenID := int32(-1)
|
||||
audioEndTokenID := int32(-1)
|
||||
for i, tok := range vocabulary.Values {
|
||||
switch tok {
|
||||
case "<|image>":
|
||||
imageTokenID = int32(i)
|
||||
case "<image|>":
|
||||
imageEndTokenID = int32(i)
|
||||
case "<|audio>":
|
||||
audioTokenID = int32(i)
|
||||
case "<audio|>":
|
||||
audioEndTokenID = int32(i)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("gemma4: token IDs", "image", imageTokenID, "image_end", imageEndTokenID, "audio", audioTokenID, "audio_end", audioEndTokenID)
|
||||
|
||||
m := Model{
|
||||
Tokenizer: t,
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
AudioModel: newAudioModel(c),
|
||||
MultiModalProjector: &MultiModalProjector{},
|
||||
AudioMultimodalProjector: &AudioMultimodalProjector{},
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
imageTokenID: imageTokenID,
|
||||
imageEndTokenID: imageEndTokenID,
|
||||
audioTokenID: audioTokenID,
|
||||
audioEndTokenID: audioEndTokenID,
|
||||
audioOpts: newAudioModelOptions(c),
|
||||
}
|
||||
|
||||
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWAMemCache(slidingWindowLen, 4096, m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
// Audio input: detect WAV format and route to audio encoder.
|
||||
if isAudioData(multimodalData) {
|
||||
return m.encodeAudioMultimodal(ctx, multimodalData)
|
||||
}
|
||||
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("vision: decode", "elapsed", time.Since(t0), "bounds", img.Bounds())
|
||||
|
||||
t1 := time.Now()
|
||||
f32s, imgW, imgH, err := m.ImageProcessor.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("vision: preprocess", "elapsed", time.Since(t1), "size", [2]int{imgW, imgH})
|
||||
|
||||
pixelValues := ctx.Input().FromFloats(f32s, imgW, imgH, m.ImageProcessor.numChannels)
|
||||
slog.Info("vision: pixelValues", "shape", pixelValues.Shape(), "dim0", pixelValues.Dim(0), "dim1", pixelValues.Dim(1), "dim2", pixelValues.Dim(2))
|
||||
|
||||
numPatchesX := imgW / m.ImageProcessor.patchSize
|
||||
numPatchesY := imgH / m.ImageProcessor.patchSize
|
||||
slog.Info("vision: patches", "patchesX", numPatchesX, "patchesY", numPatchesY, "total", numPatchesX*numPatchesY, "patchSize", m.ImageProcessor.patchSize)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, numPatchesX, numPatchesY)
|
||||
visionOutputs = visionPoolAndProject(ctx, visionOutputs, numPatchesX, numPatchesY, m.VisionModel.VisionModelOptions, m.MultiModalProjector, m.VisionModel.StdBias, m.VisionModel.StdScale)
|
||||
slog.Info("vision: encoded", "elapsed", time.Since(t0), "shape", visionOutputs.Shape())
|
||||
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostLoad() error {
|
||||
m.VisionModel.InitClamp(m.MultiModalProjector)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) encodeAudioMultimodal(ctx ml.Context, data []byte) ([]input.Multimodal, error) {
|
||||
if m.AudioModel == nil || m.audioOpts == nil {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
t0 := time.Now()
|
||||
samples, err := decodeWAV(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
slog.Info("audio: decode", "elapsed", time.Since(t0), "samples", len(samples), "duration_s", float64(len(samples))/audioSampleRate)
|
||||
|
||||
// Pad waveform to next multiple of 128.
|
||||
if rem := len(samples) % 128; rem != 0 {
|
||||
samples = append(samples, make([]float32, 128-rem)...)
|
||||
}
|
||||
|
||||
// Compute mel spectrogram.
|
||||
melData, numFrames := computeMelSpectrogram(samples)
|
||||
if numFrames == 0 {
|
||||
return nil, fmt.Errorf("audio too short to encode")
|
||||
}
|
||||
slog.Info("audio: mel", "frames", numFrames, "elapsed", time.Since(t0))
|
||||
|
||||
// Create input tensor [melBins, numFrames] (GGML ne order). FromFloats creates F32.
|
||||
melTensor := ctx.Input().FromFloats(melData, melBins, numFrames)
|
||||
|
||||
// Run audio encoder.
|
||||
audioOutputs := m.AudioModel.ForwardAudio(ctx, melTensor, m.AudioMultimodalProjector, m.audioOpts)
|
||||
slog.Info("audio: encoded", "elapsed", time.Since(t0), "shape", audioOutputs.Shape())
|
||||
|
||||
return []input.Multimodal{{Tensor: audioOutputs, Data: audioTag{}}}, nil
|
||||
}
|
||||
|
||||
// audioTag marks multimodal data as audio (vs vision) for PostTokenize.
|
||||
type audioTag struct{}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
continue
|
||||
}
|
||||
|
||||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
numTokens := inputMultimodal.Dim(1)
|
||||
|
||||
// Determine if this is audio or vision based on the tag.
|
||||
_, isAudio := inp.Multimodal[0].Data.(audioTag)
|
||||
|
||||
var beginToken, endToken int32
|
||||
if isAudio {
|
||||
beginToken = m.audioTokenID
|
||||
endToken = m.audioEndTokenID
|
||||
} else {
|
||||
beginToken = m.imageTokenID
|
||||
endToken = m.imageEndTokenID
|
||||
}
|
||||
|
||||
if beginToken >= 0 {
|
||||
result = append(result, &input.Input{Token: beginToken, SameBatch: numTokens + 2})
|
||||
}
|
||||
|
||||
result = append(result,
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash},
|
||||
)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, numTokens-1)...)
|
||||
|
||||
if endToken >= 0 {
|
||||
result = append(result, &input.Input{Token: endToken})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
|
||||
hiddenState = m.TextModel.Output.Forward(ctx, hiddenState)
|
||||
|
||||
if m.TextModel.TextOptions.finalLogitSoftcap > 0.0 {
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextModel.TextOptions.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
hiddenState = hiddenState.Scale(ctx, float64(m.TextModel.TextOptions.finalLogitSoftcap))
|
||||
}
|
||||
|
||||
return hiddenState, nil
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase, ropeDims := m.TextModel.ropeForLayer(layer)
|
||||
return nn.RoPE(ctx, key, shift, ropeDims, ropeBase, 1.0, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma4", New)
|
||||
}
|
||||
611
model/models/gemma4/model_audio.go
Normal file
611
model/models/gemma4/model_audio.go
Normal file
@@ -0,0 +1,611 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// AudioModel holds the audio encoder and configuration.
|
||||
type AudioModel struct {
|
||||
// SSCP: Sub-Sample Convolution Projection.
|
||||
SSCPConv0 *AudioConvBlock `gguf:"conv1d.0"`
|
||||
SSCPConv1 *AudioConvBlock `gguf:"conv1d.1"`
|
||||
|
||||
// SSCP output projection (linear).
|
||||
SSCPInputProj *nn.Linear `gguf:"pre_encode.out"`
|
||||
|
||||
// Conformer blocks.
|
||||
Layers []AudioConformerBlock `gguf:"blk"`
|
||||
|
||||
// Output projection to embedder dimension.
|
||||
OutputProj *AudioOutputProj `gguf:"output_proj"`
|
||||
|
||||
AudioModelOptions
|
||||
}
|
||||
|
||||
type AudioOutputProj struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
// AudioModelOptions holds audio model hyperparameters.
|
||||
type AudioModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
ffnSize int
|
||||
numLayers int
|
||||
melBins int
|
||||
chunkSize int
|
||||
maxPast int
|
||||
maxFuture int
|
||||
contextSize int
|
||||
logitCap float32
|
||||
residualWeight float32
|
||||
gradClip float32
|
||||
convKernelSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
// AudioConvBlock is a single 2D convolution block for the SSCP.
|
||||
type AudioConvBlock struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Norm *nn.LayerNorm `gguf:"norm"`
|
||||
}
|
||||
|
||||
// AudioConformerBlock is a single conformer layer.
|
||||
// All tensors are flat at the block level (a.blk.N.<name>) using underscore naming.
|
||||
type AudioConformerBlock struct {
|
||||
// Block-level norm
|
||||
Norm *nn.RMSNorm `gguf:"layer_pre_norm"`
|
||||
|
||||
// FFW start
|
||||
FFWNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
FFWUp *AudioClippableLinear `gguf:"ffn_up"`
|
||||
FFWDown *AudioClippableLinear `gguf:"ffn_down"`
|
||||
FFWPostNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
|
||||
|
||||
// FFW end
|
||||
FFWNorm1 *nn.RMSNorm `gguf:"ffn_norm_1"`
|
||||
FFWUp1 *AudioClippableLinear `gguf:"ffn_up_1"`
|
||||
FFWDown1 *AudioClippableLinear `gguf:"ffn_down_1"`
|
||||
FFWPostNorm1 *nn.RMSNorm `gguf:"ffn_post_norm_1"`
|
||||
|
||||
// Attention
|
||||
AttnQ *AudioClippableLinear `gguf:"attn_q"`
|
||||
AttnK *AudioClippableLinear `gguf:"attn_k"`
|
||||
AttnV *AudioClippableLinear `gguf:"attn_v"`
|
||||
AttnOut *AudioClippableLinear `gguf:"attn_out"`
|
||||
AttnPreNorm *nn.RMSNorm `gguf:"ln1"`
|
||||
AttnPostNorm *nn.RMSNorm `gguf:"ln2"`
|
||||
LinearPos ml.Tensor `gguf:"linear_pos.weight"`
|
||||
PerDimScale ml.Tensor `gguf:"per_dim_scale.weight"`
|
||||
|
||||
// LightConv1d
|
||||
ConvPW1 *AudioClippableLinear `gguf:"conv_pw1"`
|
||||
ConvPW2 *AudioClippableLinear `gguf:"conv_pw2"`
|
||||
ConvDW ml.Tensor `gguf:"conv_dw.weight"`
|
||||
ConvNorm *nn.RMSNorm `gguf:"conv_norm"`
|
||||
NormConv *nn.RMSNorm `gguf:"norm_conv"`
|
||||
}
|
||||
|
||||
// AudioClippableLinear is a linear layer with optional input/output clamping.
|
||||
type AudioClippableLinear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
InputMin ml.Tensor `gguf:"input_min"`
|
||||
InputMax ml.Tensor `gguf:"input_max"`
|
||||
OutputMin ml.Tensor `gguf:"output_min"`
|
||||
OutputMax ml.Tensor `gguf:"output_max"`
|
||||
|
||||
// Cached scalar clamp values (populated on first forward).
|
||||
inMin, inMax, outMin, outMax float32
|
||||
clampsLoaded bool
|
||||
}
|
||||
|
||||
func (l *AudioClippableLinear) loadClamps() {
|
||||
if l.clampsLoaded {
|
||||
return
|
||||
}
|
||||
l.clampsLoaded = true
|
||||
if l.InputMin != nil {
|
||||
vals := l.InputMin.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.inMin = vals[0]
|
||||
}
|
||||
}
|
||||
if l.InputMax != nil {
|
||||
vals := l.InputMax.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.inMax = vals[0]
|
||||
}
|
||||
}
|
||||
if l.OutputMin != nil {
|
||||
vals := l.OutputMin.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.outMin = vals[0]
|
||||
}
|
||||
}
|
||||
if l.OutputMax != nil {
|
||||
vals := l.OutputMax.BackendGet()
|
||||
if len(vals) > 0 {
|
||||
l.outMax = vals[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *AudioClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
l.loadClamps()
|
||||
if l.inMax != 0 {
|
||||
x = x.Clamp(ctx, l.inMin, l.inMax)
|
||||
}
|
||||
out := l.Weight.Mulmat(ctx, x)
|
||||
if l.Bias != nil {
|
||||
out = out.Add(ctx, l.Bias)
|
||||
}
|
||||
if l.outMax != 0 {
|
||||
out = out.Clamp(ctx, l.outMin, l.outMax)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// AudioMultimodalProjector is the audio-to-text embedding projector.
|
||||
type AudioMultimodalProjector struct {
|
||||
Projection *AudioClippableLinear `gguf:"input_projection"`
|
||||
FC *AudioFC `gguf:"fc"`
|
||||
}
|
||||
|
||||
type AudioFC struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
Bias ml.Tensor `gguf:"bias"`
|
||||
}
|
||||
|
||||
func (p *AudioMultimodalProjector) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
|
||||
// FC: output projection from conformer to embedder dimension.
|
||||
x = p.FC.Weight.Mulmat(ctx, x)
|
||||
if p.FC.Bias != nil {
|
||||
x = x.Add(ctx, p.FC.Bias)
|
||||
}
|
||||
// Pre-projection RMSNorm (without learned weight) — matches Python's embedding_pre_projection_norm.
|
||||
x = x.RMSNorm(ctx, nil, eps)
|
||||
// Embedding projection to text hidden size.
|
||||
x = p.Projection.Forward(ctx, x)
|
||||
return x
|
||||
}
|
||||
|
||||
// ForwardAudio encodes mel spectrogram features into soft tokens.
|
||||
// melFeatures: float32 tensor with ne[0]=melBins, ne[1]=numFrames.
|
||||
// Returns: [hiddenSize, numTokens] tensor.
|
||||
func (m *AudioModel) ForwardAudio(ctx ml.Context, melFeatures ml.Tensor, proj *AudioMultimodalProjector, opts *AudioModelOptions) ml.Tensor {
|
||||
// SSCP Conv2D input: ne[0]=F (freq/width), ne[1]=T (time/height), ne[2]=C_in, ne[3]=B
|
||||
// melFeatures is [melBins, numFrames], add channel and batch dims.
|
||||
x := melFeatures.Reshape(ctx, melFeatures.Dim(0), melFeatures.Dim(1), 1, 1)
|
||||
|
||||
// SSCP Conv block 0: [F, T, 1, 1] → [F', T', C0, 1]
|
||||
x = forwardConvBlock(ctx, m.SSCPConv0, x, opts)
|
||||
|
||||
// SSCP Conv block 1: [F', T', C0, 1] → [F'', T'', C1, 1]
|
||||
x = forwardConvBlock(ctx, m.SSCPConv1, x, opts)
|
||||
|
||||
// After conv blocks, layout is [F'', T'', C_out, B].
|
||||
// Permute to [C_out*F'', T'', B] for linear projection (channels+freq in ne[0]).
|
||||
fOut := x.Dim(0)
|
||||
tOut := x.Dim(1)
|
||||
cOut := x.Dim(2)
|
||||
// Permute [F'', T'', C, B] → [C, F'', T'', B]
|
||||
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0
|
||||
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
x = x.Reshape(ctx, cOut*fOut, tOut)
|
||||
|
||||
// Linear projection to hidden size.
|
||||
x = m.SSCPInputProj.Forward(ctx, x)
|
||||
|
||||
// Build causal-valid mask for conformer attention.
|
||||
causalMask := buildCausalValidMaskF32(opts.chunkSize, opts.maxPast, opts.maxFuture)
|
||||
|
||||
// Run conformer blocks.
|
||||
for i := range m.Layers {
|
||||
x = m.Layers[i].Forward(ctx, x, causalMask, opts, i)
|
||||
}
|
||||
|
||||
// Output projection.
|
||||
if m.OutputProj != nil {
|
||||
x = m.OutputProj.Weight.Mulmat(ctx, x)
|
||||
if m.OutputProj.Bias != nil {
|
||||
x = x.Add(ctx, m.OutputProj.Bias)
|
||||
}
|
||||
}
|
||||
|
||||
// Audio embedder: project to text embedding space.
|
||||
if proj != nil {
|
||||
x = proj.Forward(ctx, x, opts.eps)
|
||||
}
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardConvBlock runs a single SSCP Conv2D block.
|
||||
// Conv2D receiver is the kernel, argument is the input data.
|
||||
// Input: [F, T, C_in, B]. Output: [F', T', C_out, B].
|
||||
func forwardConvBlock(ctx ml.Context, block *AudioConvBlock, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
|
||||
// Conv2D: kernel.Conv2D(ctx, input, s0, s1, p0, p1, d0, d1)
|
||||
// Kernel is 3x3, stride 2x2, padding 1x1 (matching SSCP config).
|
||||
// Output layout: [F', T', C_out, B]
|
||||
// Make weight contiguous — the shape reversal in the converter creates
|
||||
// a tensor where the physical data order doesn't match ne[]/stride[].
|
||||
weight := block.Weight.Contiguous(ctx)
|
||||
x = weight.Conv2D(ctx, x, 2, 2, 1, 1, 1, 1)
|
||||
|
||||
// LayerNorm needs channels in ne[0]. Permute [F', T', C_out, B] → [C_out, F', T', B],
|
||||
// norm, then permute back.
|
||||
// GGML permute: axis i says where old axis i goes.
|
||||
// (1,2,0,3): old[0]→pos1, old[1]→pos2, old[2]→pos0 → [C_out, F', T', B]
|
||||
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
x = block.Norm.Forward(ctx, x, opts.eps)
|
||||
// (2,0,1,3): old[0]→pos2, old[1]→pos0, old[2]→pos1 → [F', T', C_out, B]
|
||||
x = x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
|
||||
x = x.RELU(ctx)
|
||||
return x
|
||||
}
|
||||
|
||||
// Forward runs a single conformer block.
|
||||
func (cb *AudioConformerBlock) Forward(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
|
||||
// FFW start (half-residual).
|
||||
x = cb.forwardFFW(ctx, cb.FFWNorm, cb.FFWUp, cb.FFWDown, cb.FFWPostNorm, x, opts)
|
||||
|
||||
// Self-attention.
|
||||
x = cb.forwardAttention(ctx, x, causalMask, opts, blockIdx)
|
||||
|
||||
// Lightweight Conv1d.
|
||||
x = cb.forwardLightConv(ctx, x, opts, blockIdx)
|
||||
|
||||
// FFW end (half-residual).
|
||||
x = cb.forwardFFW(ctx, cb.FFWNorm1, cb.FFWUp1, cb.FFWDown1, cb.FFWPostNorm1, x, opts)
|
||||
|
||||
// Gradient clipping + final norm.
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.Norm.Forward(ctx, x, opts.eps)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// forwardFFW runs a feedforward module with half-residual connection.
|
||||
func (cb *AudioConformerBlock) forwardFFW(ctx ml.Context, preNorm *nn.RMSNorm, up, down *AudioClippableLinear, postNorm *nn.RMSNorm, x ml.Tensor, opts *AudioModelOptions) ml.Tensor {
|
||||
residual := x
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = preNorm.Forward(ctx, x, opts.eps)
|
||||
x = up.Forward(ctx, x)
|
||||
x = x.SILU(ctx)
|
||||
x = down.Forward(ctx, x)
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = postNorm.Forward(ctx, x, opts.eps)
|
||||
x = x.Scale(ctx, float64(opts.residualWeight))
|
||||
return residual.Add(ctx, x)
|
||||
}
|
||||
|
||||
// forwardAttention runs the conformer block-local attention with relative position embeddings.
|
||||
func (cb *AudioConformerBlock) forwardAttention(ctx ml.Context, x ml.Tensor, causalMask []float32, opts *AudioModelOptions, blockIdx int) ml.Tensor {
|
||||
residual := x
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.AttnPreNorm.Forward(ctx, x, opts.eps)
|
||||
|
||||
hiddenSize := x.Dim(0)
|
||||
seqLen := x.Dim(1)
|
||||
|
||||
// QKV projections: [hiddenSize, seqLen] → [headDim, numHeads, seqLen]
|
||||
q := cb.AttnQ.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
|
||||
k := cb.AttnK.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
|
||||
v := cb.AttnV.Forward(ctx, x).Reshape(ctx, opts.headDim, opts.numHeads, seqLen)
|
||||
|
||||
// Per-dim scaling for queries: (headDim^-0.5 / log(2)) * softplus(per_dim_scale)
|
||||
// per_dim_scale is already softplus'd from the converter.
|
||||
qScale := float64(math.Pow(float64(opts.headDim), -0.5)) / math.Log(2)
|
||||
q = q.Scale(ctx, qScale)
|
||||
if cb.PerDimScale != nil {
|
||||
q = q.Mul(ctx, cb.PerDimScale)
|
||||
}
|
||||
|
||||
// Key scaling: softplus(1) / log(2) — matches the query base scaling convention.
|
||||
kScale := math.Log(1+math.E) / math.Log(2)
|
||||
k = k.Scale(ctx, kScale)
|
||||
|
||||
// Build sinusoidal position embeddings for the block-local context.
|
||||
maxSpan := opts.maxPast + opts.maxFuture + 1 // 13 unique relative positions
|
||||
posEmb := cb.buildPositionEmbeddings(ctx, maxSpan, opts)
|
||||
// posEmb: [headDim, numHeads, maxSpan]
|
||||
|
||||
// Block-local attention: process chunks of size chunkSize.
|
||||
chunkSize := opts.chunkSize
|
||||
numChunks := (seqLen + chunkSize - 1) / chunkSize
|
||||
contextSize := opts.contextSize
|
||||
|
||||
// Pad q/k/v to multiple of chunkSize on the time dimension (dim 2).
|
||||
padT := numChunks*chunkSize - seqLen
|
||||
if padT > 0 {
|
||||
q = q.Pad(ctx, 0, 0, padT, 0)
|
||||
k = k.Pad(ctx, 0, 0, padT, 0)
|
||||
v = v.Pad(ctx, 0, 0, padT, 0)
|
||||
}
|
||||
paddedLen := numChunks * chunkSize
|
||||
|
||||
// Pad k/v for context extraction: add maxPast on left, (maxFuture+chunkSize-1) on right.
|
||||
// Use Pad (right) + PadExt (left) workaround since PadExt+Slice has issues.
|
||||
// Actually use Concat with zero tensors for reliable left-padding.
|
||||
padLeft := opts.maxPast
|
||||
padRight := opts.maxFuture + chunkSize - 1
|
||||
zeroLeft := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padLeft), opts.headDim, opts.numHeads, padLeft)
|
||||
zeroRight := ctx.Input().FromFloats(make([]float32, opts.headDim*opts.numHeads*padRight), opts.headDim, opts.numHeads, padRight)
|
||||
kPadded := zeroLeft.Concat(ctx, k, 2).Concat(ctx, zeroRight, 2)
|
||||
vPadded := zeroLeft.Concat(ctx, v, 2).Concat(ctx, zeroRight, 2)
|
||||
|
||||
// Reshape q into chunks: [headDim, numHeads, numChunks, chunkSize]
|
||||
qChunked := q.Reshape(ctx, opts.headDim, opts.numHeads, numChunks, chunkSize)
|
||||
|
||||
// Process each chunk and collect results.
|
||||
chunkOutputs := make([]ml.Tensor, numChunks)
|
||||
for u := range numChunks {
|
||||
// Extract query block: [headDim, numHeads, 1, chunkSize] → [headDim, numHeads, chunkSize]
|
||||
qBlock := qChunked.Slice(ctx, 2, u, u+1, 1).Reshape(ctx, opts.headDim, opts.numHeads, chunkSize)
|
||||
|
||||
// Extract key/value context: [headDim, numHeads, contextSize]
|
||||
cStart := u * chunkSize // offset in kPadded (padLeft already accounts for left context)
|
||||
kCtx := kPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
|
||||
vCtx := vPadded.Slice(ctx, 2, cStart, cStart+contextSize, 1).Contiguous(ctx)
|
||||
|
||||
// Content-content logits: qBlock^T @ kCtx → [chunkSize, contextSize] per head.
|
||||
// Mulmat(a, b) = a^T @ b. We want Q^T K, so: kCtx.Mulmat(qBlock) but that gives
|
||||
// [numHeads, chunkSize, contextSize] with wrong batching.
|
||||
// Instead: permute to [headDim, chunkSize, numHeads] and [headDim, contextSize, numHeads]
|
||||
// then Mulmat batches over numHeads.
|
||||
// GGML permute(0,2,1,3): old[0]→0, old[1]→2, old[2]→1
|
||||
qP := qBlock.Permute(ctx, 0, 2, 1, 3) // [headDim, chunkSize, numHeads]
|
||||
kP := kCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
|
||||
|
||||
termAC := kP.MulmatFullPrec(ctx, qP) // [contextSize, chunkSize, numHeads]
|
||||
|
||||
// Content-position logits: qBlock^T @ posEmb → [chunkSize, maxSpan] per head.
|
||||
pP := posEmb.Permute(ctx, 0, 2, 1, 3) // [headDim, maxSpan, numHeads]
|
||||
termBDRaw := pP.MulmatFullPrec(ctx, qP) // [maxSpan, chunkSize, numHeads]
|
||||
|
||||
// Relative shift: [maxSpan, chunkSize, numHeads] → [contextSize, chunkSize, numHeads]
|
||||
termBD := cb.relativeShiftGGML(ctx, termBDRaw, maxSpan, chunkSize, contextSize, opts.numHeads)
|
||||
|
||||
// Combined logits.
|
||||
logits := termAC.Add(ctx, termBD)
|
||||
|
||||
// Logit softcap: tanh(logits / cap) * cap
|
||||
logits = logits.Scale(ctx, 1.0/float64(opts.logitCap))
|
||||
logits = logits.Tanh(ctx)
|
||||
logits = logits.Scale(ctx, float64(opts.logitCap))
|
||||
|
||||
// Apply combined causal + validity mask.
|
||||
// causalMask [chunkSize * contextSize]: 1=causal-allowed, 0=masked.
|
||||
// Validity: context positions before the actual sequence start are invalid.
|
||||
// For chunk u, context position c corresponds to actual time: u*chunkSize + c - padLeft.
|
||||
// Valid if 0 <= actual_time < seqLen.
|
||||
// Mask tensor layout: [contextSize, chunkSize, 1] with ne[0]=contextSize contiguous.
|
||||
// Element at (context=j, chunk=i) is at flat index: i*contextSize + j.
|
||||
maskData := make([]float32, contextSize*chunkSize)
|
||||
for i := range chunkSize {
|
||||
for j := range contextSize {
|
||||
actualTime := u*chunkSize + j - padLeft
|
||||
causalOK := causalMask[i*contextSize+j] > 0
|
||||
validOK := actualTime >= 0 && actualTime < seqLen
|
||||
if causalOK && validOK {
|
||||
maskData[i*contextSize+j] = 0
|
||||
} else {
|
||||
maskData[i*contextSize+j] = -1e9
|
||||
}
|
||||
}
|
||||
}
|
||||
mask := ctx.Input().FromFloats(maskData, contextSize, chunkSize, 1) // 3D for broadcasting over numHeads
|
||||
logits = logits.Add(ctx, mask)
|
||||
|
||||
// Softmax over context dimension (dim 0 = contextSize).
|
||||
logits = logits.Softmax(ctx) // softmax over ne[0]=contextSize
|
||||
|
||||
// Weighted sum: logits^T @ vCtx.
|
||||
// logits: [contextSize, chunkSize, numHeads], vCtx: [headDim, numHeads, contextSize]
|
||||
// vCtx permuted: [headDim, contextSize, numHeads]
|
||||
vP := vCtx.Permute(ctx, 0, 2, 1, 3) // [headDim, contextSize, numHeads]
|
||||
// Weighted sum: for each head, value[headDim, contextSize] @ weights[contextSize, chunkSize]
|
||||
// = [headDim, chunkSize].
|
||||
// Mulmat(a, b) = a^T @ b. Need a=[contextSize, headDim, numHeads], b=[contextSize, chunkSize, numHeads].
|
||||
vPT := vP.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [contextSize, headDim, numHeads]
|
||||
chunkOut := vPT.Mulmat(ctx, logits) // [headDim, chunkSize, numHeads]
|
||||
|
||||
// Permute back to [headDim, numHeads, chunkSize]
|
||||
chunkOut = chunkOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
chunkOutputs[u] = chunkOut
|
||||
}
|
||||
|
||||
// Concatenate chunk outputs along time dimension.
|
||||
var attnOut ml.Tensor
|
||||
if numChunks == 1 {
|
||||
attnOut = chunkOutputs[0]
|
||||
} else {
|
||||
attnOut = chunkOutputs[0]
|
||||
for _, co := range chunkOutputs[1:] {
|
||||
attnOut = attnOut.Concat(ctx, co, 2)
|
||||
}
|
||||
}
|
||||
|
||||
// Trim to original sequence length if we padded.
|
||||
if paddedLen > seqLen {
|
||||
attnOut = attnOut.Slice(ctx, 2, 0, seqLen, 1).Contiguous(ctx)
|
||||
}
|
||||
|
||||
// Reshape to [hiddenSize, seqLen] and project.
|
||||
attnOut = attnOut.Reshape(ctx, hiddenSize, seqLen)
|
||||
x = cb.AttnOut.Forward(ctx, attnOut)
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.AttnPostNorm.Forward(ctx, x, opts.eps)
|
||||
|
||||
return residual.Add(ctx, x)
|
||||
}
|
||||
|
||||
// buildPositionEmbeddings builds sinusoidal position embeddings and projects through linear_pos.
|
||||
// Returns [headDim, numHeads, maxSpan] tensor.
|
||||
func (cb *AudioConformerBlock) buildPositionEmbeddings(ctx ml.Context, maxSpan int, opts *AudioModelOptions) ml.Tensor {
|
||||
halfDim := opts.hiddenSize / 2
|
||||
hiddenSize := opts.hiddenSize
|
||||
|
||||
// inv_timescales: exp(-i * log(10000) / max(D/2-1, 1))
|
||||
logInc := math.Log(10000.0) / math.Max(float64(halfDim-1), 1)
|
||||
|
||||
// Sinusoidal embeddings for relative positions [maxPast, maxPast-1, ..., -maxFuture].
|
||||
posData := make([]float32, hiddenSize*maxSpan)
|
||||
for p := range maxSpan {
|
||||
relPos := float64(opts.maxPast - p)
|
||||
for d := range halfDim {
|
||||
angle := relPos * math.Exp(float64(-d)*logInc)
|
||||
posData[p*hiddenSize+d] = float32(math.Sin(angle))
|
||||
posData[p*hiddenSize+halfDim+d] = float32(math.Cos(angle))
|
||||
}
|
||||
}
|
||||
|
||||
// Create [hiddenSize, maxSpan] input tensor.
|
||||
posEmb := ctx.Input().FromFloats(posData, hiddenSize, maxSpan)
|
||||
|
||||
// Project through linear_pos: [hiddenSize, maxSpan] → Mulmat → [numHeads*headDim, maxSpan]
|
||||
projPos := cb.LinearPos.Mulmat(ctx, posEmb)
|
||||
|
||||
// Reshape to [headDim, numHeads, maxSpan].
|
||||
return projPos.Reshape(ctx, opts.headDim, opts.numHeads, maxSpan)
|
||||
}
|
||||
|
||||
// relativeShiftGGML performs the relative shift to extract correct position logits.
|
||||
// Input: [maxSpan, chunkSize, numHeads]. Output: [contextSize, chunkSize, numHeads].
|
||||
func (cb *AudioConformerBlock) relativeShiftGGML(ctx ml.Context, x ml.Tensor, maxSpan, chunkSize, contextSize, numHeads int) ml.Tensor {
|
||||
// The shift trick: pad ne[0] to contextSize+1, reshape to flatten first two dims,
|
||||
// skip first (contextSize+1-maxSpan) elements, take contextSize*chunkSize elements, reshape back.
|
||||
padAmt := contextSize + 1 - maxSpan
|
||||
if padAmt > 0 {
|
||||
x = x.Pad(ctx, padAmt, 0, 0, 0) // [maxSpan+padAmt, chunkSize, numHeads] = [contextSize+1, chunkSize, numHeads]
|
||||
}
|
||||
// Reshape to [(contextSize+1)*chunkSize, numHeads]
|
||||
x = x.Reshape(ctx, (contextSize+1)*chunkSize, numHeads)
|
||||
// Take the first contextSize*chunkSize elements (the standard relative shift trick).
|
||||
x = x.Slice(ctx, 0, 0, contextSize*chunkSize, 1).Contiguous(ctx)
|
||||
// Reshape to [contextSize, chunkSize, numHeads]
|
||||
return x.Reshape(ctx, contextSize, chunkSize, numHeads)
|
||||
}
|
||||
|
||||
// forwardLightConv runs the lightweight depthwise convolution module.
|
||||
func (cb *AudioConformerBlock) forwardLightConv(ctx ml.Context, x ml.Tensor, opts *AudioModelOptions, blockIdx int) ml.Tensor {
|
||||
residual := x
|
||||
|
||||
x = cb.ConvNorm.Forward(ctx, x, opts.eps)
|
||||
x = cb.ConvPW1.Forward(ctx, x) // [2*D, T, B]
|
||||
|
||||
// GLU: split in half along dim 0, sigmoid gate, multiply.
|
||||
d := x.Dim(0) / 2
|
||||
data := x.Slice(ctx, 0, 0, d, 1).Contiguous(ctx)
|
||||
gate := x.Slice(ctx, 0, d, d*2, 1).Contiguous(ctx).Sigmoid(ctx)
|
||||
x = data.Mul(ctx, gate) // [D, T, B]
|
||||
|
||||
// Depthwise Conv1d: manual implementation using model weight tensor slices.
|
||||
// Kernel cb.ConvDW shape: [K=5, D=1024] (ne[0]=K, ne[1]=D) after shape reversal.
|
||||
// Actually in GGML, ne[0]=K=5 contiguous, ne[1]=D=1024.
|
||||
// We need per-tap weights [D] and shifted input copies.
|
||||
kernelSize := cb.ConvDW.Dim(0) // K=5
|
||||
seqLen := x.Dim(1)
|
||||
|
||||
// Transpose kernel to [D, K] for per-tap slicing.
|
||||
// GGML permute(1,0,2,3): old[0]→pos1, old[1]→pos0 → swap ne[0] and ne[1]
|
||||
kernelT := cb.ConvDW.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) // [D, K]
|
||||
|
||||
var convOut ml.Tensor
|
||||
for k := range kernelSize {
|
||||
shift := kernelSize - 1 - k
|
||||
var shifted ml.Tensor
|
||||
if shift == 0 {
|
||||
shifted = x
|
||||
} else {
|
||||
trimmed := x.Slice(ctx, 1, 0, seqLen-shift, 1).Contiguous(ctx)
|
||||
shifted = trimmed.PadExt(ctx, 0, 0, shift, 0, 0, 0, 0, 0)
|
||||
}
|
||||
|
||||
wk := kernelT.Slice(ctx, 1, k, k+1, 1).Contiguous(ctx) // [D, 1]
|
||||
term := shifted.Mul(ctx, wk)
|
||||
if convOut == nil {
|
||||
convOut = term
|
||||
} else {
|
||||
convOut = convOut.Add(ctx, term)
|
||||
}
|
||||
}
|
||||
x = convOut
|
||||
|
||||
x = x.Clamp(ctx, -opts.gradClip, opts.gradClip)
|
||||
x = cb.NormConv.Forward(ctx, x, opts.eps)
|
||||
x = x.SILU(ctx)
|
||||
x = cb.ConvPW2.Forward(ctx, x)
|
||||
|
||||
return x.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func newAudioModel(c fs.Config) *AudioModel {
|
||||
numLayers := int(c.Uint("audio.block_count", 0))
|
||||
if numLayers == 0 {
|
||||
return nil
|
||||
}
|
||||
return &AudioModel{
|
||||
Layers: make([]AudioConformerBlock, numLayers),
|
||||
}
|
||||
}
|
||||
|
||||
func newAudioModelOptions(c fs.Config) *AudioModelOptions {
|
||||
hiddenSize := int(c.Uint("audio.embedding_length", 0))
|
||||
if hiddenSize == 0 {
|
||||
return nil
|
||||
}
|
||||
numHeads := int(c.Uint("audio.attention.head_count", 8))
|
||||
headDim := hiddenSize / numHeads
|
||||
chunkSize := 12 // default conformer chunk size
|
||||
maxPast := 12 // conf_attention_context_left - 1
|
||||
maxFuture := 0 // conf_attention_context_right
|
||||
convKernel := int(c.Uint("audio.conv_kernel_size", 5))
|
||||
|
||||
eps := c.Float("audio.attention.layer_norm_epsilon", 1e-6)
|
||||
|
||||
return &AudioModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: headDim,
|
||||
ffnSize: int(c.Uint("audio.feed_forward_length", uint32(hiddenSize*4))),
|
||||
numLayers: int(c.Uint("audio.block_count", 12)),
|
||||
melBins: int(c.Uint("audio.num_mel_bins", 128)),
|
||||
chunkSize: chunkSize,
|
||||
maxPast: maxPast,
|
||||
maxFuture: maxFuture,
|
||||
contextSize: chunkSize + maxPast + maxFuture,
|
||||
logitCap: 50.0,
|
||||
residualWeight: 0.5,
|
||||
gradClip: 1e10,
|
||||
convKernelSize: convKernel,
|
||||
eps: float32(eps),
|
||||
}
|
||||
}
|
||||
|
||||
// buildCausalValidMaskF32 creates the causal-valid mask for block-local attention.
|
||||
// Returns flat [chunkSize * contextSize] float32 data (1.0 = allowed, 0.0 = masked).
|
||||
func buildCausalValidMaskF32(chunkSize, maxPast, maxFuture int) []float32 {
|
||||
contextSize := chunkSize + maxPast + maxFuture
|
||||
upperDiag := maxPast + maxFuture
|
||||
|
||||
result := make([]float32, chunkSize*contextSize)
|
||||
for r := range chunkSize {
|
||||
for c := range contextSize {
|
||||
lower := (r <= c) // tril(contextSize, chunkSize) transposed
|
||||
upper := (c <= r+upperDiag) // tril(chunkSize, contextSize, diag=upperDiag)
|
||||
if lower && upper {
|
||||
result[r*contextSize+c] = 1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
475
model/models/gemma4/model_text.go
Normal file
475
model/models/gemma4/model_text.go
Normal file
@@ -0,0 +1,475 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTypeSWA = iota
|
||||
cacheTypeCausal
|
||||
)
|
||||
|
||||
type TextOptions struct {
|
||||
hiddenSize int
|
||||
numHeads, numKVHeads int
|
||||
numGlobalKVHeads int
|
||||
headDim, globalHeadDim int
|
||||
hiddenLayers int
|
||||
hiddenSizePerLayerInput int
|
||||
|
||||
eps float32
|
||||
ropeBase float32
|
||||
ropeLocalBase float32
|
||||
partialRotaryDims int // RoPE dims for full-attention (global) layers
|
||||
|
||||
slidingWindowPattern []bool
|
||||
// kvDonorMap maps shared layer index -> donor layer index.
|
||||
// Donor is the last non-shared layer of the same type (sliding/full).
|
||||
kvDonorMap map[int]int
|
||||
|
||||
finalLogitSoftcap float32
|
||||
|
||||
numExperts int
|
||||
numExpertsUsed int
|
||||
}
|
||||
|
||||
func (o *TextOptions) isLocal(layer int) bool {
|
||||
if layer < len(o.slidingWindowPattern) {
|
||||
return o.slidingWindowPattern[layer]
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (o *TextOptions) ropeForLayer(layer int) (base float32, dims int) {
|
||||
if o.isLocal(layer) {
|
||||
return o.ropeLocalBase, o.headDim
|
||||
}
|
||||
return o.ropeBase, o.partialRotaryDims
|
||||
}
|
||||
|
||||
func (o *TextOptions) kvHeadsForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.numKVHeads
|
||||
}
|
||||
if o.numGlobalKVHeads > 0 {
|
||||
return o.numGlobalKVHeads
|
||||
}
|
||||
return o.numKVHeads
|
||||
}
|
||||
|
||||
func (o *TextOptions) headDimForLayer(layer int) int {
|
||||
if o.isLocal(layer) {
|
||||
return o.headDim
|
||||
}
|
||||
return o.globalHeadDim
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
*PerLayerProjector
|
||||
Layers []TextLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
TextOptions
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
numLayers := int(c.Uint("block_count"))
|
||||
|
||||
// Head dimensions: key_length is global head dim, key_length_swa is local (SWA) head dim.
|
||||
globalHeadDim := int(c.Uint("attention.key_length", 512))
|
||||
headDim := int(c.Uint("attention.key_length_swa", 256))
|
||||
|
||||
// RoPE dimensions for global (full attention) layers with proportional RoPE.
|
||||
// The freq_factors tensor handles partial rotation (1.0 for rotated pairs,
|
||||
// 1e30 for non-rotated), so ropeDims equals the full global head dim.
|
||||
partialRotaryDims := int(c.Uint("rope.dimension_count", 0))
|
||||
if partialRotaryDims == 0 {
|
||||
partialFactor := c.Float("rope.partial_rotary_factor", 1.0)
|
||||
partialRotaryDims = int(float32(globalHeadDim) * partialFactor)
|
||||
}
|
||||
|
||||
ropeBase := c.Float("rope.freq_base", 1000000.0)
|
||||
ropeLocalBase := c.Float("rope.freq_base_swa", 0)
|
||||
if ropeLocalBase == 0 {
|
||||
ropeLocalBase = c.Float("rope.local.freq_base", 10000.0)
|
||||
}
|
||||
|
||||
numGlobalKVHeads := int(c.Uint("attention.global_head_count_kv", 0))
|
||||
slidingPattern := c.Bools("attention.sliding_window_pattern")
|
||||
|
||||
// KV heads: try per-layer array first (MoE models), then fall back to scalar
|
||||
numKVHeads := 0
|
||||
kvHeadsArray := c.Ints("attention.head_count_kv")
|
||||
if len(kvHeadsArray) > 0 {
|
||||
numKVHeads = int(kvHeadsArray[0])
|
||||
if numGlobalKVHeads == 0 && len(slidingPattern) > 0 {
|
||||
for i, isLocal := range slidingPattern {
|
||||
if !isLocal && i < len(kvHeadsArray) {
|
||||
numGlobalKVHeads = int(kvHeadsArray[i])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if numKVHeads == 0 {
|
||||
numKVHeads = int(c.Uint("attention.head_count_kv", 0))
|
||||
}
|
||||
|
||||
// Compute KV sharing donor map (same logic as MLX)
|
||||
sharedLayers := int(c.Uint("attention.shared_kv_layers", 0))
|
||||
kvDonorMap := make(map[int]int)
|
||||
if sharedLayers > 0 && len(slidingPattern) > 0 {
|
||||
firstShared := numLayers - sharedLayers
|
||||
for i := firstShared; i < numLayers; i++ {
|
||||
isLocal := slidingPattern[i]
|
||||
// Find last non-shared layer of same type
|
||||
for j := firstShared - 1; j >= 0; j-- {
|
||||
if slidingPattern[j] == isLocal {
|
||||
kvDonorMap[i] = j
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &TextModel{
|
||||
Layers: make([]TextLayer, numLayers),
|
||||
TextOptions: TextOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: numKVHeads,
|
||||
numGlobalKVHeads: numGlobalKVHeads,
|
||||
headDim: headDim,
|
||||
globalHeadDim: globalHeadDim,
|
||||
hiddenLayers: numLayers,
|
||||
hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input", 0)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeBase: ropeBase,
|
||||
ropeLocalBase: ropeLocalBase,
|
||||
partialRotaryDims: partialRotaryDims,
|
||||
slidingWindowPattern: slidingPattern,
|
||||
kvDonorMap: kvDonorMap,
|
||||
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
|
||||
numExperts: int(c.Uint("expert_count", 0)),
|
||||
numExpertsUsed: int(c.Uint("expert_used_count", 0)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
|
||||
|
||||
// Inject vision embeddings into the hidden state
|
||||
var except []int
|
||||
for _, image := range batch.Multimodal {
|
||||
visionOutputs := image.Multimodal[0].Tensor
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
for i := range visionOutputs.Dim(1) {
|
||||
except = append(except, image.Index+i)
|
||||
}
|
||||
}
|
||||
|
||||
// PLE
|
||||
var perLayerInputs ml.Tensor
|
||||
if m.PerLayerProjector != nil {
|
||||
perLayerInputs = m.PerLayerProjector.Forward(ctx, batch, hiddenState, &m.TextOptions)
|
||||
}
|
||||
|
||||
for i := range len(m.Layers) {
|
||||
layer := m.Layers[i]
|
||||
if cache != nil {
|
||||
cache.SetLayer(i)
|
||||
cacheType := cacheTypeSWA
|
||||
if !m.isLocal(i) {
|
||||
cacheType = cacheTypeCausal
|
||||
}
|
||||
wc := cache.(*kvcache.WrapperCache)
|
||||
wc.SetLayerType(cacheType)
|
||||
|
||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
}
|
||||
}
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
var perLayerInput ml.Tensor
|
||||
if perLayerInputs != nil {
|
||||
perLayerInput = perLayerInputs.View(ctx, i*perLayerInputs.Stride(1), perLayerInputs.Dim(0), perLayerInputs.Stride(2), perLayerInputs.Dim(2))
|
||||
}
|
||||
|
||||
// KV sharing: layers >= firstShared reuse K/V from donor layers
|
||||
isShared := false
|
||||
if donorLayer, ok := m.kvDonorMap[i]; ok {
|
||||
// Set cache layer to donor so Get() reads donor's K/V
|
||||
cache.SetLayer(donorLayer)
|
||||
isShared = true
|
||||
}
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, perLayerInput, lastLayerOutputs, cache, isShared, &m.TextOptions)
|
||||
}
|
||||
|
||||
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
}
|
||||
|
||||
// PerLayerProjector implements PLE.
|
||||
type PerLayerProjector struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"per_layer_token_embd"`
|
||||
Projector *nn.Linear `gguf:"per_layer_model_proj"`
|
||||
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
|
||||
}
|
||||
|
||||
func (p *PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
inputsPerLayer = inputsPerLayer.Scale(ctx, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
||||
// Reshape to [pleDim, numLayers, numTokens] — matching projection shape
|
||||
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
|
||||
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
||||
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
||||
|
||||
if inputsPerLayer != nil {
|
||||
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
||||
}
|
||||
|
||||
return perLayerProjection
|
||||
}
|
||||
|
||||
type TextSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` // proportional RoPE freq_factors
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
hd := opts.headDimForLayer(layer)
|
||||
kvHeads := opts.kvHeadsForLayer(layer)
|
||||
ropeBase, ropeDims := opts.ropeForLayer(layer)
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, hd, opts.numHeads, batchSize)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
|
||||
var k, v ml.Tensor
|
||||
if !sharedKV {
|
||||
k = sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, hd, kvHeads, batchSize)
|
||||
|
||||
if sa.Value != nil {
|
||||
v = sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, hd, kvHeads, batchSize)
|
||||
} else {
|
||||
// K=V: use raw K projection (before K norm) as V
|
||||
v = k
|
||||
}
|
||||
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
v = v.RMSNorm(ctx, nil, opts.eps) // V norm: unweighted RMSNorm
|
||||
}
|
||||
|
||||
// RoPE with proportional freq_factors on global layers
|
||||
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if sa.RopeFactors != nil && !opts.isLocal(layer) {
|
||||
ropeOpts = append(ropeOpts, rope.WithFactors(sa.RopeFactors))
|
||||
}
|
||||
q = nn.RoPE(ctx, q, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
|
||||
if k != nil {
|
||||
k = nn.RoPE(ctx, k, positions, ropeDims, ropeBase, 1.0, ropeOpts...)
|
||||
}
|
||||
|
||||
attention := nn.Attention(ctx, q, k, v, 1.0, cache)
|
||||
|
||||
attention = attention.Reshape(ctx, hd*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// TextRouter implements the Gemma 4 MoE router.
|
||||
type TextRouter struct {
|
||||
Proj *nn.Linear `gguf:"ffn_gate_inp"`
|
||||
Scale ml.Tensor `gguf:"ffn_gate_inp.scale"`
|
||||
}
|
||||
|
||||
func (r *TextRouter) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) (routingWeights, selectedExperts ml.Tensor) {
|
||||
// RMSNorm without learned weight
|
||||
x := hiddenState.RMSNorm(ctx, nil, opts.eps)
|
||||
// Scale by 1/sqrt(hidden_size)
|
||||
x = x.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize)))
|
||||
// Multiply by learned scale parameter
|
||||
x = x.Mul(ctx, r.Scale)
|
||||
// Project to expert logits
|
||||
expertScores := r.Proj.Forward(ctx, x)
|
||||
// Softmax over experts
|
||||
routingWeights = expertScores.Softmax(ctx)
|
||||
// TopK expert selection
|
||||
selectedExperts = routingWeights.TopK(ctx, opts.numExpertsUsed)
|
||||
return routingWeights, selectedExperts
|
||||
}
|
||||
|
||||
// TextMoEBlock implements the Gemma 4 sparse MoE.
|
||||
type TextMoEBlock struct {
|
||||
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
|
||||
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
|
||||
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
|
||||
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
|
||||
DownScale ml.Tensor `gguf:"ffn_down_exps.scale,alt:ffn_gate_inp.per_expert_scale"`
|
||||
}
|
||||
|
||||
func (moe *TextMoEBlock) Forward(ctx ml.Context, hiddenState, routingWeights, selectedExperts ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
// Select routing weights for chosen experts and renormalize
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(1)).Rows(ctx, selectedExperts)
|
||||
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
|
||||
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(1))
|
||||
|
||||
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), 1, hiddenState.Dim(1))
|
||||
|
||||
// Expert computation using LinearBatch (MulmatID selecting experts by index)
|
||||
var gateOut, upOut ml.Tensor
|
||||
if moe.GateUp != nil && moe.GateUp.Weight != nil {
|
||||
gateUp := moe.GateUp.Forward(ctx, hiddenState, selectedExperts)
|
||||
nFF := gateUp.Dim(0) / 2
|
||||
gateOut = gateUp.Slice(ctx, 0, 0, nFF, 1)
|
||||
upOut = gateUp.Slice(ctx, 0, nFF, gateUp.Dim(0), 1)
|
||||
} else {
|
||||
gateOut = moe.Gate.Forward(ctx, hiddenState, selectedExperts)
|
||||
upOut = moe.Up.Forward(ctx, hiddenState, selectedExperts)
|
||||
}
|
||||
hiddenState = gateOut.GELU(ctx, upOut)
|
||||
experts := moe.Down.Forward(ctx, hiddenState, selectedExperts)
|
||||
|
||||
// Apply per-expert down projection scale when present.
|
||||
if moe.DownScale != nil {
|
||||
expertScales := moe.DownScale.Reshape(ctx, opts.numExperts, 1)
|
||||
expertScales = expertScales.Repeat(ctx, 1, hiddenState.Dim(2))
|
||||
expertScales = expertScales.Reshape(ctx, 1, opts.numExperts, hiddenState.Dim(2)).Rows(ctx, selectedExperts)
|
||||
expertScales = expertScales.Reshape(ctx, opts.numExpertsUsed, hiddenState.Dim(2))
|
||||
expertScales = expertScales.Reshape(ctx, 1, opts.numExpertsUsed, hiddenState.Dim(2))
|
||||
experts = experts.Mul(ctx, expertScales)
|
||||
}
|
||||
|
||||
// Apply routing weights
|
||||
experts = experts.Mul(ctx, routingWeights)
|
||||
|
||||
// Sum across experts
|
||||
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
|
||||
for i := 1; i < opts.numExpertsUsed; i++ {
|
||||
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
|
||||
}
|
||||
|
||||
return nextStates
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *TextSelfAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm,alt:attn_post_norm"`
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm,alt:ffn_pre_norm"`
|
||||
MLP *TextMLP
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm,alt:ffn_post_norm"`
|
||||
|
||||
// MoE (present only for models with enable_moe_block=true)
|
||||
Router *TextRouter
|
||||
MoE *TextMoEBlock
|
||||
MoENorm *nn.RMSNorm `gguf:"pre_ffw_norm_2,alt:ffn_pre_norm_2"`
|
||||
PostMoENorm *nn.RMSNorm `gguf:"post_ffw_norm_2,alt:ffn_post_norm_2"`
|
||||
PostMLPNorm1 *nn.RMSNorm `gguf:"post_ffw_norm_1,alt:ffn_post_norm_1"` // used instead of PostMLPNorm when MoE is present
|
||||
|
||||
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
|
||||
PerLayerProjection *nn.Linear `gguf:"proj"`
|
||||
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
|
||||
LayerScalar ml.Tensor `gguf:"layer_scalar,alt:layer_output_scale.weight"`
|
||||
}
|
||||
|
||||
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positions, perLayerInput, outputs ml.Tensor, cache kvcache.Cache, sharedKV bool, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, positions, cache, sharedKV, opts)
|
||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
if perLayerInput != nil {
|
||||
perLayerInput = perLayerInput.Rows(ctx, outputs)
|
||||
}
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// MLP (+ optional MoE in parallel)
|
||||
hasSplitExperts := l.MoE != nil && l.MoE.Gate != nil && l.MoE.Up != nil && l.MoE.Gate.Weight != nil && l.MoE.Up.Weight != nil
|
||||
hasFusedExperts := l.MoE != nil && l.MoE.GateUp != nil && l.MoE.GateUp.Weight != nil
|
||||
if l.Router != nil && l.MoE != nil && l.MoE.Down != nil && l.MoE.Down.Weight != nil && (hasSplitExperts || hasFusedExperts) {
|
||||
// MoE layers: run MLP and MoE in parallel, sum results
|
||||
mlpState := l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
mlpState = l.MLP.Forward(ctx, mlpState)
|
||||
mlpState = l.PostMLPNorm1.Forward(ctx, mlpState, opts.eps)
|
||||
|
||||
routingWeights, selectedExperts := l.Router.Forward(ctx, hiddenState, opts)
|
||||
moeState := l.MoENorm.Forward(ctx, hiddenState, opts.eps)
|
||||
moeState = l.MoE.Forward(ctx, moeState, routingWeights, selectedExperts, opts)
|
||||
moeState = l.PostMoENorm.Forward(ctx, moeState, opts.eps)
|
||||
|
||||
// Combine MLP + MoE, apply outer post-FFN norm, then add residual
|
||||
combined := mlpState.Add(ctx, moeState)
|
||||
combined = l.PostMLPNorm.Forward(ctx, combined, opts.eps)
|
||||
hiddenState = combined.Add(ctx, residual)
|
||||
} else {
|
||||
// Dense layers: MLP only
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState)
|
||||
hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
// PLE injection (after MLP residual)
|
||||
if perLayerInput != nil && l.PerLayerInputGate != nil {
|
||||
pleState := l.PerLayerInputGate.Forward(ctx, hiddenState)
|
||||
pleState = pleState.GELU(ctx, perLayerInput)
|
||||
pleState = l.PerLayerProjection.Forward(ctx, pleState)
|
||||
pleState = l.PostPerLayerNorm.Forward(ctx, pleState, opts.eps)
|
||||
hiddenState = hiddenState.Add(ctx, pleState)
|
||||
}
|
||||
|
||||
// Layer scalar applied at end of layer (full-attention layers only)
|
||||
if l.LayerScalar != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, l.LayerScalar)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
384
model/models/gemma4/model_vision.go
Normal file
384
model/models/gemma4/model_vision.go
Normal file
@@ -0,0 +1,384 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
const batchSize = 1
|
||||
|
||||
// ClippableLinear is a linear layer with optional input/output clamping.
|
||||
// Required by Gemma4 vision encoder for numerical stability with F16 weights.
|
||||
type ClippableLinear struct {
|
||||
Weight ml.Tensor `gguf:"weight"`
|
||||
|
||||
InputMin ml.Tensor `gguf:"input_min"`
|
||||
InputMax ml.Tensor `gguf:"input_max"`
|
||||
OutputMin ml.Tensor `gguf:"output_min"`
|
||||
OutputMax ml.Tensor `gguf:"output_max"`
|
||||
|
||||
inMin, inMax, outMin, outMax float32
|
||||
hasClamp bool
|
||||
clampsLoaded bool
|
||||
}
|
||||
|
||||
func scalarValue(t ml.Tensor) (float32, bool) {
|
||||
if t == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
data := t.BackendGet()
|
||||
if len(data) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return data[0], true
|
||||
}
|
||||
|
||||
func (l *ClippableLinear) loadClampFromScalars() {
|
||||
if l.clampsLoaded {
|
||||
return
|
||||
}
|
||||
l.clampsLoaded = true
|
||||
|
||||
const (
|
||||
defaultMin = -math.MaxFloat32
|
||||
defaultMax = math.MaxFloat32
|
||||
)
|
||||
|
||||
inMin, hasInMin := scalarValue(l.InputMin)
|
||||
inMax, hasInMax := scalarValue(l.InputMax)
|
||||
outMin, hasOutMin := scalarValue(l.OutputMin)
|
||||
outMax, hasOutMax := scalarValue(l.OutputMax)
|
||||
|
||||
if !(hasInMin || hasInMax || hasOutMin || hasOutMax) {
|
||||
return
|
||||
}
|
||||
|
||||
l.hasClamp = true
|
||||
l.inMin = defaultMin
|
||||
l.inMax = defaultMax
|
||||
l.outMin = defaultMin
|
||||
l.outMax = defaultMax
|
||||
|
||||
if hasInMin {
|
||||
l.inMin = inMin
|
||||
}
|
||||
if hasInMax {
|
||||
l.inMax = inMax
|
||||
}
|
||||
if hasOutMin {
|
||||
l.outMin = outMin
|
||||
}
|
||||
if hasOutMax {
|
||||
l.outMax = outMax
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ClippableLinear) Forward(ctx ml.Context, x ml.Tensor) ml.Tensor {
|
||||
if l.hasClamp {
|
||||
x = x.Clamp(ctx, l.inMin, l.inMax)
|
||||
}
|
||||
out := l.Weight.Mulmat(ctx, x)
|
||||
if l.hasClamp {
|
||||
out = out.Clamp(ctx, l.outMin, l.outMax)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// InitClamp distributes packed clamp values from v.clamp_data to ClippableLinear structs.
|
||||
// If scalar clamp tensors (input_min/max, output_min/max) are present, they are used too.
|
||||
// Layout: numLayers × 7 linears (q,k,v,out,gate,up,down) × 4 floats (inMin,inMax,outMin,outMax)
|
||||
// then 4 floats for the projector.
|
||||
func (m *VisionModel) InitClamp(proj *MultiModalProjector) {
|
||||
if m.clampInitDone {
|
||||
return
|
||||
}
|
||||
m.clampInitDone = true
|
||||
|
||||
linears := func(l *VisionEncoderLayer) []*ClippableLinear {
|
||||
return []*ClippableLinear{
|
||||
l.SelfAttention.Query, l.SelfAttention.Key, l.SelfAttention.Value,
|
||||
l.SelfAttention.Output, l.MLP.Gate, l.MLP.Up, l.MLP.Down,
|
||||
}
|
||||
}
|
||||
|
||||
for i := range m.Layers {
|
||||
for _, cl := range linears(&m.Layers[i]) {
|
||||
if cl != nil {
|
||||
cl.loadClampFromScalars()
|
||||
}
|
||||
}
|
||||
}
|
||||
if proj != nil && proj.Projection != nil {
|
||||
proj.Projection.loadClampFromScalars()
|
||||
}
|
||||
|
||||
// Load packed clamp data when present (legacy Ollama format).
|
||||
if m.ClampData == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Read all clamp values from packed F32 tensor
|
||||
data := m.ClampData.BackendGet()
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Distribute to layer linears: 7 per layer × 4 values each
|
||||
for i := range m.Layers {
|
||||
for li, cl := range linears(&m.Layers[i]) {
|
||||
if cl == nil {
|
||||
continue
|
||||
}
|
||||
idx := (i*7 + li) * 4
|
||||
if idx+3 < len(data) {
|
||||
cl.inMin = data[idx]
|
||||
cl.inMax = data[idx+1]
|
||||
cl.outMin = data[idx+2]
|
||||
cl.outMax = data[idx+3]
|
||||
cl.hasClamp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Projector clamp values (last 4 floats)
|
||||
if proj != nil && proj.Projection != nil {
|
||||
projIdx := len(m.Layers) * 7 * 4
|
||||
if projIdx+3 < len(data) {
|
||||
proj.Projection.inMin = data[projIdx]
|
||||
proj.Projection.inMax = data[projIdx+1]
|
||||
proj.Projection.outMin = data[projIdx+2]
|
||||
proj.Projection.outMax = data[projIdx+3]
|
||||
proj.Projection.hasClamp = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *ClippableLinear `gguf:"attn_q"`
|
||||
Key *ClippableLinear `gguf:"attn_k"`
|
||||
Value *ClippableLinear `gguf:"attn_v"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Output *ClippableLinear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
numPatches := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, numPatches, batchSize)
|
||||
|
||||
// Q/K norms (Gemma-style: x * (1 + weight) / rms(x))
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
// V norm (RMSNorm without learned weights)
|
||||
value = value.RMSNorm(ctx, nil, opts.eps)
|
||||
|
||||
// 2D RoPE: split head dim in half, apply NeoX RoPE with x positions to first half,
|
||||
// y positions to second half, then concatenate.
|
||||
halfDim := headDim / 2
|
||||
ropeOpts := rope.WithTypeNeoX()
|
||||
|
||||
qFirst := query.View(ctx, 0, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
|
||||
qFirst = nn.RoPE(ctx, qFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)
|
||||
|
||||
kFirst := key.View(ctx, 0, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
|
||||
kFirst = nn.RoPE(ctx, kFirst, posX, halfDim, opts.ropeTheta, 1.0, ropeOpts)
|
||||
|
||||
halfOffset := halfDim * query.Stride(0)
|
||||
qSecond := query.View(ctx, halfOffset, halfDim, query.Stride(1), opts.numHeads, query.Stride(2), numPatches)
|
||||
qSecond = nn.RoPE(ctx, qSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)
|
||||
|
||||
halfOffsetK := halfDim * key.Stride(0)
|
||||
kSecond := key.View(ctx, halfOffsetK, halfDim, key.Stride(1), opts.numHeads, key.Stride(2), numPatches)
|
||||
kSecond = nn.RoPE(ctx, kSecond, posY, halfDim, opts.ropeTheta, 1.0, ropeOpts)
|
||||
|
||||
query = qFirst.Concat(ctx, qSecond, 0)
|
||||
key = kFirst.Concat(ctx, kSecond, 0)
|
||||
|
||||
// Use flash attention for numerical stability (handles large attention scores
|
||||
// from unclamped RMSNorm weights, e.g. 26B has addOne weights up to 19.5)
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
Gate *ClippableLinear `gguf:"ffn_gate"`
|
||||
Up *ClippableLinear `gguf:"ffn_up"`
|
||||
Down *ClippableLinear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
|
||||
gate := mlp.Gate.Forward(ctx, hiddenState)
|
||||
up := mlp.Up.Forward(ctx, hiddenState)
|
||||
hiddenState = gate.QuickGELU(ctx, up)
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"attn_post_norm"`
|
||||
|
||||
FFNNorm *nn.RMSNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
PostFFNNorm *nn.RMSNorm `gguf:"ffn_post_norm"`
|
||||
|
||||
LayerOutputScale ml.Tensor `gguf:"out_scale.weight"`
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, posX, posY, attnMask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
// Pre-attention norm -> self attention -> post-attention norm
|
||||
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, posX, posY, attnMask, opts)
|
||||
hiddenState = e.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
// Residual connection
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// Pre-FFN norm -> FFN -> post-FFN norm
|
||||
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState)
|
||||
hiddenState = e.PostFFNNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
// Residual connection
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
|
||||
// Per-layer output scale
|
||||
if e.LayerOutputScale != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, e.LayerOutputScale)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
patchSize int
|
||||
nMerge int
|
||||
eps float32
|
||||
ropeTheta float32
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
|
||||
PositionEmbedding ml.Tensor `gguf:"position_embd.weight"`
|
||||
ClampData ml.Tensor `gguf:"clamp_data"`
|
||||
StdBias ml.Tensor `gguf:"std_bias"`
|
||||
StdScale ml.Tensor `gguf:"std_scale"`
|
||||
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
|
||||
*VisionModelOptions
|
||||
clampInitDone bool
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, numPatchesX, numPatchesY int) ml.Tensor {
|
||||
numPatches := numPatchesX * numPatchesY
|
||||
|
||||
// Patch embedding via Conv2D
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Conv2D with F16 weights produces F16 output via im2col; cast to F32 for encoder precision
|
||||
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
|
||||
|
||||
// 2D positional embeddings from 3D tensor [nEmbd, maxPos, 2]
|
||||
posSize := m.PositionEmbedding.Dim(1)
|
||||
nb1 := m.PositionEmbedding.Stride(1)
|
||||
tblX := m.PositionEmbedding.View(ctx, 0, m.hiddenSize, nb1, posSize)
|
||||
tblY := m.PositionEmbedding.View(ctx, posSize*nb1, m.hiddenSize, nb1, posSize)
|
||||
|
||||
// Position indices for patches
|
||||
posXData := make([]int32, numPatches)
|
||||
posYData := make([]int32, numPatches)
|
||||
for i := range numPatches {
|
||||
posXData[i] = int32(i % numPatchesX)
|
||||
posYData[i] = int32(i / numPatchesX)
|
||||
}
|
||||
|
||||
posXEmb := ctx.Input().FromInts(posXData, numPatches)
|
||||
posYEmb := ctx.Input().FromInts(posYData, numPatches)
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, tblX.Rows(ctx, posXEmb))
|
||||
hiddenState = hiddenState.Add(ctx, tblY.Rows(ctx, posYEmb))
|
||||
|
||||
// No attention mask — all positions are real patches
|
||||
var attnMask ml.Tensor
|
||||
|
||||
// RoPE positions
|
||||
posXRope := ctx.Input().FromInts(posXData, numPatches)
|
||||
posYRope := ctx.Input().FromInts(posYData, numPatches)
|
||||
|
||||
// Vision transformer layers
|
||||
for i := range m.Layers {
|
||||
hiddenState = m.Layers[i].Forward(ctx, hiddenState, posXRope, posYRope, attnMask, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
||||
numHeads: int(c.Uint("vision.attention.head_count")),
|
||||
patchSize: int(c.Uint("vision.patch_size", 16)),
|
||||
nMerge: int(c.Uint("vision.projector.scale_factor", 3)),
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6),
|
||||
ropeTheta: 100.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func visionPoolAndProject(ctx ml.Context, hiddenState ml.Tensor, numPatchesX, numPatchesY int, opts *VisionModelOptions, proj *MultiModalProjector, stdBias, stdScale ml.Tensor) ml.Tensor {
|
||||
hiddenSize := opts.hiddenSize
|
||||
|
||||
// Reshape from [hiddenSize, numPatches] to spatial layout for pooling
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatchesX, numPatchesY, hiddenSize)
|
||||
|
||||
// AvgPool2D with kernel=stride=nMerge
|
||||
hiddenState = hiddenState.AvgPool2D(ctx, opts.nMerge, opts.nMerge, 0)
|
||||
|
||||
// Reshape back to [hiddenSize, numMergedPatches]
|
||||
mergedX := numPatchesX / opts.nMerge
|
||||
mergedY := numPatchesY / opts.nMerge
|
||||
hiddenState = hiddenState.Reshape(ctx, mergedX*mergedY, hiddenSize)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
hiddenState = hiddenState.Cast(ctx, ml.DTypeF32)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(hiddenSize)))
|
||||
|
||||
// Optional vision standardization before projection.
|
||||
if stdBias != nil && stdScale != nil {
|
||||
hiddenState = hiddenState.Sub(ctx, stdBias)
|
||||
hiddenState = hiddenState.Mul(ctx, stdScale)
|
||||
}
|
||||
|
||||
// Project to text embedding dimension
|
||||
hiddenState = proj.Forward(ctx, hiddenState, opts.eps)
|
||||
|
||||
return hiddenState
|
||||
}
|
||||
280
model/models/gemma4/process_audio.go
Normal file
280
model/models/gemma4/process_audio.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/cmplx"
|
||||
)
|
||||
|
||||
// Audio preprocessing constants.
|
||||
const (
|
||||
audioSampleRate = 16000
|
||||
melBins = 128
|
||||
frameLengthMs = 20.0
|
||||
hopLengthMs = 10.0
|
||||
minFrequency = 0.0
|
||||
maxFrequency = 8000.0
|
||||
melFloor = 1e-3
|
||||
maxAudioSoftTokens = 750
|
||||
)
|
||||
|
||||
// Computed from the above constants.
|
||||
var (
|
||||
frameLength = int(math.Round(audioSampleRate * frameLengthMs / 1000.0)) // 320
|
||||
hopLength = int(math.Round(audioSampleRate * hopLengthMs / 1000.0)) // 160
|
||||
)
|
||||
|
||||
// decodeWAV extracts mono float32 PCM samples from a WAV file, resampled to 16kHz.
|
||||
func decodeWAV(data []byte) ([]float32, error) {
|
||||
if len(data) < 12 {
|
||||
return nil, fmt.Errorf("WAV file too short")
|
||||
}
|
||||
if string(data[0:4]) != "RIFF" || string(data[8:12]) != "WAVE" {
|
||||
return nil, fmt.Errorf("not a WAV file")
|
||||
}
|
||||
|
||||
var audioFormat uint16
|
||||
var numChannels, sampleRate, bitsPerSample int
|
||||
var audioData []byte
|
||||
foundFmt := false
|
||||
|
||||
offset := 12
|
||||
for offset+8 <= len(data) {
|
||||
chunkID := string(data[offset : offset+4])
|
||||
chunkSize := int(binary.LittleEndian.Uint32(data[offset+4 : offset+8]))
|
||||
chunkData := data[offset+8 : min(offset+8+chunkSize, len(data))]
|
||||
|
||||
switch chunkID {
|
||||
case "fmt ":
|
||||
if len(chunkData) < 16 {
|
||||
return nil, fmt.Errorf("fmt chunk too short")
|
||||
}
|
||||
audioFormat = binary.LittleEndian.Uint16(chunkData[0:2])
|
||||
numChannels = int(binary.LittleEndian.Uint16(chunkData[2:4]))
|
||||
sampleRate = int(binary.LittleEndian.Uint32(chunkData[4:8]))
|
||||
bitsPerSample = int(binary.LittleEndian.Uint16(chunkData[14:16]))
|
||||
if audioFormat == 0xFFFE && len(chunkData) >= 26 {
|
||||
audioFormat = binary.LittleEndian.Uint16(chunkData[24:26])
|
||||
}
|
||||
foundFmt = true
|
||||
case "data":
|
||||
audioData = chunkData
|
||||
}
|
||||
|
||||
offset += 8 + chunkSize
|
||||
if chunkSize%2 != 0 {
|
||||
offset++
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFmt {
|
||||
return nil, fmt.Errorf("no fmt chunk found in WAV file")
|
||||
}
|
||||
if audioFormat != 1 && audioFormat != 3 {
|
||||
return nil, fmt.Errorf("unsupported WAV format: %d (need PCM=1 or float=3)", audioFormat)
|
||||
}
|
||||
if audioData == nil {
|
||||
return nil, fmt.Errorf("no data chunk found in WAV file")
|
||||
}
|
||||
|
||||
samples := decodeWAVSamples(audioData, audioFormat, bitsPerSample, numChannels)
|
||||
if sampleRate != audioSampleRate {
|
||||
samples = resampleLinear(samples, sampleRate, audioSampleRate)
|
||||
}
|
||||
return samples, nil
|
||||
}
|
||||
|
||||
func decodeWAVSamples(data []byte, format uint16, bits, channels int) []float32 {
|
||||
bytesPerSample := bits / 8
|
||||
totalSamples := len(data) / (bytesPerSample * channels)
|
||||
mono := make([]float32, totalSamples)
|
||||
|
||||
for i := range totalSamples {
|
||||
var sum float64
|
||||
for ch := range channels {
|
||||
off := (i*channels + ch) * bytesPerSample
|
||||
if off+bytesPerSample > len(data) {
|
||||
break
|
||||
}
|
||||
switch {
|
||||
case format == 1 && bits == 16:
|
||||
v := int16(binary.LittleEndian.Uint16(data[off : off+2]))
|
||||
sum += float64(v) / 32768.0
|
||||
case format == 1 && bits == 32:
|
||||
v := int32(binary.LittleEndian.Uint32(data[off : off+4]))
|
||||
sum += float64(v) / 2147483648.0
|
||||
case format == 1 && bits == 24:
|
||||
v := int32(data[off]) | int32(data[off+1])<<8 | int32(data[off+2])<<16
|
||||
if v&0x800000 != 0 {
|
||||
v |= ^0xFFFFFF
|
||||
}
|
||||
sum += float64(v) / 8388608.0
|
||||
case format == 3 && bits == 32:
|
||||
v := math.Float32frombits(binary.LittleEndian.Uint32(data[off : off+4]))
|
||||
sum += float64(v)
|
||||
case format == 1 && bits == 8:
|
||||
sum += (float64(data[off]) - 128.0) / 128.0
|
||||
}
|
||||
}
|
||||
mono[i] = float32(sum / float64(channels))
|
||||
}
|
||||
return mono
|
||||
}
|
||||
|
||||
func resampleLinear(samples []float32, fromRate, toRate int) []float32 {
|
||||
n := int(float64(len(samples)) / float64(fromRate) * float64(toRate))
|
||||
out := make([]float32, n)
|
||||
for i := range n {
|
||||
pos := float64(i) * float64(len(samples)-1) / float64(n-1)
|
||||
idx := int(pos)
|
||||
frac := float32(pos - float64(idx))
|
||||
if idx+1 < len(samples) {
|
||||
out[i] = samples[idx]*(1-frac) + samples[idx+1]*frac
|
||||
} else {
|
||||
out[i] = samples[idx]
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// computeMelSpectrogram computes the log mel spectrogram from PCM samples.
|
||||
// Returns shape [numFrames, melBins] as float32 slice, and numFrames.
|
||||
func computeMelSpectrogram(samples []float32) ([]float32, int) {
|
||||
fftLen := 1
|
||||
for fftLen < frameLength {
|
||||
fftLen <<= 1
|
||||
}
|
||||
fftLen *= 2 // fft_overdrive=True
|
||||
|
||||
// Hanning-nonzero window.
|
||||
window := make([]float64, frameLength)
|
||||
arg := math.Pi * 2.0 / float64(frameLength)
|
||||
for i := range frameLength {
|
||||
window[i] = 0.5 - 0.5*math.Cos(arg*(float64(i)+0.5))
|
||||
}
|
||||
|
||||
numFreqBins := fftLen/2 + 1
|
||||
melFilters := buildMelFilterBank(numFreqBins, melBins, minFrequency, maxFrequency, audioSampleRate)
|
||||
|
||||
frameSizeForUnfold := frameLength + 1
|
||||
numFrames := (len(samples) - frameSizeForUnfold) / hopLength
|
||||
if numFrames <= 0 {
|
||||
return nil, 0
|
||||
}
|
||||
|
||||
result := make([]float32, numFrames*melBins)
|
||||
fftInput := make([]complex128, fftLen)
|
||||
|
||||
for f := range numFrames {
|
||||
start := f * hopLength
|
||||
for i := range frameLength {
|
||||
fftInput[i] = complex(float64(samples[start+i])*window[i], 0)
|
||||
}
|
||||
for i := frameLength; i < fftLen; i++ {
|
||||
fftInput[i] = 0
|
||||
}
|
||||
|
||||
fft(fftInput)
|
||||
|
||||
for m := range melBins {
|
||||
var melVal float64
|
||||
for k := range numFreqBins {
|
||||
mag := cmplx.Abs(fftInput[k])
|
||||
melVal += mag * float64(melFilters[k*melBins+m])
|
||||
}
|
||||
if melVal < melFloor {
|
||||
melVal = melFloor
|
||||
}
|
||||
result[f*melBins+m] = float32(math.Log(melVal))
|
||||
}
|
||||
}
|
||||
|
||||
return result, numFrames
|
||||
}
|
||||
|
||||
func buildMelFilterBank(numFreqBins, numMels int, fMin, fMax float64, sr int) []float32 {
|
||||
hzToMel := func(f float64) float64 {
|
||||
return 2595.0 * math.Log10(1.0+f/700.0)
|
||||
}
|
||||
melToHz := func(m float64) float64 {
|
||||
return 700.0 * (math.Pow(10.0, m/2595.0) - 1.0)
|
||||
}
|
||||
|
||||
melMin := hzToMel(fMin)
|
||||
melMax := hzToMel(fMax)
|
||||
|
||||
melPts := make([]float64, numMels+2)
|
||||
for i := range melPts {
|
||||
melPts[i] = melMin + float64(i)*(melMax-melMin)/float64(numMels+1)
|
||||
}
|
||||
filterFreqs := make([]float64, numMels+2)
|
||||
for i, m := range melPts {
|
||||
filterFreqs[i] = melToHz(m)
|
||||
}
|
||||
|
||||
fftFreqs := make([]float64, numFreqBins)
|
||||
for i := range fftFreqs {
|
||||
fftFreqs[i] = float64(i) * float64(sr) / float64(2*(numFreqBins-1))
|
||||
}
|
||||
|
||||
filters := make([]float32, numFreqBins*numMels)
|
||||
for m := range numMels {
|
||||
fLeft := filterFreqs[m]
|
||||
fCenter := filterFreqs[m+1]
|
||||
fRight := filterFreqs[m+2]
|
||||
for k := range numFreqBins {
|
||||
f := fftFreqs[k]
|
||||
var v float64
|
||||
if f >= fLeft && f <= fCenter && fCenter > fLeft {
|
||||
v = (f - fLeft) / (fCenter - fLeft)
|
||||
} else if f > fCenter && f <= fRight && fRight > fCenter {
|
||||
v = (fRight - f) / (fRight - fCenter)
|
||||
}
|
||||
if v > 0 {
|
||||
filters[k*numMels+m] = float32(v)
|
||||
}
|
||||
}
|
||||
}
|
||||
return filters
|
||||
}
|
||||
|
||||
// fft performs an in-place Cooley-Tukey radix-2 FFT.
|
||||
func fft(x []complex128) {
|
||||
n := len(x)
|
||||
if n <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
j := 0
|
||||
for i := 1; i < n; i++ {
|
||||
bit := n >> 1
|
||||
for j&bit != 0 {
|
||||
j ^= bit
|
||||
bit >>= 1
|
||||
}
|
||||
j ^= bit
|
||||
if i < j {
|
||||
x[i], x[j] = x[j], x[i]
|
||||
}
|
||||
}
|
||||
|
||||
for size := 2; size <= n; size <<= 1 {
|
||||
halfSize := size / 2
|
||||
w := complex(math.Cos(2*math.Pi/float64(size)), -math.Sin(2*math.Pi/float64(size)))
|
||||
for start := 0; start < n; start += size {
|
||||
wn := complex(1, 0)
|
||||
for k := range halfSize {
|
||||
t := wn * x[start+k+halfSize]
|
||||
x[start+k+halfSize] = x[start+k] - t
|
||||
x[start+k] = x[start+k] + t
|
||||
wn *= w
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isAudioData checks if the data starts with WAV magic bytes.
|
||||
func isAudioData(data []byte) bool {
|
||||
return len(data) >= 12 && string(data[0:4]) == "RIFF" && string(data[8:12]) == "WAVE"
|
||||
}
|
||||
103
model/models/gemma4/process_image.go
Normal file
103
model/models/gemma4/process_image.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"golang.org/x/image/draw"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
patchSize int
|
||||
numChannels int
|
||||
nMerge int
|
||||
minPixels int
|
||||
maxPixels int
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 16))
|
||||
nMerge := int(c.Uint("vision.projector.scale_factor", 3))
|
||||
numChannels := int(c.Uint("vision.num_channels", 3))
|
||||
|
||||
// Token limits from reference: min=40, max=280 output tokens after pooling.
|
||||
// Convert to pixel counts: tokens * nMerge^2 * patchSize^2
|
||||
minTokens := 40
|
||||
maxTokens := 280
|
||||
patchArea := patchSize * patchSize * nMerge * nMerge
|
||||
minPixels := minTokens * patchArea
|
||||
maxPixels := maxTokens * patchArea
|
||||
|
||||
return ImageProcessor{
|
||||
patchSize: patchSize,
|
||||
numChannels: numChannels,
|
||||
nMerge: nMerge,
|
||||
minPixels: minPixels,
|
||||
maxPixels: maxPixels,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessImage resizes an image preserving aspect ratio, aligning dimensions
|
||||
// to (patchSize * nMerge) boundaries, and normalizes pixels to [-1, 1].
|
||||
// Returns the float32 pixel data and the actual output dimensions.
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, int, int, error) {
|
||||
// Compute target size preserving aspect ratio
|
||||
alignSize := p.patchSize * p.nMerge
|
||||
targetW, targetH := p.smartResize(img.Bounds().Dx(), img.Bounds().Dy(), alignSize)
|
||||
|
||||
// Resize directly without alpha compositing, matching MLX reference.
|
||||
dst := image.NewRGBA(image.Rect(0, 0, targetW, targetH))
|
||||
draw.BiLinear.Scale(dst, dst.Bounds(), img, img.Bounds(), draw.Over, nil)
|
||||
|
||||
// Normalize to [-1, 1] using mean=0.5, std=0.5: (pixel/255 - 0.5) / 0.5 = 2*pixel/255 - 1
|
||||
data := p.pack(dst)
|
||||
return data, targetW, targetH, nil
|
||||
}
|
||||
|
||||
// smartResize computes target dimensions that preserve aspect ratio and
|
||||
// align to alignSize boundaries. It scales the image to fill the maximum
|
||||
// patch budget (maxPixels), matching the MLX reference.
|
||||
func (p *ImageProcessor) smartResize(origW, origH, alignSize int) (int, int) {
|
||||
totalPx := origW * origH
|
||||
|
||||
var targetW, targetH int
|
||||
if p.maxPixels > 0 && totalPx > 0 {
|
||||
factor := math.Sqrt(float64(p.maxPixels) / float64(totalPx))
|
||||
targetH = max(alignSize, int(math.Floor(factor*float64(origH)/float64(alignSize)))*alignSize)
|
||||
targetW = max(alignSize, int(math.Floor(factor*float64(origW)/float64(alignSize)))*alignSize)
|
||||
} else {
|
||||
targetH = max(alignSize, (origH/alignSize)*alignSize)
|
||||
targetW = max(alignSize, (origW/alignSize)*alignSize)
|
||||
}
|
||||
|
||||
return targetW, targetH
|
||||
}
|
||||
|
||||
// pack extracts RGB values from an image and normalizes to [-1, 1].
|
||||
// Returns channel-first layout: [R..., G..., B...].
|
||||
func (p *ImageProcessor) pack(img image.Image) []float32 {
|
||||
bounds := img.Bounds()
|
||||
w := bounds.Dx()
|
||||
h := bounds.Dy()
|
||||
size := w * h
|
||||
|
||||
pixelVals := make([]float32, 3*size)
|
||||
rOff, gOff, bOff := 0, size, 2*size
|
||||
|
||||
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
|
||||
for x := bounds.Min.X; x < bounds.Max.X; x++ {
|
||||
c := img.At(x, y)
|
||||
r, g, b, _ := c.RGBA()
|
||||
idx := (y-bounds.Min.Y)*w + (x - bounds.Min.X)
|
||||
|
||||
// Normalize [0, 255] -> [-1, 1]: 2 * (val/255) - 1
|
||||
pixelVals[rOff+idx] = float32(r>>8)/255.0*2.0 - 1.0
|
||||
pixelVals[gOff+idx] = float32(g>>8)/255.0*2.0 - 1.0
|
||||
pixelVals[bOff+idx] = float32(b>>8)/255.0*2.0 - 1.0
|
||||
}
|
||||
}
|
||||
|
||||
return pixelVals
|
||||
}
|
||||
102
model/models/gemma4/tokenizer_compare_test.go
Normal file
102
model/models/gemma4/tokenizer_compare_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package gemma4
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
// TestTokenizerMatchesHF compares our tokenizer output against HuggingFace reference tokens.
|
||||
func TestTokenizerMatchesHF(t *testing.T) {
|
||||
modelPath := os.Getenv("GEMMA4_MODEL_PATH")
|
||||
if modelPath == "" {
|
||||
t.Skip("set GEMMA4_MODEL_PATH to a gemma4 GGUF file")
|
||||
}
|
||||
|
||||
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load model: %v", err)
|
||||
}
|
||||
defer m.Backend().Close()
|
||||
|
||||
tok := m.(tokenizer.Tokenizer)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []int32
|
||||
}{
|
||||
{
|
||||
name: "simple",
|
||||
input: "Hello, world!",
|
||||
expected: []int32{9259, 236764, 1902, 236888},
|
||||
},
|
||||
{
|
||||
name: "special_tokens",
|
||||
input: "<|turn>user\nWhat is 2+2?<turn|>\n<|turn>model\n",
|
||||
expected: []int32{105, 2364, 107, 3689, 563, 236743, 236778, 236862, 236778, 236881, 106, 107, 105, 4368, 107},
|
||||
},
|
||||
{
|
||||
name: "tool_declaration",
|
||||
input: "<|tool>declaration:bash{description:<|\"|>Run a command<|\"|>}<tool|>",
|
||||
expected: []int32{46, 163688, 236787, 42422, 236782, 7777, 236787, 52, 7306, 496, 4991, 52, 236783, 47},
|
||||
},
|
||||
{
|
||||
name: "tool_call",
|
||||
input: "<|tool_call>call:bash{command:<|\"|>ls -la<|\"|>}<tool_call|>",
|
||||
expected: []int32{48, 6639, 236787, 42422, 236782, 7674, 236787, 52, 5629, 753, 2149, 52, 236783, 49},
|
||||
},
|
||||
{
|
||||
name: "thinking",
|
||||
input: "<|channel>thought\nLet me think about this...<channel|>The answer is 42.",
|
||||
expected: []int32{100, 45518, 107, 6481, 786, 1751, 1003, 672, 1390, 101, 818, 3890, 563, 236743, 236812, 236778, 236761},
|
||||
},
|
||||
{
|
||||
name: "code",
|
||||
input: "func main() { fmt.Println(\"hello\") }",
|
||||
expected: []int32{6823, 1689, 825, 642, 22766, 236761, 29006, 885, 23391, 1373, 682},
|
||||
},
|
||||
{
|
||||
name: "numbers",
|
||||
input: "The answer is 42, not 43.5 or -1",
|
||||
expected: []int32{818, 3890, 563, 236743, 236812, 236778, 236764, 711, 236743, 236812, 236800, 236761, 236810, 653, 753, 236770},
|
||||
},
|
||||
{
|
||||
name: "mixed_chat_with_tools",
|
||||
input: "<|turn>system\nYou are a helpful assistant.\n<|tool>declaration:get_weather{description:<|\"|>Get weather<|\"|>,parameters:{properties:{city:{type:<|\"|>STRING<|\"|>}},type:<|\"|>OBJECT<|\"|>}}<tool|><turn|>\n<|turn>user\nWhat's the weather in Paris?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
|
||||
expected: []int32{105, 9731, 107, 3048, 659, 496, 11045, 16326, 236761, 107, 46, 163688, 236787, 828, 236779, 19323, 236782, 7777, 236787, 52, 3407, 7606, 52, 236764, 19031, 29616, 15921, 29616, 13319, 29616, 2084, 236787, 52, 35410, 52, 5237, 2084, 236787, 52, 60688, 52, 1807, 47, 106, 107, 105, 2364, 107, 3689, 236789, 236751, 506, 7606, 528, 9079, 236881, 106, 107, 105, 4368, 107, 100, 45518, 107, 101},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokens, err := tok.Encode(tt.input, false) // no BOS
|
||||
if err != nil {
|
||||
t.Fatalf("encode error: %v", err)
|
||||
}
|
||||
|
||||
if len(tokens) != len(tt.expected) {
|
||||
t.Errorf("token count mismatch: got %d, want %d", len(tokens), len(tt.expected))
|
||||
t.Logf("got: %v", tokens)
|
||||
t.Logf("want: %v", tt.expected)
|
||||
return
|
||||
}
|
||||
|
||||
mismatches := 0
|
||||
for i := range tokens {
|
||||
if tokens[i] != tt.expected[i] {
|
||||
mismatches++
|
||||
if mismatches <= 5 {
|
||||
t.Errorf("mismatch at [%d]: got %d, want %d", i, tokens[i], tt.expected[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if mismatches > 5 {
|
||||
t.Errorf("... and %d more mismatches", mismatches-5)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
341
model/models/gemma4/tokenizer_reference_test.go
Normal file
341
model/models/gemma4/tokenizer_reference_test.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package gemma4
|
||||
|
||||
// TestGemma4TokenizerMatchesReference verifies our BPE tokenizer matches
|
||||
// the Rust tokenizers library (the reference implementation) for Gemma 4.
|
||||
//
|
||||
// The test loads vocabulary from any local ollama gemma4 GGUF model.
|
||||
// Skips if no gemma4 model is installed.
|
||||
//
|
||||
// Set VERIFY_HF_TOKENIZER=1 to verify against the Rust tokenizers library
|
||||
// via Python. Requires python3 with tokenizers>=0.21 on PATH:
|
||||
//
|
||||
// VERIFY_HF_TOKENIZER=1 go test ./model/models/gemma4/ -run TestGemma4Tokenizer -v
|
||||
//
|
||||
// Workflow for adding a new test case:
|
||||
// 1. Add {name: "...", input: "..."} to the test list (no want field)
|
||||
// 2. Run with VERIFY_HF_TOKENIZER=1 — it prints the reference IDs
|
||||
// 3. Paste those IDs into the want field
|
||||
// 4. Run without VERIFY_HF_TOKENIZER — our tokenizer must match
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/tokenizer"
|
||||
)
|
||||
|
||||
type tokenizerRefCase struct {
|
||||
name string
|
||||
input string
|
||||
want []int32
|
||||
}
|
||||
|
||||
// Reference token IDs generated by the Rust tokenizers library using
|
||||
// vocab/merges from a gemma4 GGUF with add_special_tokens=False.
|
||||
var gemma4TokenizerRefCases = []tokenizerRefCase{
|
||||
// Basic ASCII
|
||||
{name: "basic word", input: "hello", want: []int32{23391}},
|
||||
{name: "two words", input: "hello world", want: []int32{23391, 1902}},
|
||||
{name: "punctuation", input: "Hello, World!", want: []int32{9259, 236764, 4109, 236888}},
|
||||
|
||||
// Space handling (pretokenizer bug: GPT-2 splitter mangled leading/multiple spaces)
|
||||
{name: "leading space", input: " hello", want: []int32{29104}},
|
||||
{name: "double leading space", input: " hello", want: []int32{138, 23391}},
|
||||
{name: "double space between words", input: "hello world", want: []int32{23391, 138, 12392}},
|
||||
{name: "only spaces", input: " ", want: []int32{139}},
|
||||
{name: "repeated spaces", input: " ", want: []int32{142}},
|
||||
{name: "leading spaces phrase", input: " leading spaces", want: []int32{5830, 9952}},
|
||||
{name: "multiple interior spaces", input: "multiple spaces", want: []int32{43819, 140, 35220}},
|
||||
|
||||
// Polish diacritics (issue #15231 — Decode mangled U+0105-U+0142)
|
||||
{name: "polish diacritics", input: "ąęśćżźółń", want: []int32{237198, 237202, 14732, 237277, 238992, 24875, 238041}},
|
||||
{name: "polish sentence", input: "Zażółć gęślą jaźń", want: []int32{236953, 40512, 24875, 237289, 549, 237202, 62081, 237198, 4828, 238992, 238041}},
|
||||
|
||||
// French accents (issue #15229 — Decode mangled U+00E0-U+00FF)
|
||||
{name: "french accents", input: "café résumé naïve", want: []int32{123125, 236859, 118515, 120362}},
|
||||
{name: "french with apostrophe", input: "L'élève a mangé", want: []int32{236798, 236789, 161654, 496, 14695, 236859}},
|
||||
|
||||
// German umlauts
|
||||
{name: "german umlauts", input: "über Straße Größe", want: []int32{28223, 80176, 112880}},
|
||||
|
||||
// Codepoints in GPT-2 byte reversal range (U+0100-U+0142)
|
||||
{name: "codepoints in gpt2 byte range", input: "ąęćł", want: []int32{237198, 226110, 237114}},
|
||||
{name: "latin extended A", input: "ĀāĂ㥹", want: []int32{241920, 237448, 241645, 237106, 243514, 237198}},
|
||||
|
||||
// CJK & Japanese
|
||||
{name: "chinese", input: "你好世界", want: []int32{144626, 12811}},
|
||||
{name: "japanese hiragana", input: "こんにちは", want: []int32{85141}},
|
||||
|
||||
// Mixed scripts
|
||||
{name: "mixed scripts", input: "hello ąęść world café 你好", want: []int32{23391, 236743, 237198, 237202, 14732, 1902, 33443, 43758, 237389}},
|
||||
|
||||
// Whitespace
|
||||
{name: "empty string", input: "", want: []int32{}},
|
||||
{name: "newlines", input: "\n\n", want: []int32{108}},
|
||||
{name: "tabs", input: "\t\t", want: []int32{255969}},
|
||||
|
||||
// Code-like content
|
||||
{name: "python code", input: "def foo(x): return x + 1", want: []int32{2063, 46293, 236769, 236781, 1473, 994, 1123, 900, 236743, 236770}},
|
||||
{name: "json", input: `{"key": "value"}`, want: []int32{14937, 2478, 1083, 623, 2394, 25938}},
|
||||
|
||||
// Misc
|
||||
{name: "repeated char", input: "aaaaaa", want: []int32{50354, 9236}},
|
||||
{name: "emoji", input: "hello 👋 world", want: []int32{23391, 155818, 1902}},
|
||||
{name: "digits", input: "12345", want: []int32{236770, 236778, 236800, 236812, 236810}},
|
||||
{name: "float", input: "3.14159", want: []int32{236800, 236761, 236770, 236812, 236770, 236810, 236819}},
|
||||
}
|
||||
|
||||
// findGemma4GGUF looks for any gemma4 model GGUF in the local ollama store.
|
||||
func findGemma4GGUF() (string, error) {
|
||||
modelsDir := envconfig.Models()
|
||||
manifestDir := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "gemma4")
|
||||
entries, err := os.ReadDir(manifestDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("no gemma4 manifests in %s: %w", manifestDir, err)
|
||||
}
|
||||
|
||||
blobDir := filepath.Join(modelsDir, "blobs")
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(filepath.Join(manifestDir, entry.Name()))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var manifest struct {
|
||||
Layers []struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
} `json:"layers"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &manifest); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, layer := range manifest.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.model" {
|
||||
blobPath := filepath.Join(blobDir, strings.Replace(layer.Digest, ":", "-", 1))
|
||||
if _, err := os.Stat(blobPath); err == nil {
|
||||
return blobPath, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no gemma4 model blob found in %s", modelsDir)
|
||||
}
|
||||
|
||||
// loadGemma4Tokenizer opens a GGUF and builds a BPE tokenizer from its
|
||||
// tokenizer metadata — the same configuration used at inference time.
|
||||
func loadGemma4Tokenizer(t *testing.T, ggufPath string) tokenizer.BytePairEncoding {
|
||||
t.Helper()
|
||||
|
||||
f, err := gguf.Open(ggufPath)
|
||||
if err != nil {
|
||||
t.Fatalf("gguf.Open: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tokens := f.KeyValue("tokenizer.ggml.tokens").Strings()
|
||||
if len(tokens) == 0 {
|
||||
t.Fatal("no tokenizer.ggml.tokens in GGUF")
|
||||
}
|
||||
|
||||
scores64 := f.KeyValue("tokenizer.ggml.scores").Floats()
|
||||
scores := make([]float32, len(scores64))
|
||||
for i, s := range scores64 {
|
||||
scores[i] = float32(s)
|
||||
}
|
||||
|
||||
types64 := f.KeyValue("tokenizer.ggml.token_type").Ints()
|
||||
types := make([]int32, len(types64))
|
||||
for i, tt := range types64 {
|
||||
types[i] = int32(tt)
|
||||
}
|
||||
|
||||
merges := f.KeyValue("tokenizer.ggml.merges").Strings()
|
||||
|
||||
vocab := &tokenizer.Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Scores: scores,
|
||||
Merges: merges,
|
||||
BOS: []int32{2},
|
||||
EOS: []int32{1},
|
||||
AddBOS: false,
|
||||
}
|
||||
|
||||
return tokenizer.NewBytePairEncodingWithOptions(vocab, []string{},
|
||||
tokenizer.WithSentencePieceNormalizer())
|
||||
}
|
||||
|
||||
// writeTokenizerJSON reconstructs a tokenizer.json from GGUF metadata
|
||||
// for the Rust tokenizers library to load as an independent reference.
|
||||
func writeTokenizerJSON(t *testing.T, ggufPath string) string {
|
||||
t.Helper()
|
||||
|
||||
f, err := gguf.Open(ggufPath)
|
||||
if err != nil {
|
||||
t.Fatalf("gguf.Open: %v", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tokens := f.KeyValue("tokenizer.ggml.tokens").Strings()
|
||||
mergeStrs := f.KeyValue("tokenizer.ggml.merges").Strings()
|
||||
|
||||
vocab := make(map[string]int, len(tokens))
|
||||
for i, tok := range tokens {
|
||||
vocab[tok] = i
|
||||
}
|
||||
|
||||
merges := make([][2]string, len(mergeStrs))
|
||||
for i, m := range mergeStrs {
|
||||
parts := strings.SplitN(m, " ", 2)
|
||||
if len(parts) == 2 {
|
||||
merges[i] = [2]string{parts[0], parts[1]}
|
||||
}
|
||||
}
|
||||
|
||||
tj := map[string]any{
|
||||
"version": "1.0",
|
||||
"model": map[string]any{
|
||||
"type": "BPE",
|
||||
"vocab": vocab,
|
||||
"merges": merges,
|
||||
},
|
||||
"normalizer": map[string]any{
|
||||
"type": "Replace",
|
||||
"pattern": map[string]string{"String": " "},
|
||||
"content": "\u2581",
|
||||
},
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp(t.TempDir(), "gemma4_tokenizer_*.json")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp file: %v", err)
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(tmpFile).Encode(tj); err != nil {
|
||||
tmpFile.Close()
|
||||
t.Fatalf("encode tokenizer.json: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
|
||||
return tmpFile.Name()
|
||||
}
|
||||
|
||||
func TestGemma4TokenizerMatchesReference(t *testing.T) {
|
||||
ggufPath, err := findGemma4GGUF()
|
||||
if err != nil {
|
||||
t.Skipf("skipping: %v", err)
|
||||
}
|
||||
t.Logf("using GGUF: %s", ggufPath)
|
||||
|
||||
tok := loadGemma4Tokenizer(t, ggufPath)
|
||||
|
||||
verify := os.Getenv("VERIFY_HF_TOKENIZER") != ""
|
||||
var tokenizerJSONPath string
|
||||
if verify {
|
||||
if err := exec.Command("python3", "-c", "from tokenizers import Tokenizer").Run(); err != nil {
|
||||
t.Fatal("VERIFY_HF_TOKENIZER=1 requires python3 with tokenizers>=0.21 on PATH")
|
||||
}
|
||||
tokenizerJSONPath = writeTokenizerJSON(t, ggufPath)
|
||||
defer os.Remove(tokenizerJSONPath)
|
||||
t.Log("VERIFY_HF_TOKENIZER=1: verifying against Rust tokenizers library")
|
||||
}
|
||||
|
||||
for _, tc := range gemma4TokenizerRefCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ids, err := tok.Encode(tc.input, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Encode(%q): %v", tc.input, err)
|
||||
}
|
||||
|
||||
if tc.want != nil {
|
||||
if fmt.Sprint(ids) != fmt.Sprint(tc.want) {
|
||||
t.Errorf("Encode(%q):\n got: %v\n want: %v", tc.input, ids, tc.want)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("no expected IDs for %q; our tokenizer produced: %v", tc.input, ids)
|
||||
}
|
||||
|
||||
if len(ids) > 0 {
|
||||
decoded, err := tok.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("Decode: %v", err)
|
||||
}
|
||||
if decoded != tc.input {
|
||||
t.Errorf("roundtrip %q: Decode(Encode) = %q", tc.input, decoded)
|
||||
}
|
||||
}
|
||||
|
||||
if verify {
|
||||
refIDs := encodeWithRustTokenizer(t, tokenizerJSONPath, tc.input)
|
||||
|
||||
if fmt.Sprint(refIDs) != fmt.Sprint(ids) {
|
||||
fmt.Fprintf(os.Stderr, "\nREFERENCE OUTPUT for %s (copy-paste as want):\nwant: []int32{%s},\n\n",
|
||||
tc.name, int32SliceStr(refIDs))
|
||||
}
|
||||
|
||||
if tc.want != nil && fmt.Sprint(refIDs) != fmt.Sprint(tc.want) {
|
||||
t.Errorf("hardcoded expected IDs don't match reference for %q:\n ref: %v\n hardcoded: %v",
|
||||
tc.input, refIDs, tc.want)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func encodeWithRustTokenizer(t *testing.T, tokenizerPath, text string) []int32 {
|
||||
t.Helper()
|
||||
|
||||
if text == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
script := fmt.Sprintf(`
|
||||
from tokenizers import Tokenizer
|
||||
t = Tokenizer.from_file(%q)
|
||||
ids = t.encode(%q, add_special_tokens=False).ids
|
||||
print(",".join(str(i) for i in ids))
|
||||
`, tokenizerPath, text)
|
||||
|
||||
cmd := exec.Command("python3", "-c", script)
|
||||
var stdout, stderr strings.Builder
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatalf("python3 failed: %v\nstderr: %s", err, stderr.String())
|
||||
}
|
||||
|
||||
parts := strings.Split(strings.TrimSpace(stdout.String()), ",")
|
||||
var ids []int32
|
||||
for _, p := range parts {
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
var id int32
|
||||
fmt.Sscanf(p, "%d", &id)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func int32SliceStr(ids []int32) string {
|
||||
parts := make([]string, len(ids))
|
||||
for i, id := range ids {
|
||||
parts[i] = fmt.Sprintf("%d", id)
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
Reference in New Issue
Block a user