ollama source for Momentry Core verification
This commit is contained in:
230
x/create/imagegen.go
Normal file
230
x/create/imagegen.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package create
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/safetensors"
|
||||
)
|
||||
|
||||
// CreateImageGenModel imports an image generation model from a directory.
|
||||
// Stores each tensor as a separate blob for fine-grained deduplication.
|
||||
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
|
||||
// Supported quantization types: int4, int8, nvfp4, mxfp8 (or empty for no quantization).
|
||||
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
|
||||
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
|
||||
// Validate quantization type
|
||||
switch quantize {
|
||||
case "", "int4", "int8", "nvfp4", "mxfp8":
|
||||
// valid
|
||||
default:
|
||||
return fmt.Errorf("unsupported quantization type %q: supported types are int4, int8, nvfp4, mxfp8", quantize)
|
||||
}
|
||||
|
||||
var layers []LayerInfo
|
||||
var configLayer LayerInfo
|
||||
var totalParams int64 // Count parameters from original tensor shapes
|
||||
var torchDtype string // Read from component config for quantization display
|
||||
|
||||
// Components to process - extract individual tensors from each
|
||||
components := []string{"text_encoder", "transformer", "vae"}
|
||||
|
||||
for _, component := range components {
|
||||
componentDir := filepath.Join(modelDir, component)
|
||||
if _, err := os.Stat(componentDir); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find all safetensors files in this component
|
||||
entries, err := os.ReadDir(componentDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read %s: %w", component, err)
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
if !strings.HasSuffix(entry.Name(), ".safetensors") {
|
||||
continue
|
||||
}
|
||||
|
||||
stPath := filepath.Join(componentDir, entry.Name())
|
||||
|
||||
// Extract individual tensors from safetensors file
|
||||
extractor, err := safetensors.OpenForExtraction(stPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", stPath, err)
|
||||
}
|
||||
|
||||
tensorNames := extractor.ListTensors()
|
||||
quantizeMsg := ""
|
||||
if quantize != "" && component != "vae" {
|
||||
quantizeMsg = ", quantizing to " + quantize
|
||||
}
|
||||
fn(fmt.Sprintf("importing %s/%s (%d tensors%s)", component, entry.Name(), len(tensorNames), quantizeMsg))
|
||||
|
||||
for _, tensorName := range tensorNames {
|
||||
td, err := extractor.GetTensor(tensorName)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
|
||||
}
|
||||
|
||||
// Count parameters from original tensor shape
|
||||
if len(td.Shape) > 0 {
|
||||
numElements := int64(1)
|
||||
for _, dim := range td.Shape {
|
||||
numElements *= int64(dim)
|
||||
}
|
||||
totalParams += numElements
|
||||
}
|
||||
|
||||
// Store as minimal safetensors format (88 bytes header overhead)
|
||||
// This enables native mmap loading via mlx_load_safetensors
|
||||
// Use path-style name: "component/tensor_name"
|
||||
fullName := component + "/" + tensorName
|
||||
|
||||
// Determine quantization type for this tensor (empty string if not quantizing)
|
||||
quantizeType := ""
|
||||
if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
|
||||
quantizeType = quantize
|
||||
}
|
||||
|
||||
// createTensorLayer returns multiple layers if quantizing (weight + scales)
|
||||
newLayers, err := createTensorLayer(td.SafetensorsReader(), fullName, td.Dtype, td.Shape, quantizeType)
|
||||
if err != nil {
|
||||
extractor.Close()
|
||||
return fmt.Errorf("failed to create layer for %s: %w", fullName, err)
|
||||
}
|
||||
layers = append(layers, newLayers...)
|
||||
}
|
||||
|
||||
extractor.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Read torch_dtype from text_encoder config for quantization display
|
||||
if torchDtype == "" {
|
||||
textEncoderConfig := filepath.Join(modelDir, "text_encoder/config.json")
|
||||
if data, err := os.ReadFile(textEncoderConfig); err == nil {
|
||||
var cfg struct {
|
||||
TorchDtype string `json:"torch_dtype"`
|
||||
}
|
||||
if json.Unmarshal(data, &cfg) == nil && cfg.TorchDtype != "" {
|
||||
torchDtype = cfg.TorchDtype
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Import config files
|
||||
configFiles := []string{
|
||||
"model_index.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/generation_config.json",
|
||||
"transformer/config.json",
|
||||
"vae/config.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"tokenizer/tokenizer.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
}
|
||||
|
||||
for _, cfgPath := range configFiles {
|
||||
fullPath := filepath.Join(modelDir, cfgPath)
|
||||
if _, err := os.Stat(fullPath); os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("importing config %s", cfgPath))
|
||||
|
||||
var r io.Reader
|
||||
|
||||
// For model_index.json, normalize to Ollama format and add metadata
|
||||
if cfgPath == "model_index.json" {
|
||||
data, err := os.ReadFile(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
var cfg map[string]any
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("failed to parse %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
// Rename _class_name to architecture, remove diffusers-specific fields
|
||||
if className, ok := cfg["_class_name"]; ok {
|
||||
cfg["architecture"] = className
|
||||
delete(cfg, "_class_name")
|
||||
}
|
||||
delete(cfg, "_diffusers_version")
|
||||
|
||||
// Add parameter count (counted from tensor shapes during import)
|
||||
cfg["parameter_count"] = totalParams
|
||||
|
||||
// Add quantization info - use quantize type if set, otherwise torch_dtype
|
||||
if quantize != "" {
|
||||
cfg["quantization"] = strings.ToUpper(quantize)
|
||||
} else {
|
||||
cfg["quantization"] = torchDtype
|
||||
}
|
||||
|
||||
data, err = json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal %s: %w", cfgPath, err)
|
||||
}
|
||||
r = bytes.NewReader(data)
|
||||
} else {
|
||||
f, err := os.Open(fullPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open %s: %w", cfgPath, err)
|
||||
}
|
||||
defer f.Close()
|
||||
r = f
|
||||
}
|
||||
|
||||
layer, err := createLayer(r, "application/vnd.ollama.image.json", cfgPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create layer for %s: %w", cfgPath, err)
|
||||
}
|
||||
|
||||
// Use model_index.json as the config layer
|
||||
if cfgPath == "model_index.json" {
|
||||
configLayer = layer
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
|
||||
if configLayer.Digest == "" {
|
||||
return fmt.Errorf("model_index.json not found in %s", modelDir)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("writing manifest for %s", modelName))
|
||||
|
||||
if err := writeManifest(modelName, configLayer, layers); err != nil {
|
||||
return fmt.Errorf("failed to write manifest: %w", err)
|
||||
}
|
||||
|
||||
fn(fmt.Sprintf("successfully imported %s with %d layers", modelName, len(layers)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
|
||||
// MLX requires the last dimension to be divisible by the group size.
|
||||
// nvfp4: 16, int4/mxfp8: 32, int8: 64
|
||||
func canQuantizeShape(shape []int32, quantize string) bool {
|
||||
if len(shape) < 2 {
|
||||
return false
|
||||
}
|
||||
groupSize := int32(32)
|
||||
switch strings.ToUpper(quantize) {
|
||||
case "NVFP4":
|
||||
groupSize = 16
|
||||
case "INT8":
|
||||
groupSize = 64
|
||||
}
|
||||
return shape[len(shape)-1]%groupSize == 0
|
||||
}
|
||||
Reference in New Issue
Block a user