203 lines
5.8 KiB
Go
203 lines
5.8 KiB
Go
package base
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
|
|
"github.com/ollama/ollama/x/mlxrunner/batch"
|
|
"github.com/ollama/ollama/x/mlxrunner/cache"
|
|
"github.com/ollama/ollama/x/mlxrunner/mlx"
|
|
"github.com/ollama/ollama/x/mlxrunner/model"
|
|
"github.com/ollama/ollama/x/tokenizer"
|
|
)
|
|
|
|
// Model is the interface that model implementations must satisfy.
|
|
type Model interface {
|
|
Forward(b *batch.Batch, cache []cache.Cache) *mlx.Array
|
|
Unembed(x *mlx.Array) *mlx.Array
|
|
NumLayers() int
|
|
Tokenizer() *tokenizer.Tokenizer
|
|
MaxContextLength() int
|
|
|
|
// LoadWeights receives all tensors loaded from the manifest and assigns
|
|
// them to model fields. Model-specific logic (MLA absorption, expert
|
|
// stacking, quantized layer creation) happens here.
|
|
LoadWeights(tensors map[string]*mlx.Array) error
|
|
}
|
|
|
|
// DraftModel is an auxiliary model stored alongside a target model.
|
|
type DraftModel interface {
|
|
LoadWeights(tensors map[string]*mlx.Array) error
|
|
}
|
|
|
|
// MTPDefaults holds model-provided draft-token defaults for speculative
// decoding. Environment settings in the runner may override these values.
//
// Field order is significant for positional composite literals and must be
// kept as InitialDraftTokens, MaxDraftTokens, Enabled.
type MTPDefaults struct {
	// InitialDraftTokens and MaxDraftTokens are draft-token counts; presumably
	// the starting and upper-bound number of drafted tokens per step — confirm
	// against the runner's use of them.
	InitialDraftTokens, MaxDraftTokens int

	// Enabled reports whether the model recommends MTP drafting by default.
	Enabled bool
}
|
|
|
|
// MTPDefaultsProvider lets a model provide MTP policy defaults from its own
|
|
// config without teaching the runner model-specific shape heuristics.
|
|
type MTPDefaultsProvider interface {
|
|
MTPDraftDefaults(sample bool) MTPDefaults
|
|
}
|
|
|
|
// MTPDraftModel is a draft model capable of Gemma-style multi-token
|
|
// prediction from target token embeddings, target hidden states, and target KV.
|
|
type MTPDraftModel interface {
|
|
Draft(inputEmbeds *mlx.Array, position int32, caches []cache.Cache) (logits, hidden *mlx.Array)
|
|
}
|
|
|
|
// MTPEmbeddingModel exposes the target token embedding path used by MTP drafts.
|
|
type MTPEmbeddingModel interface {
|
|
TokenEmbeddings(inputIDs *mlx.Array) *mlx.Array
|
|
}
|
|
|
|
// DFlashTargetModel exposes target-layer hidden states for DFlash drafts.
|
|
type DFlashTargetModel interface {
|
|
ForwardDFlash(b *batch.Batch, caches []cache.Cache, layerIDs []int) (hidden, targetHidden *mlx.Array)
|
|
}
|
|
|
|
// DFlashDraftModel is a block-diffusion speculative draft model.
|
|
type DFlashDraftModel interface {
|
|
DraftModel
|
|
|
|
TargetLayerIDs() []int
|
|
BlockSize() int
|
|
MaskTokenID() int32
|
|
NewCaches() []cache.Cache
|
|
AppendContext(targetHidden *mlx.Array, caches []cache.Cache)
|
|
Draft(inputIDs *mlx.Array, caches []cache.Cache) *mlx.Array
|
|
}
|
|
|
|
var (
|
|
mu sync.Mutex
|
|
registry = make(map[string]func(root *model.Root) (Model, error))
|
|
draftRegistry = make(map[string]func(root *model.Root, target Model) (DraftModel, error))
|
|
)
|
|
|
|
// Register registers a model constructor by architecture name.
|
|
// Called from init() in model packages. Panics on duplicate registration.
|
|
func Register(arch string, fn func(root *model.Root) (Model, error)) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if _, exists := registry[arch]; exists {
|
|
panic(fmt.Sprintf("model architecture %q already registered", arch))
|
|
}
|
|
registry[arch] = fn
|
|
}
|
|
|
|
// RegisterDraft registers a draft model constructor by architecture name.
|
|
func RegisterDraft(arch string, fn func(root *model.Root, target Model) (DraftModel, error)) {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
if _, exists := draftRegistry[arch]; exists {
|
|
panic(fmt.Sprintf("draft model architecture %q already registered", arch))
|
|
}
|
|
draftRegistry[arch] = fn
|
|
}
|
|
|
|
// New reads config.json from the manifest, detects the architecture, looks up
|
|
// the registered constructor, and calls it to create the model (with config
|
|
// parsed and struct created, but weights not yet loaded).
|
|
func New(root *model.Root) (Model, error) {
|
|
configData, err := root.Manifest.ReadConfig("config.json")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read config.json: %w", err)
|
|
}
|
|
|
|
var archConfig struct {
|
|
Architectures []string `json:"architectures"`
|
|
}
|
|
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
|
return nil, fmt.Errorf("failed to parse config.json: %w", err)
|
|
}
|
|
|
|
if len(archConfig.Architectures) == 0 {
|
|
return nil, fmt.Errorf("no architectures found in config.json")
|
|
}
|
|
|
|
arch := archConfig.Architectures[0]
|
|
slog.Info("Model architecture", "arch", arch)
|
|
|
|
mu.Lock()
|
|
fn, ok := registry[arch]
|
|
mu.Unlock()
|
|
|
|
if !ok {
|
|
return nil, fmt.Errorf("unsupported architecture: %s", arch)
|
|
}
|
|
|
|
return fn(root)
|
|
}
|
|
|
|
// NewDraft constructs the draft model described by the manifest config, if any.
|
|
func NewDraft(root *model.Root, target Model) (DraftModel, error) {
|
|
if root == nil || root.Draft == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
configPath := root.Draft.Config
|
|
if configPath == "" {
|
|
configPath = "draft/config.json"
|
|
}
|
|
configData, err := root.Manifest.ReadConfig(configPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read %s: %w", configPath, err)
|
|
}
|
|
|
|
var archConfig struct {
|
|
Architectures []string `json:"architectures"`
|
|
ModelType string `json:"model_type"`
|
|
}
|
|
if err := json.Unmarshal(configData, &archConfig); err != nil {
|
|
return nil, fmt.Errorf("failed to parse %s: %w", configPath, err)
|
|
}
|
|
|
|
arch := root.Draft.Architecture
|
|
if arch == "" && len(archConfig.Architectures) > 0 {
|
|
arch = archConfig.Architectures[0]
|
|
}
|
|
if arch == "" {
|
|
arch = archConfig.ModelType
|
|
}
|
|
if arch == "" {
|
|
return nil, fmt.Errorf("no draft architecture found in %s", configPath)
|
|
}
|
|
slog.Info("Draft model architecture", "arch", arch)
|
|
|
|
mu.Lock()
|
|
fn, ok := draftRegistry[arch]
|
|
mu.Unlock()
|
|
if !ok {
|
|
return nil, fmt.Errorf("unsupported draft architecture: %s", arch)
|
|
}
|
|
|
|
return fn(root, target)
|
|
}
|
|
|
|
// Weights returns a function that loads model weights, then pins all
|
|
// arrays reachable from the model struct and sweeps everything else.
|
|
func Weights(m Model) func(map[string]*mlx.Array) error {
|
|
return func(tensors map[string]*mlx.Array) error {
|
|
if err := m.LoadWeights(tensors); err != nil {
|
|
return err
|
|
}
|
|
|
|
collected := mlx.Collect(m)
|
|
for _, arr := range collected {
|
|
mlx.Pin(arr)
|
|
}
|
|
mlx.Sweep()
|
|
mlx.Eval(collected...)
|
|
|
|
return nil
|
|
}
|
|
}
|