ollama source for Momentry Core verification
This commit is contained in:
202
x/mlxrunner/model/base/base.go
Normal file
202
x/mlxrunner/model/base/base.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/batch"
|
||||
"github.com/ollama/ollama/x/mlxrunner/cache"
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/mlxrunner/model"
|
||||
"github.com/ollama/ollama/x/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the interface that model implementations must satisfy.
|
||||
type Model interface {
|
||||
Forward(b *batch.Batch, cache []cache.Cache) *mlx.Array
|
||||
Unembed(x *mlx.Array) *mlx.Array
|
||||
NumLayers() int
|
||||
Tokenizer() *tokenizer.Tokenizer
|
||||
MaxContextLength() int
|
||||
|
||||
// LoadWeights receives all tensors loaded from the manifest and assigns
|
||||
// them to model fields. Model-specific logic (MLA absorption, expert
|
||||
// stacking, quantized layer creation) happens here.
|
||||
LoadWeights(tensors map[string]*mlx.Array) error
|
||||
}
|
||||
|
||||
// DraftModel is an auxiliary model stored alongside a target model.
|
||||
type DraftModel interface {
|
||||
LoadWeights(tensors map[string]*mlx.Array) error
|
||||
}
|
||||
|
||||
// MTPDefaults holds model-provided draft-token defaults for speculative
|
||||
// decoding. Environment settings in the runner may override these values.
|
||||
type MTPDefaults struct {
|
||||
InitialDraftTokens int
|
||||
MaxDraftTokens int
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// MTPDefaultsProvider lets a model provide MTP policy defaults from its own
|
||||
// config without teaching the runner model-specific shape heuristics.
|
||||
type MTPDefaultsProvider interface {
|
||||
MTPDraftDefaults(sample bool) MTPDefaults
|
||||
}
|
||||
|
||||
// MTPDraftModel is a draft model capable of Gemma-style multi-token
|
||||
// prediction from target token embeddings, target hidden states, and target KV.
|
||||
type MTPDraftModel interface {
|
||||
Draft(inputEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array)
|
||||
}
|
||||
|
||||
// MTPEmbeddingModel exposes the target token embedding path used by MTP drafts.
|
||||
type MTPEmbeddingModel interface {
|
||||
TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// DFlashTargetModel exposes target-layer hidden states for DFlash drafts.
|
||||
type DFlashTargetModel interface {
|
||||
ForwardDFlash(b *batch.Batch, caches []cache.Cache, layerIDs []int) (hidden, targetHidden *mlx.Array)
|
||||
}
|
||||
|
||||
// DFlashDraftModel is a block-diffusion speculative draft model.
|
||||
type DFlashDraftModel interface {
|
||||
DraftModel
|
||||
|
||||
TargetLayerIDs() []int
|
||||
BlockSize() int
|
||||
MaskTokenID() int32
|
||||
NewCaches() []cache.Cache
|
||||
AppendContext(targetHidden *mlx.Array, caches []cache.Cache)
|
||||
Draft(inputIDs *mlx.Array, caches []cache.Cache) *mlx.Array
|
||||
}
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
registry = make(map[string]func(root *model.Root) (Model, error))
|
||||
draftRegistry = make(map[string]func(root *model.Root, target Model) (DraftModel, error))
|
||||
)
|
||||
|
||||
// Register registers a model constructor by architecture name.
|
||||
// Called from init() in model packages. Panics on duplicate registration.
|
||||
func Register(arch string, fn func(root *model.Root) (Model, error)) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := registry[arch]; exists {
|
||||
panic(fmt.Sprintf("model architecture %q already registered", arch))
|
||||
}
|
||||
registry[arch] = fn
|
||||
}
|
||||
|
||||
// RegisterDraft registers a draft model constructor by architecture name.
|
||||
func RegisterDraft(arch string, fn func(root *model.Root, target Model) (DraftModel, error)) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if _, exists := draftRegistry[arch]; exists {
|
||||
panic(fmt.Sprintf("draft model architecture %q already registered", arch))
|
||||
}
|
||||
draftRegistry[arch] = fn
|
||||
}
|
||||
|
||||
// New reads config.json from the manifest, detects the architecture, looks up
|
||||
// the registered constructor, and calls it to create the model (with config
|
||||
// parsed and struct created, but weights not yet loaded).
|
||||
func New(root *model.Root) (Model, error) {
|
||||
configData, err := root.Manifest.ReadConfig("config.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
||||
}
|
||||
|
||||
var archConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse config.json: %w", err)
|
||||
}
|
||||
|
||||
if len(archConfig.Architectures) == 0 {
|
||||
return nil, fmt.Errorf("no architectures found in config.json")
|
||||
}
|
||||
|
||||
arch := archConfig.Architectures[0]
|
||||
slog.Info("Model architecture", "arch", arch)
|
||||
|
||||
mu.Lock()
|
||||
fn, ok := registry[arch]
|
||||
mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported architecture: %s", arch)
|
||||
}
|
||||
|
||||
return fn(root)
|
||||
}
|
||||
|
||||
// NewDraft constructs the draft model described by the manifest config, if any.
|
||||
func NewDraft(root *model.Root, target Model) (DraftModel, error) {
|
||||
if root == nil || root.Draft == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
configPath := root.Draft.Config
|
||||
if configPath == "" {
|
||||
configPath = "draft/config.json"
|
||||
}
|
||||
configData, err := root.Manifest.ReadConfig(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
var archConfig struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", configPath, err)
|
||||
}
|
||||
|
||||
arch := root.Draft.Architecture
|
||||
if arch == "" && len(archConfig.Architectures) > 0 {
|
||||
arch = archConfig.Architectures[0]
|
||||
}
|
||||
if arch == "" {
|
||||
arch = archConfig.ModelType
|
||||
}
|
||||
if arch == "" {
|
||||
return nil, fmt.Errorf("no draft architecture found in %s", configPath)
|
||||
}
|
||||
slog.Info("Draft model architecture", "arch", arch)
|
||||
|
||||
mu.Lock()
|
||||
fn, ok := draftRegistry[arch]
|
||||
mu.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported draft architecture: %s", arch)
|
||||
}
|
||||
|
||||
return fn(root, target)
|
||||
}
|
||||
|
||||
// Weights returns a function that loads model weights, then pins all
|
||||
// arrays reachable from the model struct and sweeps everything else.
|
||||
func Weights(m Model) func(map[string]*mlx.Array) error {
|
||||
return func(tensors map[string]*mlx.Array) error {
|
||||
if err := m.LoadWeights(tensors); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
collected := mlx.Collect(m)
|
||||
for _, arr := range collected {
|
||||
mlx.Pin(arr)
|
||||
}
|
||||
mlx.Sweep()
|
||||
mlx.Eval(collected...)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
42
x/mlxrunner/model/embedding.go
Normal file
42
x/mlxrunner/model/embedding.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// MakeEmbeddingLayer constructs an embedding layer from a tensor map.
|
||||
//
|
||||
// For quantized tensors (path.weight + path.weight_scale), it returns a
|
||||
// QuantizedEmbedding using the same quant metadata path that linear layers use.
|
||||
// For non-quantized tensors, it returns a standard dense embedding.
|
||||
func MakeEmbeddingLayer(
|
||||
tensors map[string]*mlx.Array,
|
||||
path string,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) nn.EmbeddingLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
return nn.NewQuantizedEmbedding(w, scales, qbiases, groupSize, bits, mode)
|
||||
}
|
||||
|
||||
return nn.NewEmbedding(w)
|
||||
}
|
||||
78
x/mlxrunner/model/embedding_test.go
Normal file
78
x/mlxrunner/model/embedding_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
func skipIfNoMLX(t *testing.T) {
|
||||
t.Helper()
|
||||
if err := mlx.CheckInit(); err != nil {
|
||||
t.Skipf("MLX not available: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeEmbeddingLayerDense(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
weight := mlx.FromValues([]float32{
|
||||
1, 2, 3, 4,
|
||||
5, 6, 7, 8,
|
||||
}, 2, 4).AsType(mlx.DTypeBFloat16)
|
||||
|
||||
emb := MakeEmbeddingLayer(map[string]*mlx.Array{
|
||||
"model.embed_tokens.weight": weight,
|
||||
}, "model.embed_tokens", 0, 0, "", nil)
|
||||
|
||||
dense, ok := emb.(*nn.Embedding)
|
||||
if !ok {
|
||||
t.Fatalf("embedding type = %T, want *nn.Embedding", emb)
|
||||
}
|
||||
if dense.Weight.DType() != mlx.DTypeBFloat16 {
|
||||
t.Fatalf("embedding dtype = %v, want %v", dense.Weight.DType(), mlx.DTypeBFloat16)
|
||||
}
|
||||
if _, ok := emb.AsLinear().(*nn.Linear); !ok {
|
||||
t.Fatalf("AsLinear type = %T, want *nn.Linear", emb.AsLinear())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMakeEmbeddingLayerQuantized(t *testing.T) {
|
||||
skipIfNoMLX(t)
|
||||
|
||||
denseWeight := mlx.FromValues(func() []float32 {
|
||||
out := make([]float32, 2*64)
|
||||
for i := range out {
|
||||
out[i] = float32(i%17) / 8
|
||||
}
|
||||
return out
|
||||
}(), 2, 64).AsType(mlx.DTypeBFloat16)
|
||||
|
||||
qw, scales, qbiases := mlx.Quantize(denseWeight, 64, 4, "affine")
|
||||
mlx.Eval(qw, scales, qbiases)
|
||||
|
||||
emb := MakeEmbeddingLayer(map[string]*mlx.Array{
|
||||
"model.embed_tokens.weight": qw,
|
||||
"model.embed_tokens.weight_scale": scales,
|
||||
"model.embed_tokens.weight_qbias": qbiases,
|
||||
}, "model.embed_tokens", 64, 4, "affine", nil)
|
||||
|
||||
qemb, ok := emb.(*nn.QuantizedEmbedding)
|
||||
if !ok {
|
||||
t.Fatalf("embedding type = %T, want *nn.QuantizedEmbedding", emb)
|
||||
}
|
||||
if qemb.GroupSize != 64 || qemb.Bits != 4 || qemb.Mode != "affine" {
|
||||
t.Fatalf("quant params = (%d, %d, %q), want (64, 4, %q)", qemb.GroupSize, qemb.Bits, qemb.Mode, "affine")
|
||||
}
|
||||
|
||||
indices := mlx.FromValues([]int32{1, 0}, 2)
|
||||
out := emb.Forward(indices)
|
||||
mlx.Eval(out)
|
||||
if dims := out.Dims(); len(dims) != 2 || dims[0] != 2 || dims[1] != 64 {
|
||||
t.Fatalf("embedding output dims = %v, want [2 64]", dims)
|
||||
}
|
||||
if _, ok := emb.AsLinear().(*nn.QuantizedLinear); !ok {
|
||||
t.Fatalf("AsLinear type = %T, want *nn.QuantizedLinear", emb.AsLinear())
|
||||
}
|
||||
}
|
||||
99
x/mlxrunner/model/linear.go
Normal file
99
x/mlxrunner/model/linear.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
"github.com/ollama/ollama/x/models/nn"
|
||||
)
|
||||
|
||||
// LinearFactory builds linear layers using shared tensor maps and quant defaults.
|
||||
type LinearFactory struct {
|
||||
tensors map[string]*mlx.Array
|
||||
defaultGroupSize int
|
||||
defaultBits int
|
||||
defaultMode string
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// NewLinearFactory creates a reusable constructor for model linear layers.
|
||||
func NewLinearFactory(
|
||||
tensors map[string]*mlx.Array,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) LinearFactory {
|
||||
return LinearFactory{
|
||||
tensors: tensors,
|
||||
defaultGroupSize: defaultGroupSize,
|
||||
defaultBits: defaultBits,
|
||||
defaultMode: defaultMode,
|
||||
tensorQuant: tensorQuant,
|
||||
}
|
||||
}
|
||||
|
||||
// Make constructs a linear layer at path.
|
||||
func (f LinearFactory) Make(path string) nn.LinearLayer {
|
||||
return MakeLinearLayer(
|
||||
f.tensors,
|
||||
path,
|
||||
f.defaultGroupSize,
|
||||
f.defaultBits,
|
||||
f.defaultMode,
|
||||
f.tensorQuant,
|
||||
)
|
||||
}
|
||||
|
||||
// MakeLinearLayer constructs a linear layer from a tensor map.
|
||||
//
|
||||
// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
|
||||
// quant params via TensorQuant metadata (with shape-based affine fallback).
|
||||
// For non-quantized tensors, it returns a standard nn.Linear.
|
||||
func MakeLinearLayer(
|
||||
tensors map[string]*mlx.Array,
|
||||
path string,
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
) nn.LinearLayer {
|
||||
w := tensors[path+".weight"]
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
scales := tensors[path+".weight_scale"]
|
||||
if scales != nil {
|
||||
qbiases := tensors[path+".weight_qbias"]
|
||||
bias := tensors[path+".bias"]
|
||||
|
||||
groupSize, bits, mode := ResolveLinearQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
path+".weight",
|
||||
w,
|
||||
scales,
|
||||
)
|
||||
|
||||
// Check for per-tensor global scale (NVIDIA double-scale nvfp4).
|
||||
// NVIDIA ModelOpt stores this as "weight_scale_2"; our import
|
||||
// pipeline maps it to "weight.global_scale".
|
||||
globalScale := tensors[path+".weight.global_scale"]
|
||||
if globalScale == nil {
|
||||
globalScale = tensors[path+".weight_scale_2"]
|
||||
}
|
||||
|
||||
return &nn.QuantizedLinear{
|
||||
Weight: w,
|
||||
Scales: scales,
|
||||
QBiases: qbiases,
|
||||
Bias: bias,
|
||||
GlobalScale: globalScale,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
Mode: mode,
|
||||
}
|
||||
}
|
||||
|
||||
bias := tensors[path+".bias"]
|
||||
return nn.NewLinear(w, bias)
|
||||
}
|
||||
132
x/mlxrunner/model/quant.go
Normal file
132
x/mlxrunner/model/quant.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
||||
)
|
||||
|
||||
// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
|
||||
func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
|
||||
switch strings.ToUpper(quantization) {
|
||||
case "NVFP4":
|
||||
return 16, 4, "nvfp4"
|
||||
case "MXFP4":
|
||||
return 32, 4, "mxfp4"
|
||||
case "FP4", "Q4", "INT4":
|
||||
return 64, 4, "affine"
|
||||
case "MXFP8":
|
||||
return 32, 8, "mxfp8"
|
||||
case "FP8", "Q8", "INT8":
|
||||
return 64, 8, "affine"
|
||||
case "":
|
||||
return 0, 0, ""
|
||||
default:
|
||||
return 32, 8, "affine"
|
||||
}
|
||||
}
|
||||
|
||||
// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
|
||||
// when available, otherwise falling back to the provided model defaults.
|
||||
func TensorQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
) (groupSize, bits int, mode string, fromTensor bool) {
|
||||
if tensorQuant != nil {
|
||||
if tq := tensorQuant[tensorName]; tq != nil {
|
||||
groupSize, bits, mode = QuantizationParams(tq.QuantType)
|
||||
if tq.GroupSize > 0 {
|
||||
groupSize = tq.GroupSize
|
||||
}
|
||||
return groupSize, bits, mode, true
|
||||
}
|
||||
}
|
||||
return defaultGroupSize, defaultBits, defaultMode, false
|
||||
}
|
||||
|
||||
// ResolveLinearQuantParams resolves quantization params for a quantized linear
|
||||
// tensor, preferring per-tensor metadata and falling back to shape-based
|
||||
// inference for affine packed tensors.
|
||||
func ResolveLinearQuantParams(
|
||||
defaultGroupSize, defaultBits int,
|
||||
defaultMode string,
|
||||
tensorQuant map[string]*TensorQuantInfo,
|
||||
tensorName string,
|
||||
weight, scales *mlx.Array,
|
||||
) (groupSize, bits int, mode string) {
|
||||
groupSize, bits, mode, fromTensor := TensorQuantParams(
|
||||
defaultGroupSize,
|
||||
defaultBits,
|
||||
defaultMode,
|
||||
tensorQuant,
|
||||
tensorName,
|
||||
)
|
||||
|
||||
if mode == "affine" {
|
||||
if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
|
||||
if !fromTensor || groupSize == 0 || bits == 0 {
|
||||
groupSize = inferredGroupSize
|
||||
bits = inferredBits
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groupSize, bits, mode
|
||||
}
|
||||
|
||||
// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
|
||||
// tensors from packed weight and scale shapes.
|
||||
func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
|
||||
if weight == nil || scales == nil {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightShape := weight.Dims()
|
||||
scaleShape := scales.Dims()
|
||||
if len(weightShape) == 0 || len(scaleShape) == 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
weightCols := weightShape[len(weightShape)-1]
|
||||
scalesCols := scaleShape[len(scaleShape)-1]
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return 32, 4, true
|
||||
case groupSize8 == 64:
|
||||
return 64, 8, true
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
if hintBits == 8 {
|
||||
return 32, 8, true
|
||||
}
|
||||
if hintBits == 4 {
|
||||
return 64, 4, true
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return groupSize4, 4, true
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return groupSize8, 8, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func isCommonGroupSize(v int) bool {
|
||||
switch v {
|
||||
case 16, 32, 64, 128:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
299
x/mlxrunner/model/root.go
Normal file
299
x/mlxrunner/model/root.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
modeltypes "github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/imagegen/manifest"
|
||||
)
|
||||
|
||||
// TensorQuantInfo describes per-tensor quantization metadata.
|
||||
type TensorQuantInfo struct {
|
||||
QuantType string
|
||||
GroupSize int
|
||||
}
|
||||
|
||||
// Root wraps a ModelManifest with pre-scanned quantization metadata.
|
||||
type Root struct {
|
||||
Manifest *manifest.ModelManifest
|
||||
Draft *modeltypes.Draft
|
||||
|
||||
// Backwards-compatible model-level quant metadata (first tensor blob).
|
||||
quantType string
|
||||
groupSize int
|
||||
|
||||
// Per-tensor quantization metadata.
|
||||
tensorQuant map[string]*TensorQuantInfo
|
||||
}
|
||||
|
||||
// Open loads a manifest for the given model name and scans tensor blobs for
|
||||
// quantization metadata.
|
||||
func Open(modelName string) (*Root, error) {
|
||||
m, err := manifest.LoadManifest(modelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
root := &Root{
|
||||
Manifest: m,
|
||||
tensorQuant: make(map[string]*TensorQuantInfo),
|
||||
}
|
||||
root.Draft = readDraftConfig(m)
|
||||
|
||||
for _, layer := range m.GetTensorLayers("") {
|
||||
blobPath := m.BlobPath(layer.Digest)
|
||||
|
||||
infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for name, info := range infos {
|
||||
root.tensorQuant[name] = info
|
||||
}
|
||||
|
||||
if root.quantType == "" && blobQuantType != "" {
|
||||
root.quantType = strings.ToUpper(blobQuantType)
|
||||
root.groupSize = blobGroupSize
|
||||
if root.groupSize == 0 {
|
||||
root.groupSize = defaultGroupSize(root.quantType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return root, nil
|
||||
}
|
||||
|
||||
func readDraftConfig(m *manifest.ModelManifest) *modeltypes.Draft {
|
||||
if m == nil || m.Manifest == nil || m.Manifest.Config.Digest == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(m.BlobPath(m.Manifest.Config.Digest))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cfg modeltypes.ConfigV2
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil
|
||||
}
|
||||
if cfg.Draft != nil {
|
||||
return cfg.Draft
|
||||
}
|
||||
|
||||
if m.GetConfigLayer("draft/config.json") != nil {
|
||||
return &modeltypes.Draft{
|
||||
ModelFormat: "safetensors",
|
||||
TensorPrefix: "draft.",
|
||||
Config: "draft/config.json",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close is a no-op for now (future: release resources).
|
||||
func (r *Root) Close() {}
|
||||
|
||||
// QuantType returns the quantization type detected from the first tensor blob metadata.
|
||||
func (r *Root) QuantType() string { return r.quantType }
|
||||
|
||||
// GroupSize returns the quantization group size detected from the first tensor blob metadata.
|
||||
func (r *Root) GroupSize() int { return r.groupSize }
|
||||
|
||||
// TensorQuant returns per-tensor quantization metadata if available.
|
||||
func (r *Root) TensorQuant(name string) *TensorQuantInfo {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return r.tensorQuant[name]
|
||||
}
|
||||
|
||||
// AllTensorQuant returns a copy of the per-tensor quantization metadata.
|
||||
func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
|
||||
out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
|
||||
for k, v := range r.tensorQuant {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
copy := *v
|
||||
out[k] = ©
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func defaultGroupSize(quantType string) int {
|
||||
groupSize, _, _ := QuantizationParams(quantType)
|
||||
return groupSize
|
||||
}
|
||||
|
||||
func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var headerSize uint64
|
||||
if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
if headerSize > 100*1024*1024 {
|
||||
return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
|
||||
}
|
||||
|
||||
data := make([]byte, headerSize)
|
||||
if _, err := io.ReadFull(f, data); err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
var header map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &header); err != nil {
|
||||
return nil, "", 0, err
|
||||
}
|
||||
|
||||
globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
|
||||
globalQuantType = strings.ToUpper(globalQuantType)
|
||||
|
||||
// Parse full metadata for per-tensor quant info
|
||||
var metaMap map[string]string
|
||||
if metaRaw, ok := header["__metadata__"]; ok {
|
||||
json.Unmarshal(metaRaw, &metaMap)
|
||||
}
|
||||
|
||||
mainNames := mainTensorNames(header)
|
||||
infos := make(map[string]*TensorQuantInfo)
|
||||
for _, name := range mainNames {
|
||||
if _, ok := header[name+".scale"]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
quantType := globalQuantType
|
||||
groupSize := globalGroupSize
|
||||
|
||||
// Check per-tensor metadata (e.g. from packed expert blobs with mixed precision)
|
||||
if metaMap != nil {
|
||||
if qt, ok := metaMap[name+".quant_type"]; ok && qt != "" {
|
||||
quantType = strings.ToUpper(qt)
|
||||
}
|
||||
if gs, ok := metaMap[name+".group_size"]; ok && gs != "" {
|
||||
if v, err := strconv.Atoi(gs); err == nil {
|
||||
groupSize = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
|
||||
if quantType == "" {
|
||||
quantType = inferredType
|
||||
}
|
||||
if groupSize == 0 {
|
||||
groupSize = inferredGroup
|
||||
}
|
||||
if quantType == "" {
|
||||
continue
|
||||
}
|
||||
if groupSize == 0 {
|
||||
groupSize = defaultGroupSize(quantType)
|
||||
}
|
||||
|
||||
infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
|
||||
}
|
||||
|
||||
return infos, globalQuantType, globalGroupSize, nil
|
||||
}
|
||||
|
||||
func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
|
||||
metaRaw, ok := header["__metadata__"]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var meta map[string]string
|
||||
if err := json.Unmarshal(metaRaw, &meta); err != nil {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
quantType = meta["quant_type"]
|
||||
if gs := meta["group_size"]; gs != "" {
|
||||
groupSize, _ = strconv.Atoi(gs)
|
||||
}
|
||||
return quantType, groupSize
|
||||
}
|
||||
|
||||
func mainTensorNames(header map[string]json.RawMessage) []string {
|
||||
names := make([]string, 0, len(header))
|
||||
for name := range header {
|
||||
if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
|
||||
continue
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
|
||||
type tensorShape struct {
|
||||
Shape []int64 `json:"shape"`
|
||||
}
|
||||
|
||||
mainRaw, ok := header[tensorName]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
scaleRaw, ok := header[tensorName+".scale"]
|
||||
if !ok {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var mainInfo tensorShape
|
||||
if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
var scaleInfo tensorShape
|
||||
if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
|
||||
scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
|
||||
if weightCols <= 0 || scalesCols <= 0 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
groupSize4 := weightCols * 8 / scalesCols
|
||||
groupSize8 := weightCols * 4 / scalesCols
|
||||
|
||||
switch {
|
||||
case groupSize4 == 32:
|
||||
return "INT4", 32
|
||||
case groupSize8 == 64:
|
||||
return "INT8", 64
|
||||
case groupSize4 == 64 && groupSize8 == 32:
|
||||
h := strings.ToUpper(hintQuantType)
|
||||
if strings.Contains(h, "8") {
|
||||
return "INT8", 32
|
||||
}
|
||||
if strings.Contains(h, "4") {
|
||||
return "INT4", 64
|
||||
}
|
||||
}
|
||||
|
||||
if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
|
||||
return "INT4", groupSize4
|
||||
}
|
||||
if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
|
||||
return "INT8", groupSize8
|
||||
}
|
||||
|
||||
return "", 0
|
||||
}
|
||||
Reference in New Issue
Block a user