ollama source for Momentry Core verification

This commit is contained in:
Accusys
2026-05-22 17:19:10 +08:00
commit 0b31ff9135
2020 changed files with 1413145 additions and 0 deletions

100
server/auth.go Normal file
View File

@@ -0,0 +1,100 @@
package server
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
)
// registryChallenge carries the three values needed to build a token request
// URL: the realm (base URL), the service name, and the requested scope.
type registryChallenge struct {
// Realm is parsed as the base URL of the token endpoint.
Realm string
// Service is sent as the "service" query parameter.
Service string
// Scope is split on single spaces; each piece is sent as a "scope" parameter.
Scope string
}
// URL builds the token request URL for this challenge. Starting from Realm, it
// appends the service, one "scope" parameter per space-separated entry of
// Scope, the current Unix time as "ts", and a nonce obtained from
// auth.NewNonce.
//
// It returns an error when Realm cannot be parsed or the nonce cannot be
// generated.
func (r registryChallenge) URL() (*url.URL, error) {
	target, err := url.Parse(r.Realm)
	if err != nil {
		return nil, err
	}

	query := target.Query()
	query.Add("service", r.Service)

	// strings.Split returns one empty element for an empty Scope, so an empty
	// scope still produces a single empty "scope" parameter.
	scopes := strings.Split(r.Scope, " ")
	for i := range scopes {
		query.Add("scope", scopes[i])
	}

	now := time.Now().Unix()
	query.Add("ts", strconv.FormatInt(now, 10))

	nonce, nonceErr := auth.NewNonce(rand.Reader, 16)
	if nonceErr != nil {
		return nil, nonceErr
	}
	query.Add("nonce", nonce)

	target.RawQuery = query.Encode()
	return target, nil
}
// getAuthorizationToken requests a token from the realm named in challenge.
//
// It builds the token URL from challenge, refuses to continue unless that
// URL's host equals originalHost, signs "<method>,<url>,<digest>" with
// auth.Sign, and sends the signature in the Authorization header of a GET
// request. On success it returns the Token field of the JSON response.
//
// Errors are returned when the URL cannot be built, the hosts differ, signing
// or the request fails, the body cannot be read, the response status is 400 or
// above (the status code and any body are included in the message), or the
// response is not valid JSON.
func getAuthorizationToken(ctx context.Context, challenge registryChallenge, originalHost string) (string, error) {
	redirectURL, err := challenge.URL()
	if err != nil {
		return "", err
	}

	// Validate that the realm host matches the original request host to prevent sending tokens cross-origin.
	if redirectURL.Host != originalHost {
		return "", fmt.Errorf("realm host %q does not match original host %q", redirectURL.Host, originalHost)
	}

	// The request carries no body, so the digest is the SHA-256 of empty
	// input, hex-encoded and then base64-encoded.
	sha256sum := sha256.Sum256(nil)
	digest := base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))
	data := fmt.Appendf(nil, "%s,%s,%s", http.MethodGet, redirectURL.String(), digest)

	signature, err := auth.Sign(ctx, data)
	if err != nil {
		return "", err
	}

	headers := make(http.Header)
	headers.Add("Authorization", signature)

	response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, &registryOptions{})
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	body, err := io.ReadAll(response.Body)
	if err != nil {
		// Wrap with %w so callers can still inspect the underlying read error.
		return "", fmt.Errorf("%d: %w", response.StatusCode, err)
	}

	if response.StatusCode >= http.StatusBadRequest {
		if len(body) > 0 {
			return "", fmt.Errorf("%d: %s", response.StatusCode, body)
		}
		return "", fmt.Errorf("%d", response.StatusCode)
	}

	var token api.TokenResponse
	if err := json.Unmarshal(body, &token); err != nil {
		return "", err
	}

	return token.Token, nil
}

113
server/auth_test.go Normal file
View File

@@ -0,0 +1,113 @@
package server
import (
"context"
"strings"
"testing"
"time"
)
// TestGetAuthorizationTokenRejectsCrossDomain checks that
// getAuthorizationToken reports a host-mismatch error exactly when the realm
// host differs from the original request host. Any other error (for example
// from the 100ms context timeout on the matching cases) is ignored.
func TestGetAuthorizationTokenRejectsCrossDomain(t *testing.T) {
	type hostCase struct {
		realm        string
		originalHost string
		wantMismatch bool
	}

	cases := []hostCase{
		{realm: "https://example.com/token", originalHost: "example.com", wantMismatch: false},
		{realm: "https://example.com/token", originalHost: "other.com", wantMismatch: true},
		{realm: "https://example.com/token", originalHost: "localhost:8000", wantMismatch: true},
		{realm: "https://localhost:5000/token", originalHost: "localhost:5000", wantMismatch: false},
		{realm: "https://localhost:5000/token", originalHost: "localhost:6000", wantMismatch: true},
	}

	for _, tc := range cases {
		t.Run(tc.originalHost, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
			defer cancel()

			ch := registryChallenge{Realm: tc.realm, Service: "test", Scope: "repo:x:pull"}
			_, err := getAuthorizationToken(ctx, ch, tc.originalHost)

			gotMismatch := err != nil && strings.Contains(err.Error(), "does not match")
			switch {
			case tc.wantMismatch && !gotMismatch:
				t.Errorf("expected domain mismatch error, got: %v", err)
			case !tc.wantMismatch && gotMismatch:
				t.Errorf("unexpected domain mismatch error: %v", err)
			}
		})
	}
}
// TestParseRegistryChallenge checks that parseRegistryChallenge extracts the
// realm, service, and scope from a Bearer challenge string, and that an empty
// input yields empty fields.
func TestParseRegistryChallenge(t *testing.T) {
	cases := []struct {
		input                             string
		wantRealm, wantService, wantScope string
	}{
		{
			input:       `Bearer realm="https://auth.example.com/token",service="registry",scope="repo:foo:pull"`,
			wantRealm:   "https://auth.example.com/token",
			wantService: "registry",
			wantScope:   "repo:foo:pull",
		},
		{
			input:       `Bearer realm="https://r.ollama.ai/v2/token",service="ollama",scope="-"`,
			wantRealm:   "https://r.ollama.ai/v2/token",
			wantService: "ollama",
			wantScope:   "-",
		},
		// Empty input: every field is expected to be empty.
		{},
	}

	for _, tc := range cases {
		got := parseRegistryChallenge(tc.input)
		if got.Realm == tc.wantRealm && got.Service == tc.wantService && got.Scope == tc.wantScope {
			continue
		}
		t.Errorf("parseRegistryChallenge(%q) = {%q, %q, %q}, want {%q, %q, %q}",
			tc.input, got.Realm, got.Service, got.Scope,
			tc.wantRealm, tc.wantService, tc.wantScope)
	}
}
// TestRegistryChallengeURL checks the URL produced by registryChallenge.URL:
// host and path come from the realm, the service and both scopes are present,
// "ts" and "nonce" are set, and two calls produce different nonces.
func TestRegistryChallengeURL(t *testing.T) {
	challenge := registryChallenge{
		Realm:   "https://auth.example.com/token",
		Service: "registry",
		Scope:   "repo:foo:pull repo:bar:push",
	}

	u, err := challenge.URL()
	if err != nil {
		t.Fatalf("URL() error: %v", err)
	}

	if u.Host != "auth.example.com" {
		t.Errorf("host = %q, want %q", u.Host, "auth.example.com")
	}
	if u.Path != "/token" {
		t.Errorf("path = %q, want %q", u.Path, "/token")
	}

	q := u.Query()
	if q.Get("service") != "registry" {
		t.Errorf("service = %q, want %q", q.Get("service"), "registry")
	}
	if scopes := q["scope"]; len(scopes) != 2 {
		t.Errorf("scope count = %d, want 2", len(scopes))
	}
	if q.Get("ts") == "" {
		t.Error("missing ts")
	}
	if q.Get("nonce") == "" {
		t.Error("missing nonce")
	}

	// Nonces should differ between calls. Check the error from the second
	// call so a failure is reported instead of dereferencing a nil URL.
	u2, err := challenge.URL()
	if err != nil {
		t.Fatalf("second URL() error: %v", err)
	}
	if q.Get("nonce") == u2.Query().Get("nonce") {
		t.Error("nonce should be unique per call")
	}
}
// TestRegistryChallengeURLInvalid checks that URL reports an error when the
// realm is not a parseable URL.
func TestRegistryChallengeURLInvalid(t *testing.T) {
	bad := registryChallenge{Realm: "://invalid"}

	_, err := bad.URL()
	if err == nil {
		t.Error("expected error for invalid URL")
	}
}

568
server/cloud_proxy.go Normal file
View File

@@ -0,0 +1,568 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/version"
)
// Defaults and fixed names used by the cloud proxy.
const (
// defaultCloudProxyBaseURL is the upstream base URL used when no override is set.
defaultCloudProxyBaseURL = "https://ollama.com:443"
// defaultCloudProxySigningHost is the host whose requests are signed by default.
defaultCloudProxySigningHost = "ollama.com"
// cloudProxyBaseURLEnv names the environment variable that can override the base URL.
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
// legacyCloudAnthropicKey is the gin context key set when an Anthropic web
// search request is kept on the local middleware path.
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
// cloudProxyClientVersionHeader is the request header that carries the
// client version to the upstream.
cloudProxyClientVersionHeader = "X-Ollama-Client-Version"
// maxDecompressedBodySize limits the size of a decompressed request body
// (20 MiB).
maxDecompressedBodySize = 20 << 20
)
// Package-level cloud proxy settings. The base URL and signing host start at
// their defaults and may be replaced by init. The two function variables are
// indirections for the signing and sign-in URL helpers.
// NOTE(review): the function variables look like test seams — confirm against
// the tests that override them.
var (
cloudProxyBaseURL = defaultCloudProxyBaseURL
cloudProxySigningHost = defaultCloudProxySigningHost
cloudProxySignRequest = signCloudProxyRequest
cloudProxySigninURL = signinURL
)
// hopByHopHeaders is the set of header names that are never copied between the
// client and the upstream. Keys are lower-case; isHopByHopHeader lower-cases
// the name before the lookup. Note that content-length is part of this set.
var hopByHopHeaders = map[string]struct{}{
"connection": {},
"content-length": {},
"proxy-connection": {},
"keep-alive": {},
"proxy-authenticate": {},
"proxy-authorization": {},
"te": {},
"trailer": {},
"transfer-encoding": {},
"upgrade": {},
}
// init applies the optional cloud base URL override taken from the
// environment. An invalid override is logged and ignored, leaving the defaults
// in place.
func init() {
	resolvedURL, resolvedHost, isOverride, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
	if err != nil {
		slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
		return
	}

	cloudProxyBaseURL, cloudProxySigningHost = resolvedURL, resolvedHost

	if !isOverride {
		return
	}
	slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
}
// cloudPassthroughMiddleware returns a gin handler that inspects the "model"
// field of POST request bodies and, when it refers to a cloud model, proxies
// the request to the cloud instead of running the next handler.
//
// Requests fall through to the next handler when the method is not POST, when
// no model can be extracted from the body, when the model reference is invalid
// or not a cloud model, or when the request is an Anthropic web search call on
// /v1/messages. disabledOperation is passed through to the proxy helper.
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.Method != http.MethodPost {
c.Next()
return
}
// Decompress zstd-encoded request bodies so we can inspect the model
if c.GetHeader("Content-Encoding") == "zstd" {
reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to decompress request body"})
c.Abort()
return
}
defer reader.Close()
// Cap the decompressed size and drop the header so later readers see
// a plain body.
c.Request.Body = http.MaxBytesReader(c.Writer, io.NopCloser(reader), maxDecompressedBodySize)
c.Request.Header.Del("Content-Encoding")
}
// TODO(drifkin): Avoid full-body buffering here for model detection.
// A future optimization can parse just enough JSON to read "model" (and
// optionally short-circuit cloud-disabled explicit-cloud requests) while
// preserving raw passthrough semantics.
// readRequestBody also restores c.Request.Body so later handlers can read it.
body, err := readRequestBody(c.Request)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.Abort()
return
}
model, ok := extractModelField(body)
if !ok {
c.Next()
return
}
modelRef, err := parseAndValidateModelRef(model)
if err != nil || modelRef.Source != modelSourceCloud {
c.Next()
return
}
// Rewrite the "model" field to the base name taken from the parsed reference.
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
c.Abort()
return
}
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
if c.Request.URL.Path == "/v1/messages" {
if hasAnthropicWebSearchTool(body) {
c.Set(legacyCloudAnthropicKey, true)
c.Next()
return
}
}
proxyCloudRequest(c, normalizedBody, disabledOperation)
c.Abort()
}
}
// cloudModelPathPassthroughMiddleware returns a gin handler that looks at the
// "model" path parameter. When it names a valid cloud model, the request is
// proxied to "/v1/models/<base>" and the chain is aborted; otherwise the next
// handler runs.
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
	return func(c *gin.Context) {
		requested := strings.TrimSpace(c.Param("model"))
		if requested != "" {
			ref, err := parseAndValidateModelRef(requested)
			if err == nil && ref.Source == modelSourceCloud {
				proxyCloudRequestWithPath(c, nil, "/v1/models/"+ref.Base, disabledOperation)
				c.Abort()
				return
			}
		}

		c.Next()
	}
}
// proxyCloudJSONRequest marshals payload as JSON and proxies it to the cloud
// using the incoming request's URL path.
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
// TEMP(drifkin): we currently split out this `WithPath` method because we are
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
// stop doing this, we can inline this method.
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
}
// proxyCloudJSONRequestWithPath marshals payload as JSON and proxies it to the
// cloud at the given path. A marshalling failure is answered with a 500 and no
// upstream request is made.
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
	encoded, marshalErr := json.Marshal(payload)
	if marshalErr == nil {
		proxyCloudRequestWithPath(c, encoded, path, disabledOperation)
		return
	}

	c.JSON(http.StatusInternalServerError, gin.H{"error": marshalErr.Error()})
}
// proxyCloudRequest proxies body to the cloud using the incoming request's URL
// path.
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
}
// proxyCloudRequestWithPath forwards the incoming request to the cloud base
// URL at the given path and streams the upstream response back to the client.
//
// The outgoing request reuses the incoming method, query string, and non
// hop-by-hop headers, with body as its payload. Responses written directly by
// this function: 403 when cloud access is disabled, 500 when the base URL or
// request cannot be built, 401 when signing fails, and 502 when the upstream
// call fails. Errors while copying the response body are only logged, since
// the status has already been set.
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
if disabled, _ := internalcloud.Status(); disabled {
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
return
}
baseURL, err := url.Parse(cloudProxyBaseURL)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// Keep the scheme and host of the base URL; take path and query from the request.
targetURL := baseURL.ResolveReference(&url.URL{
Path: path,
RawQuery: c.Request.URL.RawQuery,
})
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
if clientVersion := strings.TrimSpace(version.Version); clientVersion != "" {
outReq.Header.Set(cloudProxyClientVersionHeader, clientVersion)
}
// Default the content type only when there is a body and the client sent none.
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
outReq.Header.Set("Content-Type", "application/json")
}
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
slog.Warn("cloud proxy signing failed", "error", err)
writeCloudUnauthorized(c)
return
}
// TODO(drifkin): Add phase-specific proxy timeouts.
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
// we should not enforce a short total timeout for long-lived responses.
resp, err := http.DefaultClient.Do(outReq)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
defer resp.Body.Close()
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
c.Status(resp.StatusCode)
var bodyWriter http.ResponseWriter = c.Writer
var framedWriter *jsonlFramingResponseWriter
// TEMP(drifkin): only needed on the cloud-proxied first leg of Anthropic
// web_search fallback (which is a path we're removing soon). Local
// /v1/messages writes one JSON value per streamResponse callback directly
// into WebSearchAnthropicWriter, but this proxy copy loop may coalesce
// multiple jsonl records into one Write. WebSearchAnthropicWriter currently
// unmarshals one JSON value per Write.
if path == "/api/chat" && resp.StatusCode == http.StatusOK && c.GetBool(legacyCloudAnthropicKey) {
framedWriter = &jsonlFramingResponseWriter{ResponseWriter: c.Writer}
bodyWriter = framedWriter
}
err = copyProxyResponseBody(bodyWriter, resp.Body)
// Emit any final line that was not terminated by a newline.
if err == nil && framedWriter != nil {
err = framedWriter.FlushPending()
}
if err != nil {
ctxErr := c.Request.Context().Err()
// A cancellation seen on both the copy and the request context is logged
// at debug level as a client-side close rather than as a failure.
if errors.Is(err, context.Canceled) && errors.Is(ctxErr, context.Canceled) {
slog.Debug(
"cloud proxy response stream closed by client",
"path", c.Request.URL.Path,
"status", resp.StatusCode,
)
return
}
slog.Warn(
"cloud proxy response copy failed",
"path", c.Request.URL.Path,
"upstream_path", path,
"status", resp.StatusCode,
"request_context_canceled", ctxErr != nil,
"request_context_err", ctxErr,
"error", err,
)
return
}
}
// replaceJSONModelField returns body with its top-level "model" field set to
// model. All other top-level fields keep their raw JSON values; the result is
// re-marshalled, so keys come out in sorted order.
//
// An empty body is returned unchanged. A body that is JSON null is treated as
// an empty object. An error is returned when body is not a JSON object or the
// result cannot be marshalled.
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
	if len(body) == 0 {
		return body, nil
	}

	var payload map[string]json.RawMessage
	if err := json.Unmarshal(body, &payload); err != nil {
		return nil, err
	}
	// JSON null unmarshals to a nil map; writing to a nil map would panic.
	if payload == nil {
		payload = make(map[string]json.RawMessage, 1)
	}

	modelJSON, err := json.Marshal(model)
	if err != nil {
		return nil, err
	}
	payload["model"] = modelJSON

	return json.Marshal(payload)
}
// readRequestBody reads the whole request body and then replaces r.Body with a
// fresh reader over the same bytes, so the body can be read again later.
// It returns nil, nil when the request has no body.
func readRequestBody(r *http.Request) ([]byte, error) {
	if r.Body == nil {
		return nil, nil
	}

	data, err := io.ReadAll(r.Body)
	if err == nil {
		r.Body = io.NopCloser(bytes.NewReader(data))
		return data, nil
	}
	return nil, err
}
// extractModelField returns the trimmed top-level "model" string from a JSON
// object body. The boolean is false when the body is empty, is not a JSON
// object, has no "model" key, holds a non-string model, or the model is blank
// after trimming.
func extractModelField(body []byte) (string, bool) {
	if len(body) == 0 {
		return "", false
	}

	var fields map[string]json.RawMessage
	if json.Unmarshal(body, &fields) != nil {
		return "", false
	}

	rawModel, present := fields["model"]
	if !present {
		return "", false
	}

	var name string
	if json.Unmarshal(rawModel, &name) != nil {
		return "", false
	}

	if name = strings.TrimSpace(name); name == "" {
		return "", false
	}
	return name, true
}
// hasAnthropicWebSearchTool reports whether the JSON body lists a tool whose
// "type", after trimming whitespace, starts with "web_search". Empty or
// undecodable bodies yield false.
func hasAnthropicWebSearchTool(body []byte) bool {
	if len(body) == 0 {
		return false
	}

	var request struct {
		Tools []struct {
			Type string `json:"type"`
		} `json:"tools"`
	}
	if json.Unmarshal(body, &request) != nil {
		return false
	}

	for i := range request.Tools {
		kind := strings.TrimSpace(request.Tools[i].Type)
		if strings.HasPrefix(kind, "web_search") {
			return true
		}
	}
	return false
}
// writeCloudUnauthorized answers with 401 and an "unauthorized" error. When a
// sign-in URL can be obtained it is included as "signin_url".
func writeCloudUnauthorized(c *gin.Context) {
	payload := gin.H{"error": "unauthorized"}
	if link, err := cloudProxySigninURL(); err == nil {
		payload["signin_url"] = link
	}
	c.JSON(http.StatusUnauthorized, payload)
}
// signCloudProxyRequest signs req when its host matches the configured signing
// host (compared case-insensitively). Signing adds a "ts" query parameter to
// the request URL and sets the Authorization header. Requests to any other
// host are left untouched and nil is returned.
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
	if strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
		now := strconv.FormatInt(time.Now().Unix(), 10)
		payload := buildCloudSignatureChallenge(req, now)

		sig, err := auth.Sign(ctx, []byte(payload))
		if err != nil {
			return err
		}
		req.Header.Set("Authorization", sig)
	}
	return nil
}
// buildCloudSignatureChallenge sets the "ts" query parameter on req (replacing
// any existing value), re-encodes the query, and returns the string to sign:
// "<METHOD>,<request URI>". The request URL is modified in place.
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
	params := req.URL.Query()
	params.Set("ts", ts)
	req.URL.RawQuery = params.Encode()

	signedTarget := req.URL.RequestURI()
	return fmt.Sprintf("%s,%s", req.Method, signedTarget)
}
// resolveCloudProxyBaseURL validates an optional base URL override and returns
// the base URL and signing host to use.
//
// A blank rawOverride yields the defaults with overridden=false. Otherwise the
// override must have a scheme and host, no userinfo, no path other than "/",
// and no query or fragment. Non-loopback hosts are rejected when runMode is
// gin.ReleaseMode and must use https in any other mode. On success the
// returned base URL has path, query, and fragment cleared, and the signing
// host is the lower-cased hostname without port.
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
	trimmed := strings.TrimSpace(rawOverride)
	if trimmed == "" {
		return defaultCloudProxyBaseURL, defaultCloudProxySigningHost, false, nil
	}

	parsed, parseErr := url.Parse(trimmed)
	if parseErr != nil {
		return "", "", false, fmt.Errorf("invalid URL: %w", parseErr)
	}

	// Structural checks, in the order they are reported.
	switch {
	case parsed.Scheme == "" || parsed.Host == "":
		return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
	case parsed.User != nil:
		return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
	case parsed.Path != "" && parsed.Path != "/":
		return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
	case parsed.RawQuery != "" || parsed.Fragment != "":
		return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
	}

	hostname := parsed.Hostname()
	if hostname == "" {
		return "", "", false, fmt.Errorf("invalid URL: host is required")
	}

	// Policy checks apply only to non-loopback hosts.
	if !isLoopbackHost(hostname) {
		if runMode == gin.ReleaseMode {
			return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
		}
		if !strings.EqualFold(parsed.Scheme, "https") {
			return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
		}
	}

	parsed.Path, parsed.RawPath, parsed.RawQuery, parsed.Fragment = "", "", "", ""
	return parsed.String(), strings.ToLower(hostname), true, nil
}
// isLoopbackHost reports whether host is "localhost" (case-insensitive) or an
// IP literal that is a loopback address.
func isLoopbackHost(host string) bool {
	if strings.EqualFold(host, "localhost") {
		return true
	}
	if parsed := net.ParseIP(host); parsed != nil {
		return parsed.IsLoopback()
	}
	return false
}
// copyProxyRequestHeaders copies request headers from src to dst, skipping
// hop-by-hop headers and any header named as a token in src's Connection
// header. Values already present in dst for a copied key are replaced.
func copyProxyRequestHeaders(dst, src http.Header) {
	hopTokens := connectionHeaderTokens(src)

	for name, vals := range src {
		if isHopByHopHeader(name) {
			continue
		}
		if isConnectionTokenHeader(name, hopTokens) {
			continue
		}

		dst.Del(name)
		for i := range vals {
			dst.Add(name, vals[i])
		}
	}
}
// copyProxyResponseHeaders copies response headers from src to dst, skipping
// hop-by-hop headers and any header named as a token in src's Connection
// header. Values already present in dst for a copied key are replaced.
func copyProxyResponseHeaders(dst, src http.Header) {
	hopTokens := connectionHeaderTokens(src)

	for name, vals := range src {
		if isHopByHopHeader(name) {
			continue
		}
		if isConnectionTokenHeader(name, hopTokens) {
			continue
		}

		dst.Del(name)
		for i := range vals {
			dst.Add(name, vals[i])
		}
	}
}
// copyProxyResponseBody streams src to dst in chunks of up to 32 KiB. When dst
// implements http.Flusher it is flushed after every successful write. It
// returns nil once src reports io.EOF, and otherwise the first write or read
// error.
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
	flusher, canFlush := dst.(http.Flusher)
	chunk := make([]byte, 32*1024)

	for {
		n, readErr := src.Read(chunk)
		if n > 0 {
			if _, err := dst.Write(chunk[:n]); err != nil {
				return err
			}
			if canFlush {
				// TODO(drifkin): Consider conditional flushing so non-streaming
				// responses don't flush every write and can optimize throughput.
				flusher.Flush()
			}
		}

		switch readErr {
		case nil:
			continue
		case io.EOF:
			return nil
		default:
			return readErr
		}
	}
}
// jsonlFramingResponseWriter wraps an http.ResponseWriter and re-frames the
// bytes written to it so that each non-blank, newline-terminated line reaches
// the underlying writer as its own Write call, with surrounding whitespace
// removed. Bytes after the last newline are held until the next Write or
// FlushPending.
type jsonlFramingResponseWriter struct {
	http.ResponseWriter
	pending []byte
}

// Flush forwards to the underlying writer when it implements http.Flusher.
func (w *jsonlFramingResponseWriter) Flush() {
	flusher, ok := w.ResponseWriter.(http.Flusher)
	if !ok {
		return
	}
	flusher.Flush()
}

// Write buffers p and emits every complete line now available. It always
// reports len(p) bytes as written, together with any error from the
// underlying writer.
func (w *jsonlFramingResponseWriter) Write(p []byte) (int, error) {
	w.pending = append(w.pending, p...)
	return len(p), w.flushCompleteLines()
}

// FlushPending writes any buffered trailing data (trimmed of whitespace) to
// the underlying writer and clears the buffer. Blank trailing data is dropped.
func (w *jsonlFramingResponseWriter) FlushPending() error {
	rest := bytes.TrimSpace(w.pending)
	w.pending = nil

	if len(rest) > 0 {
		if _, err := w.ResponseWriter.Write(rest); err != nil {
			return err
		}
	}
	return nil
}

// flushCompleteLines writes each newline-terminated line in the buffer to the
// underlying writer, skipping blank lines, and keeps the unterminated
// remainder buffered. It stops at the first write error.
func (w *jsonlFramingResponseWriter) flushCompleteLines() error {
	for {
		rawLine, rest, found := bytes.Cut(w.pending, []byte{'\n'})
		if !found {
			return nil
		}
		// Consume the line before writing so a failed write does not replay it.
		w.pending = rest

		record := bytes.TrimSpace(rawLine)
		if len(record) == 0 {
			continue
		}
		if _, err := w.ResponseWriter.Write(record); err != nil {
			return err
		}
	}
}
// isHopByHopHeader reports whether name, compared case-insensitively, is in
// the hopByHopHeaders set.
func isHopByHopHeader(name string) bool {
	if _, found := hopByHopHeaders[strings.ToLower(name)]; found {
		return true
	}
	return false
}
// connectionHeaderTokens returns the set of lower-cased, trimmed tokens listed
// across all Connection header values. Empty tokens are ignored. The returned
// map is never nil.
func connectionHeaderTokens(header http.Header) map[string]struct{} {
	set := make(map[string]struct{})

	for _, line := range header.Values("Connection") {
		for _, part := range strings.Split(line, ",") {
			if name := strings.TrimSpace(strings.ToLower(part)); name != "" {
				set[name] = struct{}{}
			}
		}
	}

	return set
}
// isConnectionTokenHeader reports whether name, lower-cased, is one of the
// given Connection tokens. An empty or nil token set always yields false.
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
	if len(tokens) > 0 {
		_, listed := tokens[strings.ToLower(name)]
		return listed
	}
	return false
}

318
server/cloud_proxy_test.go Normal file
View File

@@ -0,0 +1,318 @@
package server
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/klauspost/compress/zstd"
)
// TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders checks that
// copyProxyRequestHeaders drops hop-by-hop headers and headers named in the
// Connection header, while forwarding end-to-end headers.
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
	incoming := http.Header{
		"Connection":   {"keep-alive, X-Trace-Hop, x-alt-hop"},
		"X-Trace-Hop":  {"drop-me"},
		"X-Alt-Hop":    {"drop-me-too"},
		"Keep-Alive":   {"timeout=5"},
		"X-End-To-End": {"keep-me"},
	}
	outgoing := http.Header{}

	copyProxyRequestHeaders(outgoing, incoming)

	if v := outgoing.Get("Connection"); v != "" {
		t.Fatalf("expected Connection to be stripped, got %q", v)
	}
	if v := outgoing.Get("Keep-Alive"); v != "" {
		t.Fatalf("expected Keep-Alive to be stripped, got %q", v)
	}
	if v := outgoing.Get("X-Trace-Hop"); v != "" {
		t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", v)
	}
	if v := outgoing.Get("X-Alt-Hop"); v != "" {
		t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", v)
	}
	if v := outgoing.Get("X-End-To-End"); v != "keep-me" {
		t.Fatalf("expected X-End-To-End to be forwarded, got %q", v)
	}
}
// TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders checks that
// copyProxyResponseHeaders drops the Connection header and the headers it
// names, while forwarding the remaining response headers.
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
	upstream := http.Header{
		"Connection":     {"X-Upstream-Hop"},
		"X-Upstream-Hop": {"drop-me"},
		"Content-Type":   {"application/json"},
		"X-Server-Trace": {"keep-me"},
	}
	downstream := http.Header{}

	copyProxyResponseHeaders(downstream, upstream)

	if v := downstream.Get("Connection"); v != "" {
		t.Fatalf("expected Connection to be stripped, got %q", v)
	}
	if v := downstream.Get("X-Upstream-Hop"); v != "" {
		t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", v)
	}
	if v := downstream.Get("Content-Type"); v != "application/json" {
		t.Fatalf("expected Content-Type to be forwarded, got %q", v)
	}
	if v := downstream.Get("X-Server-Trace"); v != "keep-me" {
		t.Fatalf("expected X-Server-Trace to be forwarded, got %q", v)
	}
}
// TestResolveCloudProxyBaseURL_Default checks that an empty override yields
// the default base URL and signing host with overridden=false.
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
	gotURL, gotHost, gotOverride, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)

	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case gotOverride:
		t.Fatal("expected override=false for empty input")
	case gotURL != defaultCloudProxyBaseURL:
		t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, gotURL)
	case gotHost != defaultCloudProxySigningHost:
		t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, gotHost)
	}
}
// TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback checks that a plain-http
// localhost override is accepted in release mode.
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
	gotURL, gotHost, gotOverride, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)

	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case !gotOverride:
		t.Fatal("expected override=true")
	case gotURL != "http://localhost:8080":
		t.Fatalf("unexpected base URL: %q", gotURL)
	case gotHost != "localhost":
		t.Fatalf("unexpected signing host: %q", gotHost)
	}
}
// TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback checks that a
// non-loopback override is refused in release mode, even over https.
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
	if _, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode); err == nil {
		t.Fatal("expected error for non-loopback override in release mode")
	}
}
// TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS checks that a
// non-loopback https override is accepted in debug mode and that the signing
// host excludes the port.
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
	gotURL, gotHost, gotOverride, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)

	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case !gotOverride:
		t.Fatal("expected override=true")
	case gotURL != "https://example.com:8443":
		t.Fatalf("unexpected base URL: %q", gotURL)
	case gotHost != "example.com":
		t.Fatalf("unexpected signing host: %q", gotHost)
	}
}
// TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP checks that a
// non-loopback plain-http override is refused in debug mode.
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
	if _, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode); err == nil {
		t.Fatal("expected error for non-loopback http override in dev mode")
	}
}
// TestBuildCloudSignatureChallengeIncludesExistingQuery checks that existing
// query parameters are kept, "ts" is appended, and the request URL is updated
// to the signed query.
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
	const (
		wantChallenge = "POST,/v1/messages?beta=true&foo=bar&ts=123"
		wantQuery     = "beta=true&foo=bar&ts=123"
	)

	req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
	if err != nil {
		t.Fatalf("failed to create request: %v", err)
	}

	if challenge := buildCloudSignatureChallenge(req, "123"); challenge != wantChallenge {
		t.Fatalf("challenge mismatch: got %q want %q", challenge, wantChallenge)
	}
	if req.URL.RawQuery != wantQuery {
		t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
	}
}
// TestCloudPassthroughMiddleware_ZstdBody sends a zstd-compressed request body
// naming a cloud model through cloudPassthroughMiddleware and checks that the
// request does not fall through to the next handler, i.e. that the model was
// read from the decompressed body.
func TestCloudPassthroughMiddleware_ZstdBody(t *testing.T) {
gin.SetMode(gin.TestMode)
plainBody := []byte(`{"model":"test-model:cloud","messages":[{"role":"user","content":"hi"}]}`)
// Compress the JSON body with zstd.
var compressed bytes.Buffer
w, err := zstd.NewWriter(&compressed)
if err != nil {
t.Fatalf("zstd writer: %v", err)
}
if _, err := w.Write(plainBody); err != nil {
t.Fatalf("zstd write: %v", err)
}
if err := w.Close(); err != nil {
t.Fatalf("zstd close: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressed.Bytes()))
req.Header.Set("Content-Encoding", "zstd")
rec := httptest.NewRecorder()
// Track whether the middleware detected the cloud model by checking
// if c.Next() was called (non-cloud path) vs c.Abort() (cloud path).
nextCalled := false
r := gin.New()
r.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) {
nextCalled = true
// Verify the body is decompressed and Content-Encoding is removed.
body, err := io.ReadAll(c.Request.Body)
if err != nil {
t.Fatalf("read body: %v", err)
}
model, ok := extractModelField(body)
if !ok {
t.Fatal("expected to extract model from decompressed body")
}
if model != "test-model:cloud" {
t.Fatalf("expected model %q, got %q", "test-model:cloud", model)
}
if enc := c.GetHeader("Content-Encoding"); enc != "" {
t.Fatalf("expected Content-Encoding to be removed, got %q", enc)
}
c.Status(http.StatusOK)
})
r.ServeHTTP(rec, req)
// The cloud passthrough middleware should detect the cloud model and
// proxy (abort), so the next handler should NOT be called.
// However, since there's no actual cloud server to proxy to, the
// middleware will attempt to proxy and fail. We just verify it didn't
// fall through to c.Next() due to failure to read the compressed body.
if nextCalled {
t.Fatal("expected cloud passthrough to detect cloud model from zstd body, but it fell through to next handler")
}
}
// TestCloudPassthroughMiddleware_ZstdBodyTooLarge checks that a zstd body
// whose decompressed size exceeds maxDecompressedBodySize is rejected with 400
// and never reaches the next handler.
func TestCloudPassthroughMiddleware_ZstdBodyTooLarge(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Create a body that exceeds the 20MB limit
	payload := bytes.Repeat([]byte{'A'}, maxDecompressedBodySize+1024)

	var packed bytes.Buffer
	enc, err := zstd.NewWriter(&packed)
	if err != nil {
		t.Fatalf("zstd writer: %v", err)
	}
	if _, err := enc.Write(payload); err != nil {
		t.Fatalf("zstd write: %v", err)
	}
	if err := enc.Close(); err != nil {
		t.Fatalf("zstd close: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(packed.Bytes()))
	req.Header.Set("Content-Encoding", "zstd")
	rec := httptest.NewRecorder()

	router := gin.New()
	router.POST("/v1/responses", cloudPassthroughMiddleware("test"), func(c *gin.Context) {
		t.Fatal("handler should not be reached for oversized body")
	})
	router.ServeHTTP(rec, req)

	if status := rec.Code; status != http.StatusBadRequest {
		t.Fatalf("expected status 400, got %d", status)
	}
}
// TestBuildCloudSignatureChallengeOverwritesExistingTimestamp checks that a
// "ts" parameter already present on the request is replaced rather than
// duplicated.
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
	const (
		wantChallenge = "POST,/v1/messages?beta=true&ts=123"
		wantQuery     = "beta=true&ts=123"
	)

	req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
	if err != nil {
		t.Fatalf("failed to create request: %v", err)
	}

	if challenge := buildCloudSignatureChallenge(req, "123"); challenge != wantChallenge {
		t.Fatalf("challenge mismatch: got %q want %q", challenge, wantChallenge)
	}
	if req.URL.RawQuery != wantQuery {
		t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
	}
}
// TestJSONLFramingResponseWriter_SplitsCoalescedLines checks that two JSON
// lines delivered in a single Write reach the underlying writer as two
// separate writes without their newlines.
func TestJSONLFramingResponseWriter_SplitsCoalescedLines(t *testing.T) {
	sink := &chunkRecorder{header: http.Header{}}
	framer := &jsonlFramingResponseWriter{ResponseWriter: sink}

	payload := []byte("{\"a\":1}\n{\"b\":2}\n")
	written, err := framer.Write(payload)
	if err != nil {
		t.Fatalf("write failed: %v", err)
	}
	if written != len(payload) {
		t.Fatalf("write byte count mismatch: got %d want %d", written, len(payload))
	}

	if err := framer.FlushPending(); err != nil {
		t.Fatalf("FlushPending failed: %v", err)
	}

	if count := len(sink.chunks); count != 2 {
		t.Fatalf("expected 2 framed writes, got %d", count)
	}
	if first := string(sink.chunks[0]); first != `{"a":1}` {
		t.Fatalf("first chunk mismatch: got %q", first)
	}
	if second := string(sink.chunks[1]); second != `{"b":2}` {
		t.Fatalf("second chunk mismatch: got %q", second)
	}
}
// TestJSONLFramingResponseWriter_FlushPendingWritesTrailingLine checks that
// data without a trailing newline is held back until FlushPending is called,
// and is then written as a single chunk.
func TestJSONLFramingResponseWriter_FlushPendingWritesTrailingLine(t *testing.T) {
	sink := &chunkRecorder{header: http.Header{}}
	framer := &jsonlFramingResponseWriter{ResponseWriter: sink}

	if _, err := framer.Write([]byte("{\"a\":1")); err != nil {
		t.Fatalf("write failed: %v", err)
	}
	if count := len(sink.chunks); count != 0 {
		t.Fatalf("expected no writes before newline/flush, got %d", count)
	}

	if err := framer.FlushPending(); err != nil {
		t.Fatalf("FlushPending failed: %v", err)
	}
	if count := len(sink.chunks); count != 1 {
		t.Fatalf("expected 1 write after FlushPending, got %d", count)
	}
	if trailing := string(sink.chunks[0]); trailing != `{"a":1` {
		t.Fatalf("trailing chunk mismatch: got %q", trailing)
	}
}
// chunkRecorder is a minimal http.ResponseWriter for tests. It remembers the
// last status code and stores a copy of each Write call's bytes as a separate
// chunk.
type chunkRecorder struct {
	header http.Header
	status int
	chunks [][]byte
}

// Header returns the recorder's header map.
func (r *chunkRecorder) Header() http.Header { return r.header }

// WriteHeader records the status code.
func (r *chunkRecorder) WriteHeader(statusCode int) { r.status = statusCode }

// Write appends a copy of p as a new chunk, so later changes to p do not
// affect what was recorded.
func (r *chunkRecorder) Write(p []byte) (int, error) {
	r.chunks = append(r.chunks, append([]byte(nil), p...))
	return len(p), nil
}

903
server/create.go Normal file
View File

@@ -0,0 +1,903 @@
package server
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"slices"
"strings"
"sync/atomic"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
)
// Sentinel errors used while handling model creation requests.
var (
errNoFilesProvided = errors.New("no files provided to convert")
errOnlyOneAdapterSupported = errors.New("only one adapter is currently supported")
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
errUnknownType = errors.New("unknown type")
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
// errFilePath is reported when a requested file path fails fs.ValidPath.
errFilePath = errors.New("file path must be relative")
)
// CreateHandler handles a model create request.
//
// It validates the JSON body (file paths, digests and model name) synchronously
// and answers with 400 on any violation. The actual work runs in a goroutine
// that reports progress and errors over a channel; the channel is consumed by
// waitForStream when the client disabled streaming and by streamResponse
// otherwise.
//
// The base layers come from exactly one of:
//   - r.From with a remote host: only the remote reference is recorded;
//   - r.From without a remote host: layers are loaded via parseFromModel and
//     unset config fields are inherited from the base model's config blob;
//   - r.Files: the uploaded blobs are converted via convertModelFromFiles.
func (s *Server) CreateHandler(c *gin.Context) {
	config := &model.ConfigV2{
		OS:           "linux",
		Architecture: "amd64",
		RootFS: model.RootFS{
			Type: "layers",
		},
	}

	var r api.CreateRequest
	if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
		return
	} else if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	config.Renderer = r.Renderer
	config.Parser = r.Parser
	config.Requires = r.Requires

	for v, digest := range r.Files {
		if !fs.ValidPath(v) {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
			return
		}
		if digest == "" {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
			return
		}
	}

	for _, digest := range r.Adapters {
		if digest == "" {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": manifest.ErrInvalidDigestFormat.Error()})
			return
		}
	}

	name := model.ParseName(cmp.Or(r.Model, r.Name))
	if !name.IsValid() {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
		return
	}

	name, err := getExistingName(name)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(resp api.ProgressResponse) {
			ch <- resp
		}

		// Remember the manifest being replaced so its layers can be pruned
		// once the new model has been written.
		oldManifest, _ := manifest.ParseNamedManifest(name)

		var baseLayers []*layerGGML
		var err error
		var remote bool

		if r.From != "" {
			slog.Debug("create model from model name", "from", r.From)
			fromRef, err := parseAndValidateModelRef(r.From)
			if err != nil {
				ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
				return
			}

			fromName := fromRef.Name
			remoteHost := r.RemoteHost
			if fromRef.Source == modelSourceCloud && remoteHost == "" {
				remoteHost = cloudProxyBaseURL
			}

			if remoteHost != "" {
				ru, err := remoteURL(remoteHost)
				if err != nil {
					ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
					return
				}

				config.RemoteModel = fromRef.Base
				config.RemoteHost = ru
				remote = true
			} else {
				ctx, cancel := context.WithCancel(c.Request.Context())
				defer cancel()

				baseLayers, err = parseFromModel(ctx, fromName, fn)
				if err != nil {
					ch <- gin.H{"error": err.Error()}
					// Stop here: continuing would write a manifest without
					// base layers and report success after the error.
					return
				}

				// Inherit config fields the request left unset from the base
				// model's config blob. This is best effort: any failure to
				// locate or decode the blob is ignored.
				mf, mErr := manifest.ParseNamedManifest(fromName)
				if mErr == nil && mf.Config.Digest != "" {
					configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
					if pErr == nil {
						if cfgFile, fErr := os.Open(configPath); fErr == nil {
							var baseConfig model.ConfigV2
							if decErr := json.NewDecoder(cfgFile).Decode(&baseConfig); decErr == nil {
								if config.Renderer == "" {
									config.Renderer = baseConfig.Renderer
								}
								if config.Parser == "" {
									config.Parser = baseConfig.Parser
								}
								if config.Requires == "" {
									config.Requires = baseConfig.Requires
								}
								if config.ModelFormat == "" {
									config.ModelFormat = baseConfig.ModelFormat
								}
								if len(config.Capabilities) == 0 {
									config.Capabilities = baseConfig.Capabilities
								}
							}
							cfgFile.Close()
						}
					}
				}
			}
		} else if r.Files != nil {
			baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
			if err != nil {
				for _, badReq := range []error{errNoFilesProvided, errOnlyGGUFSupported, errUnknownType} {
					if errors.Is(err, badReq) {
						ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
						return
					}
				}
				ch <- gin.H{"error": err.Error()}
				return
			}
		} else {
			ch <- gin.H{"error": errNeitherFromOrFiles.Error(), "status": http.StatusBadRequest}
			return
		}

		var adapterLayers []*layerGGML
		if !remote && r.Adapters != nil {
			adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
			if err != nil {
				for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
					if errors.Is(err, badReq) {
						ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
						return
					}
				}
				ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
				return
			}
		}

		if len(adapterLayers) > 0 {
			baseLayers = append(baseLayers, adapterLayers...)
		}

		// Info is not currently exposed by Modelfiles, but allows overriding various
		// config values
		if r.Info != nil {
			if tcaps, ok := r.Info["capabilities"].([]any); ok {
				// Keep only string entries; anything else is skipped rather
				// than stored as an empty capability.
				caps := make([]string, 0, len(tcaps))
				for _, c := range tcaps {
					if str, ok := c.(string); ok {
						caps = append(caps, str)
					}
				}
				config.Capabilities = append(config.Capabilities, caps...)
			}

			// Info is decoded from client JSON, so values may have any type.
			// A missing key or a value of the wrong type yields the zero
			// value instead of panicking inside this goroutine.
			strFromInfo := func(k string) string {
				val, _ := r.Info[k].(string)
				return val
			}

			vFromInfo := func(k string) float64 {
				val, _ := r.Info[k].(float64)
				return val
			}

			config.ModelFamily = strFromInfo("model_family")
			if config.ModelFamily != "" {
				config.ModelFamilies = []string{config.ModelFamily}
			}

			config.BaseName = strFromInfo("base_name")
			config.FileType = strFromInfo("quantization_level")
			config.ModelType = strFromInfo("parameter_size")
			config.ContextLen = int(vFromInfo("context_length"))
			config.EmbedLen = int(vFromInfo("embedding_length"))
		}

		if err := createModel(r, name, baseLayers, config, fn); err != nil {
			if errors.Is(err, errBadTemplate) {
				ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
				return
			}
			ch <- gin.H{"error": err.Error()}
			return
		}

		if !envconfig.NoPrune() && oldManifest != nil {
			if err := oldManifest.RemoveLayers(); err != nil {
				ch <- gin.H{"error": err.Error()}
			}
		}

		s.refreshModelListCache(name)
		ch <- api.ProgressResponse{Status: "success"}
	}()

	if r.Stream != nil && !*r.Stream {
		waitForStream(c, ch)
		return
	}

	streamResponse(c, ch)
}
// remoteURL normalizes a user-supplied remote host into a full URL string.
//
// Rules applied, in order:
//   - an input starting with "/" is treated as a path on http://localhost:11434;
//   - an input without a scheme gets "http://";
//   - "ollama.com" (with or without "http://") becomes "https://ollama.com:443";
//   - a missing host becomes "localhost" and a missing port becomes 11434;
//   - the path is cleaned, and a bare "/" path is dropped.
//
// It returns an error when the input cannot be parsed as a URL.
func remoteURL(raw string) (string, error) {
	// Special case: user supplied only a path ("/foo/bar").
	if strings.HasPrefix(raw, "/") {
		return (&url.URL{
			Scheme: "http",
			Host:   net.JoinHostPort("localhost", "11434"),
			Path:   path.Clean(raw),
		}).String(), nil
	}

	if !strings.Contains(raw, "://") {
		raw = "http://" + raw
	}

	// A bare "ollama.com" has already been prefixed with "http://" above, so
	// a single comparison covers both spellings.
	if raw == "http://ollama.com" {
		raw = "https://ollama.com:443"
	}

	u, err := url.Parse(raw)
	if err != nil {
		return "", fmt.Errorf("parse error: %w", err)
	}

	if u.Host == "" {
		u.Host = "localhost"
	}

	if hostPart, portPart, err := net.SplitHostPort(u.Host); err == nil {
		u.Host = net.JoinHostPort(hostPart, portPart)
	} else {
		// No port present. Hostname strips the brackets of an IPv6 literal,
		// so JoinHostPort does not produce "[[::1]]:11434".
		u.Host = net.JoinHostPort(u.Hostname(), "11434")
	}

	if u.Path != "" {
		u.Path = path.Clean(u.Path)
	}

	if u.Path == "/" {
		u.Path = ""
	}

	return u.String(), nil
}
// convertModelFromFiles turns a map of file name -> blob digest into GGML
// layers, dispatching on the type reported by detectModelTypeFromFiles.
//
// Safetensors input is converted via convertFromSafetensors. GGUF input is
// parsed blob by blob with ggufLayers; an empty map yields errNoFilesProvided
// and more than one file for an adapter yields errOnlyOneAdapterSupported.
// Any other detected type yields errUnknownType.
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
	kind := detectModelTypeFromFiles(files)

	if kind == "safetensors" {
		converted, err := convertFromSafetensors(files, baseLayers, isAdapter, fn)
		if err != nil {
			slog.Error("error converting from safetensors", "error", err)
			return nil, err
		}
		return converted, nil
	}

	if kind != "gguf" {
		return nil, errUnknownType
	}

	switch {
	case len(files) == 0:
		return nil, errNoFilesProvided
	case isAdapter && len(files) > 1:
		return nil, errOnlyOneAdapterSupported
	}

	var collected []*layerGGML
	for _, digest := range files {
		parsed, err := ggufLayers(digest, fn)
		if err != nil {
			return nil, err
		}
		collected = append(collected, parsed...)
	}
	return collected, nil
}
// detectModelTypeFromFiles reports "safetensors" or "gguf" for a map of
// file name -> blob digest, or "" when the type cannot be determined.
//
// The file extension decides when it is ".safetensors" or ".gguf". Otherwise
// the first four bytes of the blob are sniffed for the GGUF signature. Files
// are visited in map iteration order, and the first failure to locate, open
// or read a blob ends detection with "".
func detectModelTypeFromFiles(files map[string]string) string {
	for name, digest := range files {
		switch {
		case strings.HasSuffix(name, ".safetensors"):
			return "safetensors"
		case strings.HasSuffix(name, ".gguf"):
			return "gguf"
		}

		// try to see if we can find a gguf file even without the file extension
		blobPath, err := manifest.BlobsPath(digest)
		if err != nil {
			slog.Error("error getting blobs path", "file", name, "error", err)
			return ""
		}

		f, err := os.Open(blobPath)
		if err != nil {
			slog.Error("error reading file", "error", err)
			return ""
		}

		buf := make([]byte, 4)
		_, err = f.Read(buf)
		// Close right away instead of deferring: a defer inside this loop
		// would keep every sniffed blob open until the function returns.
		f.Close()
		if err != nil {
			slog.Error("error reading file", "error", err)
			return ""
		}

		if ggml.DetectContentType(buf) == "gguf" {
			return "gguf"
		}
	}
	return ""
}
// convertFromSafetensors converts a set of uploaded safetensors files into a
// single GGML layer.
//
// files maps a relative file path to the digest of the blob holding its
// content. The blobs are linked into a temporary directory under the models
// directory, converted there, and the result is stored as a new layer. When
// isAdapter is true the files are converted as an adapter against the KV
// metadata of baseLayers; otherwise they are converted as a model and the
// result is passed through detectChatTemplate. fn receives progress updates.
//
// Paths that fail fs.ValidPath or cannot be stat'ed inside the temporary root
// produce an error wrapping errFilePath.
func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
tmpDir, err := os.MkdirTemp(envconfig.Models(), "ollama-safetensors")
if err != nil {
return nil, err
}
// The whole working directory, including the converted output, is discarded on return.
defer os.RemoveAll(tmpDir)
// Set up a root to validate paths
root, err := os.OpenRoot(tmpDir)
if err != nil {
return nil, err
}
defer root.Close()
for fp, digest := range files {
if !fs.ValidPath(fp) {
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
}
// A not-exist error is expected here because the file has not been linked yet;
// any other error is treated as a path problem.
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
// Path is likely outside the root
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return nil, err
}
if err := createLink(blobPath, filepath.Join(tmpDir, fp)); err != nil {
return nil, err
}
}
// t receives the converted output.
t, err := os.CreateTemp(tmpDir, "fp16")
if err != nil {
return nil, err
}
defer t.Close()
var mediaType string
if !isAdapter {
fn(api.ProgressResponse{Status: "converting model"})
mediaType = "application/vnd.ollama.image.model"
if err := convert.ConvertModel(os.DirFS(tmpDir), t); err != nil {
return nil, err
}
} else {
// Adapter conversion needs the base model's KV metadata.
kv, err := kvFromLayers(baseLayers)
if err != nil {
return nil, err
}
fn(api.ProgressResponse{Status: "converting adapter"})
mediaType = "application/vnd.ollama.image.adapter"
if err := convert.ConvertAdapter(os.DirFS(tmpDir), t, kv); err != nil {
return nil, err
}
}
// Rewind so the layer is built from the start of the converted output.
if _, err := t.Seek(0, io.SeekStart); err != nil {
return nil, err
}
layer, err := manifest.NewLayer(t, mediaType)
if err != nil {
return nil, err
}
// Re-open the stored layer to decode its GGML metadata.
bin, err := layer.Open()
if err != nil {
return nil, err
}
defer bin.Close()
f, err := ggml.Decode(bin, -1)
if err != nil {
return nil, err
}
layers := []*layerGGML{{layer, f}}
if !isAdapter {
return detectChatTemplate(layers)
}
return layers, nil
}
// kvFromLayers returns the KV metadata of the first base layer that carries
// decoded GGML data. If no layer does, it returns an empty ggml.KV together
// with an error.
func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
	for i := range baseLayers {
		if base := baseLayers[i]; base.GGML != nil {
			return base.KV(), nil
		}
	}
	return ggml.KV{}, fmt.Errorf("no base model was found")
}
// createModel assembles the final layer list for a create request and writes
// the manifest for name.
//
// For every base layer with GGML data it optionally quantizes the layer,
// fills unset config fields from the GGML metadata and applies
// architecture-based renderer/parser defaults. It then applies the request's
// template, system prompt, license(s), parameters and messages, builds the
// config layer and writes the manifest. fn receives progress updates.
//
// config is modified in place. r.Parameters may gain a default "stop" entry
// for the "gemma4" architecture.
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
var layers []manifest.Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
ft := layer.GGML.KV().FileType()
// With no explicit quantization request, an unquantized GGUF model layer
// flagged as having FP8 source tensors defaults to Q8_0.
if quantType == "" && hasSourceFP8Tensors(layer.GGML.KV()) && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" && slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) {
quantType = "Q8_0"
}
if quantType != "" && layer.GGML.Name() == "gguf" && layer.MediaType == "application/vnd.ollama.image.model" {
want, err := ggml.ParseFileType(quantType)
if err != nil {
return err
}
if !slices.Contains([]string{"F16", "BF16", "F32"}, ft.String()) {
return errors.New("quantization is only supported for F16, BF16 and F32 models")
} else if ft != want {
// Only re-quantize when the layer is not already in the requested type.
layer, err = quantizeLayer(layer, quantType, fn)
if err != nil {
return err
}
}
}
// cmp.Or keeps values that are already set and falls back to the GGML metadata.
config.ModelFormat = cmp.Or(config.ModelFormat, layer.GGML.Name())
config.ModelFamily = cmp.Or(config.ModelFamily, layer.GGML.KV().Architecture())
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(layer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, layer.GGML.KV().FileType().String())
config.ModelFamilies = append(config.ModelFamilies, layer.GGML.KV().Architecture())
// Auto-detect renderer, parser, and stop tokens from GGUF architecture.
// TODO: abstract this into a registry/lookup table when multiple models
// need architecture-based renderer/parser/stop defaults.
if config.Renderer == "" || config.Parser == "" {
arch := layer.GGML.KV().Architecture()
switch arch {
case "gemma4":
config.Renderer = cmp.Or(config.Renderer, gemma4RendererLegacy)
config.Parser = cmp.Or(config.Parser, "gemma4")
// Add a default stop sequence only when the request did not set one.
if _, ok := r.Parameters["stop"]; !ok {
if r.Parameters == nil {
r.Parameters = make(map[string]any)
}
r.Parameters["stop"] = []string{"<turn|>"}
}
case "laguna":
config.Renderer = cmp.Or(config.Renderer, "laguna")
config.Parser = cmp.Or(config.Parser, "laguna")
case "nemotron_h", "nemotron_h_moe", "nemotron_h_omni":
config.Renderer = cmp.Or(config.Renderer, "nemotron-3-nano")
config.Parser = cmp.Or(config.Parser, "nemotron-3-nano")
}
}
}
layers = append(layers, layer.Layer)
}
if r.Template != "" {
layers, err = setTemplate(layers, r.Template)
if err != nil {
return err
}
}
if r.System != "" {
layers, err = setSystem(layers, r.System)
if err != nil {
return err
}
}
if r.License != nil {
// License may be a single string or a list of strings.
switch l := r.License.(type) {
case string:
if l != "" {
layers, err = setLicense(layers, l)
if err != nil {
return err
}
}
case any:
// NOTE(review): because r.License is non-nil here, this case matches every
// non-string value, so the default branch below looks unreachable — confirm.
var licenses []string
b, _ := json.Marshal(l) // re-marshal to JSON
if err := json.Unmarshal(b, &licenses); err != nil {
return err
}
for _, v := range licenses {
layers, err = setLicense(layers, v)
if err != nil {
return err
}
}
default:
return fmt.Errorf("unknown license type: %T", l)
}
}
layers, err = setParameters(layers, r.Parameters)
if err != nil {
return err
}
layers, err = setMessages(layers, r.Messages)
if err != nil {
return err
}
configLayer, err := createConfigLayer(layers, *config)
if err != nil {
return err
}
// Forward any status recorded on the layers as progress.
for _, layer := range layers {
if layer.Status != "" {
fn(api.ProgressResponse{Status: layer.Status})
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
return err
}
return nil
}
// hasSourceFP8Tensors reports whether kv marks the source quantization as
// "hf_fp8" and lists at least one source FP8 tensor.
func hasSourceFP8Tensors(kv ggml.KV) bool {
	if kv.String("source_quantization") != "hf_fp8" {
		return false
	}
	return len(kv.Strings("source_fp8_tensors")) > 0
}
// quantizeLayer re-quantizes a GGML layer to quantizeType and returns the new
// layer together with its decoded metadata.
//
// The quantized output is written to a temporary file next to the source
// blob, stored as a new layer with the same media type, and decoded again to
// verify it. The temporary file is removed on return. fn receives progress
// updates while quantizing.
func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.ProgressResponse)) (*layerGGML, error) {
	ft := layer.GGML.KV().FileType()
	var doneBytes atomic.Uint64
	// Progress is measured against the tensor data only (layer size minus
	// the tensor offset) and then scaled back to the full layer size.
	totalBytes := uint64(layer.Size) - layer.GGML.Tensors().Offset
	fnWrap := func(n uint64) {
		done := doneBytes.Add(n)
		progress := float32(done) / float32(totalBytes)
		fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType), Digest: "0000000000000000000", Total: layer.Size, Completed: int64(progress * float32(layer.Size))})
	}

	ftype, err := ggml.ParseFileType(quantizeType)
	if err != nil {
		return nil, err
	}

	blob, err := manifest.BlobsPath(layer.Digest)
	if err != nil {
		return nil, err
	}
	fp, err := os.Open(blob)
	if err != nil {
		return nil, err
	}
	defer fp.Close()

	temp, err := os.CreateTemp(filepath.Dir(blob), quantizeType)
	if err != nil {
		return nil, err
	}
	defer temp.Close()
	defer os.Remove(temp.Name())

	if err := quantize(fp, temp, layer.GGML, ftype, fnWrap); err != nil {
		return nil, err
	}

	// Rewind before hashing the output into a layer. A failed seek would
	// otherwise produce a layer from a truncated or empty read.
	if _, err := temp.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}
	fn(api.ProgressResponse{Status: "verifying conversion"})
	newLayer, err := manifest.NewLayer(temp, layer.MediaType)
	if err != nil {
		return nil, err
	}
	if _, err := temp.Seek(0, io.SeekStart); err != nil {
		return nil, err
	}

	f, err := ggml.Decode(temp, 1024)
	if err != nil {
		slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
		return nil, err
	}
	return &layerGGML{newLayer, f}, nil
}
// ggufLayers builds the layer list for an already-stored GGUF blob.
//
// The blob identified by digest is checked to be GGUF, decoded, and
// classified as a model, adapter or projector layer from its metadata. The
// resulting single layer is passed through detectChatTemplate. A blob that is
// not GGUF yields errOnlyGGUFSupported.
func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
	fn(api.ProgressResponse{Status: "parsing GGUF"})

	blobPath, err := manifest.BlobsPath(digest)
	if err != nil {
		return nil, err
	}

	blob, err := os.Open(blobPath)
	if err != nil {
		return nil, err
	}
	defer blob.Close()

	// Only the first 512 bytes are handed to content type detection.
	contentType, err := detectContentType(io.NewSectionReader(blob, 0, 512))
	if err != nil {
		return nil, err
	}
	if contentType != "gguf" {
		slog.Error(fmt.Sprintf("unsupported content type: %s", contentType))
		return nil, errOnlyGGUFSupported
	}

	f, err := ggml.Decode(blob, -1)
	if err != nil {
		return nil, err
	}

	mediatype := "application/vnd.ollama.image.model"
	switch {
	case f.KV().Kind() == "adapter":
		mediatype = "application/vnd.ollama.image.adapter"
	case (f.KV().Uint("block_count") == 0 && f.KV().Uint("vision.block_count") > 0), f.KV().Kind() == "projector":
		// if a model has vision.block_count but not block_count, it is a standalone vision model
		mediatype = "application/vnd.ollama.image.projector"
	}

	layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
	if err != nil {
		slog.Debug("could not create new layer from layer", "error", err)
		return nil, err
	}

	return detectChatTemplate([]*layerGGML{{layer, f}})
}
// removeLayer drops every layer with the given media type from layers and
// calls Remove on each dropped layer. A failing Remove is logged as a warning
// and the layer is still dropped from the list.
func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
	drop := func(candidate manifest.Layer) bool {
		if candidate.MediaType != mediatype {
			return false
		}
		if err := candidate.Remove(); err != nil {
			slog.Warn("couldn't remove blob", "digest", candidate.Digest, "error", err)
		}
		return true
	}
	return slices.DeleteFunc(layers, drop)
}
// setTemplate replaces any existing template layer with one holding t.
//
// The template is validated first, so an invalid template returns an error
// wrapping errBadTemplate without removing the existing template layer.
func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
	if _, err := template.Parse(t); err != nil {
		return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
	}

	layers = removeLayer(layers, "application/vnd.ollama.image.template")

	blob := strings.NewReader(t)
	layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
	if err != nil {
		return nil, err
	}

	layers = append(layers, layer)
	return layers, nil
}
// setSystem removes any existing system prompt layer and, when s is not
// empty, appends a new layer holding s.
func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
	layers = removeLayer(layers, "application/vnd.ollama.image.system")
	if s == "" {
		return layers, nil
	}

	systemLayer, err := manifest.NewLayer(strings.NewReader(s), "application/vnd.ollama.image.system")
	if err != nil {
		return nil, err
	}
	return append(layers, systemLayer), nil
}
// setLicense appends a license layer holding l. Existing license layers are
// kept.
func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
	licenseLayer, err := manifest.NewLayer(strings.NewReader(l), "application/vnd.ollama.image.license")
	if err != nil {
		return nil, err
	}
	return append(layers, licenseLayer), nil
}
// setParameters merges p with the parameters stored in existing params layers
// and replaces those layers with a single layer holding the merged result.
//
// Values in p take precedence over stored values. When the merged set is
// empty the layers are returned unchanged. p is modified in place when it is
// non-nil.
func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
	if p == nil {
		p = make(map[string]any)
	}

	for _, layer := range layers {
		if layer.MediaType != "application/vnd.ollama.image.params" {
			continue
		}

		digestPath, err := manifest.BlobsPath(layer.Digest)
		if err != nil {
			return nil, err
		}

		f, err := os.Open(digestPath)
		if err != nil {
			return nil, err
		}

		var existing map[string]any
		err = json.NewDecoder(f).Decode(&existing)
		// Close right away instead of deferring: a defer inside this loop
		// would keep every params blob open until the function returns.
		f.Close()
		if err != nil {
			return nil, err
		}

		for k, v := range existing {
			if _, exists := p[k]; exists {
				continue
			}
			p[k] = v
		}
	}

	if len(p) == 0 {
		return layers, nil
	}

	layers = removeLayer(layers, "application/vnd.ollama.image.params")

	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(p); err != nil {
		return nil, err
	}
	layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
	if err != nil {
		return nil, err
	}
	layers = append(layers, layer)
	return layers, nil
}
// setMessages replaces any existing messages layer with one holding m encoded
// as JSON. When m is empty the layers are returned unchanged.
func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
	// this leaves the old messages intact if no new messages were specified
	// which may not be the correct behaviour
	if len(m) == 0 {
		return layers, nil
	}

	// Use the structured logger rather than printing to the server's stdout.
	slog.Debug("removing old messages")
	layers = removeLayer(layers, "application/vnd.ollama.image.messages")

	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(m); err != nil {
		return nil, err
	}
	layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
	if err != nil {
		return nil, err
	}
	layers = append(layers, layer)
	return layers, nil
}
// createConfigLayer records the digests of layers, in order, as the config's
// RootFS diff IDs, encodes the config as JSON and stores it as a new layer.
// config is received by value, so the caller's copy is not modified.
func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
	diffIDs := make([]string, 0, len(layers))
	for _, l := range layers {
		diffIDs = append(diffIDs, l.Digest)
	}
	config.RootFS.DiffIDs = diffIDs

	var encoded bytes.Buffer
	enc := json.NewEncoder(&encoded)
	if err := enc.Encode(config); err != nil {
		return nil, err
	}

	configLayer, err := manifest.NewLayer(&encoded, "application/vnd.docker.container.image.v1+json")
	if err != nil {
		return nil, err
	}
	return &configLayer, nil
}
// createLink makes dst refer to the content of src.
//
// Missing parent directories of dst are created and any existing dst is
// removed first. A symlink is attempted; when that fails the content is
// copied instead.
func createLink(src, dst string) error {
	// make any subdirs for dst
	if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
		return err
	}

	// Best effort: dst usually does not exist yet, and a real problem shows
	// up in the symlink or copy below.
	_ = os.Remove(dst)
	if err := os.Symlink(src, dst); err != nil {
		if err := copyFile(src, dst); err != nil {
			return err
		}
	}
	return nil
}

// copyFile copies the content of src to dst, creating or truncating dst.
//
// An error from closing dst is returned when the copy itself succeeded,
// because buffered data may only fail to be written at close time.
func copyFile(src, dst string) (err error) {
	srcFile, err := os.Open(src)
	if err != nil {
		return err
	}
	defer srcFile.Close()

	dstFile, err := os.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := dstFile.Close(); err == nil {
			err = cerr
		}
	}()

	_, err = io.Copy(dstFile, srcFile)
	return err
}

258
server/create_test.go Normal file
View File

@@ -0,0 +1,258 @@
package server
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
)
// TestConvertFromSafetensors checks the path validation performed by
// convertFromSafetensors: file names that are not valid relative paths must
// fail with errFilePath, while a plain relative file name must be accepted.
func TestConvertFromSafetensors(t *testing.T) {
// Point the models directory at a throwaway location for this test.
t.Setenv("OLLAMA_MODELS", t.TempDir())
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}
return l.Digest
}
// Create a safetensors compatible file with empty JSON content
// (a little-endian int64 header length followed by the "{}" header).
var buf bytes.Buffer
headerSize := int64(len("{}"))
binary.Write(&buf, binary.LittleEndian, headerSize)
buf.WriteString("{}")
model := makeTemp(buf.String())
config := makeTemp(`{
"architectures": ["LlamaForCausalLM"],
"vocab_size": 32000
}`)
tokenizer := makeTemp(`{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 0,
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
]
}`)
// Each case supplies the model file under a different path; wantErr is the
// sentinel the returned error must match, or nil for success.
tests := []struct {
name string
filePath string
wantErr error
}{
// Invalid
{
name: "InvalidRelativePathShallow",
filePath: filepath.Join("..", "file.safetensors"),
wantErr: errFilePath,
},
{
name: "InvalidRelativePathDeep",
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
wantErr: errFilePath,
},
{
name: "InvalidNestedPath",
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
wantErr: errFilePath,
},
{
name: "AbsolutePathOutsideRoot",
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
wantErr: errFilePath, // Should fail since it's outside tmpDir
},
{
name: "ValidRelativePath",
filePath: "model.safetensors",
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create the minimum required file map for convertFromSafetensors
files := map[string]string{
tt.filePath: model,
"config.json": config,
"tokenizer.json": tokenizer,
}
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
// Fail on an unexpected error, a missing expected error, or an error
// that does not match the expected sentinel.
if (tt.wantErr == nil && err != nil) ||
(tt.wantErr != nil && err == nil) ||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
// TestRemoteURL is a table test for remoteURL covering path-only inputs,
// hosts with and without scheme or port, the ollama.com special case, path
// cleanup and an unparsable input.
func TestRemoteURL(t *testing.T) {
	type testCase struct {
		name    string
		input   string
		want    string
		wantErr bool
	}

	cases := []testCase{
		{name: "absolute path", input: "/foo/bar", want: "http://localhost:11434/foo/bar"},
		{name: "absolute path with cleanup", input: "/foo/../bar", want: "http://localhost:11434/bar"},
		{name: "root path", input: "/", want: "http://localhost:11434/"},
		{name: "host without scheme", input: "example.com", want: "http://example.com:11434"},
		{name: "host with port", input: "example.com:8080", want: "http://example.com:8080"},
		{name: "full URL", input: "https://example.com:8080/path", want: "https://example.com:8080/path"},
		{name: "full URL with path cleanup", input: "https://example.com:8080/path/../other", want: "https://example.com:8080/other"},
		{name: "ollama.com special case", input: "ollama.com", want: "https://ollama.com:443"},
		{name: "http ollama.com special case", input: "http://ollama.com", want: "https://ollama.com:443"},
		{name: "URL with only host", input: "http://example.com", want: "http://example.com:11434"},
		{name: "URL with root path cleaned", input: "http://example.com/", want: "http://example.com:11434"},
		{name: "invalid URL", input: "http://[::1]:namedport", wantErr: true}, // invalid port
		{name: "empty string", input: "", want: "http://localhost:11434"},
		{name: "host with scheme but no port", input: "http://localhost", want: "http://localhost:11434"},
		{name: "complex path cleanup", input: "/a/b/../../c/./d", want: "http://localhost:11434/c/d"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := remoteURL(tc.input)
			switch {
			case tc.wantErr && err == nil:
				t.Errorf("expected error but got none")
			case tc.wantErr:
				// The expected failure occurred; nothing else to check.
			case err != nil:
				t.Errorf("unexpected error: %v", err)
			case got != tc.want:
				t.Errorf("expected %q, got %q", tc.want, got)
			}
		})
	}
}
// TestRemoteURL_Idempotent verifies that feeding remoteURL's output back into
// remoteURL yields the same string.
func TestRemoteURL_Idempotent(t *testing.T) {
	inputs := [...]string{
		"/foo/bar",
		"example.com",
		"https://example.com:8080/path",
		"ollama.com",
		"http://localhost:11434",
	}

	for i := range inputs {
		in := inputs[i]
		t.Run(in, func(t *testing.T) {
			once, err := remoteURL(in)
			if err != nil {
				t.Fatalf("first call failed: %v", err)
			}

			twice, err := remoteURL(once)
			if err != nil {
				t.Fatalf("second call failed: %v", err)
			}

			if once != twice {
				t.Errorf("function is not idempotent: first=%q, second=%q", once, twice)
			}
		})
	}
}

509
server/download.go Normal file
View File

@@ -0,0 +1,509 @@
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math"
"math/rand/v2"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
// maxRetries bounds the number of attempts made per part before the download
// fails with errMaxRetriesExceeded. Attempts that end in errPartStalled are
// not counted against this limit.
const maxRetries = 6
var (
// errMaxRetriesExceeded wraps the last error of a part that failed maxRetries times.
errMaxRetriesExceeded = errors.New("max retries exceeded")
// errPartStalled is returned when a part made no progress for 30 seconds.
errPartStalled = errors.New("part stalled")
// errMaxRedirectsExceeded is returned when resolving the direct URL follows more than 10 redirects.
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
)
// blobDownloadManager holds in-flight downloads; entries are deleted by digest
// when a download's run finishes.
// NOTE(review): presumably keyed by digest with *blobDownload values — the
// code that stores entries is outside this view.
var blobDownloadManager sync.Map
// blobDownload tracks the state of a single blob download that is split into
// parts.
type blobDownload struct {
// Name is the destination path of the blob; the partial data file and the
// per-part state files are named by appending "-partial" suffixes to it.
Name string
// Digest identifies the blob and is used in log and progress messages.
Digest string
// Total is the size of the blob in bytes.
Total int64
// Completed is the number of bytes received so far across all parts.
Completed atomic.Int64
// Parts holds the byte ranges the download is split into.
Parts []*blobDownloadPart
// CancelFunc cancels the download's context; it is assigned in run.
context.CancelFunc
// done is closed when Run returns.
done chan struct{}
// err holds the result of run and is read by Wait after done is closed.
err error
// references counts active Wait callers; the download is cancelled when it drops to zero.
references atomic.Int32
}
// blobDownloadPart describes one byte range of a blob download and its
// progress.
type blobDownloadPart struct {
// N is the index of the part, used in its state file name.
N int
// Offset is the position of the part's first byte within the blob.
Offset int64
// Size is the length of the part in bytes.
Size int64
// Completed is the number of bytes of this part already written.
Completed atomic.Int64
// lastUpdatedMu guards lastUpdated.
lastUpdatedMu sync.Mutex
// lastUpdated is the time data last arrived for this part; the zero value
// means no data has been seen since the last reset.
lastUpdated time.Time
// The owning download. It is not serialized and is re-attached by readPart.
*blobDownload `json:"-"`
}
// jsonBlobDownloadPart is the on-disk JSON form of a blobDownloadPart. It
// carries Completed as a plain int64 instead of an atomic value.
type jsonBlobDownloadPart struct {
N int
Offset int64
Size int64
Completed int64
}
// MarshalJSON encodes the part's index, offset, size and current completed
// byte count.
func (p *blobDownloadPart) MarshalJSON() ([]byte, error) {
	snapshot := jsonBlobDownloadPart{
		N:      p.N,
		Offset: p.Offset,
		Size:   p.Size,
	}
	snapshot.Completed = p.Completed.Load()
	return json.Marshal(snapshot)
}
// UnmarshalJSON replaces the part with the state decoded from b. All fields
// not present in the JSON form, including the owning download, are reset.
func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
	var decoded jsonBlobDownloadPart
	err := json.Unmarshal(b, &decoded)
	if err != nil {
		return err
	}

	*p = blobDownloadPart{
		N:      decoded.N,
		Offset: decoded.Offset,
		Size:   decoded.Size,
	}
	p.Completed.Store(decoded.Completed)
	return nil
}
// Sizing of download parts. Prepare divides the blob size by numDownloadParts
// and clamps the resulting part size to [minDownloadPartSize,
// maxDownloadPartSize]; numDownloadParts is also the concurrency limit used
// when downloading parts.
const (
numDownloadParts = 16
minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte
)
// Name returns the path of the part's state file:
// "<download name>-partial-<N>".
func (p *blobDownloadPart) Name() string {
	return p.blobDownload.Name + "-partial-" + strconv.Itoa(p.N)
}
// StartsAt returns the blob offset of the next byte still to be downloaded
// for this part.
func (p *blobDownloadPart) StartsAt() int64 {
	done := p.Completed.Load()
	return p.Offset + done
}
// StopsAt returns the blob offset one past the last byte of this part.
func (p *blobDownloadPart) StopsAt() int64 {
	end := p.Offset + p.Size
	return end
}
// Write records progress only: it adds len(b) to the owning download's
// completed counter and stamps the part as updated now. The bytes themselves
// are not stored here. It always reports len(b) written and a nil error.
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
	received := len(b)
	p.blobDownload.Completed.Add(int64(received))

	func() {
		p.lastUpdatedMu.Lock()
		defer p.lastUpdatedMu.Unlock()
		p.lastUpdated = time.Now()
	}()

	return received, nil
}
// Prepare sets up the download's parts.
//
// Existing part state files ("<Name>-partial-*") are loaded to resume an
// earlier download. When none exist, a HEAD request determines the blob size
// and the blob is split into new parts whose size is Total/numDownloadParts
// clamped to [minDownloadPartSize, maxDownloadPartSize].
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
	existing, err := filepath.Glob(b.Name + "-partial-*")
	if err != nil {
		return err
	}

	b.done = make(chan struct{})

	for _, statePath := range existing {
		part, err := b.readPart(statePath)
		if err != nil {
			return err
		}

		b.Total += part.Size
		b.Completed.Add(part.Completed.Load())
		b.Parts = append(b.Parts, part)
	}

	if len(b.Parts) == 0 {
		resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
		if err != nil {
			return err
		}
		defer resp.Body.Close()

		b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)

		partSize := min(max(b.Total/numDownloadParts, minDownloadPartSize), maxDownloadPartSize)

		for offset := int64(0); offset < b.Total; offset += partSize {
			// The final part only covers what is left of the blob.
			if remaining := b.Total - offset; partSize > remaining {
				partSize = remaining
			}

			if err := b.newPart(offset, partSize); err != nil {
				return err
			}
		}
	}

	if len(b.Parts) == 0 {
		return nil
	}

	slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
	return nil
}
// Run performs the download, stores its result in b.err and then closes
// b.done.
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
	defer close(b.done)

	result := b.run(ctx, requestURL, opts)
	b.err = result
}
func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
var n int
return func(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
t := time.NewTimer(d)
defer t.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
return nil
}
}
}
// run downloads all incomplete parts of the blob into "<Name>-partial" and,
// once every part is complete, removes the part state files and renames the
// partial file to Name.
//
// It first resolves a direct URL by issuing a GET that stops at the first
// redirect to a different hostname, then downloads the parts concurrently
// (at most numDownloadParts at a time), retrying each failed part up to
// maxRetries times.
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
// Unregister this download when it finishes, whatever the outcome.
defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return err
}
defer file.Close()
setSparse(file)
// Size the file to the full blob up front; a failure here is ignored.
_ = file.Truncate(b.Total)
// Resolve the URL the parts are fetched from, giving up after 30 seconds.
directURL, err := func() (*url.URL, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
backoff := newBackoff(10 * time.Second)
for {
// shallow clone opts to be used in the closure
// without affecting the outer opts.
newOpts := new(registryOptions)
*newOpts = *opts
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 {
return errMaxRedirectsExceeded
}
// if the hostname is the same, allow the redirect
if req.URL.Hostname() == requestURL.Hostname() {
return nil
}
// stop at the first redirect that is not
// the same hostname as the original
// request.
return http.ErrUseLastResponse
}
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
if err != nil {
slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
if err := backoff(ctx); err != nil {
return nil, err
}
continue
}
// Reached at most once: every path below returns from the closure.
defer resp.Body.Close()
if resp.StatusCode != http.StatusTemporaryRedirect && resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
return resp.Location()
}
}()
if err != nil {
return err
}
g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts)
for i := range b.Parts {
part := b.Parts[i]
// Parts finished in an earlier run are skipped.
if part.Completed.Load() == part.Size {
continue
}
g.Go(func() error {
var err error
for try := 0; try < maxRetries; try++ {
// Resume writing where this part left off.
w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, directURL, w, part)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
return err
case errors.Is(err, errPartStalled):
// A stall does not consume a retry: undo the loop's increment.
try--
continue
case err != nil:
// Exponential delay between attempts: 1s, 2s, 4s, ...
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
time.Sleep(sleep)
continue
default:
return nil
}
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
})
}
if err := g.Wait(); err != nil {
return err
}
// explicitly close the file so we can rename it
if err := file.Close(); err != nil {
return err
}
// Remove the per-part state files ("<Name>-partial-<i>").
for i := range b.Parts {
if err := os.Remove(file.Name() + "-" + strconv.Itoa(i)); err != nil {
return err
}
}
if err := os.Rename(file.Name(), b.Name); err != nil {
return err
}
return nil
}
// downloadChunk downloads the remaining bytes of part from requestURL into w
// and persists the part's progress.
//
// Two goroutines cooperate: one issues a ranged GET and copies the body,
// the other checks once per second whether the part has stalled. A part that
// received data and then made no progress for 30 seconds ends the call with
// errPartStalled.
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
	g, ctx := errgroup.WithContext(ctx)
	g.Go(func() error {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
		if err != nil {
			return err
		}
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return err
		}
		defer resp.Body.Close()

		// The tee feeds every chunk to part.Write, which updates the overall
		// progress counter and the part's last-updated time.
		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
			// rollback progress
			b.Completed.Add(-n)
			return err
		}

		part.Completed.Add(n)
		if err := b.writePart(part.Name(), part); err != nil {
			return err
		}

		// return nil or context.Canceled or UnexpectedEOF (resumable)
		return err
	})

	g.Go(func() error {
		ticker := time.NewTicker(time.Second)
		// Release the ticker when this watchdog exits.
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if part.Completed.Load() >= part.Size {
					return nil
				}

				part.lastUpdatedMu.Lock()
				lastUpdated := part.lastUpdated
				part.lastUpdatedMu.Unlock()

				if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
					const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
					slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
					// reset last updated
					part.lastUpdatedMu.Lock()
					part.lastUpdated = time.Time{}
					part.lastUpdatedMu.Unlock()
					return errPartStalled
				}
			case <-ctx.Done():
				return ctx.Err()
			}
		}
	})

	return g.Wait()
}
// newPart creates the next part of the download, covering size bytes starting
// at offset, persists its state via writePart, and appends it to b.Parts. The
// part number N is the index the part will have in b.Parts.
func (b *blobDownload) newPart(offset, size int64) error {
	part := &blobDownloadPart{
		blobDownload: b,
		Offset:       offset,
		Size:         size,
		N:            len(b.Parts),
	}

	// Only register the part once its state file was written successfully.
	if err := b.writePart(part.Name(), part); err != nil {
		return err
	}

	b.Parts = append(b.Parts, part)
	return nil
}
// readPart loads the JSON-encoded part state stored in the file partName and
// attaches the result to b. It returns an error if the file cannot be opened
// or decoded.
func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
	f, err := os.Open(partName)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	part := &blobDownloadPart{}
	if err := json.NewDecoder(f).Decode(part); err != nil {
		return nil, err
	}

	// The back-reference is set here after decoding rather than read from the file.
	part.blobDownload = b
	return part, nil
}
// writePart JSON-encodes part into the file partName, creating the file if
// needed and truncating any previous content.
//
// The Close error is returned as well: a failed close of a file that was just
// written can mean the state never reached disk, which would otherwise go
// unnoticed until a resume reads a corrupt or empty part file.
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
	partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
	if err != nil {
		return err
	}

	if err := json.NewEncoder(partFile).Encode(part); err != nil {
		// The encode error is the interesting one; close on a best-effort basis.
		partFile.Close()
		return err
	}
	return partFile.Close()
}
// acquire registers one more user of the download by incrementing the
// reference counter. Each call is expected to be paired with release.
func (b *blobDownload) acquire() {
b.references.Add(1)
}
// release drops one reference to the download. When the counter reaches zero,
// i.e. nobody is waiting on the download any more, b.CancelFunc is called.
func (b *blobDownload) release() {
	if remaining := b.references.Add(-1); remaining == 0 {
		b.CancelFunc()
	}
}
// Wait blocks until the download finishes or ctx is done, reporting progress
// to fn every 60 milliseconds.
//
// The caller holds a reference to the download for the duration of the call
// (acquire/release), so the download is cancelled once its last waiter leaves.
// It returns b.err when the download is done, or ctx.Err() when ctx is
// cancelled first.
func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
	b.acquire()
	defer b.release()

	ticker := time.NewTicker(60 * time.Millisecond)
	// Release the ticker when the waiter returns instead of leaving it running.
	defer ticker.Stop()
	for {
		select {
		case <-b.done:
			return b.err
		case <-ticker.C:
			fn(api.ProgressResponse{
				Status:    fmt.Sprintf("pulling %s", b.Digest[7:19]),
				Digest:    b.Digest,
				Total:     b.Total,
				Completed: b.Completed.Load(),
			})
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
// downloadOpts bundles the arguments of downloadBlob.
type downloadOpts struct {
// n is the model name; downloadBlob uses it to build the registry blob URL
// and to label error messages.
n model.Name
// digest identifies the blob to download; it must not be empty.
digest string
// regOpts is passed through to the download's Prepare and Run steps.
regOpts *registryOptions
// fn receives progress updates.
fn func(api.ProgressResponse)
}
// downloadBlob makes sure the blob identified by opts.digest exists in the
// blobs directory, downloading it from the registry when necessary.
//
// cacheHit is true when the blob file was already present; in that case a
// single, fully completed progress update is sent to opts.fn. Otherwise the
// call joins the in-flight download for that digest, or starts a new one, and
// waits for it while forwarding progress to opts.fn.
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
	if opts.digest == "" {
		return false, fmt.Errorf("%s: %s", opts.n.DisplayNamespaceModel(), "digest is empty")
	}

	blobPath, err := manifest.BlobsPath(opts.digest)
	if err != nil {
		return false, err
	}

	info, statErr := os.Stat(blobPath)
	if statErr == nil {
		// Already on disk: report it as complete and skip the download.
		size := info.Size()
		opts.fn(api.ProgressResponse{
			Status:    fmt.Sprintf("pulling %s", opts.digest[7:19]),
			Digest:    opts.digest,
			Total:     size,
			Completed: size,
		})
		return true, nil
	}
	if !errors.Is(statErr, os.ErrNotExist) {
		return false, statErr
	}

	entry, loaded := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: blobPath, Digest: opts.digest})
	download := entry.(*blobDownload)
	if !loaded {
		// This call stored the entry, so it is responsible for starting the download.
		requestURL := opts.n.BaseURL().JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
		if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
			blobDownloadManager.Delete(opts.digest)
			return false, err
		}

		// The download runs on a background context so that it is not tied to
		// the context of the request that happened to start it.
		//nolint:contextcheck
		go download.Run(context.Background(), requestURL, opts.regOpts)
	}

	return false, download.Wait(ctx, opts.fn)
}

26
server/fixblobs.go Normal file
View File

@@ -0,0 +1,26 @@
package server
import (
"os"
"path/filepath"
"strings"
)
// fixBlobs walks dir and renames every entry whose base name starts with
// "sha256:" so that the colon becomes a dash (e.g. sha256:1234 -> sha256-1234).
// Entries with any other prefix are left untouched. The first walk or rename
// error aborts the walk and is returned.
func fixBlobs(dir string) error {
	const legacyPrefix = "sha256:"

	rename := func(path string, _ os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}

		base := filepath.Base(path)
		if !strings.HasPrefix(base, legacyPrefix) {
			return nil
		}

		fixed := filepath.Join(filepath.Dir(path), "sha256-"+base[len(legacyPrefix):])
		return os.Rename(path, fixed)
	}

	return filepath.Walk(dir, rename)
}

83
server/fixblobs_test.go Normal file
View File

@@ -0,0 +1,83 @@
package server
import (
"io/fs"
"os"
"path/filepath"
"runtime"
"slices"
"strings"
"testing"
)
// TestFixBlobs creates the files listed in each case under a fresh temporary
// directory, runs fixBlobs on it, and checks that the resulting set of files
// (relative paths, order-insensitive) equals want.
func TestFixBlobs(t *testing.T) {
cases := []struct {
path []string
want []string
}{
{path: []string{"sha256-1234"}, want: []string{"sha256-1234"}},
{path: []string{"sha256:1234"}, want: []string{"sha256-1234"}},
{path: []string{"sha259:5678"}, want: []string{"sha259:5678"}},
{path: []string{"sha256:abcd"}, want: []string{"sha256-abcd"}},
{path: []string{"x/y/sha256:abcd"}, want: []string{"x/y/sha256-abcd"}},
{path: []string{"x:y/sha256:abcd"}, want: []string{"x:y/sha256-abcd"}},
// NOTE(review): this case is identical to the previous one.
{path: []string{"x:y/sha256:abcd"}, want: []string{"x:y/sha256-abcd"}},
{path: []string{"x:y/sha256:abcd", "sha256:1234"}, want: []string{"x:y/sha256-abcd", "sha256-1234"}},
{path: []string{"x:y/sha256:abcd", "sha256-1234"}, want: []string{"x:y/sha256-abcd", "sha256-1234"}},
}
for _, tt := range cases {
t.Run(strings.Join(tt.path, "|"), func(t *testing.T) {
// Cases whose paths contain a colon are skipped on Windows (presumably
// because such file names cannot be created there).
hasColon := slices.ContainsFunc(tt.path, func(s string) bool { return strings.Contains(s, ":") })
if hasColon && runtime.GOOS == "windows" {
t.Skip("skipping test on windows")
}
rootDir := t.TempDir()
for _, path := range tt.path {
fullPath := filepath.Join(rootDir, path)
fullDir, _ := filepath.Split(fullPath)
t.Logf("creating dir %s", fullDir)
if err := os.MkdirAll(fullDir, 0o755); err != nil {
t.Fatal(err)
}
t.Logf("writing file %s", fullPath)
if err := os.WriteFile(fullPath, nil, 0o644); err != nil {
t.Fatal(err)
}
}
if err := fixBlobs(rootDir); err != nil {
t.Fatal(err)
}
got := slurpFiles(os.DirFS(rootDir))
// Sort both sides so the comparison does not depend on walk order.
slices.Sort(tt.want)
slices.Sort(got)
if !slices.Equal(got, tt.want) {
t.Fatalf("got = %v, want %v", got, tt.want)
}
})
}
}
// slurpFiles returns the paths of all non-directory entries in fsys, relative
// to its root, in the order fs.WalkDir visits them. It panics if the walk
// reports an error.
func slurpFiles(fsys fs.FS) []string {
	var files []string
	err := fs.WalkDir(fsys, ".", func(name string, entry fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if !entry.IsDir() {
			files = append(files, name)
		}
		return nil
	})
	if err != nil {
		panic(err)
	}
	return files
}

78
server/gemma4_test.go Normal file
View File

@@ -0,0 +1,78 @@
package server
import "testing"
// TestResolveGemma4Renderer checks which renderer name resolveGemma4Renderer
// returns for a nil model, for models that already name the small or large
// renderer, and for models that still carry the legacy renderer name.
func TestResolveGemma4Renderer(t *testing.T) {
tests := []struct {
name string
model *Model
want string
}{
{
name: "nil model falls back to legacy alias",
model: nil,
want: gemma4RendererLegacy,
},
{
name: "explicit small passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererSmall),
},
want: gemma4RendererSmall,
},
{
name: "explicit large passes through",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLarge),
},
want: gemma4RendererLarge,
},
// The remaining cases all start from the legacy renderer name and differ
// in the model name or model type they provide.
{
name: "legacy e4b tag resolves small",
model: &Model{
Name: "gemma4:e4b",
ShortName: "gemma4:e4b",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
{
name: "legacy 31b tag resolves large",
model: &Model{
Name: "gemma4:31b-cloud",
ShortName: "gemma4:31b-cloud",
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererLarge,
},
{
name: "legacy model type resolves small",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "4.3B"),
},
want: gemma4RendererSmall,
},
{
name: "legacy model type resolves large",
model: &Model{
Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
},
want: gemma4RendererLarge,
},
{
name: "legacy unknown defaults small",
model: &Model{
Config: testConfigWithRenderer(gemma4RendererLegacy),
},
want: gemma4RendererSmall,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := resolveGemma4Renderer(tt.model); got != tt.want {
t.Fatalf("resolveGemma4Renderer() = %q, want %q", got, tt.want)
}
})
}
}

1050
server/images.go Normal file

File diff suppressed because it is too large Load Diff

432
server/images_test.go Normal file
View File

@@ -0,0 +1,432 @@
package server
import (
"crypto/sha256"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// TestPruneLayersSkipsRecentOrphans writes two blob files, back-dates one of
// them to more than layerPruneGracePeriod ago, and verifies that PruneLayers
// removes only the back-dated one.
func TestPruneLayersSkipsRecentOrphans(t *testing.T) {
// Point the models directory at a fresh temporary directory for this test.
t.Setenv("OLLAMA_MODELS", t.TempDir())
recentDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000001"
oldDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000002"
for _, digest := range []string{recentDigest, oldDigest} {
p, err := manifest.BlobsPath(digest)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(p, nil, 0o644); err != nil {
t.Fatal(err)
}
}
oldPath, err := manifest.BlobsPath(oldDigest)
if err != nil {
t.Fatal(err)
}
// Push the old blob's access and modification times one hour past the
// grace period.
oldTime := time.Now().Add(-layerPruneGracePeriod - time.Hour)
if err := os.Chtimes(oldPath, oldTime, oldTime); err != nil {
t.Fatal(err)
}
if err := PruneLayers(); err != nil {
t.Fatal(err)
}
recentPath, err := manifest.BlobsPath(recentDigest)
if err != nil {
t.Fatal(err)
}
if _, err := os.Stat(recentPath); err != nil {
t.Fatalf("recent orphan was pruned: %v", err)
}
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
t.Fatalf("old orphan still exists: %v", err)
}
}
// TestModelCapabilities checks that Model.Capabilities reports the expected
// set of capabilities, compared without regard to order, for models built
// from different model files, templates and configs.
func TestModelCapabilities(t *testing.T) {
// Create completion model (llama architecture without vision)
completionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
}, []*ggml.Tensor{})
// Create vision model (llama architecture with vision block count)
visionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
}, []*ggml.Tensor{})
// Create embedding model (bert architecture with pooling type)
embeddingModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
}, []*ggml.Tensor{})
// Template that references .tools and .suffix.
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
// Template that references neither .tools nor .suffix.
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
// Template that references .tools only.
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testModels := []struct {
name string
model Model
expectedCaps []model.Capability
}{
{
name: "model with image generation capability via config",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
expectedCaps: []model.Capability{model.CapabilityImage},
},
{
name: "model with image and vision capability (image editing)",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image", "vision"},
},
},
expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
},
{
name: "model with completion capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: Model{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
// The three gemma4 safetensors cases below leave expectedCaps unset, so
// they expect Capabilities to return no capabilities at all.
{
name: "gemma4 small safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererSmall,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "gemma4 large safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLarge,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
{
name: "legacy gemma4 safetensors suppresses vision and audio",
model: Model{
Config: model.ConfigV2{
ModelFormat: "safetensors",
Renderer: gemma4RendererLegacy,
Capabilities: []string{"vision", "audio"},
},
Template: chatTemplate,
},
},
}
// compare two slices of model.Capability regardless of order
compareCapabilities := func(a, b []model.Capability) bool {
if len(a) != len(b) {
return false
}
aCount := make(map[model.Capability]int)
for _, cap := range a {
aCount[cap]++
}
bCount := make(map[model.Capability]int)
for _, cap := range b {
bCount[cap]++
}
// The lengths are equal, so matching counts for every key of aCount
// imply the two multisets are identical.
for cap, count := range aCount {
if bCount[cap] != count {
return false
}
}
return true
}
for _, tt := range testModels {
t.Run(tt.name, func(t *testing.T) {
// Test Capabilities method
caps := tt.model.Capabilities()
if !compareCapabilities(caps, tt.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
}
})
}
}
// TestModelCheckCapabilities checks Model.CheckCapabilities: a case with an
// empty expectedErrMsg must succeed, any other case must fail with an error
// whose text contains expectedErrMsg.
func TestModelCheckCapabilities(t *testing.T) {
// Create simple model file for tests that don't depend on GGUF content
completionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
}, []*ggml.Tensor{})
// Create vision model (llama architecture with vision block count)
visionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
}, []*ggml.Tensor{})
// Create embedding model (bert architecture with pooling type)
embeddingModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
}, []*ggml.Tensor{})
// Template that references .tools and .suffix.
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
// Template that references neither .tools nor .suffix.
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
// Template that references .tools only.
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
tests := []struct {
name string
model Model
checkCaps []model.Capability
expectedErrMsg string
}{
{
name: "completion model without tools capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools},
expectedErrMsg: "does not support tools",
},
{
name: "model with all needed capabilities",
model: Model{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model missing insert capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityInsert},
expectedErrMsg: "does not support insert",
},
{
name: "model missing vision capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
checkCaps: []model.Capability{model.CapabilityVision},
expectedErrMsg: "does not support vision",
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityVision},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityEmbedding},
},
{
name: "unknown capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{"unknown"},
expectedErrMsg: "unknown capability",
},
{
name: "model missing image generation capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
checkCaps: []model.Capability{model.CapabilityImage},
expectedErrMsg: "does not support image generation",
},
{
name: "model with image generation capability",
model: Model{
Config: model.ConfigV2{
Capabilities: []string{"image"},
},
},
checkCaps: []model.Capability{model.CapabilityImage},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test CheckCapabilities method
err := tt.model.CheckCapabilities(tt.checkCaps...)
if tt.expectedErrMsg == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg)
} else if !strings.Contains(err.Error(), tt.expectedErrMsg) {
t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err)
}
}
})
}
}
// TestPullModelManifest serves a fixed manifest body from a local HTTP test
// server and verifies that pullModelManifest returns those bytes unmodified,
// so that their SHA-256 matches the served content, together with a parsed
// manifest that is still usable.
func TestPullModelManifest(t *testing.T) {
cases := []struct {
name string
manifest string
}{
{
name: "pretty printed",
manifest: `{ "schemaVersion": 2, "mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": { "digest": "sha256:abc", "mediaType": "application/vnd.docker.container.image.v1+json", "size": 50 },
"layers": [{ "digest": "sha256:t1", "mediaType": "application/vnd.ollama.image.tensor", "size": 1024, "name": "model.weight" }]
}`,
},
{
name: "non-standard field order",
manifest: `{"layers":[{"size":999,"digest":"sha256:def","mediaType":"application/vnd.ollama.image.model"}],"schemaVersion":2,"config":{"size":50,"digest":"sha256:abc","mediaType":"application/vnd.docker.container.image.v1+json"},"mediaType":"application/vnd.docker.distribution.manifest.v2+json"}`,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
// The server answers every request with the case's manifest text.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(tt.manifest))
}))
defer ts.Close()
// Point the model name at the test server over plain HTTP.
n := model.ParseName("test/model:latest")
n.ProtocolScheme = "http"
n.Host = strings.TrimPrefix(ts.URL, "http://")
mf, data, err := pullModelManifest(t.Context(), n, &registryOptions{})
if err != nil {
t.Fatal(err)
}
// Raw bytes must be byte-for-byte identical to what the server sent
if string(data) != tt.manifest {
t.Fatalf("raw bytes differ from server response")
}
// SHA256 of returned data must match the expected registry digest
expectedDigest := fmt.Sprintf("%x", sha256.Sum256([]byte(tt.manifest)))
gotDigest := fmt.Sprintf("%x", sha256.Sum256(data))
if gotDigest != expectedDigest {
t.Fatalf("digest mismatch\ngot: %s\nwant: %s", gotDigest, expectedDigest)
}
// Parsed manifest must still be usable
if mf.SchemaVersion != 2 {
t.Fatalf("schemaVersion = %d, want 2", mf.SchemaVersion)
}
if mf.Config.Digest == "" {
t.Fatal("config digest is empty")
}
if len(mf.Layers) == 0 {
t.Fatal("expected at least one layer")
}
})
}
}

View File

@@ -0,0 +1,144 @@
package server
import (
"bytes"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/envconfig"
)
// inferenceRequestLogger writes inference request bodies and matching curl
// replay scripts into a directory for debugging.
type inferenceRequestLogger struct {
// dir is the directory the log files are written to.
dir string
// counter is incremented atomically for every logged request and is part
// of the generated file names.
counter uint64
}
// newInferenceRequestLogger returns a logger that writes into a newly created
// directory named "ollama-request-logs-*" under the default temporary
// directory. It fails if that directory cannot be created.
func newInferenceRequestLogger() (*inferenceRequestLogger, error) {
	logDir, err := os.MkdirTemp("", "ollama-request-logs-*")
	if err != nil {
		return nil, err
	}

	logger := &inferenceRequestLogger{dir: logDir}
	return logger, nil
}
// initRequestLogging installs a request logger on the server when
// envconfig.DebugLogRequests reports that request debug logging is enabled,
// and announces the log directory. When logging is disabled it does nothing.
func (s *Server) initRequestLogging() error {
	if enabled := envconfig.DebugLogRequests(); !enabled {
		return nil
	}

	logger, err := newInferenceRequestLogger()
	if err != nil {
		return fmt.Errorf("enable OLLAMA_DEBUG_LOG_REQUESTS: %w", err)
	}
	s.requestLogger = logger

	// The notice spells out that request bodies end up on disk.
	const notice = "request debug logging enabled; inference request logs will be stored in %s and include request bodies and replay curl commands"
	slog.Info(fmt.Sprintf(notice, logger.dir))
	return nil
}
// withInferenceRequestLogging returns handlers with the request-logging
// middleware for route placed in front. When request logging is not enabled,
// handlers is returned unchanged.
func (s *Server) withInferenceRequestLogging(route string, handlers ...gin.HandlerFunc) []gin.HandlerFunc {
	if s.requestLogger == nil {
		return handlers
	}

	chain := make([]gin.HandlerFunc, 0, len(handlers)+1)
	chain = append(chain, s.requestLogger.middleware(route))
	return append(chain, handlers...)
}
// middleware returns a gin handler that captures the request's method, host,
// scheme, Content-Type and full body before running the rest of the handler
// chain, and logs them via l.log once the chain has returned.
//
// The body is read into memory and replaced with an in-memory reader so that
// later handlers can still consume it. A read failure is only reported with a
// warning; whatever was read is passed on and logged.
func (l *inferenceRequestLogger) middleware(route string) gin.HandlerFunc {
	return func(c *gin.Context) {
		req := c.Request
		if req == nil {
			c.Next()
			return
		}

		scheme := "http"
		if req.TLS != nil {
			scheme = "https"
		}
		// Snapshot request metadata before the downstream handlers run.
		method, host := req.Method, req.Host
		contentType := c.GetHeader("Content-Type")

		var payload []byte
		if req.Body != nil {
			data, readErr := io.ReadAll(req.Body)
			payload = data
			req.Body = io.NopCloser(bytes.NewReader(payload))
			if readErr != nil {
				slog.Warn("failed to read request body for debug logging", "route", route, "error", readErr)
			}
		}

		c.Next()
		l.log(route, method, scheme, host, contentType, payload)
	}
}
// log stores one request in l.dir as two files that share a timestamped name:
// "<timestamp>_<route>_body.json" with the raw request body, and
// "<timestamp>_<route>_request.sh" with a shell script that replays the
// request using curl.
//
// An empty contentType defaults to "application/json"; an empty host or scheme
// is filled in from envconfig.Host(). A nil logger or an empty directory makes
// the call a no-op. Write failures are reported with a warning and are not
// returned.
//
// host and contentType come from the incoming request, so they are embedded
// in the script as single-quoted shell words. Go's %q quoting is not safe
// here: inside sh double quotes, "$(...)" and backticks would still be
// executed when the script is replayed.
func (l *inferenceRequestLogger) log(route, method, scheme, host, contentType string, body []byte) {
	if l == nil || l.dir == "" {
		return
	}

	if contentType == "" {
		contentType = "application/json"
	}
	if host == "" || scheme == "" {
		base := envconfig.Host()
		if host == "" {
			host = base.Host
		}
		if scheme == "" {
			scheme = base.Scheme
		}
	}

	// The atomic counter keeps names unique even for requests that are logged
	// within the same timestamp.
	routeForFilename := sanitizeRouteForFilename(route)
	timestamp := fmt.Sprintf("%s-%06d", time.Now().UTC().Format("20060102T150405.000000000Z"), atomic.AddUint64(&l.counter, 1))
	bodyFilename := fmt.Sprintf("%s_%s_body.json", timestamp, routeForFilename)
	curlFilename := fmt.Sprintf("%s_%s_request.sh", timestamp, routeForFilename)
	bodyPath := filepath.Join(l.dir, bodyFilename)
	curlPath := filepath.Join(l.dir, curlFilename)

	if err := os.WriteFile(bodyPath, body, 0o600); err != nil {
		slog.Warn("failed to write debug request body", "route", route, "error", err)
		return
	}

	// The script locates the body file relative to its own directory, so the
	// pair of files can be moved together.
	url := fmt.Sprintf("%s://%s%s", scheme, host, route)
	curl := fmt.Sprintf(
		"#!/bin/sh\nSCRIPT_DIR=\"$(CDPATH= cd -- \"$(dirname -- \"$0\")\" && pwd)\"\ncurl --request %s --url %s --header %s --data-binary @\"${SCRIPT_DIR}/%s\"\n",
		shellSingleQuote(method),
		shellSingleQuote(url),
		shellSingleQuote("Content-Type: "+contentType),
		bodyFilename,
	)
	if err := os.WriteFile(curlPath, []byte(curl), 0o600); err != nil {
		slog.Warn("failed to write debug request replay command", "route", route, "error", err)
		return
	}

	slog.Info(fmt.Sprintf("logged to %s, replay using curl with `sh %s`", bodyPath, curlPath))
}

// shellSingleQuote returns s as a single-quoted POSIX shell word. Embedded
// single quotes are written as '\'' (close quote, escaped quote, reopen), so
// the shell never expands any part of s.
func shellSingleQuote(s string) string {
	return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
}
// sanitizeRouteForFilename turns a route into a string that is safe to use in
// a file name: one leading "/" is dropped and every character other than an
// ASCII letter or digit becomes "_". An empty result is reported as "root".
func sanitizeRouteForFilename(route string) string {
	trimmed := strings.TrimPrefix(route, "/")
	if trimmed == "" {
		return "root"
	}

	return strings.Map(func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return r
		default:
			return '_'
		}
	}, trimmed)
}

544
server/internal/cache/blob/cache.go vendored Normal file
View File

@@ -0,0 +1,544 @@
// Package blob implements a content-addressable disk cache for blobs and
// manifests.
package blob
import (
"bytes"
"crypto/sha256"
"errors"
"fmt"
"hash"
"io"
"io/fs"
"iter"
"os"
"path/filepath"
"strings"
"time"
"github.com/ollama/ollama/server/internal/internal/names"
)
// Entry contains metadata about a blob in the cache.
type Entry struct {
// Digest identifies the blob.
Digest Digest
// Size is the size of the blob file in bytes.
Size int64
Time time.Time // when added to the cache
}
// DiskCache caches blobs and manifests on disk.
//
// The cache is rooted at a directory, which is created if it does not exist.
//
// Blobs are stored in the "blobs" subdirectory, and manifests are stored in the
// "manifests" subdirectory. An example directory structure might look like:
//
//	<dir>/
//	  blobs/
//	    sha256-<digest> - <blob data>
//	  manifests/
//	    <host>/
//	      <namespace>/
//	        <name>/
//	          <tag> - <manifest data>
//
// Name casing is preserved in the cache, but is not significant when resolving
// names. For example, "Foo" and "foo" are considered the same name.
//
// Concurrency: the cache guards concurrent writes, but does not prevent
// duplicated effort. Because blobs are immutable, duplicate writes should
// result in the same file being written to disk.
//
// NOTE(review): this comment previously both asserted and denied that the
// cache is safe for concurrent use; confirm the intended guarantee.
type DiskCache struct {
// dir specifies the top-level directory where blobs and manifest
// pointers are stored.
dir string
// now returns the current time; Open sets it to time.Now.
now func() time.Time
// testHookBeforeFinalWrite is a hook for tests (presumably invoked just
// before the final write of a file; its call site is not visible here).
testHookBeforeFinalWrite func(f *os.File)
}
// PutBytes stores in-memory data, given as a string or a byte slice, in the
// cache under digest d. It is shorthand for
// c.Put(d, bytes.NewReader([]byte(data)), int64(len(data))).
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
	size := int64(len(data))
	src := bytes.NewReader([]byte(data))
	return c.Put(d, src, size)
}
// Open returns a cache rooted at dir. The directory and its "blobs" and
// "manifests" subdirectories are created if they do not exist yet.
//
// An error is returned if dir is empty, if dir exists but is not a directory,
// or if any of the directories cannot be created.
func Open(dir string) (*DiskCache, error) {
	if dir == "" {
		return nil, errors.New("blob: empty directory name")
	}

	// Only the "exists but is not a directory" case is rejected here; any
	// other problem with dir surfaces through MkdirAll below.
	if fi, statErr := os.Stat(dir); statErr == nil && !fi.IsDir() {
		return nil, fmt.Errorf("%q is not a directory", dir)
	}

	required := []string{
		dir,
		filepath.Join(dir, "blobs"),
		filepath.Join(dir, "manifests"),
	}
	for _, p := range required {
		if err := os.MkdirAll(p, 0o777); err != nil {
			return nil, err
		}
	}

	// TODO(bmizerany): support shards
	return &DiskCache{dir: dir, now: time.Now}, nil
}
// readAndSum reads at most limit bytes from the file filename and returns
// them together with the SHA-256 digest of exactly those bytes. Content
// beyond limit is neither returned nor hashed.
func readAndSum(filename string, limit int64) (data []byte, _ Digest, err error) {
	file, err := os.Open(filename)
	if err != nil {
		return nil, Digest{}, err
	}
	defer file.Close()

	// The hash sits behind the limit, so it only sees bytes that are returned.
	hasher := sha256.New()
	limited := io.LimitReader(io.TeeReader(file, hasher), limit)
	if data, err = io.ReadAll(limited); err != nil {
		return nil, Digest{}, err
	}

	var sum Digest
	hasher.Sum(sum.sum[:0])
	return data, sum, nil
}
// debug switches on step tracing in debugger. It is false by default, in
// which case debugger returns a no-op.
//
//lint:ignore U1000 used for debugging purposes as needed in tests
var debug = false
// debugger returns a function for recording the steps taken before an error
// occurred. Each call with a step appends it to an internal list in call
// order; only the 100 most recent steps are kept.
//
// Calling the returned function with an empty string while *err is non-nil
// wraps *err so that its message is prefixed with the recorded steps.
//
// When debug is false the returned function does nothing.
//
//lint:ignore U1000 used for debugging purposes as needed in tests
func debugger(err *error) func(step string) {
	if !debug {
		return func(string) {}
	}

	const maxSteps = 100
	var trail []string
	return func(step string) {
		if step == "" && *err != nil {
			*err = fmt.Errorf("%q: %w", trail, *err)
			return
		}

		trail = append(trail, step)
		if len(trail) > maxSteps {
			// shift out the oldest step in case a bug produces a lot of them
			copy(trail, trail[1:])
			trail = trail[:maxSteps]
		}
	}
}
// Resolve resolves a name to a digest. The name is expected to
// be in either of the following forms:
//
//	@<digest>
//	<name>@<digest>
//	<name>
//
// If the name carries a digest, that digest is parsed and returned and
// nothing else happens.
//
// Otherwise the manifest file for the name is read and hashed, and the
// resulting digest is returned. If the manifest file cannot be read, the error
// from reading it is returned (for a missing manifest this is expected to be
// [fs.ErrNotExist]).
//
// Because a manifest may change without the cache knowing (e.g. it was
// reformatted or modified by hand), the data that was read and hashed is also
// passed to PutBytes, so that the manifest is present in the blob store and
// later calls to [Get] with the returned digest succeed.
func (c *DiskCache) Resolve(name string) (Digest, error) {
	base, explicit := splitNameDigest(name)
	if explicit != "" {
		return ParseDigest(explicit)
	}

	// Manifests need to be addressable by digest through Get, which means
	// they have to exist as blobs. Looking in the blob store alone is not
	// enough: a manifest can change behind Ollama's back (a user edits it by
	// hand and pushes it to update their model), and blob caches inherited
	// from older versions of Ollama never stored manifests as blobs. So the
	// manifest file is read and hashed here, and added to the blob store just
	// in time.
	//
	// This should be cheap because manifests are small, and accessed
	// infrequently.
	path, err := c.manifestPath(base)
	if err != nil {
		return Digest{}, err
	}

	contents, sum, err := readAndSum(path, 1<<20)
	if err != nil {
		return Digest{}, err
	}

	// Ideally the "manifest" file would be a pointer to the blob file, but
	// that is not changing yet, so copy the manifest into the blob store to
	// make it addressable by digest in subsequent calls to Get.
	if err := PutBytes(c, sum, contents); err != nil {
		return Digest{}, err
	}
	return sum, nil
}
// Put writes a new blob to the cache under digest d. The content is read from
// r and must match both the given size and the digest exactly.
//
// The write is delegated to copyNamedFile, targeting the blob's file in the
// cache. Per the original documentation, concurrent write safety is achieved
// through file locking, and size limits and content validation are enforced
// before the file reaches its final state.
func (c *DiskCache) Put(d Digest, r io.Reader, size int64) error {
	dst := c.GetFile(d)
	return c.copyNamedFile(dst, r, d, size)
}
// Import imports a blob from the provided reader into the cache. It reads the
// entire content of the reader, calculates its digest, and stores it in the
// cache under that digest. It fails if the number of bytes read differs from
// size.
//
// Import should be considered unsafe for use with untrusted data, such as data
// read from a network. The caller is responsible for ensuring the integrity of
// the data being imported.
func (c *DiskCache) Import(r io.Reader, size int64) (Digest, error) {
	// The temporary file lives in the default temp directory (os.TempDir);
	// users that want to change it can set TMPDIR.
	f, err := os.CreateTemp("", "blob-")
	if err != nil {
		return Digest{}, err
	}
	// Deferred calls run in reverse order: the file is closed first and then
	// removed. This releases the descriptor on every error path. After a
	// successful run both calls are harmless no-ops, because the file was
	// already closed and renamed away.
	defer os.Remove(f.Name())
	defer f.Close()

	// Copy the blob to a temporary file, hashing it on the way.
	h := sha256.New()
	r = io.TeeReader(r, h)
	n, err := io.Copy(f, r)
	if err != nil {
		return Digest{}, err
	}
	if n != size {
		return Digest{}, fmt.Errorf("blob: expected %d bytes, got %d", size, n)
	}

	// Compute the digest of what was written.
	var d Digest
	h.Sum(d.sum[:0])

	// Close explicitly so that a failed flush is reported before the rename.
	if err := f.Close(); err != nil {
		return Digest{}, err
	}
	name := c.GetFile(d)
	// Rename the temporary file to the final file.
	if err := os.Rename(f.Name(), name); err != nil {
		return Digest{}, err
	}
	os.Chtimes(name, c.now(), c.now()) // mainly for tests
	return d, nil
}
// Get returns the cache entry for the blob with digest d. The entry's size and
// time are the size and modification time of the blob file.
//
// It returns the error from os.Stat if the blob file cannot be inspected, and
// [fs.ErrNotExist] if the file exists but is empty.
func (c *DiskCache) Get(d Digest) (Entry, error) {
	fi, err := os.Stat(c.GetFile(d))
	switch {
	case err != nil:
		return Entry{}, err
	case fi.Size() == 0:
		// An empty file is treated the same as a missing blob.
		return Entry{}, fs.ErrNotExist
	}

	entry := Entry{
		Digest: d,
		Size:   fi.Size(),
		Time:   fi.ModTime(),
	}
	return entry, nil
}
// Link makes the blob identified by d retrievable by name through [Resolve].
// It does so by copying the blob's content into the manifest file for name,
// creating the manifest's parent directories as needed.
//
// It returns an error if the name cannot be mapped to a manifest path, if the
// blob cannot be opened, or if creating the manifest fails.
func (c *DiskCache) Link(name string, d Digest) error {
	dst, err := c.manifestPath(name)
	if err != nil {
		return err
	}

	blob, err := os.Open(c.GetFile(d))
	if err != nil {
		return err
	}
	defer blob.Close()

	// TODO(bmizerany): test this happens only if the blob was found to
	// avoid leaving debris
	if err := os.MkdirAll(filepath.Dir(dst), 0o777); err != nil {
		return err
	}

	fi, err := blob.Stat()
	if err != nil {
		return err
	}

	// Copy manifest to cache directory.
	return c.copyNamedFile(dst, blob, d, fi.Size())
}
// Unlink removes the manifest for name from the cache.
//
// ok is true only if a manifest was actually removed. If no manifest exists
// for the name, it returns ok false and a nil error. If an error occurs, it
// returns ok false and the error.
func (c *DiskCache) Unlink(name string) (ok bool, _ error) {
	manifest, err := c.manifestPath(name)
	if err != nil {
		return false, err
	}

	if err := os.Remove(manifest); err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			return false, nil
		}
		// Nothing was removed, so ok must not be reported as true alongside
		// the error.
		return false, err
	}
	return true, nil
}
// GetFile returns the path of the file in the cache that holds, or would
// hold, the blob with digest d: "sha256-<hex digest>" inside the "blobs"
// directory, joined with absJoin. It does not check if the file exists.
//
// The returned path should not be stored, used outside the lifetime of the
// cache, or interpreted in any way.
func (c *DiskCache) GetFile(d Digest) string {
	return absJoin(c.dir, "blobs", fmt.Sprintf("sha256-%x", d.sum))
}
// Links returns a sequence of link names, in the order produced by c.links
// (lexical). Each manifest path is converted to its name form with
// pathToName; the names are not guaranteed to be valid, so callers should
// validate them before use.
//
// If listing fails, the error is yielded once with an empty name and the
// sequence ends.
func (c *DiskCache) Links() iter.Seq2[string, error] {
	return func(yield func(string, error) bool) {
		c.links()(func(path string, err error) bool {
			if err != nil {
				yield("", err)
				return false
			}
			return yield(pathToName(path), nil)
		})
	}
}
// pathToName converts a manifest path to a name; it is the inverse of
// nameToPath. A leading "manifests/" is removed and the last "/" of the
// remainder, unless it is the very first character, is replaced by ":"
// (separating the tag). The path is assumed to be in filepath.ToSlash format.
func pathToName(s string) string {
	rel := strings.TrimPrefix(s, "manifests/")
	runes := []rune(rel)

	// Scan backwards for the last slash; index 0 is never replaced.
	last := len(runes) - 1
	for last > 0 && runes[last] != '/' {
		last--
	}
	if last <= 0 {
		return rel
	}

	runes[last] = ':'
	return string(runes)
}
// manifestPath finds the first manifest file on disk that matches the given
// name using a case-insensitive comparison. If no manifest file is found, it
// returns the path where the manifest file would be if it existed.
//
// If two manifest files exist on disk that match the given name using a
// case-insensitive comparison, the first one yielded by links is returned.
//
// It returns an error if name cannot be converted to a path or if listing the
// existing links fails.
func (c *DiskCache) manifestPath(name string) (string, error) {
	np, err := nameToPath(name)
	if err != nil {
		return "", err
	}

	maybe := filepath.Join("manifests", np)

	// links yields slash-separated paths (they come from fs.Glob), while
	// filepath.Join uses the OS separator. Compare in slash form so the
	// match also works where the separator is not '/'.
	want := filepath.ToSlash(maybe)
	for l, err := range c.links() {
		if err != nil {
			return "", err
		}
		if strings.EqualFold(want, l) {
			return filepath.Join(c.dir, filepath.FromSlash(l)), nil
		}
	}
	return filepath.Join(c.dir, maybe), nil
}
// links returns a sequence of the manifest paths found in the cache
// directory, in the order reported by fs.Glob. Paths are slash-separated and
// relative to the cache directory (for example "manifests/h/n/m/t").
//
// If the glob fails, the sequence yields a single ("", err) pair and stops.
func (c *DiskCache) links() iter.Seq2[string, error] {
	// TODO(bmizerany): reuse empty dirnames if exist
	return func(yield func(string, error) bool) {
		matches, err := fs.Glob(os.DirFS(c.dir), "manifests/*/*/*/*")
		if err != nil {
			yield("", err)
			return
		}
		for i := range matches {
			if !yield(matches[i], nil) {
				return
			}
		}
	}
}
// checkWriter forwards writes to an underlying writer while hashing them, and
// refuses the write that would complete the content unless the hash matches
// the expected digest.
type checkWriter struct {
size int64 // total number of bytes the content is expected to have
d Digest // digest the complete content must hash to
f *os.File // file handed to testHookBeforeFinalWrite
h hash.Hash // running hash of every byte passed to Write
w io.Writer // underlying writer; set by creator
n int64 // number of bytes written to w so far
err error // first error recorded by seterr; once set, Write keeps returning it
testHookBeforeFinalWrite func(*os.File) // if non-nil, called with f just before the final write
}
// seterr records err as the writer's sticky error unless one has already been
// recorded, and returns err unchanged so it can be used in a return statement.
func (w *checkWriter) seterr(err error) error {
	if w.err != nil {
		return err
	}
	w.err = err
	return err
}
// Write feeds p to the hash and then to the underlying writer.
//
// The write that brings the total to the expected size is only forwarded to
// the underlying writer if the accumulated hash equals the expected digest; a
// write that would exceed the expected size is rejected. Once any error has
// been recorded, every later call returns that error without writing.
func (w *checkWriter) Write(p []byte) (int, error) {
	if w.err != nil {
		return 0, w.err
	}
	if _, err := w.h.Write(p); err != nil {
		return 0, w.seterr(err)
	}

	switch total := w.n + int64(len(p)); {
	case total > w.size:
		return 0, w.seterr(fmt.Errorf("content exceeds expected size: %d > %d", total, w.size))
	case total == w.size:
		// last write. check hash.
		if !bytes.Equal(w.h.Sum(nil), w.d.sum[:]) {
			return 0, w.seterr(fmt.Errorf("file content changed underfoot"))
		}
		if hook := w.testHookBeforeFinalWrite; hook != nil {
			hook(w.f)
		}
	}

	written, err := w.w.Write(p)
	w.n += int64(written)
	return written, w.seterr(err)
}
// copyNamedFile copies file into name, expecting it to have the given Digest
// and size, if that file is not present already.
//
// If a file called name already exists with exactly the expected size, it is
// assumed to be correct and nothing is written. Otherwise the content is
// streamed through a checkWriter, which verifies the size and digest before
// the final write is allowed. On a copy failure or short read the file is
// truncated to zero length; if closing the file fails, the file is removed.
// On success the file's access and modification times are set to c.now().
func (c *DiskCache) copyNamedFile(name string, file io.Reader, out Digest, size int64) error {
info, err := os.Stat(name)
if err == nil && info.Size() == size {
// File already exists with correct size. This is good enough.
// We can skip expensive hash checks.
//
// TODO: Do the hash check, but give caller a way to skip it.
return nil
}
// Copy file to cache directory.
mode := os.O_RDWR | os.O_CREATE
if err == nil && info.Size() > size { // shouldn't happen but fix in case
mode |= os.O_TRUNC
}
f, err := os.OpenFile(name, mode, 0o666)
if err != nil {
return err
}
// Safety net for the early returns below; the success path also closes f
// explicitly so that the Close error can be inspected.
defer f.Close()
if size == 0 {
// File now exists with correct size.
// Only one possible zero-length file, so contents are OK too.
// Early return here makes sure there's a "last byte" for code below.
return nil
}
// From here on, if any of the I/O writing the file fails,
// we make a best-effort attempt to truncate the file f
// before returning, to avoid leaving bad bytes in the file.
// Copy file to f, but also into h to double-check hash.
cw := &checkWriter{
d: out,
size: size,
h: sha256.New(),
f: f,
w: f,
testHookBeforeFinalWrite: c.testHookBeforeFinalWrite,
}
n, err := io.Copy(cw, file)
if err != nil {
// Best effort: the Truncate error is intentionally ignored.
f.Truncate(0)
return err
}
if n < size {
// The reader ended before the expected number of bytes arrived.
f.Truncate(0)
return io.ErrUnexpectedEOF
}
if err := f.Close(); err != nil {
// Data might not have been written,
// but file may look like it is the right size.
// To be extra careful, remove cached file.
os.Remove(name)
return err
}
os.Chtimes(name, c.now(), c.now()) // mainly for tests
return nil
}
// splitNameDigest splits s at its last '@' into the part before it (name) and
// the part after it (digest). If s contains no '@', it returns s and an empty
// digest.
func splitNameDigest(s string) (name, digest string) {
	if at := strings.LastIndexByte(s, '@'); at >= 0 {
		return s[:at], s[at+1:]
	}
	return s, ""
}
// errInvalidName is returned by nameToPath when a name is not fully qualified.
var errInvalidName = errors.New("invalid name")

// nameToPath converts a fully qualified name into the relative path formed by
// joining its host, namespace, model, and tag. It returns errInvalidName if
// the parsed name is not fully qualified.
func nameToPath(name string) (_ string, err error) {
	if n := names.Parse(name); n.IsFullyQualified() {
		return filepath.Join(n.Host(), n.Namespace(), n.Model(), n.Tag()), nil
	}
	return "", errInvalidName
}
// absJoin joins the path elements and returns the absolute form of the
// result. It panics if the absolute path cannot be determined.
func absJoin(pp ...string) string {
	joined := filepath.Join(pp...)
	abs, err := filepath.Abs(joined)
	if err == nil {
		return abs
	}
	panic(err) // this should never happen
}

688
server/internal/cache/blob/cache_test.go vendored Normal file
View File

@@ -0,0 +1,688 @@
package blob
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
"slices"
"strings"
"testing"
"time"
"github.com/ollama/ollama/server/internal/testutil"
)
// init sets the package-level debug flag to true for all tests in this
// package.
func init() {
debug = true
}
// epoch is the fixed instant (2021-01-01 00:00:00 UTC) that the tests use as
// the starting value of the cache clock.
var epoch = func() time.Time {
d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
// Guard against the reference time ever being the zero time.
if d.IsZero() {
panic("time zero")
}
return d
}()
// TestOpenErrors checks that Open accepts a fresh directory and rejects both
// an empty directory name and a path that is not a directory.
func TestOpenErrors(t *testing.T) {
	// The test binary itself is a convenient path that exists but is not a
	// directory.
	exe, err := os.Executable()
	if err != nil {
		// Report through the test framework instead of panicking.
		t.Fatal(err)
	}

	cases := []struct {
		dir string
		err string // substring expected in the error; empty means success
	}{
		{t.TempDir(), ""},
		{"", "empty directory name"},
		{exe, "not a directory"},
	}

	for _, tt := range cases {
		t.Run(tt.dir, func(t *testing.T) {
			_, err := Open(tt.dir)
			switch {
			case tt.err == "":
				if err != nil {
					t.Fatal(err)
				}
			case err == nil:
				t.Fatal("expected error")
			case !strings.Contains(err.Error(), tt.err):
				t.Fatalf("err = %v, want %q", err, tt.err)
			}
		})
	}
}
// TestGetFile verifies that GetFile returns a clean, absolute path located
// under the cache directory, even when the cache was opened with a relative
// directory.
func TestGetFile(t *testing.T) {
	t.Chdir(t.TempDir())

	c, err := Open(".")
	if err != nil {
		t.Fatal(err)
	}

	path := c.GetFile(mkdigest("1"))
	if filepath.Clean(path) != path {
		t.Fatalf("got is unclean: %q", path)
	}
	if !filepath.IsAbs(path) {
		t.Fatal("got is not absolute")
	}

	root, _ := filepath.Abs(c.dir)
	if !strings.HasPrefix(path, root) {
		t.Fatalf("got is not local to %q", c.dir)
	}
}
// TestBasic walks through the basic cache operations: resolving names and
// name@digest references, getting missing entries, and putting a blob twice
// to confirm that the second Put leaves the stored entry untouched.
func TestBasic(t *testing.T) {
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
// Drive the cache clock from a local variable so the test controls time.
now := epoch
c.now = func() time.Time { return now }
checkEntry := entryChecker(t, c)
checkFailed := func(err error) {
if err == nil {
t.Helper()
t.Fatal("expected error")
}
}
_, err = c.Resolve("invalid")
checkFailed(err)
// Nothing has been linked under this name yet.
_, err = c.Resolve("h/n/m:t")
checkFailed(err)
// A name@digest reference resolves to the digest it carries.
dx := mkdigest("x")
d, err := c.Resolve(fmt.Sprintf("h/n/m:t@%s", dx))
if err != nil {
t.Fatal(err)
}
if d != dx {
t.Fatalf("d = %v, want %v", d, dx)
}
_, err = c.Get(Digest{})
checkFailed(err)
// not committed yet
_, err = c.Get(dx)
checkFailed(err)
// Content that does not hash to dx must be rejected.
err = PutBytes(c, dx, "!")
checkFailed(err)
err = PutBytes(c, dx, "x")
if err != nil {
t.Fatal(err)
}
checkEntry(dx, 1, now)
t0 := now
now = now.Add(1*time.Hour + 1*time.Minute)
err = PutBytes(c, dx, "x")
if err != nil {
t.Fatal(err)
}
// check not updated
checkEntry(dx, 1, t0)
}
// sleepFunc advances a test clock by d and returns the new time.
type sleepFunc func(d time.Duration) time.Time

// openTester opens a cache in a fresh temporary directory whose clock starts
// at epoch. The returned sleepFunc moves that clock forward; calling it with
// zero reads the current time.
func openTester(t *testing.T) (*DiskCache, sleepFunc) {
	t.Helper()

	c, err := Open(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}

	clock := epoch
	c.now = func() time.Time { return clock }

	advance := func(d time.Duration) time.Time {
		clock = clock.Add(d)
		return clock
	}
	return c, advance
}
// TestManifestPath checks that linking the same name to the same digest a
// second time is a no-op: the manifest file keeps the modification time of
// the first link.
func TestManifestPath(t *testing.T) {
	check := testutil.Checker(t)
	c, sleep := openTester(t)

	d1 := mkdigest("1")
	check(PutBytes(c, d1, "1"))
	check(c.Link("h/n/m:t", d1))

	linkedAt := sleep(0)
	sleep(1 * time.Hour)
	check(c.Link("h/n/m:t", d1)) // nop expected

	info, err := os.Stat(must(c.manifestPath("h/n/m:t")))
	check(err)
	testutil.CheckTime(t, info.ModTime(), linkedAt)
}
// TestManifestExistsWithoutBlob writes a manifest file directly, without
// storing a blob first, and checks that Resolve returns the digest of the
// manifest's content and that Get then succeeds for that digest.
func TestManifestExistsWithoutBlob(t *testing.T) {
	t.Chdir(t.TempDir())

	check := testutil.Checker(t)
	c, err := Open(".")
	check(err)
	checkEntry := entryChecker(t, c)

	manifest := must(c.manifestPath("h/n/m:t"))
	os.MkdirAll(filepath.Dir(manifest), 0o777)
	testutil.WriteFile(t, manifest, "1")

	resolved, err := c.Resolve("h/n/m:t")
	check(err)
	if want := mkdigest("1"); resolved != want {
		t.Fatalf("got = %v, want %v", resolved, want)
	}

	entry, err := c.Get(resolved)
	check(err)
	checkEntry(resolved, 1, entry.Time)
}
// TestPut exercises Put/PutBytes with valid content and with several kinds of
// bad input (wrong content, reader error, wrong size, trailing byte, and a
// file that becomes read-only just before the final write). After each
// failure it checks that no entry for the digest is left in the cache.
func TestPut(t *testing.T) {
c, sleep := openTester(t)
check := testutil.Checker(t)
checkEntry := entryChecker(t, c)
d := mkdigest("hello, world")
// Content shorter than what the digest was computed from must fail.
err := PutBytes(c, d, "hello")
if err == nil {
t.Fatal("expected error")
}
got, err := c.Get(d)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("expected error, got %v", got)
}
// Put a valid blob
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob with content that does not hash to the digest
err = PutBytes(c, d, "hello")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Put the valid blob back and check it
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob that errors during Read
err = c.Put(d, &errOnBangReader{s: "!"}, 1)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Put valid blob back and check it
err = PutBytes(c, d, "hello, world")
check(err)
checkEntry(d, 12, sleep(0))
// Put a blob with mismatched size
err = c.Put(d, strings.NewReader("hello, world"), 11)
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Final byte does not match the digest (testing commit phase)
err = PutBytes(c, d, "hello, world$")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// Sabotage the file right before the final write: empty it, make it
// read-only, and swap the handle for a read-only one so the write fails.
reset := c.setTestHookBeforeFinalWrite(func(f *os.File) {
// change mode to read-only
f.Truncate(0)
f.Chmod(0o400)
f.Close()
f1, err := os.OpenFile(f.Name(), os.O_RDONLY, 0)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { f1.Close() })
*f = *f1
})
defer reset()
err = PutBytes(c, d, "hello, world")
if err == nil {
t.Fatal("expected error")
}
checkNotExists(t, c, d)
// NOTE(review): reset is also deferred above, so this call restores the hook
// a second time.
reset()
}
// TestImport imports the same one-byte content twice and checks that both
// imports return the expected digest and leave an entry of size 1 stamped
// with epoch.
func TestImport(t *testing.T) {
	c, _ := openTester(t)
	checkEntry := entryChecker(t, c)
	want := mkdigest("x")

	for range 2 {
		got, err := c.Import(strings.NewReader("x"), 1)
		if err != nil {
			t.Fatal(err)
		}
		if want != got {
			t.Fatalf("digest = %v, want %v", got, want)
		}
		checkEntry(want, 1, epoch)
	}
}
// setTestHookBeforeFinalWrite installs h as the cache's
// testHookBeforeFinalWrite and returns a function that puts the previous hook
// back.
func (c *DiskCache) setTestHookBeforeFinalWrite(h func(*os.File)) (reset func()) {
	prev := c.testHookBeforeFinalWrite
	reset = func() { c.testHookBeforeFinalWrite = prev }
	c.testHookBeforeFinalWrite = h
	return reset
}
// TestPutGetZero checks that a blob whose file has been truncated to zero
// bytes is reported by Get as not existing.
func TestPutGetZero(t *testing.T) {
	c, sleep := openTester(t)
	check := testutil.Checker(t)
	checkEntry := entryChecker(t, c)

	d := mkdigest("x")
	check(PutBytes(c, d, "x"))
	checkEntry(d, 1, sleep(0))

	check(os.Truncate(c.GetFile(d), 0))

	if _, err := c.Get(d); !errors.Is(err, fs.ErrNotExist) {
		t.Fatalf("err = %v, want fs.ErrNotExist", err)
	}
}
// TestPutZero checks that a Put with a declared size of zero succeeds but
// leaves no entry that Get can find.
func TestPutZero(t *testing.T) {
	c, _ := openTester(t)
	d := mkdigest("x")

	// size == 0 (not size of content)
	testutil.Check(t, c.Put(d, strings.NewReader("x"), 0))
	checkNotExists(t, c, d)
}
// TestCommit checks linking a name to blobs: linking to a missing blob fails,
// linking to a stored blob makes the name resolvable, relinking to a
// different blob replaces the manifest content, and relinking to the same
// blob leaves the manifest file's modification time unchanged.
func TestCommit(t *testing.T) {
check := testutil.Checker(t)
c, err := Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
checkEntry := entryChecker(t, c)
// Drive the cache clock from a local variable so the test controls time.
now := epoch
c.now = func() time.Time { return now }
d1 := mkdigest("1")
// The blob for d1 has not been stored yet, so Link must fail.
err = c.Link("h/n/m:t", d1)
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want fs.ErrNotExist", err)
}
err = PutBytes(c, d1, "1")
check(err)
err = c.Link("h/n/m:t", d1)
check(err)
got, err := c.Resolve("h/n/m:t")
check(err)
if got != d1 {
t.Fatalf("d = %v, want %v", got, d1)
}
// commit again, more than 1 byte
d2 := mkdigest("22")
err = PutBytes(c, d2, "22")
check(err)
err = c.Link("h/n/m:t", d2)
check(err)
checkEntry(d2, 2, now)
filename := must(c.manifestPath("h/n/m:t"))
data, err := os.ReadFile(filename)
check(err)
if string(data) != "22" {
t.Fatalf("data = %q, want %q", data, "22")
}
t0 := now
now = now.Add(1 * time.Hour)
err = c.Link("h/n/m:t", d2) // same contents; nop
check(err)
info, err := os.Stat(filename)
check(err)
// The manifest must still carry the timestamp of the earlier link.
testutil.CheckTime(t, info.ModTime(), t0)
}
// TestManifestInvalidBlob checks that Link fails when the blob is missing and
// when the blob file's content no longer matches its digest.
func TestManifestInvalidBlob(t *testing.T) {
	c, _ := openTester(t)
	d := mkdigest("1")

	// No blob stored yet.
	err := c.Link("h/n/m:t", d)
	if err == nil {
		t.Fatal("expected error")
	}
	checkNotExists(t, c, d)

	err = PutBytes(c, d, "1")
	testutil.Check(t, err)

	// Corrupt the blob behind the cache's back.
	err = os.WriteFile(c.GetFile(d), []byte("invalid"), 0o666)
	if err != nil {
		t.Fatal(err)
	}

	err = c.Link("h/n/m:t", d)
	// Guard against a nil error so an unexpected success fails the test
	// instead of panicking on err.Error().
	if err == nil || !strings.Contains(err.Error(), "underfoot") {
		t.Fatalf("err = %v, want error to contain %q", err, "underfoot")
	}
}
// TestManifestNameReuse runs testManifestNameReuse twice: once against the
// default temporary directory and once after useCaseInsensitiveTempDir has
// had the chance to redirect TMPDIR or skip the subtest.
//
// NOTE(review): useCaseInsensitiveTempDir reports true when the temp dir is
// case-sensitive, so the second subtest's name matches what it runs on even
// though the helper's name suggests the opposite — confirm the intended naming.
func TestManifestNameReuse(t *testing.T) {
t.Run("case-insensitive", func(t *testing.T) {
// This should run on all file system types.
testManifestNameReuse(t)
})
t.Run("case-sensitive", func(t *testing.T) {
useCaseInsensitiveTempDir(t)
testManifestNameReuse(t)
})
}
// testManifestNameReuse checks that names differing only in case share one
// manifest file: linking "H/N/M:T" after "h/n/m:t" reuses the existing link,
// and after an Unlink the next Link creates the file with the casing of the
// name passed to it.
func testManifestNameReuse(t *testing.T) {
	check := testutil.Checker(t)
	c, _ := openTester(t)

	// listLinks collects every path currently yielded by c.links.
	listLinks := func() []string {
		var out []string
		for l, err := range c.links() {
			if err != nil {
				t.Fatal(err)
			}
			out = append(out, l)
		}
		return out
	}

	d1 := mkdigest("1")
	check(PutBytes(c, d1, "1"))
	check(c.Link("h/n/m:t", d1))

	d2 := mkdigest("22")
	check(PutBytes(c, d2, "22"))
	check(c.Link("H/N/M:T", d2))

	var resolved [2]Digest
	var err error
	resolved[0], err = c.Resolve("h/n/m:t")
	check(err)
	resolved[1], err = c.Resolve("H/N/M:T")
	check(err)

	if expect := [2]Digest{d2, d2}; resolved != expect {
		t.Fatalf("g = %v, want %v", resolved, expect)
	}

	got := listLinks()
	want := []string{"manifests/h/n/m/t"}
	if !slices.Equal(got, want) {
		t.Fatalf("got = %v, want %v", got, want)
	}

	// relink with different case
	unlinked, err := c.Unlink("h/n/m:t")
	check(err)
	if !unlinked {
		t.Fatal("expected unlinked")
	}
	check(c.Link("h/n/m:T", d1))

	// we should have only one link that is same case as the last link
	got = listLinks()
	want = []string{"manifests/h/n/m/T"}
	if !slices.Equal(got, want) {
		t.Fatalf("got = %v, want %v", got, want)
	}
}
// TestManifestFile checks which inputs manifestPath accepts. For a valid name
// the result, made relative to the cache directory, must equal want; an empty
// want means manifestPath is expected to return an error.
func TestManifestFile(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		{"", ""},

		// valid names
		{"h/n/m:t", "/manifests/h/n/m/t"},
		{"hh/nn/mm:tt", "/manifests/hh/nn/mm/tt"},

		{"%/%/%/%", ""},

		// already a path
		{"h/n/m/t", ""},

		// refs are not names
		{"h/n/m:t@sha256-1", ""},
		{"m@sha256-1", ""},
		{"n/m:t@sha256-1", ""},
	}

	c, _ := openTester(t)
	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			got, err := c.manifestPath(tt.in)
			switch {
			case err != nil && tt.want != "":
				t.Fatalf("unexpected error: %v", err)
			case err == nil && tt.want == "":
				t.Fatalf("expected error")
			}

			rel := strings.TrimPrefix(filepath.ToSlash(got), filepath.ToSlash(c.dir))
			if rel != tt.want {
				t.Fatalf("got = %q, want %q", rel, tt.want)
			}
		})
	}
}
// TestNames links two tags of the same model and checks that Links yields
// both names, in name form and in order.
func TestNames(t *testing.T) {
	c, _ := openTester(t)
	check := testutil.Checker(t)

	check(PutBytes(c, mkdigest("1"), "1"))
	check(PutBytes(c, mkdigest("2"), "2"))
	check(c.Link("h/n/m:t", mkdigest("1")))
	check(c.Link("h/n/m:u", mkdigest("2")))

	var names []string
	for name, err := range c.Links() {
		if err != nil {
			t.Fatal(err)
		}
		names = append(names, name)
	}

	if want := []string{"h/n/m:t", "h/n/m:u"}; !slices.Equal(names, want) {
		t.Fatalf("got = %v, want %v", names, want)
	}
}
// mkdigest returns the Digest holding the SHA-256 hash of s.
func mkdigest(s string) Digest {
	sum := sha256.Sum256([]byte(s))
	return Digest{sum: sum}
}
// checkNotExists fails the test unless Get reports fs.ErrNotExist for d.
func checkNotExists(t *testing.T, c *DiskCache, d Digest) {
	t.Helper()
	if _, err := c.Get(d); !errors.Is(err, fs.ErrNotExist) {
		t.Fatalf("err = %v, want fs.ErrNotExist", err)
	}
}
// entryChecker returns a function that verifies, in a subtest named after the
// digest, that the cache entry for d has the given size and time, and that
// the blob file on disk has the same size and modification time. If the
// subtest fails, the cache contents are dumped to the test log.
func entryChecker(t *testing.T, c *DiskCache) func(Digest, int64, time.Time) {
t.Helper()
return func(d Digest, size int64, mod time.Time) {
t.Helper()
t.Run("checkEntry:"+d.String(), func(t *testing.T) {
t.Helper()
// Dump the cache on failure to help diagnose what is on disk.
defer func() {
if t.Failed() {
dumpCacheContents(t, c)
}
}()
e, err := c.Get(d)
// A not-exist error from Get is tolerated when the expected size is zero.
if size == 0 && errors.Is(err, fs.ErrNotExist) {
err = nil
}
if err != nil {
t.Fatal(err)
}
if e.Digest != d {
t.Errorf("e.Digest = %v, want %v", e.Digest, d)
}
if e.Size != size {
t.Fatalf("e.Size = %v, want %v", e.Size, size)
}
testutil.CheckTime(t, e.Time, mod)
// Cross-check the entry against the file itself.
info, err := os.Stat(c.GetFile(d))
if err != nil {
t.Fatal(err)
}
if info.Size() != size {
t.Fatalf("info.Size = %v, want %v", info.Size(), size)
}
testutil.CheckTime(t, info.ModTime(), mod)
})
}
}
// must returns v if err is nil and panics with err otherwise.
func must[T any](v T, err error) T {
	if err == nil {
		return v
	}
	panic(err)
}
// TestNameToPath checks that a fully qualified name converts without error.
func TestNameToPath(t *testing.T) {
	if _, err := nameToPath("h/n/m:t"); err != nil {
		t.Fatal(err)
	}
}
// errOnBangReader is an io.Reader that yields the bytes of s one at a time
// and fails with a "bang" error when it reaches a '!' character.
type errOnBangReader struct {
	s string // content to serve
	n int    // index of the next byte of s to return
}

// Read copies at most one byte of s into p. It returns io.ErrShortBuffer if p
// is empty, io.EOF once all of s has been consumed, and a "bang" error when
// the next byte is '!'.
func (e *errOnBangReader) Read(p []byte) (int, error) {
	if len(p) < 1 {
		return 0, io.ErrShortBuffer
	}
	// The end of the data is determined by the source string, not by the
	// size of the caller's buffer.
	if e.n >= len(e.s) {
		return 0, io.EOF
	}
	c := e.s[e.n]
	if c == '!' {
		return 0, errors.New("bang")
	}
	p[0] = c
	e.n++
	return 1, nil
}
// dumpCacheContents logs one line per file and directory under the cache
// directory, showing mode, size, modification time, and path. Errors from the
// walk are not reported.
func dumpCacheContents(t *testing.T, c *DiskCache) {
	t.Helper()

	var listing strings.Builder
	visit := func(path string, d fs.DirEntry, err error) error {
		t.Helper()
		if err != nil {
			return err
		}
		info, err := d.Info()
		if err != nil {
			return err
		}

		// Format like ls:
		//
		// ; ls -la
		// drwxr-xr-x 224 Jan 13 14:22 blob/sha256-123
		// drwxr-xr-x 224 Jan 13 14:22 manifest/h/n/m
		fmt.Fprintf(&listing, " %s % 4d %s %s\n",
			info.Mode(),
			info.Size(),
			info.ModTime().Format("Jan 2 15:04"),
			path,
		)
		return nil
	}
	fs.WalkDir(os.DirFS(c.dir), ".", visit)

	t.Log()
	t.Logf("cache contents:\n%s", listing.String())
}

View File

@@ -0,0 +1,93 @@
package blob
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
)
// isCaseSensitive reports whether file names in dir are case-sensitive. It
// creates a probe file "_casecheck" and checks whether "_CASECHECK" then
// resolves to it; the probe is removed before returning.
//
// It panics if either spelling of the probe already exists in dir or if the
// probe cannot be created.
func isCaseSensitive(dir string) bool {
	lower := filepath.Join(dir, "_casecheck")
	upper := filepath.Join(dir, "_CASECHECK")
	defer os.Remove(lower)

	present := func(p string) bool {
		_, err := os.Stat(p)
		return err == nil
	}

	if present(lower) || present(upper) {
		panic(fmt.Sprintf("_casecheck already exists in %q; remove and try again.", dir))
	}
	if err := os.WriteFile(lower, nil, 0o666); err != nil {
		panic(err)
	}
	return !present(upper)
}
// isCI reports whether the CI environment variable is set to a non-empty
// value.
func isCI() bool {
	ci := os.Getenv("CI")
	return ci != ""
}
// volumeHint is logged when no suitable temporary directory can be found on a
// developer machine.
const volumeHint = `
Unable to locate case-insensitive TMPDIR on darwin.
To run tests, create the case-insensitive volume /Volumes/data:
$ sudo diskutil apfs addVolume disk1 APFSX data -mountpoint /Volumes/data
or run with:
CI=1 go test ./...
`

// useCaseInsensitiveTempDir makes sure the test runs with a suitable
// temporary directory, or skips the test.
//
// It returns true if the default temp dir already passes isCaseSensitive, or
// if, on darwin, the volume /Volumes/data exists and TMPDIR has been pointed
// at its "tmp" subdirectory. Otherwise the test is skipped, except when the
// CI environment variable is set on a platform other than darwin, in which
// case it returns false.
//
// NOTE(review): despite the name and the skip messages, the directory this
// helper looks for is a case-sensitive one — confirm and rename separately.
func useCaseInsensitiveTempDir(t *testing.T) bool {
	if isCaseSensitive(os.TempDir()) {
		// Use the default temp dir if it is already case-sensitive.
		return true
	}
	if runtime.GOOS == "darwin" {
		// If darwin, check for the special case-sensitive volume and
		// use it if available.
		const volume = "/Volumes/data"
		if _, err := os.Stat(volume); err == nil {
			tmpdir := filepath.Join(volume, "tmp")
			if err := os.MkdirAll(tmpdir, 0o700); err != nil {
				t.Fatalf("creating %s: %v", tmpdir, err)
			}
			t.Setenv("TMPDIR", tmpdir)
			return true
		}
		if isCI() {
			// Special case darwin in CI; it is not case-sensitive
			// by default, and we will be testing other platforms
			// that are case-sensitive, so we'll have the test
			// being skipped covered there.
			t.Skip("Skipping test in CI for darwin; TMPDIR is not case-insensitive.")
		}
	}
	if !isCI() {
		// Require devs to always tests with a case-insensitive TMPDIR.
		// TODO(bmizerany): Print platform-specific instructions or
		// link to docs on that topic.
		//
		// t.Skip stops the test on its first call, so log every line of
		// the hint first and skip afterwards.
		for _, line := range strings.Split(volumeHint, "\n") {
			t.Log(line)
		}
		t.SkipNow()
	}
	return false
}

73
server/internal/cache/blob/chunked.go vendored Normal file
View File

@@ -0,0 +1,73 @@
package blob
import (
"crypto/sha256"
"errors"
"io"
"os"
)
// Chunk represents a range of bytes in a blob. Size treats both Start and End
// as inclusive offsets.
type Chunk struct {
	Start, End int64
}

// Size returns end minus start plus one.
func (c Chunk) Size() int64 {
	span := c.End - c.Start
	return span + 1
}
// Chunker writes to a blob in chunks.
// Its zero value is invalid. Use [DiskCache.Chunked] to create a new Chunker.
type Chunker struct {
digest Digest // digest of the blob being written, as passed to Chunked
size int64 // size of the blob being written, as passed to Chunked
f *os.File // nil means pre-validated
}
// Chunked returns a new Chunker, ready for use storing a blob of the given
// size in chunks.
//
// If the blob file already exists with exactly the given size, the returned
// Chunker is pre-validated: it has no open file and its Put does nothing.
// Otherwise the blob file is created if needed and opened for writing.
//
// Use [Chunker.Put] to write data to the blob at specific offsets.
func (c *DiskCache) Chunked(d Digest, size int64) (*Chunker, error) {
	name := c.GetFile(d)
	if info, err := os.Stat(name); err == nil && info.Size() == size {
		return &Chunker{}, nil
	}

	f, err := os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0o666)
	if err != nil {
		return nil, err
	}
	return &Chunker{digest: d, size: size, f: f}, nil
}
// Put copies chunk.Size() bytes from r into the blob, starting at offset
// chunk.Start, and verifies them against the digest d. It returns an error if
// any. As a special case, if r has less than chunk.Size() bytes, Put returns
// io.ErrUnexpectedEOF.
//
// On a pre-validated Chunker (no open file) Put returns nil without reading
// from r.
func (c *Chunker) Put(chunk Chunk, d Digest, r io.Reader) error {
	if c.f == nil {
		return nil
	}

	want := chunk.Size()
	cw := &checkWriter{
		d:    d,
		size: want,
		h:    sha256.New(),
		f:    c.f,
		w:    io.NewOffsetWriter(c.f, chunk.Start),
	}

	_, err := io.CopyN(cw, r, want)
	if errors.Is(err, io.EOF) {
		return io.ErrUnexpectedEOF
	}
	return err
}
// Close closes the underlying file, if any. A pre-validated Chunker has no
// open file, so closing it is a no-op that returns nil.
func (c *Chunker) Close() error {
	if c.f == nil {
		return nil
	}
	return c.f.Close()
}

99
server/internal/cache/blob/digest.go vendored Normal file
View File

@@ -0,0 +1,99 @@
package blob
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"slices"
"strings"
)
// ErrInvalidDigest is returned by ParseDigest when its input is not a
// well-formed digest string.
var ErrInvalidDigest = errors.New("invalid digest")
// Digest is a blob identifier that is the SHA-256 hash of a blob's content.
//
// It is comparable and can be used as a map key.
type Digest struct {
sum [32]byte // raw hash bytes; all zeros for the zero (invalid) Digest
}
// ParseDigest parses a digest from a string. If the string is not a valid
// digest, it returns the zero Digest, for which IsValid reports false, and
// ErrInvalidDigest.
//
// The input string may be in one of two forms:
//
//   - ("sha256-<hex>"), where <hex> is a 64-character hexadecimal string.
//   - ("sha256:<hex>"), where <hex> is a 64-character hexadecimal string.
//
// The [Digest.String] method will return the canonical form of the
// digest, "sha256:<hex>".
func ParseDigest[S ~[]byte | ~string](v S) (Digest, error) {
	s := string(v)

	// The algorithm prefix ends at the first ':' or '-'.
	sep := strings.IndexAny(s, ":-")
	if sep < 0 {
		return Digest{}, ErrInvalidDigest
	}
	if s[:sep] != "sha256" || len(s)-sep-1 != 64 {
		return Digest{}, ErrInvalidDigest
	}

	var d Digest
	if _, err := hex.Decode(d.sum[:], []byte(s[sep+1:])); err != nil {
		return Digest{}, ErrInvalidDigest
	}
	return d, nil
}
// DigestFromBytes returns the Digest holding the SHA-256 hash of v.
func DigestFromBytes[S ~[]byte | ~string](v S) Digest {
	sum := sha256.Sum256([]byte(v))
	return Digest{sum: sum}
}
// String returns the digest in the conventional form "sha256:<hex>", where
// <hex> is the lower-case hexadecimal encoding of the hash.
func (d Digest) String() string {
	const prefix = "sha256:"
	return prefix + fmt.Sprintf("%x", d.sum[:])
}
// Short returns the lower-case hexadecimal encoding of the first four bytes
// of the hash, without any prefix.
func (d Digest) Short() string {
	head := d.sum[:4]
	return fmt.Sprintf("%x", head)
}
// Sum returns a copy of the raw 32-byte hash held by the digest.
func (d Digest) Sum() [32]byte {
return d.sum
}
// Compare orders d and other by their raw hash bytes. The result is negative
// if d sorts before other, zero if they are equal, and positive otherwise.
func (d Digest) Compare(other Digest) int {
	a, b := d.sum[:], other.sum[:]
	return slices.Compare(a, b)
}
// IsValid reports whether the digest is valid, which here means it differs
// from the zero Digest.
func (d Digest) IsValid() bool {
	var zero Digest
	return d != zero
}
// MarshalText implements the encoding.TextMarshaler interface. It returns the
// [Digest.String] form of the digest and always returns a nil error.
//
// NOTE(review): an earlier version of this comment said an error is returned
// when [Digest.IsValid] is false; the code does not do that — confirm which
// behavior is intended.
func (d Digest) MarshalText() ([]byte, error) {
return []byte(d.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface. It only
// works on a zero digest: if d already holds a non-zero value it returns an
// error and leaves d untouched. It also returns an error if text cannot be
// parsed by ParseDigest.
func (d *Digest) UnmarshalText(text []byte) error {
	var zero Digest
	if *d != zero {
		return errors.New("digest: illegal UnmarshalText on valid digest")
	}

	parsed, err := ParseDigest(string(text))
	if err != nil {
		return err
	}
	*d = parsed
	return nil
}

View File

@@ -0,0 +1,63 @@
package blob
import (
"encoding/json"
"testing"
)
// TestParseDigest checks that ParseDigest accepts both the "sha256-" and
// "sha256:" forms, that accepted digests print in the canonical "sha256:"
// form, and that malformed inputs are rejected with an error.
func TestParseDigest(t *testing.T) {
	cases := []struct {
		in    string
		valid bool
	}{
		{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},
		{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", true},

		// too short
		{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},
		{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", false},

		// too long
		{"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},
		{"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", false},

		// invalid prefix
		{"sha255-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
		{"sha255:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},
		{"sha256!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", false},

		// invalid hex
		{"sha256-XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
		{"sha256:XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", false},
	}

	for _, tt := range cases {
		got, err := ParseDigest(tt.in)
		if !tt.valid {
			// Invalid inputs must actually be rejected.
			if err == nil {
				t.Errorf("ParseDigest(%q) = %v, <nil>; want error", tt.in, got)
			}
			continue
		}
		if err != nil {
			t.Errorf("ParseDigest(%q) = %v, %v; want valid", tt.in, got, err)
		}
		// Both accepted separators normalize to ':' in String.
		want := "sha256:" + tt.in[7:]
		if got.String() != want {
			t.Errorf("ParseDigest(%q).String() = %q, want %q", tt.in, got.String(), want)
		}
	}
}
// TestDigestMarshalText round-trips a digest through JSON: the "sha256-" form
// is accepted on input, the "sha256:" form is produced on output, and a
// string that is not a digest fails to unmarshal.
func TestDigestMarshalText(t *testing.T) {
	const (
		in   = `"sha256-0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
		want = `"sha256:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"`
	)

	var d Digest
	if err := json.Unmarshal([]byte(in), &d); err != nil {
		t.Errorf("json.Unmarshal: %v", err)
	}

	out, err := json.Marshal(d)
	if err != nil {
		t.Errorf("json.Marshal: %v", err)
	}
	if got := string(out); got != want {
		t.Errorf("json.Marshal: got %s, want %s", out, want)
	}

	var scratch Digest
	if err := json.Unmarshal([]byte(`"invalid"`), &scratch); err == nil {
		t.Errorf("json.Unmarshal: expected error")
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,51 @@
// TODO: go:build goexperiment.synctest
package ollama
import (
"context"
"errors"
"io"
"net/http"
"strings"
"testing"
"time"
)
// TestPullDownloadTimeout checks that Pull gives up with
// context.DeadlineExceeded when the upstream sends the blob response headers
// and then stalls for longer than the client's ReadTimeout.
func TestPullDownloadTimeout(t *testing.T) {
	rc, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		defer t.Log("upstream", r.Method, r.URL.Path)
		switch {
		case strings.HasPrefix(r.URL.Path, "/v2/library/smol/manifests/"):
			io.WriteString(w, `{
"layers": [{"digest": "sha256:1111111111111111111111111111111111111111111111111111111111111111", "size": 3}]
}`)
		case strings.HasPrefix(r.URL.Path, "/v2/library/smol/blobs/sha256:1111111111111111111111111111111111111111111111111111111111111111"):
			// Get headers out to client and then hang on the response
			w.WriteHeader(200)
			w.(http.Flusher).Flush()

			// Hang on the response and unblock when the client
			// gives up
			<-r.Context().Done()
		default:
			// Pull is started on its own goroutine below, so this
			// handler never runs on the test goroutine and must not
			// call Fatalf; record the failure and return instead.
			t.Errorf("unexpected request: %s", r.URL.Path)
		}
	})

	rc.ReadTimeout = 100 * time.Millisecond

	// Buffered so the goroutine can finish even if the test has timed out.
	done := make(chan error, 1)
	go func() {
		done <- rc.Pull(ctx, "http://example.com/library/smol")
	}()

	select {
	case err := <-done:
		want := context.DeadlineExceeded
		if !errors.Is(err, want) {
			t.Errorf("err = %v, want %v", err, want)
		}
	case <-time.After(3 * time.Second):
		t.Error("timeout waiting for Pull to finish")
	}
}

View File

@@ -0,0 +1,953 @@
package ollama
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/http/httptest"
"os"
"reflect"
"strings"
"sync/atomic"
"testing"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/testutil"
)
func ExampleRegistry_cancelOnFirstError() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if err != nil {
// Discontinue pulling layers if there is an
// error instead of continuing to pull more
// data.
cancel()
}
},
})
var r Registry
if err := r.Pull(ctx, "model"); err != nil {
// panic for demo purposes
panic(err)
}
}
// TestManifestMarshalJSON checks that even a zero Manifest marshals with a
// "config" object that carries a sha256 digest.
func TestManifestMarshalJSON(t *testing.T) {
	// All manifests should contain an "empty" config object.
	data, err := json.Marshal(Manifest{})
	if err != nil {
		t.Fatal(err)
	}

	marker := []byte(`"config":{"digest":"sha256:`)
	if bytes.Contains(data, marker) {
		return
	}
	t.Error("expected manifest to contain empty config")
	t.Fatalf("got:\n%s", string(data))
}
// errRoundTrip is the error recordRoundTripper returns to simulate a failed
// round trip.
var errRoundTrip = errors.New("forced roundtrip error")

// recordRoundTripper is an http.RoundTripper that serves every request by
// calling the handler directly and recording its response.
type recordRoundTripper http.HandlerFunc

// RoundTrip runs the handler against req and returns the recorded response.
// A handler that answers with status 499 makes RoundTrip fail with
// errRoundTrip instead.
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	rec := httptest.NewRecorder()
	rr(rec, req)
	if rec.Code == 499 {
		return nil, errRoundTrip
	}

	// For some reason, Response.Request is not set by httptest.NewRecorder, so we
	// set it manually.
	resp := rec.Result()
	resp.Request = req
	return rec.Result(), nil
}
// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
//	empty: no data
//	zero: no layers
//	single: one layer with the contents "exists"
//	multiple: two layers with the contents "exists" and "present"
//	notfound: a layer that does not exist in the cache
//	null: one null layer (e.g. [null])
//	sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
//	invalid: a manifest whose data is not valid JSON
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
	t.Helper()

	c, err := blob.Open(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}

	// mklayer stores data in the cache and returns a layer describing it.
	mklayer := func(data string) *Layer {
		return &Layer{
			Digest: importBytes(t, c, data),
			Size:   int64(len(data)),
		}
	}

	r := &Registry{
		Cache: c,
		HTTPClient: &http.Client{
			Transport: recordRoundTripper(upstreamRegistry),
		},
	}

	// link stores the raw manifest data in the cache and links name to it.
	// Setup failures are reported through t, like in commit below.
	link := func(name string, manifest string) {
		t.Helper()
		n, err := r.parseName(name)
		if err != nil {
			t.Fatal(err)
		}
		d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
		if err != nil {
			t.Fatal(err)
		}
		if err := c.Link(n.String(), d); err != nil {
			t.Fatal(err)
		}
	}

	// commit links name to a manifest built from the given layers.
	commit := func(name string, layers ...*Layer) {
		t.Helper()
		data, err := json.Marshal(&Manifest{Layers: layers})
		if err != nil {
			t.Fatal(err)
		}
		link(name, string(data))
	}

	link("empty", "")
	commit("zero")
	commit("single", mklayer("exists"))
	commit("multiple", mklayer("exists"), mklayer("present"))
	commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
	commit("null", nil)
	commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
	link("invalid", "!!!!!")

	return r, c
}
// okHandler answers every request with 200 OK and an empty body.
func okHandler(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
}
// checkErrCode reports a test error unless err is (or wraps) an *Error with
// the given status and code.
func checkErrCode(t *testing.T, err error, status int, code string) {
	t.Helper()
	var e *Error
	if errors.As(err, &e) && e.status == status && e.Code == code {
		return
	}
	t.Errorf("err = %v; want %v %v", err, status, code)
}
// importBytes stores data in the cache and returns its digest, failing the
// test if the import fails.
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
	// Attribute failures to the caller, like the other helpers in this file.
	t.Helper()
	d, err := c.Import(strings.NewReader(data), int64(len(data)))
	if err != nil {
		t.Fatal(err)
	}
	return d
}
// withTraceUnexpected returns ctx with a Trace attached whose Update callback
// panics, for tests in which no progress update should ever be reported. The
// Trace is returned alongside the context.
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
	trace := &Trace{
		Update: func(*Layer, int64, error) { panic("unexpected") },
	}
	return WithTrace(ctx, trace), trace
}
// TestPushZero checks that pushing the "empty" manifest (no data) fails with
// ErrManifestInvalid.
func TestPushZero(t *testing.T) {
	rc, _ := newClient(t, okHandler)
	if err := rc.Push(t.Context(), "empty", nil); !errors.Is(err, ErrManifestInvalid) {
		t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
	}
}
// TestPushSingle checks that a manifest with one layer pushes successfully.
func TestPushSingle(t *testing.T) {
	rc, _ := newClient(t, okHandler)
	testutil.Check(t, rc.Push(t.Context(), "single", nil))
}
// TestPushMultiple checks that a manifest with two layers pushes
// successfully.
func TestPushMultiple(t *testing.T) {
	rc, _ := newClient(t, okHandler)
	testutil.Check(t, rc.Push(t.Context(), "multiple", nil))
}
// TestPushNotFound checks that pushing a manifest whose layer is missing from
// the cache fails with fs.ErrNotExist and never contacts the upstream.
func TestPushNotFound(t *testing.T) {
	upstream := func(w http.ResponseWriter, r *http.Request) {
		t.Errorf("unexpected request: %v", r)
	}
	rc, _ := newClient(t, upstream)

	if err := rc.Push(t.Context(), "notfound", nil); !errors.Is(err, fs.ErrNotExist) {
		t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
	}
}
// TestPushNullLayer checks that a manifest containing a null layer is
// rejected as an invalid manifest.
func TestPushNullLayer(t *testing.T) {
	rc, _ := newClient(t, nil)
	err := rc.Push(t.Context(), "null", nil)
	if err != nil && strings.Contains(err.Error(), "invalid manifest") {
		return
	}
	t.Errorf("err = %v; want invalid manifest", err)
}
// TestPushSizeMismatch checks that a layer whose recorded size differs from
// the cached blob makes Push fail with a size mismatch, without any trace
// update being reported.
func TestPushSizeMismatch(t *testing.T) {
	rc, _ := newClient(t, nil)
	ctx, _ := withTraceUnexpected(t.Context())

	err := rc.Push(ctx, "sizemismatch", nil)
	if err != nil && strings.Contains(err.Error(), "size mismatch") {
		return
	}
	t.Errorf("err = %v; want size mismatch", err)
}
// TestPushInvalid checks that pushing the "invalid" model fails with an error
// whose text contains "invalid manifest".
func TestPushInvalid(t *testing.T) {
	client, _ := newClient(t, nil)
	if err := client.Push(t.Context(), "invalid", nil); err == nil || !strings.Contains(err.Error(), "invalid manifest") {
		t.Errorf("err = %v; want invalid manifest", err)
	}
}
// TestPushExistsAtRemote pushes the "single" model twice through a registry
// stub. On the first request to an "/uploads/" path the stub hands out an
// upload Location; on every later "/uploads/" request it answers 202
// Accepted without a Location. All other requests are drained and answered
// with 200 OK.
//
// Both pushes must succeed, and no trace update collected during the first
// push may carry an error.
func TestPushExistsAtRemote(t *testing.T) {
	var pushed bool
	rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.Path, "/uploads/") {
			if !pushed {
				// First push. Return an uploadURL.
				pushed = true
				w.Header().Set("Location", "http://blob.store/blobs/123")
				return
			}
			w.WriteHeader(http.StatusAccepted)
			return
		}

		io.Copy(io.Discard, r.Body)
		w.WriteHeader(http.StatusOK)
	})

	rc.MaxStreams = 1 // prevent concurrent uploads

	var errs []error
	ctx := WithTrace(t.Context(), &Trace{
		Update: func(_ *Layer, n int64, err error) {
			// uploading one at a time so no need to lock
			errs = append(errs, err)
		},
	})

	check := testutil.Checker(t)

	err := rc.Push(ctx, "single", nil)
	check(err)

	// errors.Join yields nil only when every collected error is nil, so
	// this asserts that the first push reported no error through the
	// trace. The message now states that expectation; it previously
	// claimed the test wanted []error{ErrCached}, which the condition
	// never checked.
	if joined := errors.Join(errs...); joined != nil {
		t.Errorf("errs = %v; want all nil", errs)
	}

	err = rc.Push(ctx, "single", nil)
	check(err)
}
// TestPushRemoteError checks that a 500 response carrying a registry error
// body on a "/blobs/" path surfaces from Push as an *Error with status 500
// and code "blob_error".
func TestPushRemoteError(t *testing.T) {
	failBlobs := func(w http.ResponseWriter, r *http.Request) {
		if !strings.Contains(r.URL.Path, "/blobs/") {
			return
		}
		w.WriteHeader(500)
		io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
	}
	client, _ := newClient(t, failBlobs)

	pushErr := client.Push(t.Context(), "single", nil)
	checkErrCode(t, pushErr, 500, "blob_error")
}
// TestPushLocationError checks that Push fails with an error containing
// "invalid upload URL" when the registry answers with the malformed Location
// header ":///x".
func TestPushLocationError(t *testing.T) {
	const wantContains = "invalid upload URL"

	client, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Location", ":///x")
		w.WriteHeader(http.StatusAccepted)
	})

	pushErr := client.Push(t.Context(), "single", nil)
	if pushErr != nil && strings.Contains(pushErr.Error(), wantContains) {
		return
	}
	t.Errorf("err = %v; want to contain %v", pushErr, wantContains)
}
// TestPushUploadRoundtripError checks that Push returns errRoundTrip when the
// upload request sent to the host "blob.store" is answered with status 499.
// Requests to any other host receive a Location pointing at blob.store.
func TestPushUploadRoundtripError(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		if r.Host != "blob.store" {
			w.Header().Set("Location", "http://blob.store/blobs/123")
			return
		}
		w.WriteHeader(499) // force RoundTrip error on upload
	}
	client, _ := newClient(t, handler)

	if got := client.Push(t.Context(), "single", nil); !errors.Is(got, errRoundTrip) {
		t.Errorf("got = %v; want %v", got, errRoundTrip)
	}
}
// TestPushUploadFileOpenError checks that Push returns fs.ErrNotExist when a
// layer's file is removed from the cache while the push is under way. The
// removal is done from the trace Update callback.
func TestPushUploadFileOpenError(t *testing.T) {
rc, c := newClient(t, okHandler)
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, _ int64, err error) {
// Remove the file just before it is opened for upload,
// but after the initial Stat that happens before the
// upload starts
// The result of os.Remove is not checked.
os.Remove(c.GetFile(l.Digest))
},
})
got := rc.Push(ctx, "single", nil)
if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got)
}
}
// TestPushCommitRoundtripError checks that Push of the "zero" model returns
// errRoundTrip when the registry answers with status 499. The handler panics
// if any request reaches a "/blobs/" path.
func TestPushCommitRoundtripError(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.Path, "/blobs/") {
			panic("unexpected")
		}
		w.WriteHeader(499) // force RoundTrip error
	}
	client, _ := newClient(t, handler)

	if err := client.Push(t.Context(), "zero", nil); !errors.Is(err, errRoundTrip) {
		t.Errorf("err = %v; want %v", err, errRoundTrip)
	}
}
// TestRegistryPullInvalidName checks that Pull rejects the name "://" with
// ErrNameInvalid.
func TestRegistryPullInvalidName(t *testing.T) {
	client, _ := newRegistryClient(t, nil)
	if err := client.Pull(t.Context(), "://"); !errors.Is(err, ErrNameInvalid) {
		t.Errorf("err = %v; want %v", err, ErrNameInvalid)
	}
}
// TestRegistryPullInvalidManifest checks that Pull returns ErrManifestInvalid
// for each of several manifest response bodies: empty, JSON null, non-JSON
// text, and a manifest with an empty layer list.
func TestRegistryPullInvalidManifest(t *testing.T) {
	for _, body := range []string{
		"",
		"null",
		"!!!",
		`{"layers":[]}`,
	} {
		client, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
			io.WriteString(w, body)
		})

		if err := client.Pull(t.Context(), "http://example.com/a/b"); !errors.Is(err, ErrManifestInvalid) {
			t.Errorf("err = %v; want invalid manifest", err)
		}
	}
}
// TestRegistryResolveByDigest checks that Resolve succeeds for a name of the
// form "alice/palace@<digest>" when the registry serves a manifest body at
// the blob path for that digest.
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
exists := blob.DigestFromBytes("exists")
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
// NOTE(review): there is no return after WriteHeader(499), so the
// manifest body below is still written for unexpected paths —
// confirm that is intended.
if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
w.WriteHeader(499) // should not hit manifest endpoint
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
})
_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
check(err)
}
// TestInsecureSkipVerify resolves the same model from an httptest TLS server
// twice using a zero-value Registry: once with the "https" scheme, which must
// fail with an error containing "failed to verify", and once with the
// "https+insecure" scheme, which must succeed.
func TestInsecureSkipVerify(t *testing.T) {
exists := blob.DigestFromBytes("exists")
// The server answers every request with a one-layer manifest.
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
}))
defer s.Close()
const name = "library/insecure"
var rc Registry
// First attempt: plain https against the test server's certificate.
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
_, err := rc.Resolve(t.Context(), url)
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
t.Errorf("err = %v; want cert verification failure", err)
}
// Second attempt: same server and name, https+insecure scheme.
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
_, err = rc.Resolve(t.Context(), url)
testutil.Check(t, err)
}
// TestErrorUnmarshal decodes a table of JSON error payloads into Error via
// json.Unmarshal and compares the result with the expected value.
//
// NOTE(review): the first two cases are identical ("errors empty").
//
// NOTE(review): wantErr is only consulted when Unmarshal fails. If a wantErr
// case decodes without error, the test does not fail on that account; got is
// instead compared with a zero Error because want is nil. Confirm whether a
// missing error should be reported.
func TestErrorUnmarshal(t *testing.T) {
cases := []struct {
name string
data string
want *Error
wantErr bool
}{
{
name: "errors empty",
data: `{"errors":[]}`,
wantErr: true,
},
{
name: "errors empty",
data: `{"errors":[]}`,
wantErr: true,
},
{
name: "errors single",
data: `{"errors":[{"code":"blob_unknown"}]}`,
want: &Error{Code: "blob_unknown", Message: ""},
},
{
name: "errors multiple",
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
want: &Error{Code: "blob_unknown", Message: ""},
},
{
name: "error empty",
data: `{"error":""}`,
wantErr: true,
},
{
name: "error very empty",
data: `{}`,
wantErr: true,
},
{
name: "error message",
data: `{"error":"message", "code":"code"}`,
want: &Error{Code: "code", Message: "message"},
},
{
name: "invalid value",
data: `{"error": 1}`,
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var got Error
err := json.Unmarshal([]byte(tt.data), &got)
if err != nil {
if tt.wantErr {
return
}
t.Errorf("Unmarshal() error = %v", err)
// fallthrough and check got
}
// A nil want means "expect the zero Error".
if tt.want == nil {
tt.want = &Error{}
}
if !reflect.DeepEqual(got, *tt.want) {
t.Errorf("got = %v; want %v", got, *tt.want)
}
})
}
}
// TestParseNameExtendedErrors tests that parseNameExtended returns error
// messages with enough detail for users to debug naming issues they may
// encounter. Prior to this test, the error messages were not very helpful and
// each problem was reported as the same message.
//
// It is only for testing error messages, not that all invalids and valids are
// covered. Those are in other tests for names.Name and blob.Digest.
//
// NOTE(review): the cases table below is empty, so the loop body never runs
// and this test currently asserts nothing.
func TestParseNameExtendedErrors(t *testing.T) {
cases := []struct {
name string
err error
want string
}{}
var r Registry
for _, tt := range cases {
_, _, _, err := r.parseNameExtended(tt.name)
if !errors.Is(err, tt.err) {
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
}
if err != nil && !strings.Contains(err.Error(), tt.want) {
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
}
}
}
// TestParseNameExtended runs parseNameExtended on a zero-value Registry for a
// table of inputs and checks the returned scheme, name, digest, and error.
//
// For each case: err, when set, is a substring the returned error must
// contain; when err is empty no error is allowed. On success the returned
// name must be fully qualified. The scheme, name, and digest checks run for
// every case, including those that expect an error.
func TestParseNameExtended(t *testing.T) {
cases := []struct {
in string
scheme string
name string
digest string
err string
}{
{in: "http://m", scheme: "http", name: "m"},
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
{in: "http+insecure://m", err: "unsupported scheme"},
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
{in: "", err: "invalid or missing name"},
{in: "m", scheme: "https", name: "m"},
{in: "://", err: "invalid or missing name"},
{in: "@sha256:deadbeef", err: "invalid digest"},
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
var r Registry
scheme, n, digest, err := r.parseNameExtended(tt.in)
if err != nil {
if tt.err == "" {
t.Errorf("err = %v; want nil", err)
} else if !strings.Contains(err.Error(), tt.err) {
t.Errorf("err = %v; want %q", err, tt.err)
}
} else if tt.err != "" {
t.Errorf("err = nil; want %q", tt.err)
}
if err == nil && !n.IsFullyQualified() {
t.Errorf("name = %q; want fully qualified", n)
}
if scheme != tt.scheme {
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
}
// smoke-test name is superset of tt.name
if !strings.Contains(n.String(), tt.name) {
t.Errorf("name = %q; want %q", n, tt.name)
}
// An unset expected digest means "the string form of a zero Digest".
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
if digest.String() != tt.digest {
t.Errorf("digest = %q; want %q", digest, tt.digest)
}
})
}
}
// TestUnlink covers Registry.Unlink in two situations: a name that is linked
// in the cache, and a name that is not.
func TestUnlink(t *testing.T) {
// A linked name resolves locally before Unlink and yields
// fs.ErrNotExist afterwards.
t.Run("found by name", func(t *testing.T) {
check := testutil.Checker(t)
rc, _ := newRegistryClient(t, nil)
// make a blob and link it
d := blob.DigestFromBytes("{}")
err := blob.PutBytes(rc.Cache, d, "{}")
check(err)
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
check(err)
// confirm linked
_, err = rc.ResolveLocal("single")
check(err)
// unlink
_, err = rc.Unlink("single")
check(err)
// confirm unlinked
_, err = rc.ResolveLocal("single")
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want fs.ErrNotExist", err)
}
})
// Unlinking a name that was never linked must return ok == false and
// a nil error.
t.Run("not found by name", func(t *testing.T) {
rc, _ := newRegistryClient(t, nil)
ok, err := rc.Unlink("manifestNotFound")
if err != nil {
t.Fatal(err)
}
if ok {
t.Error("expected not found")
}
})
}
// Many tests from here out, in this file are based on a single blob, "abc",
// with the checksum of its sha256 hash. The checksum is:
//
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
//
// Using the literal value instead of a constant with fmt.Xprintf calls proved
// to be the most readable and maintainable approach. The sum is consistently
// used in the tests and unique so searches do not yield false positives.
// checkRequest reports a test error for each of the request's URL path and
// HTTP method that differs from the expected path and method. The path is
// checked first.
func checkRequest(t *testing.T, req *http.Request, method, path string) {
	t.Helper()

	gotPath, gotMethod := req.URL.Path, req.Method
	if gotPath != path {
		t.Errorf("URL = %q, want %q", gotPath, path)
	}
	if gotMethod != method {
		t.Errorf("Method = %q, want %q", gotMethod, method)
	}
}
// newRegistryClient starts an httptest server running upstream and returns a
// Registry wired to it, together with a context whose Trace logs every
// progress update through t.Log.
//
// The Registry uses a fresh blob cache in a per-test temporary directory. Its
// HTTP transport ignores the address it is asked to dial and always connects
// to the test server, so tests may use arbitrary host names in model names
// and URLs. The server is closed via t.Cleanup.
func newRegistryClient(t *testing.T, upstream http.HandlerFunc) (*Registry, context.Context) {
	s := httptest.NewServer(upstream)
	t.Cleanup(s.Close)

	cache, err := blob.Open(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}

	ctx := WithTrace(t.Context(), &Trace{
		Update: func(l *Layer, n int64, err error) {
			t.Log("trace:", l.Digest.Short(), n, err)
		},
	})

	rc := &Registry{
		Cache: cache,
		HTTPClient: &http.Client{Transport: &http.Transport{
			// DialContext replaces the deprecated Dial field so
			// that a cancelled request also aborts its dial.
			DialContext: func(ctx context.Context, network, _ string) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, network, s.Listener.Addr().String())
			},
		}},
	}
	return rc, ctx
}
// TestPullChunked pulls the three-byte blob "abc" with chunking forced
// (ChunkingThreshold = 1) and checks the exact request sequence: one manifest
// request, one chunksums request, then two ranged blob requests. After the
// pull the model must resolve in the cache, and exactly four requests must
// have been served.
func TestPullChunked(t *testing.T) {
var steps atomic.Int64
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch steps.Add(1) {
case 1:
// Manifest listing the single 3-byte layer.
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
// Chunksums: the blob is split into "ab" (bytes 0-1) and "c" (byte 2).
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3, 4:
// Ranged blob reads; either range may arrive first.
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
switch rng := r.Header.Get("Range"); rng {
case "bytes=0-1":
io.WriteString(w, "ab")
case "bytes=2-2":
t.Logf("writing c")
io.WriteString(w, "c")
default:
t.Errorf("unexpected range %q", rng)
}
default:
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
err := c.Pull(ctx, "http://o.com/library/abc")
testutil.Check(t, err)
_, err = c.Cache.Resolve("o.com/library/abc:latest")
testutil.Check(t, err)
if g := steps.Load(); g != 4 {
t.Fatalf("got %d steps, want 4", g)
}
}
// TestPullCached checks that Pull succeeds when the only layer is already in
// the local cache. The handler validates every request as the manifest
// request.
func TestPullCached(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
})
check := testutil.Checker(t)
// Preemptively cache the blob
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
check(err)
err = blob.PutBytes(c.Cache, d, []byte("abc"))
check(err)
// Pull only the manifest, which should be enough to resolve the cached blob
err = c.Pull(ctx, "http://o.com/library/abc")
check(err)
}
// TestPullManifestError checks that a 404 manifest response carrying the
// registry error code MANIFEST_UNKNOWN makes Pull return an error matching
// ErrModelNotFound.
func TestPullManifestError(t *testing.T) {
	c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
		w.WriteHeader(http.StatusNotFound)
		io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
	})

	err := c.Pull(ctx, "http://o.com/library/abc")
	if err == nil {
		t.Fatalf("expected error")
	}
	if !errors.Is(err, ErrModelNotFound) {
		// Report the error actually returned. The previous message
		// printed an unrelated variable that was always nil.
		t.Fatalf("err = %v, want %v", err, ErrModelNotFound)
	}
}
func TestPullLayerError(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `!`)
})
err := c.Pull(ctx, "http://o.com/library/abc")
if err == nil {
t.Fatalf("expected error")
}
var want *json.SyntaxError
if !errors.As(err, &want) {
t.Fatalf("err = %T, want %T", err, want)
}
}
// TestPullLayerChecksumError pulls a chunked blob from a registry that fails
// the first chunk request with 404 BLOB_UNKNOWN and serves "c" for the
// second. With MaxStreams = 1 the requests are served in order. Pull must
// return an *Error with code BLOB_UNKNOWN, and the trace must have reported
// exactly one transferred byte in total.
func TestPullLayerChecksumError(t *testing.T) {
	var step atomic.Int64
	c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		switch step.Add(1) {
		case 1:
			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
		case 2:
			checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
			fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
		case 3:
			w.WriteHeader(http.StatusNotFound)
			io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
		case 4:
			io.WriteString(w, "c")
		default:
			t.Errorf("unexpected steps %d: %v", step.Load(), r)
			http.Error(w, "unexpected steps", http.StatusInternalServerError)
		}
	})

	c.MaxStreams = 1
	c.ChunkingThreshold = 1 // force chunking

	var written atomic.Int64
	ctx := WithTrace(t.Context(), &Trace{
		Update: func(l *Layer, n int64, err error) {
			t.Log("trace:", l.Digest.Short(), n, err)
			written.Add(n)
		},
	})

	err := c.Pull(ctx, "http://o.com/library/abc")
	var got *Error
	if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
		// State the expected code. The previous message printed got
		// as the "want" value, which is either nil or the very error
		// being rejected.
		t.Fatalf("err = %v, want *Error with code %q", err, "BLOB_UNKNOWN")
	}

	if g := written.Load(); g != 1 {
		t.Fatalf("wrote %d bytes, want 1", g)
	}
}
// TestPullChunksumStreamError checks that Pull returns ErrIncomplete when the
// chunksum response contains one valid line followed by an invalid one. The
// valid chunk ("ab") is still served on the third request.
func TestPullChunksumStreamError(t *testing.T) {
var step atomic.Int64
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
// Write one valid chunksum and one invalid chunksum
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
fmt.Fprint(w, "sha256:!") // invalid
case 3:
io.WriteString(w, "ab")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
got := c.Pull(ctx, "http://o.com/library/abc")
if !errors.Is(got, ErrIncomplete) {
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
}
}
// flushAfterWriter wraps an io.Writer and flushes it after every Write so
// that data reaches the peer as soon as it is written.
type flushAfterWriter struct {
	w io.Writer
}

// Write forwards p to the wrapped writer and then flushes it, returning the
// wrapped writer's byte count and error unchanged. It panics if the wrapped
// writer does not implement http.Flusher.
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
	n, err = f.w.Write(p)
	flusher := f.w.(http.Flusher) // panic if not a flusher
	flusher.Flush()
	return n, err
}
// TestPullChunksumStreaming checks that Pull starts downloading a chunk as
// soon as its chunksum line arrives, without waiting for the chunksum
// response to finish. The chunksum response body is fed from a pipe that the
// test writes to one line at a time, waiting for a progress update after
// each line.
func TestPullChunksumStreaming(t *testing.T) {
// csw is the test's end of the pipe; csr is copied into the chunksum
// response by the handler.
csr, csw := io.Pipe()
defer csw.Close()
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
_, err := io.Copy(fw, csr)
if err != nil {
t.Errorf("copy: %v", err)
}
case 3:
io.WriteString(w, "ab")
case 4:
io.WriteString(w, "c")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
// update receives every positive byte count reported by the trace.
update := make(chan int64, 1)
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
if n > 0 {
update <- n
}
},
})
// Run the pull in the background so the test can drive the pipe.
errc := make(chan error, 1)
go func() {
errc <- c.Pull(ctx, "http://o.com/library/abc")
}()
// Send first chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
if g := <-update; g != 2 {
t.Fatalf("got %d, want 2", g)
}
// now send the second chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
if g := <-update; g != 3 {
t.Fatalf("got %d, want 3", g)
}
// Closing the pipe ends the chunksum response body.
csw.Close()
testutil.Check(t, <-errc)
}
// TestPullChunksumsCached checks that chunks downloaded by an interrupted
// pull are reused by a later pull of the same blob.
//
// The first pull is cancelled from the trace callback as soon as any bytes
// are reported; it must fail with context.Canceled and leave the model
// unresolvable in the cache. The second pull runs to completion and must
// report 5 bytes through the trace in total, 2 of them flagged ErrCached.
func TestPullChunksumsCached(t *testing.T) {
	var step atomic.Int64
	c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
		switch step.Add(1) {
		case 1:
			checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
			io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
		case 2:
			w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
			fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
			fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
		case 3, 4:
			switch rng := r.Header.Get("Range"); rng {
			case "bytes=0-1":
				io.WriteString(w, "ab")
			case "bytes=2-2":
				io.WriteString(w, "c")
			default:
				t.Errorf("unexpected range %q", rng)
			}
		default:
			t.Errorf("unexpected steps %d: %v", step.Load(), r)
			http.Error(w, "unexpected steps", http.StatusInternalServerError)
		}
	})

	c.MaxStreams = 1        // force serial processing of chunksums
	c.ChunkingThreshold = 1 // force chunking

	ctx, cancel := context.WithCancel(t.Context())
	defer cancel()

	// Cancel the pull after the first chunksum is processed, but before
	// the second chunksum is processed (which is waiting because
	// MaxStreams=1). This should cause the second chunksum to error out
	// leaving the blob incomplete.
	ctx = WithTrace(ctx, &Trace{
		Update: func(l *Layer, n int64, err error) {
			if n > 0 {
				cancel()
			}
		},
	})
	err := c.Pull(ctx, "http://o.com/library/abc")
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("err = %v, want %v", err, context.Canceled)
	}

	// The interrupted pull must not have linked the model.
	_, err = c.Cache.Resolve("o.com/library/abc:latest")
	if !errors.Is(err, fs.ErrNotExist) {
		t.Fatalf("err = %v, want %v", err, fs.ErrNotExist)
	}

	// Reset state and pull again to ensure the blob chunks that should
	// have been cached are, and the remaining chunk was downloaded, making
	// the blob complete.
	step.Store(0)

	var written atomic.Int64
	var cached atomic.Int64
	ctx = WithTrace(t.Context(), &Trace{
		Update: func(l *Layer, n int64, err error) {
			t.Log("trace:", l.Digest.Short(), n, err)
			if errors.Is(err, ErrCached) {
				cached.Add(n)
			}
			written.Add(n)
		},
	})

	check := testutil.Checker(t)

	err = c.Pull(ctx, "http://o.com/library/abc")
	check(err)

	_, err = c.Cache.Resolve("o.com/library/abc:latest")
	check(err)

	// The expected values in the messages below now match the
	// conditions. They previously read "want 3" and "want 5".
	if g := written.Load(); g != 5 {
		t.Fatalf("wrote %d bytes, want 5", g)
	}
	if g := cached.Load(); g != 2 { // "ab" should have been cached
		t.Fatalf("cached %d bytes, want 2", g)
	}
}

View File

@@ -0,0 +1,72 @@
package ollama
import (
"context"
)
// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
//
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
// and [Registry.Pull].
//
// A Trace whose Update field is nil is valid; updates are then dropped.
type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.
//
// The n argument is the number of bytes transferred so far, and err is
// any error that has occurred. If n == 0, and err is nil, the download
// or upload has just started. If err is [ErrCached], the download or
// upload has been skipped because the blob is already present in the
// local cache or remote registry, respectively. Otherwise, if err is
// non-nil, the download or upload has failed. When l.Size == n, and
// err is nil, the download or upload has completed.
//
// A function assigned must be safe for concurrent use. The function is
// called synchronously and so should not block or take long to run.
Update func(_ *Layer, n int64, _ error)
}
// update forwards a progress report to t.Update, doing nothing when no
// Update function has been set.
func (t *Trace) update(l *Layer, n int64, err error) {
	report := t.Update
	if report == nil {
		return
	}
	report(l, n, err)
}
// traceKey is the context key type under which a *Trace is stored.
type traceKey struct{}
// WithTrace adds a trace to the context for transfer progress reporting.
//
// Traces compose: if ctx already carries a Trace, the returned context
// carries one that calls the existing Trace first and then t. If t is nil, or
// is the very Trace already carried by ctx, ctx is returned unchanged.
func WithTrace(ctx context.Context, t *Trace) context.Context {
	if t == nil {
		// Nothing to add. Wrapping a nil Trace would only produce a
		// context that panics on the first progress update.
		return ctx
	}

	old := traceFromContext(ctx)
	if old == t {
		// No change, return the original context. This also prevents
		// infinite recursion below, if the caller passes the same
		// Trace.
		return ctx
	}

	// Create a new Trace that wraps the old one. If we used the same
	// pointer t, we would end up with a recursive structure. old is never
	// nil because traceFromContext falls back to an empty Trace.
	composed := &Trace{
		Update: func(l *Layer, n int64, err error) {
			old.update(l, n, err)
			t.update(l, n, err)
		},
	}
	return context.WithValue(ctx, traceKey{}, composed)
}
// emptyTrace is a Trace with no Update function. It is what
// traceFromContext returns when a context carries no Trace.
var emptyTrace = &Trace{}
// traceFromContext looks up the Trace stored in ctx. When ctx holds no
// Trace, or holds a nil one, it returns the shared empty Trace instead.
//
// The result is never nil.
func traceFromContext(ctx context.Context) *Trace {
	if stored, ok := ctx.Value(traceKey{}).(*Trace); ok && stored != nil {
		return stored
	}
	return emptyTrace
}

View File

@@ -0,0 +1,45 @@
package backoff
import (
"context"
"iter"
"math/rand/v2"
"time"
)
func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
var n int
return func(yield func(int, error) bool) {
var t *time.Timer
for {
if ctx.Err() != nil {
yield(n, ctx.Err())
return
}
if !yield(n, nil) {
return
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
if t == nil {
t = time.NewTimer(d)
} else {
t.Reset(d)
}
select {
case <-ctx.Done():
t.Stop()
case <-t.C:
}
}
}
}

View File

@@ -0,0 +1,38 @@
package backoff
import (
"context"
"errors"
"testing"
"testing/synctest"
"time"
)
func TestLoop(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
last := -1
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
for n, err := range Loop(ctx, 100*time.Millisecond) {
if !errors.Is(err, ctx.Err()) {
t.Errorf("err = %v, want nil", err)
}
if err != nil {
break
}
if n != last+1 {
t.Errorf("n = %d, want %d", n, last+1)
}
last = n
if n > 5 {
cancel()
}
}
if last != 6 {
t.Errorf("last = %d, want 6", last)
}
})
}

View File

@@ -0,0 +1,24 @@
package backoff
import (
"testing"
)
func TestLoopAllocs(t *testing.T) {
for i := range 3 {
got := testing.AllocsPerRun(1000, func() {
for tick := range Loop(t.Context(), 1) {
if tick >= i {
break
}
}
})
want := float64(0)
if i > 0 {
want = 3 // due to time.NewTimer
}
if got > want {
t.Errorf("[%d ticks]: allocs = %v, want 0", i, want)
}
}
}

View File

@@ -0,0 +1,228 @@
package names
import (
"cmp"
"fmt"
"strings"
"github.com/ollama/ollama/server/internal/internal/stringsx"
)
// MaxNameLength is the longest name string, in bytes, that [Parse] accepts:
// a 350-byte host and 80 bytes each for namespace, model, and tag, plus the
// three separators.
const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>
// Name holds the four parts of a model name as produced by [Parse]. Any part
// may be empty.
type Name struct {
// Make incomparable to enforce use of Compare / Equal for
// case-insensitive comparisons.
_ [0]func()
h string // host, see Host
n string // namespace, see Namespace
m string // model, see Model
t string // tag, see Tag
}
// Parse parses and assembles a Name from a name string. The
// format of a valid name string is:
//
//	s:
//		{ host } "/" { namespace } "/" { model } ":" { tag }
//		{ host } "/" { namespace } "/" { model }
//		{ namespace } "/" { model } ":" { tag }
//		{ namespace } "/" { model }
//		{ model } ":" { tag }
//		{ model }
//	host:
//		pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
//		length:  [1, 350]
//	namespace:
//		pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
//		length:  [1, 80]
//	model:
//		pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
//		length:  [1, 80]
//	tag:
//		pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
//		length:  [1, 80]
//
// Parse itself only splits the string on its separators; it does not check
// the patterns above. Strings longer than [MaxNameLength] yield the zero
// Name. The name returned is not guaranteed to be valid. If it is not valid,
// the field values are left in an undefined state. Use [Name.IsValid] to
// check if the name is valid.
func Parse(s string) Name {
	var parsed Name
	if len(s) > MaxNameLength {
		return parsed
	}

	// Peel parts off the right-hand end: an optional ":tag" first, then
	// the model, then whatever precedes the last "/" of the remainder.
	for {
		rest, tail, sep := cutLastAny(s, "/:")
		if sep == ':' {
			parsed.t = tail
			s = rest
			continue // look for model
		}

		parsed.m = tail
		if sep == '/' {
			parsed.h, parsed.n, _ = cutLastAny(rest, "/")
		}
		return parsed
	}
}
// Split splits an extended name string into its scheme, name, and digest
// parts.
//
// The scheme is everything before the first "://"; it is empty when the
// string contains no "://". The digest is everything after the last "@" of
// the remainder; it is empty when there is no "@". Split performs no
// validation and applies no defaults.
//
// Examples:
//
//	http://ollama.com/bmizerany/smol:latest@digest
//	https://ollama.com/bmizerany/smol:latest
//	ollama.com/bmizerany/smol:latest@digest // empty scheme
//	model@digest
//	@digest
func Split(s string) (scheme, name, digest string) {
	if before, after, found := strings.Cut(s, "://"); found {
		scheme, s = before, after
	}
	if at := strings.LastIndex(s, "@"); at >= 0 {
		s, digest = s[:at], s[at+1:]
	}
	return scheme, s, digest
}
// Merge fills the empty host, namespace, and tag parts of a with the
// corresponding parts of b and returns the result. Parts that a already has
// are kept, and the model part is always a's.
//
// The returned name is not guaranteed to be valid. Use [Name.IsValid] to check
// if the name is valid.
func Merge(a, b Name) Name {
	if a.h == "" {
		a.h = b.h
	}
	if a.n == "" {
		a.n = b.n
	}
	if a.t == "" {
		a.t = b.t
	}
	return a
}
// IsValid reports whether the name is valid: the model part must be present
// and well-formed, and each of host, namespace, and tag must be well-formed
// if present.
func (n Name) IsValid() bool {
	optional := [...]struct {
		kind  int
		value string
	}{
		{partHost, n.h},
		{partNamespace, n.n},
		{partTag, n.t},
	}
	for _, part := range optional {
		if part.value != "" && !isValidPart(part.kind, part.value) {
			return false
		}
	}

	// at bare minimum, model must be present and valid
	if n.m == "" {
		return false
	}
	return isValidPart(partModel, n.m)
}
// IsFullyQualified reports whether the name is valid and has all four parts
// (host, namespace, model, and tag) set.
func (n Name) IsFullyQualified() bool {
	if !n.IsValid() {
		return false
	}
	return n.h != "" && n.n != "" && n.m != "" && n.t != ""
}
// Part kinds, passed to isValidPart to select which rules apply to the
// string being checked.
const (
partHost = iota
partNamespace
partModel
partTag
)
// isValidPart reports whether s is acceptable as a name part of the given
// kind (one of the part* constants).
//
// A part may be at most 80 bytes long, or 350 for a host. Its first byte must
// be an ASCII letter, digit, or underscore. Later bytes may additionally be
// '-', '.' (not in a namespace), or ':' (only in a host). The empty string
// passes; callers reject empty parts themselves where required.
func isValidPart(kind int, s string) bool {
	limit := 80
	if kind == partHost {
		limit = 350
	}
	if len(s) > limit {
		return false
	}
	if s == "" {
		return true
	}
	if !isAlphanumericOrUnderscore(s[0]) {
		return false
	}

	// Any byte outside ASCII is rejected by the default case, so walking
	// the string byte by byte is sufficient.
	for i := 1; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '_' || c == '-':
			// always allowed after the first byte
		case c == '.':
			if kind == partNamespace {
				return false
			}
		case c == ':':
			if kind != partHost {
				return false
			}
		case !isAlphanumericOrUnderscore(c):
			return false
		}
	}
	return true
}
// isAlphanumericOrUnderscore reports whether c is an ASCII letter, an ASCII
// digit, or '_'.
func isAlphanumericOrUnderscore(c byte) bool {
	switch {
	case 'A' <= c && c <= 'Z':
		return true
	case 'a' <= c && c <= 'z':
		return true
	case '0' <= c && c <= '9':
		return true
	}
	return c == '_'
}
// Host returns the host part of the name.
func (n Name) Host() string { return n.h }
// Namespace returns the namespace part of the name.
func (n Name) Namespace() string { return n.n }
// Model returns the model part of the name.
func (n Name) Model() string { return n.m }
// Tag returns the tag part of the name.
func (n Name) Tag() string { return n.t }
// Compare compares n and o case-insensitively, part by part in the order
// host, namespace, model, tag. The result is that of the first part whose
// stringsx.CompareFold result is non-zero, or 0 when all four parts compare
// equal.
func (n Name) Compare(o Name) int {
	byHost := stringsx.CompareFold(n.h, o.h)
	byNamespace := stringsx.CompareFold(n.n, o.n)
	byModel := stringsx.CompareFold(n.m, o.m)
	byTag := stringsx.CompareFold(n.t, o.t)
	return cmp.Or(byHost, byNamespace, byModel, byTag)
}
// String returns the name in the format
// <host>/<namespace>/<model>:<tag>, leaving out each of host, namespace, and
// tag (along with its separator) when that part is empty.
func (n Name) String() string {
	var out strings.Builder
	for _, prefix := range [...]string{n.h, n.n} {
		if prefix == "" {
			continue
		}
		out.WriteString(prefix)
		out.WriteByte('/')
	}
	out.WriteString(n.m)
	if n.t == "" {
		return out.String()
	}
	out.WriteByte(':')
	out.WriteString(n.t)
	return out.String()
}
// GoString returns a debugging representation of the name that shows each of
// host, namespace, model, and tag as a separate quoted string. It is what
// the %#v verb prints for a Name.
func (n Name) GoString() string {
return fmt.Sprintf("<Name %q %q %q %q>", n.h, n.n, n.m, n.t)
}
// cutLastAny splits s around the last occurrence of any character from
// chars, in the manner of strings.Cut but searching from the end. When a
// separator is found, sep is its byte value and before/after are the text on
// either side of it. When none is found, before is empty, after is s, and sep
// is 0.
func cutLastAny(s, chars string) (before, after string, sep byte) {
	idx := strings.LastIndexAny(s, chars)
	if idx < 0 {
		return "", s, 0
	}
	before, after, sep = s[:idx], s[idx+1:], s[idx]
	return before, after, sep
}

View File

@@ -0,0 +1,220 @@
package names
import (
"strings"
"testing"
)
// TestParseName checks that Parse splits each input into the expected host,
// namespace, model, and tag parts. Results are compared with Name.Compare,
// so the comparison is case-insensitive.
func TestParseName(t *testing.T) {
	type testCase struct {
		in   string
		want Name
	}
	table := []testCase{
		{"", Name{}},
		{"m:t", Name{m: "m", t: "t"}},
		{"m", Name{m: "m"}},
		{"/m", Name{m: "m"}},
		{"/n/m:t", Name{n: "n", m: "m", t: "t"}},
		{"n/m", Name{n: "n", m: "m"}},
		{"n/m:t", Name{n: "n", m: "m", t: "t"}},
		{"n/m", Name{n: "n", m: "m"}},
		{"n/m", Name{n: "n", m: "m"}},
		{strings.Repeat("m", MaxNameLength+1), Name{}},
		{"h/n/m:t", Name{h: "h", n: "n", m: "m", t: "t"}},
		{"ollama.com/library/_:latest", Name{h: "ollama.com", n: "library", m: "_", t: "latest"}},

		// Invalids
		// TODO: {"n:t/m:t", Name{}},
		// TODO: {"/h/n/m:t", Name{}},
	}

	for _, tc := range table {
		t.Run(tc.in, func(t *testing.T) {
			parsed := Parse(tc.in)
			if parsed.Compare(tc.want) == 0 {
				return
			}
			t.Errorf("parseName(%q) = %#v, want %q", tc.in, parsed, tc.want)
		})
	}
}
// TestString checks that Parse followed by String reproduces the input.
// Inputs with a leading slash are compared after that slash is removed; the
// subtest is still named after the original input.
func TestString(t *testing.T) {
	inputs := []string{
		"",
		"m:t",
		"m:t",
		"m",
		"n/m",
		"n/m:t",
		"n/m",
		"n/m",
		"h/n/m:t",
		"ollama.com/library/_:latest",

		// Special cased to "round trip" without the leading slash.
		"/m",
		"/n/m:t",
	}

	for _, in := range inputs {
		t.Run(in, func(t *testing.T) {
			trimmed := strings.TrimPrefix(in, "/")
			got := Parse(trimmed).String()
			if got != trimmed {
				t.Errorf("parse(%q).String() = %q", trimmed, got)
			}
		})
	}
}
// TestParseExtended checks Split followed by Parse: for each input the
// scheme and digest returned by Split, and the Name parsed from the
// remaining string, must match the expected values.
func TestParseExtended(t *testing.T) {
	cases := []struct {
		in         string
		wantScheme string
		wantName   Name
		wantDigest string
	}{
		{"", "", Name{}, ""},
		{"m", "", Name{m: "m"}, ""},
		{"http://m", "http", Name{m: "m"}, ""},
		{"http+insecure://m", "http+insecure", Name{m: "m"}, ""},
		{"http://m@sha256:deadbeef", "http", Name{m: "m"}, "sha256:deadbeef"},
	}
	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			scheme, name, digest := Split(tt.in)
			n := Parse(name)
			if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
				// Print the parsed Name that was compared, so
				// that the got and want halves of the message
				// show the same kind of value.
				t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, n, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
			}
		})
	}
}
func TestMerge(t *testing.T) {
cases := []struct {
a, b string
want string
}{
{"", "", ""},
{"m", "", "m"},
{"", "m", ""},
{"x", "y", "x"},
{"o.com/n/m:t", "o.com/n/m:t", "o.com/n/m:t"},
{"o.com/n/m:t", "o.com/n/_:t", "o.com/n/m:t"},
{"bmizerany/smol", "ollama.com/library/_:latest", "ollama.com/bmizerany/smol:latest"},
{"localhost:8080/bmizerany/smol", "ollama.com/library/_:latest", "localhost:8080/bmizerany/smol:latest"},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
a, b := Parse(tt.a), Parse(tt.b)
got := Merge(a, b)
if got.Compare(Parse(tt.want)) != 0 {
t.Errorf("merge(%q, %q) = %#v, want %q", tt.a, tt.b, got, tt.want)
}
})
}
}
// TestParseStringRoundTrip checks that Parse followed by String returns each
// input unchanged.
func TestParseStringRoundTrip(t *testing.T) {
	inputs := []string{
		"",
		"m",
		"m:t",
		"n/m",
		"n/m:t",
		"n/m:t",
		"n/m",
		"n/m",
		"h/n/m:t",
		"ollama.com/library/_:latest",
	}

	for _, in := range inputs {
		t.Run(in, func(t *testing.T) {
			roundTripped := Parse(in).String()
			if roundTripped == in {
				return
			}
			t.Errorf("parse(%q).String() = %q", in, roundTripped)
		})
	}
}
// junkName receives the result of Parse in BenchmarkParseName.
var junkName Name
// BenchmarkParseName measures Parse on a four-part name and reports
// allocations. Each result is stored in the package-level junkName.
func BenchmarkParseName(b *testing.B) {
b.ReportAllocs()
for range b.N {
junkName = Parse("h/n/m:t")
}
}
const (
part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
)
// testCases maps a raw name string to the value Parse(name).IsValid() is
// expected to report. It is consumed by TestParseNameValidation.
var testCases = map[string]bool{ // name -> valid
"": false,
"_why/_the/_lucky:_stiff": true,
// minimal fully and partially qualified names
"h/n/m:t": true,
"host/namespace/model:tag": true,
"host/namespace/model": true,
"namespace/model": true,
"model": true,
// long (but valid): 80-named parts everywhere, and a 350-named host part
part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
// too long: a 350-named tag, and a host one byte longer than part350
part80 + "/" + part80 + "/" + part80 + ":" + part350: false,
"x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,
"h/nn/mm:t": true, // bare minimum part sizes
// unqualified names, and inputs containing the '@' digest separator
"m": true,
"n/m:": true,
"h/n/m": true,
"@t": false,
"m@d": false,
// stray separators and empty parts (validity varies per entry)
"^": false,
"mm:": true,
"/nn/mm": true,
"//": false, // empty model
"//mm": true,
"hh//": false, // empty model
"//mm:@": false,
"00@": false,
"@": false,
// not starting with alphanum
"-hh/nn/mm:tt": false,
"hh/-nn/mm:tt": false,
"hh/nn/-mm:tt": false,
"hh/nn/mm:-tt": false,
// smells like a flag
"-h": false,
// hosts may contain a colon
"host:https/namespace/model:tag": true,
// colon in non-host part before tag
"host/name:space/model:tag": false,
}
// TestParseNameValidation parses every key of testCases and requires
// IsValid to agree with the expected value stored in the map.
func TestParseNameValidation(t *testing.T) {
	for input, want := range testCases {
		parsed := Parse(input)
		if parsed.IsValid() == want {
			continue
		}
		t.Logf("got: %v", parsed)
		t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", input, parsed.IsValid())
	}
}

View File

@@ -0,0 +1,52 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package stringsx provides additional string manipulation functions
// that aren't in the standard library's strings package or go4.org/mem.
package stringsx
import (
"unicode"
"unicode/utf8"
)
// CompareFold compares a and b rune by rune after lowercasing each rune with
// unicode.ToLower. It returns -1 if a sorts before b, 1 if a sorts after b,
// and 0 if they are equal under that folding. When one string is a folded
// prefix of the other, the shorter string sorts first.
func CompareFold(a, b string) int {
	// Consume one rune from the front of each string per iteration; slicing
	// a string does not allocate.
	for a != "" && b != "" {
		ra, wa := nextRuneLower(a)
		rb, wb := nextRuneLower(b)
		if ra != rb {
			if ra < rb {
				return -1
			}
			return 1
		}
		a, b = a[wa:], b[wb:]
	}

	// At least one string is exhausted and everything so far matched.
	switch {
	case a == "" && b == "":
		return 0
	case a == "":
		return -1
	default:
		return 1
	}
}

// nextRuneLower decodes the first rune of s and returns it lowercased,
// together with the number of bytes the original rune occupied. For an empty
// string it returns (utf8.RuneError, 0).
func nextRuneLower(s string) (r rune, width int) {
	decoded, size := utf8.DecodeRuneInString(s)
	return unicode.ToLower(decoded), size
}

View File

@@ -0,0 +1,78 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package stringsx
import (
"cmp"
"strings"
"testing"
)
// TestCompareFold checks CompareFold against
// cmp.Compare(strings.ToLower(a), strings.ToLower(b)) for every pair in the
// table, then re-runs all pairs under testing.AllocsPerRun and requires zero
// allocations.
func TestCompareFold(t *testing.T) {
tests := []struct {
a, b string
}{
// Basic ASCII cases
{"", ""},
{"a", "a"},
{"a", "A"},
{"A", "a"},
{"a", "b"},
{"b", "a"},
{"abc", "ABC"},
{"ABC", "abc"},
{"abc", "abd"},
{"abd", "abc"},
// Length differences
{"abc", "ab"},
{"ab", "abc"},
// Unicode cases
{"世界", "世界"},
{"Hello世界", "hello世界"},
{"世界Hello", "世界hello"},
{"世界", "世界x"},
{"世界x", "世界"},
// Special case folding examples
{"ß", "ss"}, // German sharp s
{"fi", "fi"}, // fi ligature
{"Σ", "σ"}, // Greek sigma
{"İ", "i\u0307"}, // Turkish dotted I
// Mixed cases
{"HelloWorld", "helloworld"},
{"HELLOWORLD", "helloworld"},
{"helloworld", "HELLOWORLD"},
{"HelloWorld", "helloworld"},
{"helloworld", "HelloWorld"},
// Edge cases
{" ", " "},
{"1", "1"},
{"123", "123"},
{"!@#", "!@#"},
}
// wants records the reference result for each row so the allocation check
// below can verify results without calling strings.ToLower itself.
wants := []int{}
for _, tt := range tests {
got := CompareFold(tt.a, tt.b)
want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b))
if got != want {
t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want)
}
wants = append(wants, want)
}
// CompareFold is expected to be allocation-free.
if n := testing.AllocsPerRun(1000, func() {
for i, tt := range tests {
if CompareFold(tt.a, tt.b) != wants[i] {
panic("unexpected")
}
}
}); n > 0 {
t.Errorf("allocs = %v; want 0", int(n))
}
}

View File

@@ -0,0 +1,201 @@
// Package syncs provides synchronization primitives.
package syncs
import (
"cmp"
"io"
"sync"
)
var closedChan = func() chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}()
// Ticket represents a ticket in a sequence of tickets. The zero value is
// invalid. Use [Line.Take] to get a valid ticket.
//
// A Ticket is not safe for concurrent use.
type Ticket struct {
ahead chan struct{} // ticket ahead of this one
ch chan struct{}
}
// Ready returns a channel that is closed when the ticket before this one is
// done.
//
// It is incorrect to wait on Ready after the ticket is done.
func (t *Ticket) Ready() chan struct{} {
return cmp.Or(t.ahead, closedChan)
}
// Done signals that this ticket is done and that the next ticket in line can
// proceed.
//
// The first call to [Done] unblocks the ticket after it, if any. Subsequent
// calls are no-ops.
func (t *Ticket) Done() {
if t.ch != nil {
close(t.ch)
}
t.ch = nil
}
// Line is an ordered sequence of tickets waiting for their turn to proceed.
//
// To get a ticket use [Line.Take].
// To signal that a ticket is done use [Ticket.Done].
// To wait your turn use [Ticket.Ready].
//
// A Line is not safe for concurrent use.
type Line struct {
last chan struct{} // last ticket in line
}
func (q *Line) Take() *Ticket {
t := &Ticket{
ahead: q.last,
ch: make(chan struct{}),
}
q.last = t.ch
return t
}
// RelayReader implements an [io.WriterTo] that hands the writer passed to its
// [RelayReader.WriteTo] method to each [io.WriteCloser] obtained from
// [RelayReader.Take], in the order they were taken. Each [io.WriteCloser]
// blocks until the previous one is closed, or until a call to
// [RelayReader.CloseWithError] is made.
//
// The zero value is invalid. Use [NewRelayReader] to get a valid RelayReader.
//
// It is not safe for concurrent use.
type RelayReader struct {
line Line
t *Ticket // ticket taken by NewRelayReader; WriteTo marks it done
w io.Writer // destination writer, set by WriteTo
n int64 // byte count accumulated by relayWriter.Write; returned by WriteTo
mu sync.Mutex
err error // set by CloseWithError
closedCh chan struct{} // closed if err is set
}
// Compile-time checks that *RelayReader satisfies the interfaces it is used as.
var (
_ io.Closer = (*RelayReader)(nil)
_ io.WriterTo = (*RelayReader)(nil)
_ io.Reader = (*RelayReader)(nil)
)
// NewRelayReader returns a ready-to-use RelayReader. It takes the first
// ticket in the reader's line for itself; WriteTo later marks that ticket
// done, which lets the first writer proceed.
func NewRelayReader() *RelayReader {
	q := &RelayReader{closedCh: make(chan struct{})}
	q.t = q.line.Take()
	return q
}
// CloseWithError terminates the line. Writers still waiting for their turn
// are unblocked and fail with err, or with [io.EOF] if err is nil. It may be
// called multiple times and from multiple goroutines.
//
// Only the first call has an effect; once the line is closed, later calls are
// no-ops and the original error is kept.
//
// It never returns an error.
func (q *RelayReader) CloseWithError(err error) error {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.err != nil {
		// Already closed.
		return nil
	}
	q.err = cmp.Or(err, io.EOF)
	close(q.closedCh)
	return nil
}
// Close closes the line by calling CloseWithError(nil). Any writer still
// waiting for its turn will be unblocked with an [io.EOF] error, which is the
// error CloseWithError records when given nil.
//
// It never returns an error.
func (q *RelayReader) Close() error {
return q.CloseWithError(nil)
}
// closed returns the channel that CloseWithError closes. The channel is read
// while holding q.mu.
func (q *RelayReader) closed() <-chan struct{} {
	q.mu.Lock()
	ch := q.closedCh
	q.mu.Unlock()
	return ch
}
// Read exists only so that *RelayReader satisfies [io.Reader]. It always
// panics; data is delivered through WriteTo instead.
func (q *RelayReader) Read(p []byte) (int, error) {
panic("RelayReader.Read is for show only; use WriteTo")
}
// WriteTo makes dst the destination for the writers obtained from [Take],
// releases the first writer in line, and then blocks until the reader is
// closed via [Close] or [CloseWithError]. It returns the byte count
// accumulated in q.n and a nil error.
//
// If the reader is already closed when WriteTo is called, it returns
// (0, io.ErrClosedPipe) without touching dst.
//
// It is safe to call [Take] concurrently with [WriteTo].
//
// NOTE(review): the error passed to CloseWithError is not returned here, and
// q.n is read without waiting for in-flight writers — confirm both are
// intended.
func (q *RelayReader) WriteTo(dst io.Writer) (int64, error) {
// Non-blocking check: refuse to start if already closed.
select {
case <-q.closed():
return 0, io.ErrClosedPipe
default:
}
// We have a destination writer; let the relay begin.
q.w = dst
// Marking the reader's own ticket done unblocks the first writer.
q.t.Done()
<-q.closed()
return q.n, nil
}
// Take reserves the next position in line and returns a writer bound to it.
// Writes through the returned writer go to the destination given to
// [RelayReader.WriteTo] once every earlier position has been closed.
//
// It is not safe for use across multiple goroutines.
func (q *RelayReader) Take() io.WriteCloser {
	w := &relayWriter{q: q}
	w.t = q.line.Take()
	return w
}
// relayWriter is the io.WriteCloser returned by RelayReader.Take. It holds one
// ticket in the reader's line and writes to the reader's destination once
// that ticket's turn arrives.
type relayWriter struct {
q *RelayReader // owning reader
t *Ticket // this writer's position in line
ready bool // set once awaitTurn has observed that it is this writer's turn
}
// Compile-time check that *relayWriter implements io.StringWriter.
var _ io.StringWriter = (*relayWriter)(nil)
// Write blocks until it is this writer's turn and then writes p to the
// destination passed to [RelayReader.WriteTo], adding the bytes written to
// the reader's running total.
//
// If the reader is closed before this writer's turn arrives, Write returns 0
// and the reader's close error (the error given to CloseWithError, or io.EOF
// for a plain Close).
func (w *relayWriter) Write(p []byte) (int, error) {
	if ok := w.awaitTurn(); !ok {
		return 0, w.q.err
	}
	written, err := w.q.w.Write(p)
	w.q.n += int64(written)
	return written, err
}
// WriteString is the string counterpart of Write: it blocks until it is this
// writer's turn and then writes s to the destination passed to
// [RelayReader.WriteTo] via io.WriteString.
//
// Bytes written are added to the reader's running total, exactly as Write
// does, so the count returned by RelayReader.WriteTo covers data written
// through either method.
//
// If the reader is closed before this writer's turn arrives, WriteString
// returns 0 and the reader's close error.
func (w *relayWriter) WriteString(s string) (int, error) {
	if !w.awaitTurn() {
		return 0, w.q.err
	}
	n, err := io.WriteString(w.q.w, s)
	w.q.n += int64(n)
	return n, err
}
// Close signals that the writer is done, unblocking the next writer in line.
// It marks this writer's ticket done; because Ticket.Done is a no-op after
// the first call, calling Close more than once is harmless. It always returns
// nil.
func (w *relayWriter) Close() error {
w.t.Done()
return nil
}
// awaitTurn blocks until either it is this writer's turn or the reader is
// closed. It reports true when the writer may proceed and false when the
// reader was closed first. Once a turn has been observed the result is
// remembered, so later calls return true without blocking.
func (w *relayWriter) awaitTurn() (ok bool) {
	if !w.ready {
		select {
		case <-w.t.Ready():
			w.ready = true
		case <-w.q.closed():
			return false
		}
	}
	return true
}

View File

@@ -0,0 +1,65 @@
package syncs
import (
"bytes"
"io"
"math/rand/v2"
"testing"
"testing/synctest"
)
// TestPipelineReadWriterTo takes five writers from a RelayReader in a fixed
// order, shuffles them, writes from one goroutine per writer, and requires
// the bytes copied out of the reader to appear in the order the writers were
// taken, regardless of the order the goroutines run in. The scenario is
// repeated ten times, each inside a synctest bubble.
func TestPipelineReadWriterTo(t *testing.T) {
for range 10 {
synctest.Test(t, func(t *testing.T) {
q := NewRelayReader()
// The order of Take calls defines the expected output order.
tickets := []struct {
io.WriteCloser
s string
}{
{q.Take(), "you"},
{q.Take(), " say hi,"},
{q.Take(), " and "},
{q.Take(), "I say "},
{q.Take(), "hello"},
}
// Randomize the order in which the writer goroutines are started.
rand.Shuffle(len(tickets), func(i, j int) {
tickets[i], tickets[j] = tickets[j], tickets[i]
})
var g Group
// NOTE: the loop variable t shadows the *testing.T inside this loop.
for i, t := range tickets {
g.Go(func() {
defer t.Close()
if i%2 == 0 {
// Use [relayWriter.WriteString]
io.WriteString(t.WriteCloser, t.s)
} else {
t.Write([]byte(t.s))
}
})
}
var got bytes.Buffer
var copyErr error // checked at end
g.Go(func() {
_, copyErr = io.Copy(&got, q)
})
// Wait until every goroutine in the bubble is blocked, then close the
// reader so that WriteTo (and therefore io.Copy) returns.
synctest.Wait()
q.Close()
g.Wait()
if copyErr != nil {
t.Fatal(copyErr)
}
want := "you say hi, and I say hello"
if got.String() != want {
t.Fatalf("got %q, want %q", got.String(), want)
}
})
}
}

View File

@@ -0,0 +1,41 @@
package syncs
import (
"sync"
"sync/atomic"
)
// Group is a [sync.WaitGroup] with a Go method and a counter of goroutines
// that have been started but have not yet finished.
type Group struct {
	wg sync.WaitGroup
	n  atomic.Int64 // goroutines started by Go that have not finished
}

// Go runs f in a new goroutine tracked by the group.
//
// The running counter is incremented before the goroutine is launched, in
// step with the WaitGroup, so that [Group.Running] can never report zero
// while [Group.Wait] would still block on a goroutine started by an earlier
// call to Go.
func (g *Group) Go(f func()) {
	g.wg.Add(1)
	g.n.Add(1) // counted from the moment it is scheduled
	go func() {
		defer func() {
			// Release the WaitGroup before decrementing the counter: a zero
			// counter then implies every goroutine has already called Done.
			g.wg.Done()
			g.n.Add(-1) // Now we are done
		}()
		f()
	}()
}

// Running returns the number of goroutines started with [Group.Go] that have
// not yet finished.
//
// If a call to [Running] returns zero, and a call to [Wait] is made without
// any calls to [Go], then [Wait] will return immediately. This is true even if
// a goroutine is started and finishes between the two calls.
//
// It is possible for [Running] to return non-zero and for [Wait] to return
// immediately. This can happen if all the running goroutines finish between
// the two calls.
func (g *Group) Running() int64 {
	return g.n.Load()
}

// Wait blocks until every goroutine started with [Group.Go] has returned.
func (g *Group) Wait() {
	g.wg.Wait()
}

View File

@@ -0,0 +1,116 @@
// Package manifest provides documentation for the Ollama manifest format.
// This package contains no code.
//
// # Manifests
//
// A manifest is a JSON object that describes a model. The JSON object has a
// single field "layers" which is a list of layers that make up the model.
// A layer is a single, logical unit of a model. Layers are stored in the cache
// as files with the name of the digest of the layer. Layers are pushed and
// pulled from the registry as blobs.
//
// A layer is represented as a JSON object with the following fields:
//
// - "digest": The digest of the layer.
// - "mediaType": The media type of the layer.
// - "size": The size of the layer in bytes.
//
// Layers are typically stored in a blob store, such as a registry, and are
// referenced by their digest. This package does not define how layers are
// stored or retrieved.
//
// # Configuration Layer
//
// The configuration of a model is represented as a layer with the media type:
//
// application/vnd.ollama.image.config; type=<type>
//
// The "type" parameter in the media type specifies the format of the
// configuration (e.g., "safetensors" or "gguf").
//
// There may be only one configuration layer in a model.
//
// # Template Layer
//
// The model template is a layer with the media type:
//
// application/vnd.ollama.image.template; [name=<name>]
//
// The "name" parameter in the media type specifies the name of the template
// for lookup at runtime. The name is optional and may be omitted. If omitted,
// the template is the default template for the model.
//
// # Tensor Layers
//
// The tensors of a model are represented as layers with the media type:
//
// application/vnd.ollama.image.tensor; name=<name>; dtype=<dtype>; shape=<shape>
//
// The "name" parameter in the media type specifies the name of the tensor as
// defined in the model's configuration. Tensor names are bound only by the
// rules for names defined by the configuration format, as indicated by the
// configuration's "type".
//
// The "dtype" parameter in the media type specifies the data type of the tensor
// as a string.
//
// TODO: Define more specifically how to represent data types as strings.
//
// The "shape" parameter in the media type specifies the shape of the tensor as
// a comma-separated list of integers; one per dimension.
//
// # Tokenization Layers
//
// The tokenization of a model is represented as a layer with the media type:
//
// application/vnd.ollama.image.tokenizer
//
// The configuration of the tokenizer is represented as a layer with the media type:
//
// application/vnd.ollama.image.tokenizer.config
//
// # Miscellaneous Layers
//
// These extra layer mime types are reserved:
//
// application/vnd.ollama.image.license
//
// This layer contains one of the many licenses for the model in plain text.
//
// # Example Manifest
//
// The following is an example manifest containing a configuration, a model
// template, and two tensors (digests shortened for brevity):
//
// {
// "layers": [{
// "digest": "sha256:a...",
// "mediaType": "application/vnd.ollama.image.config; type=safetensors",
// "size": 1234
// },{
// "digest": "sha256:b...",
// "mediaType": "application/vnd.ollama.image.template",
// "size": 5678
// },{
// "digest": "sha256:c...",
// "mediaType": "application/vnd.ollama.image.tensor; name=input; dtype=F32; shape=1,2,3",
// "size": 9012
// },{
// "digest": "sha256:d...",
// "mediaType": "application/vnd.ollama.image.tensor; name=output; dtype=I32; shape=4,5,6",
// "size": 3456
// }]
// }
//
// # Legacy Media Types
//
// The application/vnd.ollama.image.model media type is deprecated, but will
// remain supported for backwards compatibility, for some undefined amount of
// time. New models should use the media types defined above.
//
// # Reserved media types
//
// The media type prefix "application/vnd.ollama.image." is reserved for
// defining new media types for layers known to Ollama. Currently, all other
// prefixes are ignored by official Ollama registry clients.
package manifest

View File

@@ -0,0 +1,417 @@
// Package registry implements an http.Handler for handling local Ollama API
// model management requests. See [Local] for details.
package registry
import (
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/internal/backoff"
)
// Local implements an http.Handler for handling local Ollama API model
// management requests, such as pushing, pulling, and deleting models.
// (In this file the handler routes only /api/delete and /api/pull itself.)
//
// It can be arranged for all unknown requests to be passed through to a
// fallback handler, if one is provided; otherwise unknown paths receive a
// "not found" error response.
type Local struct {
Client *ollama.Registry // required
Logger *slog.Logger // required
// Fallback, if set, is used to handle requests that are not handled by
// this handler.
Fallback http.Handler
// Prune, if set, is called to prune the local disk cache after a model
// is deleted. Its error, if any, becomes the result of the delete request.
Prune func() error // optional
}
// serverError is like ollama.Error, but with a Status field for the HTTP
// response code. We want to avoid adding that field to ollama.Error because it
// would always be 0 to clients (we don't want to leak the status code in
// errors), and so it would be confusing to have a field that is always 0.
type serverError struct {
// Status is the HTTP status code to respond with. It is excluded from the
// JSON body.
Status int `json:"-"`
// TODO(bmizerany): Decide if we want to keep this and maybe
// bring back later.
Code string `json:"code"`
// Message is the human-readable description; it is serialized under the
// "error" key.
Message string `json:"error"`
}
// Error implements the error interface by returning Message.
func (e serverError) Error() string {
return e.Message
}
// Shared error responses returned by the handlers in this package. Each
// carries the HTTP status to send along with the JSON code and message.
var (
	errMethodNotAllowed = &serverError{Status: 405, Code: "method_not_allowed", Message: "method not allowed"}
	errNotFound         = &serverError{Status: 404, Code: "not_found", Message: "not found"}
	errModelNotFound    = &serverError{Status: 404, Code: "not_found", Message: "model not found"}
	errInternalError    = &serverError{Status: 500, Code: "internal_error", Message: "internal server error"}
)
type statusCodeRecorder struct {
_status int // use status() to get the status code
http.ResponseWriter
}
func (r *statusCodeRecorder) WriteHeader(status int) {
if r._status == 0 {
r._status = status
r.ResponseWriter.WriteHeader(status)
}
}
func (r *statusCodeRecorder) Write(b []byte) (int, error) {
r._status = r.status()
return r.ResponseWriter.Write(b)
}
var (
_ http.ResponseWriter = (*statusCodeRecorder)(nil)
_ http.CloseNotifier = (*statusCodeRecorder)(nil)
_ http.Flusher = (*statusCodeRecorder)(nil)
)
// CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
//
// It panics if the underlying ResponseWriter is not a CloseNotifier.
func (r *statusCodeRecorder) CloseNotify() <-chan bool {
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Flush implements the http.Flusher interface, for Gin. Remove with Gin.
//
// It panics if the underlying ResponseWriter is not a Flusher.
func (r *statusCodeRecorder) Flush() {
r.ResponseWriter.(http.Flusher).Flush()
}
func (r *statusCodeRecorder) status() int {
return cmp.Or(r._status, 200)
}
// ServeHTTP implements http.Handler. It wraps w in a statusCodeRecorder so
// the response status can be logged, then delegates to serveHTTP.
func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	s.serveHTTP(&statusCodeRecorder{ResponseWriter: w}, r)
}
// serveHTTP routes the request by URL path, writes a JSON error body when a
// handler returns an error, and logs the request unless it was handed off to
// the Fallback handler.
func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
var errattr slog.Attr
// proxied reports whether the Fallback handler served the request.
proxied, err := func() (bool, error) {
switch r.URL.Path {
case "/api/delete":
return false, s.handleDelete(rec, r)
case "/api/pull":
return false, s.handlePull(rec, r)
default:
if s.Fallback != nil {
s.Fallback.ServeHTTP(rec, r)
return true, nil
}
return false, errNotFound
}
}()
if err != nil {
// We always log the error, so fill in the error log attribute
errattr = slog.String("error", err.Error())
// Map the error to a serverError: use it as-is if it already is one,
// report invalid names as 400, and hide everything else behind a
// generic 500.
var e *serverError
switch {
case errors.As(err, &e):
case errors.Is(err, ollama.ErrNameInvalid):
e = &serverError{400, "bad_request", err.Error()}
default:
e = errInternalError
}
data, err := json.Marshal(e)
if err != nil {
// unreachable
panic(err)
}
rec.Header().Set("Content-Type", "application/json")
// If a status was already recorded (for example because a handler
// started writing the body), the recorder ignores this WriteHeader
// and the error JSON is appended to the existing response.
rec.WriteHeader(e.Status)
rec.Write(data)
// fallthrough to log
}
if !proxied {
// we're only responsible for logging if we handled the request
// level starts at its zero value and is raised for 4xx and 5xx responses.
var level slog.Level
if rec.status() >= 500 {
level = slog.LevelError
} else if rec.status() >= 400 {
level = slog.LevelWarn
}
s.Logger.LogAttrs(r.Context(), level, "http",
errattr, // report first in line to make it easy to find
// TODO(bmizerany): Write a test to ensure that we are logging
// all of this correctly. That also goes for the level+error
// logic above.
slog.Int("status", rec.status()),
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int64("content-length", r.ContentLength),
slog.String("remote", r.RemoteAddr),
slog.String("proto", r.Proto),
slog.String("query", r.URL.RawQuery),
)
}
}
type params struct {
// DeprecatedName is the name of the model to push, pull, or delete,
// but is deprecated. New clients should use [Model] instead.
//
// Use [model()] to get the model name for both old and new API requests.
DeprecatedName string `json:"name"`
// Model is the name of the model to push, pull, or delete.
//
// Use [model()] to get the model name for both old and new API requests.
Model string `json:"model"`
// AllowNonTLS is a flag that indicates a client using HTTP
// is doing so, deliberately.
//
// Deprecated: This field is ignored and only present for this
// deprecation message. It should be removed in a future release.
//
// Users can just use http or https+insecure to show intent to
// communicate they want to do insecure things, without awkward and
// confusing flags such as this.
AllowNonTLS bool `json:"insecure"`
// Stream, if true, will make the server send progress updates in a
// streaming of JSON objects. If false, the server will send a single
// JSON object with the final status as "success", or an error object
// if an error occurred.
//
// Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check
// if the client decisively set it to false. So, we use a pointer to a
// bool. Gross.
//
// Use [stream()] to get the correct value for this field.
Stream *bool `json:"stream"`
}
// model returns the model name for both old and new API requests.
func (p params) model() string {
return cmp.Or(p.Model, p.DeprecatedName)
}
func (p params) stream() bool {
if p.Stream == nil {
return true
}
return *p.Stream
}
// handleDelete serves /api/delete. It accepts only the DELETE method, decodes
// the JSON request body, unlinks the named model via the client, and then
// runs Prune if one is configured.
//
// It returns errMethodNotAllowed for other methods, a 400 serverError for a
// malformed, empty, or null body, errModelNotFound when the client reports
// that nothing was unlinked, and otherwise any error from Unlink or Prune.
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
	if r.Method != "DELETE" {
		return errMethodNotAllowed
	}
	p, err := decodeUserJSON[*params](r.Body)
	if err != nil {
		return err
	}
	if p == nil {
		// A literal JSON null decodes without error into a nil *params;
		// reject it here instead of dereferencing it below.
		return &serverError{Status: 400, Code: "bad_request", Message: "request body must be a JSON object"}
	}
	ok, err := s.Client.Unlink(p.model())
	if err != nil {
		return err
	}
	if !ok {
		return errModelNotFound
	}
	if s.Prune != nil {
		return s.Prune()
	}
	return nil
}
// progressUpdateJSON is one JSON object written to the client by handlePull:
// either a status line (Status set) or a per-layer progress record (Digest,
// Total, and Completed set). Every field is omitted from the output when it
// holds its zero value.
type progressUpdateJSON struct {
Error string `json:"error,omitempty,omitzero"`
Status string `json:"status,omitempty,omitzero"`
Digest blob.Digest `json:"digest,omitempty,omitzero"`
Total int64 `json:"total,omitempty,omitzero"`
Completed int64 `json:"completed,omitempty,omitzero"`
}
// handlePull serves /api/pull. It accepts only the POST method and decodes
// the JSON request body.
//
// When the request disables streaming, the model is pulled synchronously and
// a single {"status":"success"} object is written on success.
//
// Otherwise a "pulling manifest" status line is written first, the pull runs
// in a background goroutine (retried via backoff.Loop while canRetry reports
// the error as retryable), and per-layer progress is flushed to the client
// every 100ms once the first layer update has been seen. When the pull
// finishes, a final progress flush is followed by the legacy trailing status
// lines.
//
// It returns errMethodNotAllowed for other methods, a 400 serverError for a
// malformed, empty, or null body, a 404 serverError when the model is not
// found, and otherwise the error from the pull. In the streaming case the
// response has already begun when an error is returned.
func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
	if r.Method != "POST" {
		return errMethodNotAllowed
	}

	p, err := decodeUserJSON[*params](r.Body)
	if err != nil {
		return err
	}
	if p == nil {
		// A literal JSON null decodes without error into a nil *params;
		// reject it here instead of dereferencing it below.
		return &serverError{Status: 400, Code: "bad_request", Message: "request body must be a JSON object"}
	}

	enc := json.NewEncoder(w)
	if !p.stream() {
		if err := s.Client.Pull(r.Context(), p.model()); err != nil {
			if errors.Is(err, ollama.ErrModelNotFound) {
				return errModelNotFound
			}
			return err
		}
		enc.Encode(progressUpdateJSON{Status: "success"})
		return nil
	}

	var mu sync.Mutex
	var progress []progressUpdateJSON
	flushProgress := func() {
		mu.Lock()
		progress := slices.Clone(progress) // make a copy and release lock before encoding to the wire
		mu.Unlock()
		for _, p := range progress {
			enc.Encode(p)
		}
		fl, _ := w.(http.Flusher)
		if fl != nil {
			fl.Flush()
		}
	}

	t := time.NewTicker(1<<63 - 1) // "unstarted" timer
	defer t.Stop()                 // release the ticker on every return path
	start := sync.OnceFunc(func() {
		flushProgress() // flush initial state
		t.Reset(100 * time.Millisecond)
	})
	ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			if err != nil && !errors.Is(err, ollama.ErrCached) {
				s.Logger.Error("pulling", "model", p.model(), "error", err)
				return
			}

			func() {
				mu.Lock()
				defer mu.Unlock()
				for i, p := range progress {
					if p.Digest == l.Digest {
						progress[i].Completed = n
						return
					}
				}
				progress = append(progress, progressUpdateJSON{
					Digest: l.Digest,
					Total:  l.Size,
				})
			}()

			// Block flushing progress updates until every
			// layer is accounted for. Clients depend on a
			// complete model size to calculate progress
			// correctly; if they use an incomplete total,
			// progress indicators would erratically jump
			// as new layers are registered.
			start()
		},
	})

	// Write the initial status line before the pull goroutine exists. The
	// first progress flush is issued from the Trace callback on that
	// goroutine, and neither the encoder nor the ResponseWriter may be used
	// from two goroutines at once.
	enc.Encode(progressUpdateJSON{Status: "pulling manifest"})

	done := make(chan error, 1)
	go func() (err error) {
		// Deliver the final result to the loop below; done is buffered so
		// this send never blocks.
		defer func() { done <- err }()
		for _, err := range backoff.Loop(ctx, 3*time.Second) {
			if err != nil {
				return err
			}
			err := s.Client.Pull(ctx, p.model())
			if canRetry(err) {
				continue
			}
			return err
		}
		return nil
	}()

	for {
		select {
		case <-t.C:
			flushProgress()
		case err := <-done:
			flushProgress()
			if errors.Is(err, ollama.ErrModelNotFound) {
				return &serverError{
					Status:  404,
					Code:    "not_found",
					Message: fmt.Sprintf("model %q not found", p.model()),
				}
			}
			if err != nil {
				return err
			}

			// Emulate old client pull progress (for now):
			enc.Encode(progressUpdateJSON{Status: "verifying sha256 digest"})
			enc.Encode(progressUpdateJSON{Status: "writing manifest"})
			enc.Encode(progressUpdateJSON{Status: "success"})
			return nil
		}
	}
}
// decodeUserJSON decodes a single JSON value from r into a T.
//
// Errors caused by the client are converted into 400 "bad_request"
// serverErrors: JSON syntax and type-mismatch errors keep the decoder's
// message, and an io.EOF (nothing to decode) is reported as "empty request
// body". Any other error is returned unchanged. On error the zero T is
// returned.
func decodeUserJSON[T any](r io.Reader) (T, error) {
	var v T
	err := json.NewDecoder(r).Decode(&v)
	if err == nil {
		return v, nil
	}

	// errors.As needs a pointer to a type that implements error. The json
	// error types implement error on their pointer receivers, so the
	// targets must be pointers to *json.UnmarshalTypeError and
	// *json.SyntaxError; passing &json.UnmarshalTypeError{} directly is not
	// a valid target.
	var (
		typeErr   *json.UnmarshalTypeError
		syntaxErr *json.SyntaxError
	)
	switch {
	case errors.As(err, &typeErr), errors.As(err, &syntaxErr):
		err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
	case errors.Is(err, io.EOF):
		err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
	}

	var zero T
	return zero, err
}
// canRetry reports whether a failed pull is worth retrying. A nil error is
// never retried. An *ollama.Error decides for itself via Temporary. Any other
// error is retried if it is a context deadline expiry or if its text mentions
// one of a few network failure phrases.
func canRetry(err error) bool {
	if err == nil {
		return false
	}

	var oe *ollama.Error
	if errors.As(err, &oe) {
		return oe.Temporary()
	}

	msg := err.Error()
	if errors.Is(err, context.DeadlineExceeded) {
		return true
	}
	for _, phrase := range [...]string{
		"unreachable",
		"no route to host",
		"connection reset by peer",
	} {
		if strings.Contains(msg, phrase) {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,302 @@
package registry
import (
"bytes"
"context"
"encoding/json"
"io"
"io/fs"
"net"
"net/http"
"net/http/httptest"
"os"
"regexp"
"strings"
"sync"
"testing"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/testutil"
"golang.org/x/tools/txtar"
_ "embed"
)
// panicTransport is an http.RoundTripper that panics on use. It backs
// panicOnRoundTrip, the default client of test servers that are not expected
// to make any outbound HTTP request.
type panicTransport struct{}
// RoundTrip always panics.
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
panic("unexpected RoundTrip call")
}
// panicOnRoundTrip is an HTTP client whose every request panics.
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
// bytesResetter is an interface for types that can be reset and return a byte
// slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
// etc for the purpose of checking logs.
type bytesResetter interface {
// Bytes returns the accumulated contents.
Bytes() []byte
// Reset discards the accumulated contents.
Reset()
}
// newTestServer returns a Local backed by a blob cache opened on a temporary
// copy of testdata/models.
//
// If upstreamRegistry is nil, the registry client uses panicOnRoundTrip, so
// any outbound request panics. Otherwise upstreamRegistry is served by a TLS
// test server and every connection the client dials is redirected to that
// server, whatever host the request names.
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
	t.Helper()

	dir := t.TempDir()
	if err := os.CopyFS(dir, os.DirFS("testdata/models")); err != nil {
		t.Fatal(err)
	}
	cache, err := blob.Open(dir)
	if err != nil {
		t.Fatal(err)
	}

	httpClient := panicOnRoundTrip
	if upstreamRegistry != nil {
		upstream := httptest.NewTLSServer(upstreamRegistry)
		t.Cleanup(upstream.Close)

		transport := upstream.Client().Transport.(*http.Transport).Clone()
		transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
			// Ignore the requested address; always dial the test server.
			var dialer net.Dialer
			return dialer.DialContext(ctx, "tcp", upstream.Listener.Addr().String())
		}
		httpClient = &http.Client{Transport: transport}
	}

	return &Local{
		Client: &ollama.Registry{
			Cache:      cache,
			HTTPClient: httpClient,
			Mask:       "example.com/library/_:latest",
		},
		Logger: testutil.Slogger(t),
	}
}
// send builds a request with the given method, path, and body, attaches a
// trace that logs every layer update to t, and serves it through s. It
// returns the recorded response.
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
	t.Helper()

	trace := &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			t.Logf("update: %s %d %v", l.Digest, n, err)
		},
	}
	ctx := ollama.WithTrace(t.Context(), trace)
	request := httptest.NewRequestWithContext(ctx, method, path, strings.NewReader(body))
	return s.sendRequest(t, request)
}
// sendRequest serves req through s and returns the recorded response.
func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
	t.Helper()
	rec := httptest.NewRecorder()
	s.ServeHTTP(rec, req)
	return rec
}
// invalidReader is an io.Reader whose Read always fails with os.ErrInvalid.
// It is used as a request body to provoke a non-JSON, non-EOF decode error.
type invalidReader struct{}
// Read always returns (0, os.ErrInvalid).
func (r *invalidReader) Read(p []byte) (int, error) {
return 0, os.ErrInvalid
}
// captureLogs returns a shallow copy of s whose Logger writes into an
// in-memory buffer, together with a bytesResetter for inspecting and
// clearing what was logged. The original server is left untouched.
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
	t.Helper()
	logger, buf := testutil.SlogBuffer()
	clone := new(Local)
	*clone = *s // shallow copy
	clone.Logger = logger
	return clone, buf
}
// TestServerDelete exercises /api/delete end to end: a successful delete of a
// model present in the test cache, followed by the error responses for a
// malformed body, a wrong method, an empty body, an invalid name, and an
// unknown path, and finally the 500 response and error log produced when the
// request body cannot be read.
func TestServerDelete(t *testing.T) {
check := testutil.Checker(t)
s := newTestServer(t, nil)
// Precondition: the model resolves locally before it is deleted.
_, err := s.Client.ResolveLocal("smol")
check(err)
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
_, err = s.Client.ResolveLocal("smol")
if err == nil {
t.Fatal("expected smol to have been deleted")
}
got = s.send(t, "DELETE", "/api/delete", `!`)
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
got = s.send(t, "DELETE", "/api/delete", ``)
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
checkErrorResponse(t, got, 404, "not_found", "not found")
// From here on, s is a copy of the server that logs into a buffer.
s, logs := captureLogs(t, s)
req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
got = s.sendRequest(t, req)
// The client sees only the generic message; the real cause must be logged.
checkErrorResponse(t, got, 500, "internal_error", "internal server error")
ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
check(err)
if !ok {
t.Logf("logs:\n%s", logs)
t.Fatalf("expected log to contain ERROR with invalid argument")
}
}
//go:embed testdata/registry.txt
var registryTXT []byte
// registryFS parses the embedded registryTXT txtar archive on first use and
// returns it as an fs.FS; later calls return the same value. It panics if the
// archive cannot be exposed as a file system.
var registryFS = sync.OnceValue(func() fs.FS {
// Txtar gets hung up on \r\n line endings, so we need to convert them
// to \n when parsing the txtar on Windows.
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
a := txtar.Parse(data)
fsys, err := txtar.FS(a)
if err != nil {
panic(err)
}
return fsys
})
// TestServerPull exercises /api/pull against a fake upstream registry served
// from registryFS, covering streaming and non-streaming pulls, unknown
// models, invalid names, malformed and empty bodies, and a wrong method.
func TestServerPull(t *testing.T) {
modelsHandler := http.FileServerFS(registryFS())
s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/v2/library/BOOM/manifests/latest":
w.WriteHeader(999)
io.WriteString(w, `{"error": "boom"}`)
case "/v2/library/unknown/manifests/latest":
w.WriteHeader(404)
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
default:
t.Logf("serving blob: %s", r.URL.Path)
modelsHandler.ServeHTTP(w, r)
}
})
// checkResponse requires a 200 status and checks the body line by line:
// each line of wantlines must appear somewhere in the body, unless it is
// prefixed with "!", in which case it must not appear.
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
t.Helper()
if got.Code != 200 {
t.Errorf("Code = %d; want 200", got.Code)
}
gotlines := got.Body.String()
if strings.TrimSpace(gotlines) == "" {
gotlines = "<empty>"
}
t.Logf("got:\n%s", gotlines)
for want := range strings.Lines(wantlines) {
want = strings.TrimSpace(want)
want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) {
t.Errorf("\t! missing %q in body", want)
}
if unwanted && strings.Contains(gotlines, want) {
t.Errorf("\t! unexpected %q in body", want)
}
}
}
// Streaming pulls (the default): the status is 200 even when the pull
// fails, and the error is reported as a JSON line in the body.
got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"status":"pulling manifest"}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"status":"verifying sha256 digest"}
{"status":"writing manifest"}
{"status":"success"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
checkResponse(got, `
{"code":"not_found","error":"model \"unknown\" not found"}
`)
// Requests rejected before any response body has been written.
got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
got = s.send(t, "POST", "/api/pull", `!`)
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
got = s.send(t, "POST", "/api/pull", ``)
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
checkResponse(got, `
{"code":"bad_request","error":"invalid or missing name: \"\""}
`)
// Non-streaming pulls
got = s.send(t, "POST", "/api/pull", `{"model": "://", "stream": false}`)
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
got = s.send(t, "POST", "/api/pull", `{"model": "smol", "stream": false}`)
checkResponse(got, `
{"status":"success"}
!digest
!total
!completed
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown", "stream": false}`)
checkErrorResponse(t, got, 404, "not_found", "model not found")
}
// TestServerUnknownPath checks that a request for an unregistered route
// produces a 404 "not_found" error when no Fallback is set, and that the
// request is handed to Fallback once one is installed.
func TestServerUnknownPath(t *testing.T) {
	srv := newTestServer(t, nil)

	// No Fallback: the server must answer with its own error response.
	resp := srv.send(t, "DELETE", "/api/unknown", `{}`)
	checkErrorResponse(t, resp, 404, "not_found", "not found")

	// Fallback installed: the same request must reach it instead.
	called := false
	srv.Fallback = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		called = true
	})
	resp = srv.send(t, "DELETE", "/api/unknown", `{}`)
	switch {
	case !called:
		t.Fatal("expected Fallback to be called")
	case resp.Code != 200:
		t.Fatalf("Code = %d; want 200", resp.Code)
	}
}
// checkErrorResponse asserts that got carries the HTTP status `status` and a
// JSON error body whose Code equals code and whose Message contains msg.
//
// The response body is logged once, before the first reported mismatch, so a
// failing test shows what the server actually sent. A body that cannot be
// decoded (or that decodes to JSON null) fails the test immediately.
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
	t.Helper()

	var printedBody bool
	errorf := func(format string, args ...any) {
		t.Helper()
		if !printedBody {
			t.Logf("BODY:\n%s", got.Body.String())
			printedBody = true
		}
		t.Errorf(format, args...)
	}

	if got.Code != status {
		errorf("Code = %d; want %d", got.Code, status)
	}

	// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
	var e *ollama.Error
	if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
		errorf("unmarshal error: %v", err)
		t.FailNow()
	}
	// A literal JSON null decodes successfully but leaves e nil; report it
	// instead of panicking on the field accesses below.
	if e == nil {
		errorf("error body decoded to nil; want code %q", code)
		t.FailNow()
	}
	if e.Code != code {
		errorf("Code = %q; want %q", e.Code, code)
	}
	if !strings.Contains(e.Message, msg) {
		errorf("Message = %q; want to contain %q", e.Message, msg)
	}
}

View File

@@ -0,0 +1 @@
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}

View File

@@ -0,0 +1 @@
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}

View File

@@ -0,0 +1,22 @@
-- v2/library/smol/manifests/latest --
{
"schemaVersion": 2,
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
"config": {
"mediaType": "application/vnd.docker.container.image.v1+json",
"digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
"size": 3
},
"layers": [
{
"mediaType": "application/vnd.ollama.image.model",
"digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
"size": 5
}
]
}
-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
GGUF
-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
{}

View File

@@ -0,0 +1,102 @@
package testutil
import (
"bytes"
"io"
"log/slog"
"os"
"path/filepath"
"testing"
"time"
)
// LogWriter returns an [io.Writer] that forwards every Write to t.Logf.
func LogWriter(t *testing.T) io.Writer {
	return testWriter{t: t}
}

// testWriter adapts a *testing.T to the io.Writer interface.
type testWriter struct{ t *testing.T }

// Write logs b through the wrapped test and reports all of b as written.
// It never returns an error.
func (tw testWriter) Write(b []byte) (int, error) {
	n := len(b)
	tw.t.Logf("%s", b)
	return n, nil
}
// Slogger returns a [*slog.Logger] whose text handler writes through
// LogWriter(t), so each record ends up in the test log.
func Slogger(t *testing.T) *slog.Logger {
	handler := slog.NewTextHandler(LogWriter(t), nil)
	return slog.New(handler)
}
// SlogBuffer returns a text-format [*slog.Logger] together with the buffer
// (out) that receives every record it emits.
func SlogBuffer() (lg *slog.Logger, out *bytes.Buffer) {
	out = new(bytes.Buffer)
	lg = slog.New(slog.NewTextHandler(out, nil))
	return lg, out
}
// Check fails the test immediately via t.Fatal(err) when err is non-nil and
// does nothing otherwise.
func Check(t *testing.T, err error) {
	if err == nil {
		return
	}
	t.Helper()
	t.Fatal(err)
}
// CheckFunc is the signature of an error-checking callback. It exists so
// other packages do not need to invent their own type for taking a Check
// function.
type CheckFunc func(err error)
// Checker returns a check function bound to t. Calling it with a non-nil
// error fails the test via t.Fatal; a nil error is ignored.
func Checker(t *testing.T) (check func(err error)) {
	check = func(err error) {
		if err == nil {
			return
		}
		t.Helper()
		t.Fatal(err)
	}
	return check
}
// StopPanic invokes f and swallows any panic raised while f runs.
// Typical use:
//
//	testutil.StopPanic(func() {
//		callThatShouldPanic()
//		t.Errorf("callThatShouldPanic did not panic")
//	})
func StopPanic(f func()) {
	defer func() {
		// The recovered value is deliberately discarded.
		_ = recover()
	}()
	f()
}
// CheckTime fails the test via t.Fatalf unless got and want denote the same
// instant (compared with Time.Equal). The failure message prints both times
// in UTC followed by want.Sub(got) to make the gap easy to see.
func CheckTime(t *testing.T, got, want time.Time) {
	t.Helper()
	if got.Equal(want) {
		return
	}
	t.Fatalf("got %v, want %v (%v)", got.UTC(), want.UTC(), want.Sub(got))
}
// WriteFile writes data to the file called name, creating any missing parent
// directories first (directories are created with mode 0o755, the file with
// mode 0o644).
//
// The name must be a relative path that stays inside the current working
// directory: it must not be absolute and must not climb out via "..". A name
// that violates this, or any filesystem error, fails the test through
// t.Fatalf/t.Fatal.
func WriteFile[S []byte | string](t testing.TB, name string, data S) {
	t.Helper()

	if filepath.IsAbs(name) {
		t.Fatalf("WriteFile: name must be a relative path, got %q", name)
	}
	name = filepath.Clean(name)
	// IsLocal is a purely lexical check; after Clean it rejects names that
	// still begin with ".." and would therefore land outside the working
	// directory.
	if !filepath.IsLocal(name) {
		t.Fatalf("WriteFile: name must not escape the working directory, got %q", name)
	}

	if err := os.MkdirAll(filepath.Dir(name), 0o755); err != nil {
		t.Fatal(err)
	}
	if err := os.WriteFile(name, []byte(data), 0o644); err != nil {
		t.Fatal(err)
	}
}

View File

@@ -0,0 +1,90 @@
package server
import (
"testing"
fsggml "github.com/ollama/ollama/fs/ggml"
)
// TestLagunaGGUFQuantization drives lagunaGGUFQuantization with a table of
// tensor names, original/requested tensor types, file types and block counts
// and verifies both the chosen tensor type and the quantize flag.
func TestLagunaGGUFQuantization(t *testing.T) {
	type testCase struct {
		name          string
		tensor        string
		originalType  fsggml.TensorType
		requestedType fsggml.TensorType
		fileType      fsggml.FileType
		blockCount    int
		wantType      fsggml.TensorType
		wantQuantize  bool
	}

	tests := []testCase{
		{
			name:          "non_routed_weights_preserved",
			tensor:        "blk.1.attn_q.weight",
			originalType:  fsggml.TensorTypeBF16,
			requestedType: fsggml.TensorTypeQ8_0,
			fileType:      fsggml.FileTypeQ8_0,
			blockCount:    2,
			wantType:      fsggml.TensorTypeBF16,
			wantQuantize:  false,
		},
		{
			name:          "shared_expert_weights_preserved",
			tensor:        "blk.1.ffn_gate_shexp.weight",
			originalType:  fsggml.TensorTypeBF16,
			requestedType: fsggml.TensorTypeQ4_K,
			fileType:      fsggml.FileTypeQ4_K_M,
			blockCount:    2,
			wantType:      fsggml.TensorTypeBF16,
			wantQuantize:  false,
		},
		{
			name:          "routed_gate_q8",
			tensor:        "blk.1.ffn_gate_exps.weight",
			originalType:  fsggml.TensorTypeBF16,
			requestedType: fsggml.TensorTypeQ8_0,
			fileType:      fsggml.FileTypeQ8_0,
			blockCount:    2,
			wantType:      fsggml.TensorTypeQ8_0,
			wantQuantize:  true,
		},
		{
			name:          "routed_down_q4_promoted",
			tensor:        "blk.1.ffn_down_exps.weight",
			originalType:  fsggml.TensorTypeBF16,
			requestedType: fsggml.TensorTypeQ4_K,
			fileType:      fsggml.FileTypeQ4_K_M,
			blockCount:    2,
			wantType:      fsggml.TensorTypeQ6_K,
			wantQuantize:  true,
		},
		{
			name:          "routed_down_q4_not_promoted_when_q8_requested",
			tensor:        "blk.1.ffn_down_exps.weight",
			originalType:  fsggml.TensorTypeBF16,
			requestedType: fsggml.TensorTypeQ8_0,
			fileType:      fsggml.FileTypeQ4_K_M,
			blockCount:    2,
			wantType:      fsggml.TensorTypeQ8_0,
			wantQuantize:  true,
		},
		{
			name:          "routed_down_q4_k_s_promoted",
			tensor:        "blk.0.ffn_down_exps.weight",
			originalType:  fsggml.TensorTypeBF16,
			requestedType: fsggml.TensorTypeQ4_K,
			fileType:      fsggml.FileTypeQ4_K_S,
			blockCount:    8,
			wantType:      fsggml.TensorTypeQ5_K,
			wantQuantize:  true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			gotType, gotQuantize := lagunaGGUFQuantization(tc.tensor, tc.originalType, tc.requestedType, tc.fileType, tc.blockCount)
			if gotType == tc.wantType && gotQuantize == tc.wantQuantize {
				return
			}
			t.Fatalf("lagunaGGUFQuantization(%q) = (%s, %v), want (%s, %v)", tc.tensor, gotType, gotQuantize, tc.wantType, tc.wantQuantize)
		})
	}
}

44
server/logprob.go Normal file
View File

@@ -0,0 +1,44 @@
package server
import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
// toAPILogprobs converts runner-level llm.Logprob values into their
// api.Logprob counterparts. Each token, and each of its top alternatives,
// also gets its raw bytes expanded via stringToByteInts. The result always
// has the same length as logprobs.
func toAPILogprobs(logprobs []llm.Logprob) []api.Logprob {
	out := make([]api.Logprob, 0, len(logprobs))
	for _, src := range logprobs {
		dst := api.Logprob{
			TokenLogprob: api.TokenLogprob{
				Token:   src.Token,
				Bytes:   stringToByteInts(src.Token),
				Logprob: src.Logprob,
			},
		}
		// TopLogprobs stays nil when the source has no alternatives.
		if n := len(src.TopLogprobs); n > 0 {
			dst.TopLogprobs = make([]api.TokenLogprob, 0, n)
			for _, top := range src.TopLogprobs {
				dst.TopLogprobs = append(dst.TopLogprobs, api.TokenLogprob{
					Token:   top.Token,
					Bytes:   stringToByteInts(top.Token),
					Logprob: top.Logprob,
				})
			}
		}
		out = append(out, dst)
	}
	return out
}
// stringToByteInts returns the bytes of s as a slice of ints, one element
// per byte (not per rune). An empty string yields nil.
func stringToByteInts(s string) []int {
	if len(s) == 0 {
		return nil
	}
	out := make([]int, 0, len(s))
	for i := 0; i < len(s); i++ {
		out = append(out, int(s[i]))
	}
	return out
}

129
server/model.go Normal file
View File

@@ -0,0 +1,129 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// intermediateBlobs is a package-level string-to-string map.
// NOTE(review): its readers and writers are not visible in this part of the
// file — confirm what the keys and values represent before relying on it.
var intermediateBlobs map[string]string = make(map[string]string)
// layerGGML pairs a manifest layer with the GGML metadata decoded from its
// blob. parseFromModel leaves GGML nil for layers whose media type is not a
// model, projector or adapter.
type layerGGML struct {
manifest.Layer
*ggml.GGML
}
// parseFromModel loads the manifest for name and returns one layerGGML per
// manifest layer.
//
// If the manifest does not exist locally the model is pulled first (progress
// is reported through fn) and the manifest is parsed again. Layers whose
// media type is model, projector or adapter have their blob decoded with
// ggml.Decode; every other layer is returned with a nil GGML.
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
	m, err := manifest.ParseNamedManifest(name)
	switch {
	case errors.Is(err, os.ErrNotExist):
		if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
			return nil, err
		}

		m, err = manifest.ParseNamedManifest(name)
		if err != nil {
			return nil, err
		}
	case err != nil:
		return nil, err
	}

	for _, srcLayer := range m.Layers {
		layer, err := manifest.NewLayerFromLayer(srcLayer.Digest, srcLayer.MediaType, name.DisplayShortest())
		if err != nil {
			return nil, err
		}
		layer.Name = srcLayer.Name

		switch layer.MediaType {
		case "application/vnd.ollama.image.model",
			"application/vnd.ollama.image.projector",
			"application/vnd.ollama.image.adapter":
			f, err := decodeLayerGGML(layer)
			if err != nil {
				return nil, err
			}
			layers = append(layers, &layerGGML{layer, f})
		default:
			layers = append(layers, &layerGGML{layer, nil})
		}
	}

	return layers, nil
}

// decodeLayerGGML opens the blob behind layer, decodes its GGML metadata and
// closes the file before returning, so a manifest with many layers does not
// keep one open file per layer until parseFromModel returns.
func decodeLayerGGML(layer manifest.Layer) (*ggml.GGML, error) {
	blobpath, err := manifest.BlobsPath(layer.Digest)
	if err != nil {
		return nil, err
	}

	blob, err := os.Open(blobpath)
	if err != nil {
		return nil, err
	}
	defer blob.Close()

	return ggml.Decode(blob, -1)
}
// detectChatTemplate inspects the GGML metadata of each layer for a chat
// template. When template.Named recognises it, a template layer (and, if the
// named template carries parameters, a params layer) is appended to layers.
//
// Layers without GGML metadata are skipped. A template that is not recognised
// is only logged at debug level. The returned slice is the input plus any
// appended layers; the layers added here are not themselves inspected because
// the range expression is evaluated once.
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
	for _, layer := range layers {
		// parseFromModel produces layers with a nil GGML for non-model
		// media types; dereferencing those would panic.
		if layer == nil || layer.GGML == nil {
			continue
		}

		s := layer.GGML.KV().ChatTemplate()
		if s == "" {
			continue
		}

		t, err := template.Named(s)
		if err != nil {
			slog.Debug("template detection", "error", err, "template", s)
			continue
		}

		tmplLayer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
		if err != nil {
			return nil, err
		}
		tmplLayer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
		layers = append(layers, &layerGGML{tmplLayer, nil})

		if t.Parameters == nil {
			continue
		}

		var b bytes.Buffer
		if err := json.NewEncoder(&b).Encode(t.Parameters); err != nil {
			return nil, err
		}
		paramsLayer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
		if err != nil {
			return nil, err
		}
		layers = append(layers, &layerGGML{paramsLayer, nil})
	}

	return layers, nil
}
// detectContentType reads all of r and classifies its content. A non-empty
// answer from ggml.DetectContentType wins; otherwise the result of
// http.DetectContentType is used unless it is the generic
// "application/octet-stream", in which case "unknown" is returned.
func detectContentType(r io.Reader) (string, error) {
	var buf bytes.Buffer
	_, err := io.Copy(&buf, r)
	if err != nil {
		return "", err
	}
	data := buf.Bytes()

	if ct := ggml.DetectContentType(data); ct != "" {
		return ct, nil
	}

	switch ct := http.DetectContentType(data); ct {
	case "application/octet-stream":
		return "unknown", nil
	default:
		return ct, nil
	}
}

32
server/model_caches.go Normal file
View File

@@ -0,0 +1,32 @@
package server
import "context"
// modelCaches groups the server's model caches so they can be created and
// started together (see newModelCaches and Start).
type modelCaches struct {
recommendations *modelRecommendationsCache
show *modelShowCache
modelList *modelListCache
}
// newModelCaches builds a modelCaches with a fresh recommendations, show and
// model-list cache. Nothing is started until Start is called.
func newModelCaches() *modelCaches {
	caches := new(modelCaches)
	caches.recommendations = newModelRecommendationsCache()
	caches.show = newModelShowCache()
	caches.modelList = newModelListCache()
	return caches
}
// Start starts every cache that is present, passing ctx to each. A nil
// receiver and nil member caches are tolerated.
func (c *modelCaches) Start(ctx context.Context) {
	if c == nil {
		return
	}
	if rec := c.recommendations; rec != nil {
		rec.Start(ctx)
	}
	if show := c.show; show != nil {
		show.Start(ctx)
	}
	if list := c.modelList; list != nil {
		list.Start(ctx)
	}
}

824
server/model_list_cache.go Normal file
View File

@@ -0,0 +1,824 @@
package server
import (
"bufio"
"cmp"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log/slog"
"os"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/model/parsers"
ollamatemplate "github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
"github.com/ollama/ollama/types/model"
)
// modelListSummary is the cached per-model data from which a list response
// is produced (see ListModelResponse). buildModelListSummary fills it from a
// manifest, its config blob and, for local GGUF models, the GGUF header.
type modelListSummary struct {
// Model and Name both hold the shortest display form of the model name.
Model string
Name string
RemoteModel string
RemoteHost string
Size int64
Digest string
ModifiedAt time.Time
Details api.ModelDetails
Capabilities []model.Capability
}
// modelListCacheEntry is one cached model: the manifest digest the summary
// was built from, plus the summary itself. syncManifests compares Digest
// with the current manifest digest to decide whether to rebuild.
type modelListCacheEntry struct {
Digest string
Summary modelListSummary
}
// modelListCache caches model list summaries keyed by the model's full name
// string (name.String()).
type modelListCache struct {
// mu guards entries and hydrateErr.
mu sync.RWMutex
entries map[string]modelListCacheEntry
// once ensures Start launches hydration a single time.
once sync.Once
// readyOnce ensures ready is closed a single time.
readyOnce sync.Once
// ready is closed by markReady once the initial hydration has finished.
ready chan struct{}
// hydrateErr is the result of the initial hydration, returned by Wait.
hydrateErr error
// build produces the summary for a manifest; newModelListCache sets it
// to buildModelListSummary.
build func(model.Name, *manifest.Manifest) (modelListSummary, error)
}
// newModelListCache returns an empty, not-yet-hydrated cache that builds
// summaries with buildModelListSummary.
func newModelListCache() *modelListCache {
	cache := &modelListCache{build: buildModelListSummary}
	cache.entries = make(map[string]modelListCacheEntry)
	cache.ready = make(chan struct{})
	return cache
}
// Start launches the initial hydration in a background goroutine. Only the
// first call has any effect; a nil receiver is a no-op. When hydration
// finishes the cache is marked ready, and a failure is logged as a warning
// unless ctx has already been cancelled.
func (c *modelListCache) Start(ctx context.Context) {
	if c == nil {
		return
	}

	run := func() {
		err := c.hydrate(ctx)
		c.markReady(err)
		if err == nil {
			return
		}
		// A cancelled context is an expected shutdown path, not worth a warning.
		if ctx != nil && ctx.Err() != nil {
			return
		}
		slog.Warn("model list cache hydration failed", "error", err)
	}

	c.once.Do(func() {
		slog.Debug("starting model list cache")
		go run()
	})
}
// hydrate builds a summary for every manifest returned by
// manifest.Manifests(true) and stores it in the cache. A manifest whose
// summary cannot be built is logged and skipped. The function returns an
// error only when the manifests cannot be listed or ctx is done.
func (c *modelListCache) hydrate(ctx context.Context) error {
	began := time.Now()

	manifests, err := manifest.Manifests(true)
	if err != nil {
		return err
	}

	var ok, bad int
	for name, mf := range manifests {
		if ctx != nil {
			if cerr := ctx.Err(); cerr != nil {
				return cerr
			}
		}

		summary, buildErr := c.build(name, mf)
		if buildErr == nil {
			c.set(name, mf.Digest(), summary)
			ok++
			continue
		}

		bad++
		slog.Warn("failed to hydrate model list cache", "model", name.String(), "error", buildErr)
	}

	slog.Info("model list cache hydration complete", "models", ok, "failures", bad, "elapsed", time.Since(began))
	return nil
}
// markReady records err as the hydration result and then closes the ready
// channel (at most once), releasing anything blocked in Wait.
func (c *modelListCache) markReady(err error) {
	func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		c.hydrateErr = err
	}()

	c.readyOnce.Do(func() { close(c.ready) })
}
// Wait blocks until the initial hydration has finished or ctx is done. It
// returns the hydration error in the first case and ctx.Err() in the second.
// A nil receiver returns nil at once; a nil ctx is treated as
// context.Background().
func (c *modelListCache) Wait(ctx context.Context) error {
	if c == nil {
		return nil
	}
	if ctx == nil {
		ctx = context.Background()
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-c.ready:
	}

	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.hydrateErr
}
// List waits for the initial hydration, reconciles the cache with the
// manifests currently on disk and returns one response per cached model,
// ordered by sortListModelResponses.
func (c *modelListCache) List(ctx context.Context) ([]api.ListModelResponse, error) {
	err := c.Wait(ctx)
	if err == nil {
		err = c.syncManifests(ctx)
	}
	if err != nil {
		return nil, err
	}

	var models []api.ListModelResponse
	func() {
		c.mu.RLock()
		defer c.mu.RUnlock()
		models = make([]api.ListModelResponse, 0, len(c.entries))
		for _, cached := range c.entries {
			models = append(models, cached.Summary.ListModelResponse())
		}
	}()

	sortListModelResponses(models)
	return models, nil
}
// syncManifests reconciles the cache with the manifests currently reported
// by manifest.Manifests(true). Entries whose manifest is gone are removed,
// entries whose digest changed are rebuilt, and previously cached entries
// whose rebuild fails are dropped rather than served stale. It returns an
// error only when the manifests cannot be listed or ctx is done.
func (c *modelListCache) syncManifests(ctx context.Context) error {
manifests, err := manifest.Manifests(true)
if err != nil {
return err
}
// Snapshot name -> digest under the read lock so summaries can be built
// below without holding any lock.
c.mu.RLock()
current := make(map[string]string, len(c.entries))
for name, entry := range c.entries {
current[name] = entry.Digest
}
c.mu.RUnlock()
type update struct {
name model.Name
digest string
summary modelListSummary
}
// seen: every manifest name found now; stale: cached names whose rebuild failed.
seen := make(map[string]struct{}, len(manifests))
stale := make(map[string]struct{})
var updates []update
for name, mf := range manifests {
if ctx != nil {
if err := ctx.Err(); err != nil {
return err
}
}
key := name.String()
digest := mf.Digest()
seen[key] = struct{}{}
// Unchanged digest: the cached summary is still valid.
if current[key] == digest {
continue
}
summary, err := c.build(name, mf)
if err != nil {
slog.Warn("failed to refresh model list cache", "model", key, "error", err)
if _, ok := current[key]; ok {
stale[key] = struct{}{}
}
continue
}
updates = append(updates, update{name: name, digest: digest, summary: summary})
}
// Apply removals and updates in one critical section.
c.mu.Lock()
for name := range c.entries {
if _, ok := seen[name]; !ok {
delete(c.entries, name)
continue
}
if _, ok := stale[name]; ok {
delete(c.entries, name)
}
}
for _, update := range updates {
c.entries[update.name.String()] = modelListCacheEntry{
Digest: update.digest,
Summary: cloneModelListSummary(update.summary),
}
}
c.mu.Unlock()
return nil
}
// RefreshModel rebuilds the cached summary for name from its manifest. A
// name that is not fully qualified is first resolved with getExistingName.
// If the manifest cannot be parsed or the summary cannot be built, any
// cached entry for name is removed and the error is returned. A nil receiver
// is a no-op.
func (c *modelListCache) RefreshModel(name model.Name) error {
	if c == nil {
		return nil
	}

	if !name.IsFullyQualified() {
		resolved, err := getExistingName(name)
		if err != nil {
			return err
		}
		name = resolved
	}

	// evict drops the entry for name and passes the error through.
	evict := func(err error) error {
		c.DeleteModel(name)
		return err
	}

	mf, err := manifest.ParseNamedManifest(name)
	if err != nil {
		return evict(err)
	}
	summary, err := c.build(name, mf)
	if err != nil {
		return evict(err)
	}

	c.set(name, mf.Digest(), summary)
	return nil
}
// DeleteModel removes the cache entry keyed by name.String(), if any. A nil
// receiver is a no-op.
func (c *modelListCache) DeleteModel(name model.Name) {
	if c == nil {
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.entries, name.String())
}
// Get returns a copy of the cached summary for name and whether it was
// found. A name that is not fully qualified is resolved with getExistingName
// when that succeeds; otherwise the name is looked up as given. A nil
// receiver reports not found.
func (c *modelListCache) Get(name model.Name) (modelListSummary, bool) {
	var none modelListSummary
	if c == nil {
		return none, false
	}

	if !name.IsFullyQualified() {
		existing, err := getExistingName(name)
		if err == nil {
			name = existing
		}
	}

	c.mu.RLock()
	cached, found := c.entries[name.String()]
	c.mu.RUnlock()

	if found {
		return cloneModelListSummary(cached.Summary), true
	}
	return none, false
}
// Len reports how many models are currently cached; a nil receiver has 0.
func (c *modelListCache) Len() int {
	if c == nil {
		return 0
	}

	c.mu.RLock()
	n := len(c.entries)
	c.mu.RUnlock()
	return n
}
// set stores a copy of summary, tagged with digest, under name.String(),
// replacing any previous entry.
func (c *modelListCache) set(name model.Name, digest string, summary modelListSummary) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.entries[name.String()] = modelListCacheEntry{
		Digest:  digest,
		Summary: cloneModelListSummary(summary),
	}
}
// buildModelListSummary assembles the list summary for one model from its
// manifest: basic identity and details come from the config blob, the layers
// supply the parent model, projector count and template, and for models that
// are not remote the GGUF header supplies capabilities plus any context or
// embedding length missing from the config. Capabilities are accumulated
// without duplicates, in the order the sources are consulted.
func buildModelListSummary(name model.Name, mf *manifest.Manifest) (modelListSummary, error) {
cfg, err := readModelListConfig(mf)
if err != nil {
return modelListSummary{}, err
}
var modified time.Time
if fi := mf.FileInfo(); fi != nil {
modified = fi.ModTime()
}
summary := modelListSummary{
Model: name.DisplayShortest(),
Name: name.DisplayShortest(),
RemoteModel: cfg.RemoteModel,
RemoteHost: cfg.RemoteHost,
Size: mf.Size(),
Digest: mf.Digest(),
ModifiedAt: modified,
Details: api.ModelDetails{
Format: cfg.ModelFormat,
Family: cfg.ModelFamily,
Families: append([]string(nil), cfg.ModelFamilies...),
ParameterSize: cfg.ModelType,
QuantizationLevel: cfg.FileType,
ContextLength: cfg.ContextLen,
EmbeddingLength: cfg.EmbedLen,
},
}
modelPath, projectorCount, tmpl, err := readModelListLayers(mf, &summary)
if err != nil {
return modelListSummary{}, err
}
// Only local models (no remote host/model in the config) have a GGUF blob
// worth inspecting. A read failure is logged at debug level and ignored.
if cfg.RemoteHost == "" && cfg.RemoteModel == "" && modelPath != "" {
info, err := readModelListGGUF(modelPath)
if err != nil {
slog.Debug("failed to read gguf model metadata", "model", name.String(), "error", err)
} else {
summary.Capabilities = appendModelListCapabilities(summary.Capabilities, info.Capabilities...)
// Values from the config take precedence; the GGUF header only fills gaps.
if summary.Details.ContextLength == 0 {
summary.Details.ContextLength = info.ContextLength
}
if summary.Details.EmbeddingLength == 0 {
summary.Details.EmbeddingLength = info.EmbeddingLength
}
}
}
for _, c := range cfg.Capabilities {
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.Capability(c))
}
builtinParser := parsers.ParserForName(cfg.Parser)
if tmpl != nil {
// A template error is only logged; whatever vars were returned are still used.
vars, err := tmpl.Vars()
if err != nil {
slog.Warn("model template contains errors", "model", name.String(), "error", err)
}
if slices.Contains(vars, "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.CapabilityTools)
}
if slices.Contains(vars, "suffix") {
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.CapabilityInsert)
}
// Thinking is inferred from template tags, the gptoss/gpt-oss family, or the parser.
openingTag, closingTag := thinking.InferTags(tmpl.Template)
hasTags := openingTag != "" && closingTag != ""
isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, cfg.ModelFamily)
if !slices.Contains(summary.Capabilities, model.CapabilityThinking) &&
(hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport())) {
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.CapabilityThinking)
}
}
if projectorCount > 0 {
summary.Capabilities = appendModelListCapability(summary.Capabilities, model.CapabilityVision)
}
// For safetensors models using a Gemma 4 renderer, vision and audio are
// stripped again. NOTE(review): the reason is not visible here — confirm.
if cfg.ModelFormat == "safetensors" && isGemma4Renderer(cfg.Renderer) {
summary.Capabilities = slices.DeleteFunc(summary.Capabilities, func(c model.Capability) bool {
return c == model.CapabilityVision || c == model.CapabilityAudio
})
}
return summary, nil
}
// readModelListConfig decodes the manifest's config blob as JSON into a
// model.ConfigV2. A nil manifest or one with an empty config digest yields
// the zero config and no error.
func readModelListConfig(mf *manifest.Manifest) (model.ConfigV2, error) {
	var cfg model.ConfigV2

	switch {
	case mf == nil:
		return cfg, nil
	case mf.Config.Digest == "":
		return cfg, nil
	}

	blob, err := mf.Config.Open()
	if err != nil {
		return cfg, err
	}
	defer blob.Close()

	err = json.NewDecoder(blob).Decode(&cfg)
	return cfg, err
}
// readModelListLayers walks the manifest layers and returns the blob path of
// the model layer, the number of projector layers and the parsed template.
// The template starts out as ollamatemplate.DefaultTemplate and is replaced
// by each prompt/template layer encountered. The model layer's From value is
// written to summary.Details.ParentModel. Layers of other media types are
// ignored.
func readModelListLayers(mf *manifest.Manifest, summary *modelListSummary) (string, int, *ollamatemplate.Template, error) {
	var (
		ggufPath   string
		projectors int
	)
	tmpl := ollamatemplate.DefaultTemplate

	for _, layer := range mf.Layers {
		mediaType := layer.MediaType
		if mediaType == "application/vnd.ollama.image.projector" {
			projectors++
			continue
		}

		isModel := mediaType == "application/vnd.ollama.image.model"
		isTemplate := mediaType == "application/vnd.ollama.image.prompt" ||
			mediaType == "application/vnd.ollama.image.template"
		if !isModel && !isTemplate {
			continue
		}

		blobPath, err := manifest.BlobsPath(layer.Digest)
		if err != nil {
			return "", 0, nil, err
		}

		if isModel {
			ggufPath = blobPath
			summary.Details.ParentModel = layer.From
			continue
		}

		raw, err := os.ReadFile(blobPath)
		if err != nil {
			return "", 0, nil, err
		}
		tmpl, err = ollamatemplate.Parse(string(raw))
		if err != nil {
			return "", 0, nil, err
		}
	}

	return ggufPath, projectors, tmpl, nil
}
// modelListGGUF holds the few values readModelListGGUF extracts from a GGUF
// header: the capabilities it implies and the architecture's context and
// embedding lengths (0 when the key is absent).
type modelListGGUF struct {
Capabilities []model.Capability
ContextLength int
EmbeddingLength int
}
// GGUF file magic as it appears when the first four bytes are read as a
// little-endian uint32: the bytes "GGUF" give the LE value; the byte-swapped
// value makes readModelListGGUF switch to big-endian decoding.
const (
modelListGGUFMagicLE = 0x46554747
modelListGGUFMagicBE = 0x47475546
)
// Type tags that precede each GGUF metadata value, numbered from 0 in
// declaration order.
const (
modelListGGUFTypeUint8 uint32 = iota
modelListGGUFTypeInt8
modelListGGUFTypeUint16
modelListGGUFTypeInt16
modelListGGUFTypeUint32
modelListGGUFTypeInt32
modelListGGUFTypeFloat32
modelListGGUFTypeBool
modelListGGUFTypeString
modelListGGUFTypeArray
modelListGGUFTypeUint64
modelListGGUFTypeInt64
modelListGGUFTypeFloat64
)
// readModelListGGUF scans only the small GGUF header values launch needs
// and stops before tokenizer arrays. Using gguf.File.KeyValue for missing keys
// can otherwise advance through large arrays just to discover absence.
//
// It reads the magic, version and key/value count, then walks the metadata
// entries in file order. Only keys seen after "general.architecture" and
// prefixed with that architecture are interpreted; every other value is
// skipped. The result reports vision/audio capabilities when the matching
// block_count keys exist, embedding when a pooling_type key exists (else
// completion), and the context and embedding lengths.
func readModelListGGUF(path string) (modelListGGUF, error) {
f, err := os.Open(path)
if err != nil {
return modelListGGUF{}, err
}
defer f.Close()
r := bufio.NewReaderSize(f, 32<<10)
// The magic is always read little-endian; its value selects the byte order
// used for everything that follows.
var magic uint32
if err := binary.Read(r, binary.LittleEndian, &magic); err != nil {
return modelListGGUF{}, err
}
var byteOrder binary.ByteOrder = binary.LittleEndian
switch magic {
case modelListGGUFMagicLE:
case modelListGGUFMagicBE:
byteOrder = binary.BigEndian
default:
return modelListGGUF{}, fmt.Errorf("invalid file magic")
}
var version uint32
if err := binary.Read(r, byteOrder, &version); err != nil {
return modelListGGUF{}, err
}
// Version 1 stores the tensor and key/value counts as 32-bit values; all
// other versions are read as 64-bit.
var numKV uint64
switch version {
case 1:
var header struct {
NumTensor uint32
NumKV uint32
}
if err := binary.Read(r, byteOrder, &header); err != nil {
return modelListGGUF{}, err
}
numKV = uint64(header.NumKV)
default:
var header struct {
NumTensor uint64
NumKV uint64
}
if err := binary.Read(r, byteOrder, &header); err != nil {
return modelListGGUF{}, err
}
numKV = header.NumKV
}
info := modelListGGUF{}
var architecture string
var hasPoolingType bool
for range numKV {
key, err := readModelListGGUFString(r, byteOrder, version)
if err != nil {
return modelListGGUF{}, err
}
var valueType uint32
if err := binary.Read(r, byteOrder, &valueType); err != nil {
return modelListGGUF{}, err
}
if key == "general.architecture" {
value, err := readModelListGGUFStringValue(r, byteOrder, version, valueType)
if err != nil {
return modelListGGUF{}, err
}
architecture = value
continue
}
// Once the architecture is known, the first tokenizer key ends the scan.
if architecture != "" && strings.HasPrefix(key, "tokenizer.") {
break
}
if architecture != "" && strings.HasPrefix(key, architecture+".") {
switch strings.TrimPrefix(key, architecture+".") {
// The first three cases only note that the key exists; their values are
// skipped by the call after the switch.
case "pooling_type":
hasPoolingType = true
case "vision.block_count":
info.Capabilities = appendModelListCapability(info.Capabilities, model.CapabilityVision)
case "audio.block_count":
info.Capabilities = appendModelListCapability(info.Capabilities, model.CapabilityAudio)
// The length cases consume the value themselves, hence the continue.
case "context_length":
value, err := readModelListGGUFIntValue(r, byteOrder, version, valueType)
if err != nil {
return modelListGGUF{}, err
}
info.ContextLength = value
continue
case "embedding_length":
value, err := readModelListGGUFIntValue(r, byteOrder, version, valueType)
if err != nil {
return modelListGGUF{}, err
}
info.EmbeddingLength = value
continue
}
}
if err := skipModelListGGUFValue(r, byteOrder, version, valueType); err != nil {
return modelListGGUF{}, err
}
}
if hasPoolingType {
info.Capabilities = appendModelListCapability(info.Capabilities, model.CapabilityEmbedding)
} else {
info.Capabilities = appendModelListCapability(info.Capabilities, model.CapabilityCompletion)
}
return info, nil
}
// readModelListGGUFStringValue reads a metadata value that is expected to be
// a string. For any other value type the value is skipped, so the reader
// stays aligned with the next entry, and an error is returned.
func readModelListGGUFStringValue(r io.Reader, byteOrder binary.ByteOrder, version uint32, valueType uint32) (string, error) {
	if valueType == modelListGGUFTypeString {
		return readModelListGGUFString(r, byteOrder, version)
	}

	if err := skipModelListGGUFValue(r, byteOrder, version, valueType); err != nil {
		return "", err
	}
	return "", fmt.Errorf("unexpected gguf string type %d", valueType)
}
// readModelListGGUFIntValue reads a metadata value of any GGUF integer type
// (8 to 64 bits, signed or unsigned) and converts it to int. For any other
// value type the value is skipped and an error is returned. On error the
// returned int is 0.
func readModelListGGUFIntValue(r io.Reader, byteOrder binary.ByteOrder, version uint32, valueType uint32) (int, error) {
	var (
		result int
		err    error
	)

	switch valueType {
	case modelListGGUFTypeUint8:
		var v uint8
		if err = binary.Read(r, byteOrder, &v); err == nil {
			result = int(v)
		}
	case modelListGGUFTypeInt8:
		var v int8
		if err = binary.Read(r, byteOrder, &v); err == nil {
			result = int(v)
		}
	case modelListGGUFTypeUint16:
		var v uint16
		if err = binary.Read(r, byteOrder, &v); err == nil {
			result = int(v)
		}
	case modelListGGUFTypeInt16:
		var v int16
		if err = binary.Read(r, byteOrder, &v); err == nil {
			result = int(v)
		}
	case modelListGGUFTypeUint32:
		var v uint32
		if err = binary.Read(r, byteOrder, &v); err == nil {
			result = int(v)
		}
	case modelListGGUFTypeInt32:
		var v int32
		if err = binary.Read(r, byteOrder, &v); err == nil {
			result = int(v)
		}
	case modelListGGUFTypeUint64:
		var v uint64
		if err = binary.Read(r, byteOrder, &v); err == nil {
			result = int(v)
		}
	case modelListGGUFTypeInt64:
		var v int64
		if err = binary.Read(r, byteOrder, &v); err == nil {
			result = int(v)
		}
	default:
		// Not an integer: consume the value so the stream stays aligned,
		// then report the type mismatch.
		if err = skipModelListGGUFValue(r, byteOrder, version, valueType); err == nil {
			err = fmt.Errorf("unexpected gguf integer type %d", valueType)
		}
	}

	if err != nil {
		return 0, err
	}
	return result, nil
}
func skipModelListGGUFValue(r io.Reader, byteOrder binary.ByteOrder, version uint32, valueType uint32) error {
switch valueType {
case modelListGGUFTypeUint8, modelListGGUFTypeInt8, modelListGGUFTypeBool:
return discardModelListGGUFBytes(r, 1)
case modelListGGUFTypeUint16, modelListGGUFTypeInt16:
return discardModelListGGUFBytes(r, 2)
case modelListGGUFTypeUint32, modelListGGUFTypeInt32, modelListGGUFTypeFloat32:
return discardModelListGGUFBytes(r, 4)
case modelListGGUFTypeUint64, modelListGGUFTypeInt64, modelListGGUFTypeFloat64:
return discardModelListGGUFBytes(r, 8)
case modelListGGUFTypeString:
return skipModelListGGUFString(r, byteOrder, version)
case modelListGGUFTypeArray:
var arrayType uint32
if err := binary.Read(r, byteOrder, &arrayType); err != nil {
return err
}
var count uint64
if err := binary.Read(r, byteOrder, &count); err != nil {
return err
}
return skipModelListGGUFArray(r, byteOrder, version, arrayType, count)
default:
return fmt.Errorf("unsupported gguf value type %d", valueType)
}
}
// skipModelListGGUFArray advances r past the elements of an array whose
// element type and count have already been read. Fixed-width elements are
// discarded in one step; string elements are skipped one by one. An unknown
// element type, or a total size that does not fit in an int64, is an error.
func skipModelListGGUFArray(r io.Reader, byteOrder binary.ByteOrder, version uint32, arrayType uint32, count uint64) error {
	var size uint64
	switch arrayType {
	case modelListGGUFTypeUint8, modelListGGUFTypeInt8, modelListGGUFTypeBool:
		size = 1
	case modelListGGUFTypeUint16, modelListGGUFTypeInt16:
		size = 2
	case modelListGGUFTypeUint32, modelListGGUFTypeInt32, modelListGGUFTypeFloat32:
		size = 4
	case modelListGGUFTypeUint64, modelListGGUFTypeInt64, modelListGGUFTypeFloat64:
		size = 8
	case modelListGGUFTypeString:
		for range count {
			if err := skipModelListGGUFString(r, byteOrder, version); err != nil {
				return err
			}
		}
		return nil
	default:
		return fmt.Errorf("unsupported gguf array type %d", arrayType)
	}

	// count comes straight from the file. Without this check count*size can
	// wrap around or convert to a negative int64, which the discard helper
	// treats as "nothing to skip" and silently desynchronises the parser.
	const maxInt64 = 1<<63 - 1
	if count > maxInt64/size {
		return fmt.Errorf("gguf array too large: %d elements of %d bytes", count, size)
	}
	return discardModelListGGUFBytes(r, int64(count*size))
}
// readModelListGGUFString reads a length-prefixed string: a uint64 length in
// byteOrder followed by that many bytes. For version 1 files a single
// trailing NUL byte is trimmed.
//
// The length comes from the file and is not trusted: the bytes are copied
// incrementally instead of pre-allocating the declared length, so a corrupt
// header produces an error rather than a huge allocation or a panic. A
// string cut short by end of input yields io.ErrUnexpectedEOF (io.EOF when no
// string bytes were available at all).
func readModelListGGUFString(r io.Reader, byteOrder binary.ByteOrder, version uint32) (string, error) {
	var length uint64
	if err := binary.Read(r, byteOrder, &length); err != nil {
		return "", err
	}
	if length == 0 {
		return "", nil
	}

	const maxInt64 = 1<<63 - 1
	if length > maxInt64 {
		return "", fmt.Errorf("invalid gguf string length %d", length)
	}

	var sb strings.Builder
	n, err := io.CopyN(&sb, r, int64(length))
	if err != nil {
		// Keep the io.ReadFull convention: a partial read is an
		// unexpected EOF, not a clean one.
		if err == io.EOF && n > 0 {
			err = io.ErrUnexpectedEOF
		}
		return "", err
	}

	s := sb.String()
	if version == 1 && s[len(s)-1] == 0 {
		s = s[:len(s)-1]
	}
	return s, nil
}
// skipModelListGGUFString advances r past one length-prefixed string without
// keeping its contents. A declared length that does not fit in an int64 is
// rejected: converting it would give a negative count, which the discard
// helper treats as "nothing to skip" and would leave the reader misaligned.
func skipModelListGGUFString(r io.Reader, byteOrder binary.ByteOrder, version uint32) error {
	var length uint64
	if err := binary.Read(r, byteOrder, &length); err != nil {
		return err
	}

	const maxInt64 = 1<<63 - 1
	if length > maxInt64 {
		return fmt.Errorf("invalid gguf string length %d", length)
	}
	return discardModelListGGUFBytes(r, int64(length))
}

// discardModelListGGUFBytes reads and throws away n bytes from r. A count of
// zero or less is a no-op. If r ends before n bytes were read the error from
// io.CopyN is returned.
func discardModelListGGUFBytes(r io.Reader, n int64) error {
	if n <= 0 {
		return nil
	}
	_, err := io.CopyN(io.Discard, r, n)
	return err
}
// appendModelListCapabilities adds each of values to capabilities in order,
// applying the same empty/duplicate filtering as appendModelListCapability,
// and returns the resulting slice.
func appendModelListCapabilities(capabilities []model.Capability, values ...model.Capability) []model.Capability {
	for i := 0; i < len(values); i++ {
		capabilities = appendModelListCapability(capabilities, values[i])
	}
	return capabilities
}
// appendModelListCapability appends capability to capabilities unless it is
// empty or already present, and returns the resulting slice.
func appendModelListCapability(capabilities []model.Capability, capability model.Capability) []model.Capability {
	switch {
	case capability == "":
		return capabilities
	case slices.Contains(capabilities, capability):
		return capabilities
	}
	return append(capabilities, capability)
}
// cloneModelListSummary returns a copy of summary whose Families and
// Capabilities slices do not share backing storage with the original. Empty
// slices come back as nil.
func cloneModelListSummary(summary modelListSummary) modelListSummary {
	dup := summary
	dup.Details.Families = append([]string(nil), summary.Details.Families...)
	dup.Capabilities = append([]model.Capability(nil), summary.Capabilities...)
	return dup
}
// ListModelResponse converts the summary into an api.ListModelResponse. The
// Families and Capabilities slices are copied so the response does not share
// storage with the cached summary.
func (s modelListSummary) ListModelResponse() api.ListModelResponse {
	details := api.ModelDetails{
		ParentModel:       s.Details.ParentModel,
		Format:            s.Details.Format,
		Family:            s.Details.Family,
		Families:          append([]string(nil), s.Details.Families...),
		ParameterSize:     s.Details.ParameterSize,
		QuantizationLevel: s.Details.QuantizationLevel,
		ContextLength:     s.Details.ContextLength,
		EmbeddingLength:   s.Details.EmbeddingLength,
	}

	return api.ListModelResponse{
		Model:        s.Model,
		Name:         s.Name,
		RemoteModel:  s.RemoteModel,
		RemoteHost:   s.RemoteHost,
		Size:         s.Size,
		Digest:       s.Digest,
		ModifiedAt:   s.ModifiedAt,
		Details:      details,
		Capabilities: append([]model.Capability(nil), s.Capabilities...),
	}
}
// sortListModelResponses stably sorts models in place by ModifiedAt, newest
// first, comparing at whole-second (Unix) resolution.
func sortListModelResponses(models []api.ListModelResponse) {
	// Preserve the existing /api/tags order: most recently modified first.
	newestFirst := func(a, b api.ListModelResponse) int {
		return cmp.Compare(b.ModifiedAt.Unix(), a.ModifiedAt.Unix())
	}
	slices.SortStableFunc(models, newestFirst)
}
// refreshModelListCache refreshes the model-list cache entry for name and
// logs a warning if that fails. It does nothing when the server, its caches
// or the model-list cache is nil.
func (s *Server) refreshModelListCache(name model.Name) {
	if s == nil || s.modelCaches == nil {
		return
	}
	cache := s.modelCaches.modelList
	if cache == nil {
		return
	}

	err := cache.RefreshModel(name)
	if err == nil {
		return
	}
	slog.Warn("failed to refresh model list cache", "model", name.String(), "error", err)
}
// deleteModelListCache removes name from the model-list cache. It does
// nothing when the server, its caches or the model-list cache is nil.
func (s *Server) deleteModelListCache(name model.Name) {
	if s == nil || s.modelCaches == nil {
		return
	}
	if cache := s.modelCaches.modelList; cache != nil {
		cache.DeleteModel(name)
	}
}

View File

@@ -0,0 +1,239 @@
package server
import (
"context"
"errors"
"net/http"
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
// TestModelListCacheHydratesSummary creates a model, hydrates a fresh model
// list cache, and checks that the cached summary and the list response built
// from it carry the expected name, digest, size, details, and capabilities.
func TestModelListCacheHydratesSummary(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	kv := map[string]any{
		"test.context_length":   uint32(4096),
		"test.embedding_length": uint32(384),
	}
	const tmpl = "{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}"
	createListCacheModel(t, "list-cache", kv, tmpl)

	cache := newModelListCache()
	if err := cache.hydrate(context.Background()); err != nil {
		t.Fatalf("hydrate failed: %v", err)
	}

	summary, found := cache.Get(model.ParseName("list-cache"))
	if !found {
		t.Fatal("list summary missing")
	}

	// Every case is fatal, so only the first failing check is reported.
	switch {
	case summary.Model != "list-cache:latest" || summary.Name != "list-cache:latest":
		t.Fatalf("summary model/name = %q/%q, want list-cache:latest", summary.Model, summary.Name)
	case summary.Digest == "":
		t.Fatal("summary digest is empty")
	case summary.Size == 0:
		t.Fatal("summary size is zero")
	case summary.Details.Family != "test" || summary.Details.Format != "gguf":
		t.Fatalf("summary details = %+v, want gguf/test", summary.Details)
	case summary.Details.ContextLength != 4096:
		t.Fatalf("context length = %d, want 4096", summary.Details.ContextLength)
	case summary.Details.EmbeddingLength != 384:
		t.Fatalf("embedding length = %d, want 384", summary.Details.EmbeddingLength)
	}

	wantCapabilities := []model.Capability{
		model.CapabilityCompletion,
		model.CapabilityTools,
		model.CapabilityInsert,
	}
	for _, want := range wantCapabilities {
		if !slices.Contains(summary.Capabilities, want) {
			t.Fatalf("capabilities = %v, want %s", summary.Capabilities, want)
		}
	}

	listModel := summary.ListModelResponse()
	responseOK := slices.Contains(listModel.Capabilities, model.CapabilityTools) &&
		listModel.Details.ContextLength == 4096 &&
		listModel.Details.EmbeddingLength == 384
	if !responseOK {
		t.Fatalf("list response = %+v, want capabilities/context/embedding", listModel)
	}
}
// TestModelListCacheRefreshUpdatesEntry checks that RefreshModel replaces the
// cached entry in place after the model's manifest changes: the digest must
// differ from the one seen before the change and the cache must still hold
// exactly one entry.
func TestModelListCacheRefreshUpdatesEntry(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createListCacheModel(t, "list-refresh", map[string]any{"test.context_length": uint32(1024)}, "")

	cache := newModelListCache()
	if err := cache.hydrate(context.Background()); err != nil {
		t.Fatalf("hydrate failed: %v", err)
	}

	name := model.ParseName("list-refresh")
	before, found := cache.Get(name)
	if !found {
		t.Fatal("list summary missing")
	}

	changeShowCacheManifest(t, "list-refresh")
	if err := cache.RefreshModel(name); err != nil {
		t.Fatalf("refresh failed: %v", err)
	}

	after, found := cache.Get(name)
	if !found {
		t.Fatal("refreshed list summary missing")
	}
	if after.Digest == before.Digest {
		t.Fatalf("digest did not change after refresh: %s", after.Digest)
	}
	if entries := cache.Len(); entries != 1 {
		t.Fatalf("cache entries = %d, want 1", entries)
	}
}
// TestModelListCacheMutationHooks drives the create, copy, and delete
// handlers on a Server that owns a model list cache and checks after each
// request that the cache gained or lost the corresponding entry.
func TestModelListCacheMutationHooks(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	// The cache is neither hydrated nor started in this test, so entries
	// observed below are presumably populated by the handlers' cache hooks.
	cache := newModelListCache()
	s := Server{modelCaches: &modelCaches{modelList: cache}}
	_, digest := createBinFile(t, map[string]any{"test.context_length": uint32(2048)}, nil)
	// Create: the new model should appear in the cache.
	w := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:  "list-hooks",
		Files:  map[string]string{"model.gguf": digest},
		Stream: &stream,
	})
	if w.Code != http.StatusOK {
		t.Fatalf("create model status = %d, want 200: %s", w.Code, w.Body.String())
	}
	if _, ok := cache.Get(model.ParseName("list-hooks")); !ok {
		t.Fatal("create did not refresh model list cache")
	}
	// Copy: the destination name should appear in the cache.
	w = createRequest(t, s.CopyHandler, api.CopyRequest{
		Source:      "list-hooks",
		Destination: "list-hooks-copy",
	})
	if w.Code != http.StatusOK {
		t.Fatalf("copy model status = %d, want 200: %s", w.Code, w.Body.String())
	}
	if _, ok := cache.Get(model.ParseName("list-hooks-copy")); !ok {
		t.Fatal("copy did not refresh model list cache")
	}
	// Delete: the copied entry should be gone from the cache.
	w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Model: "list-hooks-copy"})
	if w.Code != http.StatusOK {
		t.Fatalf("delete model status = %d, want 200: %s", w.Code, w.Body.String())
	}
	if _, ok := cache.Get(model.ParseName("list-hooks-copy")); ok {
		t.Fatal("delete did not remove model list cache entry")
	}
}
// TestModelListCacheSyncsManifestChanges checks that cache.List reflects
// models that were created and deleted through Servers that hold no reference
// to the cache, i.e. changes the cache was never explicitly told about.
func TestModelListCacheSyncsManifestChanges(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createListCacheModel(t, "list-sync-a", map[string]any{"test.context_length": uint32(1024)}, "")
	cache := newModelListCache()
	// NOTE(review): Start followed by Wait presumably completes the initial
	// hydration before the test continues — confirm against the cache type.
	cache.Start(context.Background())
	if err := cache.Wait(context.Background()); err != nil {
		t.Fatalf("wait failed: %v", err)
	}
	// createListCacheModel uses its own zero-value Server, so this creation
	// does not go through any hook attached to cache.
	createListCacheModel(t, "list-sync-b", map[string]any{"test.context_length": uint32(2048)}, "")
	models, err := cache.List(context.Background())
	if err != nil {
		t.Fatalf("list failed: %v", err)
	}
	names := make([]string, 0, len(models))
	for _, m := range models {
		names = append(names, m.Name)
	}
	for _, want := range []string{"list-sync-a:latest", "list-sync-b:latest"} {
		if !slices.Contains(names, want) {
			t.Fatalf("names = %v, want %s", names, want)
		}
	}
	// Delete through a Server with no caches attached; List must still
	// notice that list-sync-a is gone.
	var other Server
	w := createRequest(t, other.DeleteHandler, api.DeleteRequest{Model: "list-sync-a"})
	if w.Code != http.StatusOK {
		t.Fatalf("delete model status = %d, want 200: %s", w.Code, w.Body.String())
	}
	models, err = cache.List(context.Background())
	if err != nil {
		t.Fatalf("list after delete failed: %v", err)
	}
	// Reuse the names slice's storage for the second listing.
	names = names[:0]
	for _, m := range models {
		names = append(names, m.Name)
	}
	if slices.Contains(names, "list-sync-a:latest") || !slices.Contains(names, "list-sync-b:latest") {
		t.Fatalf("names after delete = %v, want only list-sync-b", names)
	}
}
// TestModelListCacheSyncDropsStaleEntryOnRefreshFailure checks that when a
// model's manifest has changed and rebuilding its summary fails, List returns
// without error and the outdated entry is removed rather than served.
func TestModelListCacheSyncDropsStaleEntryOnRefreshFailure(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createListCacheModel(t, "list-stale", map[string]any{"test.context_length": uint32(1024)}, "")
	cache := newModelListCache()
	cache.Start(context.Background())
	if err := cache.Wait(context.Background()); err != nil {
		t.Fatalf("wait failed: %v", err)
	}
	name := model.ParseName("list-stale")
	if _, ok := cache.Get(name); !ok {
		t.Fatal("list summary missing")
	}
	// Change the manifest so the cached entry no longer matches, then make
	// every subsequent summary build fail.
	changeShowCacheManifest(t, "list-stale")
	cache.build = func(model.Name, *manifest.Manifest) (modelListSummary, error) {
		return modelListSummary{}, errors.New("refresh failed")
	}
	models, err := cache.List(context.Background())
	if err != nil {
		t.Fatalf("list failed: %v", err)
	}
	if len(models) != 0 {
		t.Fatalf("models = %+v, want stale entry removed", models)
	}
	if _, ok := cache.Get(name); ok {
		t.Fatal("stale entry remained in cache after refresh failure")
	}
}
// createListCacheModel creates a model called name through CreateHandler on a
// zero-value Server, using a generated model file built from kv. tmpl becomes
// the request's Template; an empty tmpl leaves Template at its zero value.
// The test fails immediately if the handler does not answer 200.
func createListCacheModel(t *testing.T, name string, kv map[string]any, tmpl string) {
	t.Helper()

	_, digest := createBinFile(t, kv, nil)

	// Template is a plain string field, so assigning an empty tmpl is the
	// same as not setting it.
	req := api.CreateRequest{
		Model:    name,
		Files:    map[string]string{"model.gguf": digest},
		Stream:   &stream,
		Template: tmpl,
	}

	var srv Server
	resp := createRequest(t, srv.CreateHandler, req)
	if resp.Code == http.StatusOK {
		return
	}
	t.Fatalf("create model status = %d, want 200: %s", resp.Code, resp.Body.String())
}

View File

@@ -0,0 +1,402 @@
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand/v2"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
)
// modelRecommendationsURL is the remote endpoint that refresh issues its GET
// request to when fetching model recommendations.
const modelRecommendationsURL = "https://ollama.com/api/experimental/model-recommendations"
// Tunables for the model recommendations cache.
var (
	// modelRecommendationsRefreshInterval is the base wait (before jitter)
	// between background refreshes in run while there are no consecutive
	// failures.
	modelRecommendationsRefreshInterval = 4 * time.Hour
	// modelRecommendationsFetchTimeout bounds a single remote fetch in refresh.
	modelRecommendationsFetchTimeout = 3 * time.Second
	// modelRecommendationsReadRefreshCooldown is how long after a read-triggered
	// refresh finishes before another read may trigger one (see endReadRefresh).
	modelRecommendationsReadRefreshCooldown = 5 * time.Second
	// modelRecommendationsBackoffSteps are the base waits (before jitter) used
	// by run after 1, 2, 3, ... consecutive failures; the last step is reused
	// once the failure count exceeds the number of steps.
	modelRecommendationsBackoffSteps = []time.Duration{
		5 * time.Minute,
		15 * time.Minute,
		time.Hour,
		4 * time.Hour,
	}
	// errModelRecommendationsNoCloud is returned by refresh when
	// envconfig.NoCloud() reports true; callers treat it as a skip rather
	// than a failure.
	errModelRecommendationsNoCloud = errors.New("cloud disabled")
)
// modelRecommendationsCache holds the current list of model recommendations
// and coordinates refreshing it from the remote endpoint.
type modelRecommendationsCache struct {
	// mu is held by the methods below whenever they read or write
	// recommendations, refreshing, or nextReadRefreshAfter.
	mu sync.RWMutex
	// recommendations is the current list; Get hands out copies of it.
	recommendations []api.ModelRecommendation
	// refreshing is true while a timer- or read-triggered refresh is in flight.
	refreshing bool
	// nextReadRefreshAfter is the earliest time a read may trigger a refresh.
	nextReadRefreshAfter time.Time
	// once ensures Start launches the background loop only one time.
	once sync.Once
	// client performs the remote fetch in refresh.
	client *http.Client
}
// newModelRecommendationsCache returns a cache seeded with a copy of
// defaultModelRecommendations that fetches through http.DefaultClient.
func newModelRecommendationsCache() *modelRecommendationsCache {
	cache := new(modelRecommendationsCache)
	cache.recommendations = cloneModelRecommendations(defaultModelRecommendations)
	cache.client = http.DefaultClient
	return cache
}
// Start launches the background refresh loop in its own goroutine, bound to
// ctx. Only the first call has any effect; later calls return immediately.
func (c *modelRecommendationsCache) Start(ctx context.Context) {
	launch := func() {
		slog.Debug("starting model recommendations cache",
			"default_recommendations", len(defaultModelRecommendations),
			"refresh_interval", modelRecommendationsRefreshInterval.String(),
			"fetch_timeout", modelRecommendationsFetchTimeout.String(),
		)
		go c.run(ctx)
	}
	c.once.Do(launch)
}
// Get returns a copy of the current recommendations, taken under the read
// lock. The caller owns the returned slice.
func (c *modelRecommendationsCache) Get() []api.ModelRecommendation {
	c.mu.RLock()
	snapshot := cloneModelRecommendations(c.recommendations)
	c.mu.RUnlock()
	return snapshot
}
// GetSWR returns a copy of the current recommendations and then asks for a
// background refresh (stale-while-revalidate). The returned data is whatever
// was cached before that refresh was requested.
func (c *modelRecommendationsCache) GetSWR(ctx context.Context) []api.ModelRecommendation {
	// The deferred call runs after Get's result has been evaluated, so the
	// snapshot is taken first and the refresh is requested second.
	defer c.triggerRefreshOnRead(ctx)
	return c.Get()
}
// set replaces the cached recommendations with a copy of recs, under the
// write lock.
func (c *modelRecommendationsCache) set(recs []api.ModelRecommendation) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.recommendations = cloneModelRecommendations(recs)
}
// beginRefresh marks a refresh as in flight and reports true, unless one is
// already in flight, in which case it changes nothing and reports false.
func (c *modelRecommendationsCache) beginRefresh() bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	idle := !c.refreshing
	if idle {
		c.refreshing = true
	}
	return idle
}
// beginReadRefresh is the read-path variant of beginRefresh: it also refuses
// to start while the current time is before nextReadRefreshAfter. It reports
// true only when it has marked a refresh as in flight.
func (c *modelRecommendationsCache) beginReadRefresh() bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	coolingDown := time.Now().Before(c.nextReadRefreshAfter)
	if c.refreshing || coolingDown {
		return false
	}
	c.refreshing = true
	return true
}
// endRefresh clears the in-flight marker set by beginRefresh.
func (c *modelRecommendationsCache) endRefresh() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.refreshing = false
}
// endReadRefresh clears the in-flight marker and starts the read-refresh
// cooldown: no read can trigger another refresh until
// modelRecommendationsReadRefreshCooldown has elapsed from now.
func (c *modelRecommendationsCache) endReadRefresh() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.refreshing = false
	c.nextReadRefreshAfter = time.Now().Add(modelRecommendationsReadRefreshCooldown)
}
// refreshIfIdle runs refresh unless another refresh is already in flight.
// The first result reports whether a refresh was actually attempted; the
// second is that refresh's error (always nil when nothing was attempted).
func (c *modelRecommendationsCache) refreshIfIdle(ctx context.Context) (bool, error) {
	if started := c.beginRefresh(); !started {
		return false, nil
	}
	defer c.endRefresh()

	err := c.refresh(ctx)
	return true, err
}
// triggerRefreshOnRead starts an asynchronous refresh if none is in flight
// and the read-refresh cooldown has passed; otherwise it returns at once.
// The refresh outlives the caller: cancellation of ctx is stripped before it
// is handed to the goroutine. Failures are only logged.
func (c *modelRecommendationsCache) triggerRefreshOnRead(ctx context.Context) {
	if allowed := c.beginReadRefresh(); !allowed {
		return
	}

	if ctx == nil {
		ctx = context.Background()
	}
	// Detach from the caller's cancellation while keeping its values.
	detached := context.WithoutCancel(ctx)

	slog.Debug("triggering model recommendations refresh on read")
	go func() {
		defer c.endReadRefresh()

		err := c.refresh(detached)
		if err == nil {
			return
		}
		if errors.Is(err, errModelRecommendationsNoCloud) {
			slog.Debug("skipping model recommendations read refresh because cloud is disabled")
			return
		}
		slog.Warn("model recommendations read refresh failed", "error", err)
	}()
}
// run is the background loop started by Start. It loads the on-disk snapshot
// once, then repeatedly attempts a refresh and sleeps until the next attempt,
// returning when ctx is done.
//
// The sleep is the regular refresh interval while there are no consecutive
// failures, and a step from modelRecommendationsBackoffSteps otherwise; both
// are passed through withJitter. A skipped attempt (another refresh already
// in flight), a success, and a cloud-disabled result all reset the failure
// count to zero.
func (c *modelRecommendationsCache) run(ctx context.Context) {
	c.loadSnapshot()

	failures := 0
	for {
		started, err := c.refreshIfIdle(ctx)
		switch {
		case !started:
			failures = 0
			slog.Debug("skipping timer model recommendations refresh because refresh is already running")
		case err != nil && !errors.Is(err, errModelRecommendationsNoCloud):
			failures++
			slog.Warn("model recommendations refresh failed", "error", err)
		case err != nil:
			// Only the cloud-disabled sentinel reaches this case.
			failures = 0
			slog.Debug("skipping model recommendations refresh because cloud is disabled")
		default:
			failures = 0
		}

		// Pick the un-jittered delay first so withJitter is called exactly once.
		base := modelRecommendationsRefreshInterval
		if failures > 0 {
			step := min(failures-1, len(modelRecommendationsBackoffSteps)-1)
			base = modelRecommendationsBackoffSteps[step]
		}
		wait := withJitter(base)

		slog.Info("model recommendations cache sleep scheduled", "wait", wait.String(), "consecutive_failures", failures)
		select {
		case <-ctx.Done():
			slog.Debug("stopping model recommendations cache")
			return
		case <-time.After(wait):
		}
	}
}
// refresh fetches recommendations from modelRecommendationsURL, validates
// them, stores them in the cache, and writes them to the on-disk snapshot.
//
// It returns errModelRecommendationsNoCloud without contacting the network
// when envconfig.NoCloud() is true. On any fetch, status, decode, or
// validation error the cached recommendations are left unchanged. A failure
// to persist the snapshot is logged but does not make refresh fail.
func (c *modelRecommendationsCache) refresh(ctx context.Context) error {
	if envconfig.NoCloud() {
		return errModelRecommendationsNoCloud
	}
	slog.Debug("refreshing model recommendations from remote", "url", modelRecommendationsURL)
	// Bound the whole request, including reading the body, by the fetch timeout.
	reqCtx, cancel := context.WithTimeout(ctx, modelRecommendationsFetchTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelRecommendationsURL, nil)
	if err != nil {
		return err
	}
	req.Header.Set("Accept", "application/json")
	resp, err := c.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode >= http.StatusBadRequest {
		// Include at most 2 KiB of the error body in the message; a read
		// error here is ignored because the status code is already the
		// primary signal.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
		return fmt.Errorf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	var payload api.ModelRecommendationsResponse
	if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
		return err
	}
	recs, err := validateModelRecommendations(payload.Recommendations)
	if err != nil {
		return err
	}
	c.set(recs)
	slog.Debug("model recommendations refreshed", "count", len(recs))
	// The in-memory cache is already updated, so a snapshot write failure is
	// reported as a warning only.
	if err := c.persistSnapshot(recs); err != nil {
		slog.Warn("failed to persist model recommendations snapshot", "error", err)
	}
	return nil
}
// loadSnapshot replaces the cached recommendations with the contents of the
// on-disk snapshot, if one exists and passes validation. Every failure is
// logged and leaves the cache as it was; a missing file is logged at debug
// level only.
func (c *modelRecommendationsCache) loadSnapshot() {
	path, err := modelRecommendationsSnapshotPath()
	if err != nil {
		slog.Warn("failed to resolve model recommendations snapshot path", "error", err)
		return
	}

	raw, err := os.ReadFile(path)
	switch {
	case err == nil:
		// Snapshot read; continue below.
	case errors.Is(err, os.ErrNotExist):
		slog.Debug("model recommendations snapshot not found", "path", path)
		return
	default:
		slog.Warn("failed to read model recommendations snapshot", "path", path, "error", err)
		return
	}

	var snapshot api.ModelRecommendationsResponse
	if err := json.Unmarshal(raw, &snapshot); err != nil {
		slog.Warn("failed to parse model recommendations snapshot", "path", path, "error", err)
		return
	}

	recs, err := validateModelRecommendations(snapshot.Recommendations)
	if err != nil {
		slog.Warn("ignoring invalid model recommendations snapshot", "path", path, "error", err)
		return
	}

	c.set(recs)
	slog.Debug("loaded model recommendations snapshot", "path", path, "count", len(recs))
}
// persistSnapshot writes recs as indented JSON to the snapshot path. The data
// is written to a temporary file in the same directory, synced, closed, and
// then renamed over the destination, so the snapshot file is replaced in a
// single rename step. It returns the first error encountered.
func (c *modelRecommendationsCache) persistSnapshot(recs []api.ModelRecommendation) error {
	path, err := modelRecommendationsSnapshotPath()
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return err
	}
	payload := api.ModelRecommendationsResponse{Recommendations: recs}
	data, err := json.MarshalIndent(payload, "", " ")
	if err != nil {
		return err
	}
	// Create the temp file next to the destination so the final rename stays
	// within one directory.
	tmp, err := os.CreateTemp(filepath.Dir(path), ".model-recommendations-*.tmp")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()
	// Best-effort cleanup of the temp file on every exit path; after a
	// successful rename the temp path no longer exists and the ignored
	// Remove error is expected.
	defer os.Remove(tmpPath)
	if _, err := tmp.Write(data); err != nil {
		_ = tmp.Close()
		return err
	}
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	if err := os.Rename(tmpPath, path); err != nil {
		return err
	}
	slog.Debug("persisted model recommendations snapshot", "path", path, "count", len(recs))
	return nil
}
// modelRecommendationsSnapshotPath returns the location of the on-disk
// recommendations snapshot: <home>/.ollama/cache/model-recommendations.json,
// where <home> comes from os.UserHomeDir. It fails only if the home directory
// cannot be determined.
func modelRecommendationsSnapshotPath() (string, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}

	cacheDir := filepath.Join(home, ".ollama", "cache")
	return filepath.Join(cacheDir, "model-recommendations.json"), nil
}
// validateModelRecommendations normalizes and filters recs, returning a new
// slice and leaving the input untouched.
//
// For each entry, Model, Description, and RequiredPlan are whitespace-trimmed.
// The whole list is rejected with an error if it is empty, if any entry has
// an empty model name after trimming, or if a model name appears twice.
// Cloud entries (see isCloudRecommendation) without a positive ContextLength
// and MaxOutputTokens are dropped with a warning; if nothing survives, an
// error is returned.
func validateModelRecommendations(recs []api.ModelRecommendation) ([]api.ModelRecommendation, error) {
	if len(recs) == 0 {
		return nil, errors.New("empty recommendations")
	}

	seen := make(map[string]struct{}, len(recs))
	kept := make([]api.ModelRecommendation, 0, len(recs))
	for i := range recs {
		// Work on a copy so trimming never modifies the caller's slice.
		rec := recs[i]
		rec.Model = strings.TrimSpace(rec.Model)
		rec.Description = strings.TrimSpace(rec.Description)
		rec.RequiredPlan = strings.TrimSpace(rec.RequiredPlan)

		switch _, duplicate := seen[rec.Model]; {
		case rec.Model == "":
			return nil, errors.New("recommendation missing model")
		case duplicate:
			return nil, fmt.Errorf("duplicate recommendation %q", rec.Model)
		}
		// Dropped entries still count as seen for duplicate detection.
		seen[rec.Model] = struct{}{}

		missingLimits := rec.ContextLength <= 0 || rec.MaxOutputTokens <= 0
		if missingLimits && isCloudRecommendation(rec.Model) {
			slog.Warn("dropping cloud recommendation missing limits", "model", rec.Model)
			continue
		}
		kept = append(kept, rec)
	}

	if len(kept) == 0 {
		return nil, errors.New("no valid recommendations")
	}
	return kept, nil
}
// isCloudRecommendation reports whether modelName ends in ":cloud" or
// "-cloud".
func isCloudRecommendation(modelName string) bool {
	for _, suffix := range [...]string{":cloud", "-cloud"} {
		if strings.HasSuffix(modelName, suffix) {
			return true
		}
	}
	return false
}
func withJitter(d time.Duration) time.Duration {
if d <= 0 {
return d
}
// jitter in range [0.8x, 1.2x]
factor := 0.8 + rand.Float64()*0.4
return time.Duration(float64(d) * factor)
}
// cloneModelRecommendations returns a newly allocated slice holding the same
// elements as in. The result is never nil, even for a nil or empty input.
func cloneModelRecommendations(in []api.ModelRecommendation) []api.ModelRecommendation {
	return append(make([]api.ModelRecommendation, 0, len(in)), in...)
}
// defaultModelRecommendations is the built-in list that
// newModelRecommendationsCache seeds a new cache with, before any snapshot or
// remote refresh replaces it. The ":cloud" entries carry ContextLength and
// MaxOutputTokens; the local entries carry VRAMBytes instead.
var defaultModelRecommendations = []api.ModelRecommendation{
	{
		Model:           "kimi-k2.6:cloud",
		Description:     "State-of-the-art coding, long-horizon execution, and multimodal agent swarm capability",
		ContextLength:   262_144,
		MaxOutputTokens: 262_144,
	},
	{
		Model:           "glm-5.1:cloud",
		Description:     "Reasoning and code generation",
		ContextLength:   202_752,
		MaxOutputTokens: 131_072,
	},
	{
		Model:           "qwen3.5:cloud",
		Description:     "Reasoning, coding, and agentic tool use with vision",
		ContextLength:   262_144,
		MaxOutputTokens: 32_768,
	},
	{
		Model:           "minimax-m2.7:cloud",
		Description:     "Fast, efficient coding and real-world productivity",
		ContextLength:   204_800,
		MaxOutputTokens: 128_000,
	},
	{
		Model:       "gemma4",
		Description: "Reasoning and code generation locally",
		VRAMBytes:   12 * format.GigaByte,
	},
	{
		Model:       "qwen3.5",
		Description: "Reasoning, coding, and visual understanding locally",
		VRAMBytes:   14 * format.GigaByte,
	},
}

View File

@@ -0,0 +1,619 @@
package server
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
)
// TestModelRecommendationsDefaultOrder pins the names and ordering of the
// built-in default recommendations.
func TestModelRecommendationsDefaultOrder(t *testing.T) {
	got := modelRecommendationNames(defaultModelRecommendations)
	want := []string{
		"kimi-k2.6:cloud",
		"glm-5.1:cloud",
		"qwen3.5:cloud",
		"minimax-m2.7:cloud",
		"gemma4",
		"qwen3.5",
	}
	if slices.Equal(got, want) {
		return
	}
	t.Fatalf("recommendations = %v, want %v", got, want)
}
// TestModelRecommendationsCacheRefreshAppliesServerSideChanges refreshes the
// cache twice against a stubbed transport that serves a different payload on
// the second call. It checks that each refresh replaces the cached list
// (with whitespace trimmed) and that the on-disk snapshot ends up holding the
// second payload.
func TestModelRecommendationsCacheRefreshAppliesServerSideChanges(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")
	// The first payload has surrounding whitespace to exercise trimming.
	first := []api.ModelRecommendation{
		{Model: " first-cloud:cloud ", Description: " first ", ContextLength: 2048, MaxOutputTokens: 512},
		{Model: " first-local ", Description: " first local ", VRAMBytes: 3 * format.GigaByte},
	}
	second := []api.ModelRecommendation{
		{Model: "second-cloud:cloud", Description: "second", ContextLength: 4096, MaxOutputTokens: 1024},
		{Model: "second-local", Description: "second local", VRAMBytes: 6 * format.GigaByte},
	}
	// calls is a plain int: refresh is invoked directly from this test rather
	// than through a background goroutine.
	calls := 0
	cache := newModelRecommendationsCache()
	cache.client = &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
		if req.Method != http.MethodGet {
			t.Fatalf("method = %q, want GET", req.Method)
		}
		if req.URL.String() != modelRecommendationsURL {
			t.Fatalf("url = %q, want %q", req.URL.String(), modelRecommendationsURL)
		}
		calls++
		payload := api.ModelRecommendationsResponse{Recommendations: first}
		if calls > 1 {
			payload.Recommendations = second
		}
		data, err := json.Marshal(payload)
		if err != nil {
			t.Fatalf("marshal payload failed: %v", err)
		}
		return jsonHTTPResponse(http.StatusOK, string(data)), nil
	})}
	if err := cache.refresh(context.Background()); err != nil {
		t.Fatalf("first refresh failed: %v", err)
	}
	// Expect the trimmed form of the first payload.
	if got, want := cache.Get(), []api.ModelRecommendation{
		{Model: "first-cloud:cloud", Description: "first", ContextLength: 2048, MaxOutputTokens: 512},
		{Model: "first-local", Description: "first local", VRAMBytes: 3 * format.GigaByte},
	}; !slices.Equal(got, want) {
		t.Fatalf("after first refresh recommendations = %#v, want %#v", got, want)
	}
	if err := cache.refresh(context.Background()); err != nil {
		t.Fatalf("second refresh failed: %v", err)
	}
	if got, want := cache.Get(), second; !slices.Equal(got, want) {
		t.Fatalf("after second refresh recommendations = %#v, want %#v", got, want)
	}
	// The snapshot on disk should mirror the most recent refresh.
	path, err := modelRecommendationsSnapshotPath()
	if err != nil {
		t.Fatalf("snapshot path failed: %v", err)
	}
	snapshotData, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("read snapshot failed: %v", err)
	}
	var snapshot api.ModelRecommendationsResponse
	if err := json.Unmarshal(snapshotData, &snapshot); err != nil {
		t.Fatalf("unmarshal snapshot failed: %v", err)
	}
	if !slices.Equal(snapshot.Recommendations, second) {
		t.Fatalf("snapshot recommendations = %#v, want %#v", snapshot.Recommendations, second)
	}
}
// TestModelRecommendationsCacheRefreshErrorCasesPreserveCurrentData runs
// refresh against transports that fail in different ways (transport error,
// error status, malformed JSON, and payloads rejected by validation). For
// each case it checks the error text, that the cached recommendations are
// unchanged, and that no snapshot file was written.
func TestModelRecommendationsCacheRefreshErrorCasesPreserveCurrentData(t *testing.T) {
	cases := []struct {
		name      string
		transport roundTripFunc
		// errSubstr must appear in the error returned by refresh.
		errSubstr string
	}{
		{
			name: "transport error",
			transport: func(*http.Request) (*http.Response, error) {
				return nil, errors.New("network down")
			},
			errSubstr: "network down",
		},
		{
			name: "remote status error",
			transport: func(*http.Request) (*http.Response, error) {
				return jsonHTTPResponse(http.StatusInternalServerError, "upstream broken"), nil
			},
			errSubstr: "status 500: upstream broken",
		},
		{
			name: "invalid json payload",
			transport: func(*http.Request) (*http.Response, error) {
				return jsonHTTPResponse(http.StatusOK, "{"), nil
			},
			errSubstr: "unexpected EOF",
		},
		{
			name: "duplicate recommendations",
			transport: func(*http.Request) (*http.Response, error) {
				return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"dup","description":"a"},{"model":"dup","description":"b"}]}`), nil
			},
			errSubstr: `duplicate recommendation "dup"`,
		},
		{
			name: "empty recommendations",
			transport: func(*http.Request) (*http.Response, error) {
				return jsonHTTPResponse(http.StatusOK, `{"recommendations":[]}`), nil
			},
			errSubstr: "empty recommendations",
		},
		{
			// The only entry is a cloud model without limits, which
			// validation drops, leaving nothing valid.
			name: "only invalid cloud recommendations",
			transport: func(*http.Request) (*http.Response, error) {
				return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"bad:cloud","description":"missing limits"}]}`), nil
			},
			errSubstr: "no valid recommendations",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Each subtest gets a fresh environment and a cache pre-loaded
			// with a known list.
			setupModelRecommendationsTestEnv(t, "")
			cache := newModelRecommendationsCache()
			stable := []api.ModelRecommendation{{Model: "stable-local", Description: "stable desc", VRAMBytes: 2 * format.GigaByte}}
			cache.set(stable)
			cache.client = &http.Client{Transport: tc.transport}
			err := cache.refresh(context.Background())
			if err == nil {
				t.Fatalf("refresh returned nil error")
			}
			if !strings.Contains(err.Error(), tc.errSubstr) {
				t.Fatalf("error = %q, want substring %q", err.Error(), tc.errSubstr)
			}
			if got := cache.Get(); !slices.Equal(got, stable) {
				t.Fatalf("recommendations changed on error: got %#v, want %#v", got, stable)
			}
			path, pathErr := modelRecommendationsSnapshotPath()
			if pathErr != nil {
				t.Fatalf("snapshot path failed: %v", pathErr)
			}
			if _, statErr := os.Stat(path); !errors.Is(statErr, os.ErrNotExist) {
				t.Fatalf("snapshot file should not be written on error, stat err = %v", statErr)
			}
		})
	}
}
// TestModelRecommendationsCacheRefreshNoCloudShortCircuits checks that, in
// the environment set up with the no-cloud value "1", refresh returns
// errModelRecommendationsNoCloud and never reaches the HTTP transport.
func TestModelRecommendationsCacheRefreshNoCloudShortCircuits(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "1")

	transportHit := false
	stub := roundTripFunc(func(*http.Request) (*http.Response, error) {
		transportHit = true
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"should-not-be-used","description":"n/a"}]}`), nil
	})

	cache := newModelRecommendationsCache()
	cache.client = &http.Client{Transport: stub}

	if err := cache.refresh(context.Background()); !errors.Is(err, errModelRecommendationsNoCloud) {
		t.Fatalf("refresh error = %v, want %v", err, errModelRecommendationsNoCloud)
	}
	if transportHit {
		t.Fatalf("remote endpoint should not be called when cloud is disabled")
	}
}
// TestModelRecommendationsSnapshotPersistAndLoad writes a snapshot through
// one cache and loads it through another, checking that the loaded list
// replaces what the second cache held before.
func TestModelRecommendationsSnapshotPersistAndLoad(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")

	want := []api.ModelRecommendation{
		{Model: "persist-cloud:cloud", Description: "persisted", ContextLength: 8192, MaxOutputTokens: 2048},
		{Model: "persist-local", Description: "persisted local", VRAMBytes: 5 * format.GigaByte},
	}

	if err := newModelRecommendationsCache().persistSnapshot(want); err != nil {
		t.Fatalf("persistSnapshot failed: %v", err)
	}

	// Pre-load the reader with different content so a successful load is
	// distinguishable from a no-op.
	reader := newModelRecommendationsCache()
	reader.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
	reader.loadSnapshot()

	got := reader.Get()
	if !slices.Equal(got, want) {
		t.Fatalf("loaded recommendations = %#v, want %#v", got, want)
	}
}
// TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite places a
// malformed snapshot file on disk and checks that loadSnapshot leaves the
// cache's existing recommendations alone.
func TestModelRecommendationsLoadSnapshotInvalidDoesNotOverwrite(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")

	snapshotPath, err := modelRecommendationsSnapshotPath()
	if err != nil {
		t.Fatalf("snapshot path failed: %v", err)
	}
	if err := os.MkdirAll(filepath.Dir(snapshotPath), 0o755); err != nil {
		t.Fatalf("mkdir failed: %v", err)
	}
	if err := os.WriteFile(snapshotPath, []byte("{invalid"), 0o644); err != nil {
		t.Fatalf("write invalid snapshot failed: %v", err)
	}

	existing := []api.ModelRecommendation{{Model: "existing", Description: "existing description"}}
	cache := newModelRecommendationsCache()
	cache.set(existing)
	cache.loadSnapshot()

	got := cache.Get()
	if !slices.Equal(got, existing) {
		t.Fatalf("recommendations overwritten by invalid snapshot: got %#v, want %#v", got, existing)
	}
}
// TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries checks
// that validation trims Model, Description, and RequiredPlan, keeps valid
// cloud and local entries in order, and drops a cloud entry that has no
// limits.
func TestValidateModelRecommendationsTrimsAndDropsInvalidCloudEntries(t *testing.T) {
	want := []api.ModelRecommendation{
		{Model: "good-cloud:cloud", Description: "good cloud", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: "pro"},
		{Model: "good-local", Description: "good local", VRAMBytes: 2 * format.GigaByte},
	}

	got, err := validateModelRecommendations([]api.ModelRecommendation{
		{Model: " good-cloud:cloud ", Description: " good cloud ", ContextLength: 1024, MaxOutputTokens: 256, RequiredPlan: " pro "},
		{Model: "bad-cloud:cloud", Description: "missing limits"},
		{Model: " good-local ", Description: " good local ", VRAMBytes: 2 * format.GigaByte},
	})
	if err != nil {
		t.Fatalf("validateModelRecommendations failed: %v", err)
	}
	if slices.Equal(got, want) {
		return
	}
	t.Fatalf("validated recommendations = %#v, want %#v", got, want)
}
// TestValidateModelRecommendationsDoesNotSynthesizeRequiredPlans checks that
// validation passes RequiredPlan through as given: entries supplied without
// a plan come back without one, and an explicitly supplied plan is kept.
func TestValidateModelRecommendationsDoesNotSynthesizeRequiredPlans(t *testing.T) {
	// Only the last entry carries an explicit RequiredPlan.
	input := []api.ModelRecommendation{
		{Model: "kimi-k2.6:cloud", Description: "coding", ContextLength: 262_144, MaxOutputTokens: 262_144},
		{Model: "qwen3.5:cloud", Description: "reasoning", ContextLength: 262_144, MaxOutputTokens: 32_768},
		{Model: "custom:cloud", Description: "custom", ContextLength: 4096, MaxOutputTokens: 1024},
		{Model: "minimax-m2.7:cloud", Description: "custom", ContextLength: 204_800, MaxOutputTokens: 128_000, RequiredPlan: "team"},
	}
	got, err := validateModelRecommendations(input)
	if err != nil {
		t.Fatalf("validateModelRecommendations failed: %v", err)
	}
	// Index the results by model name for direct lookup below.
	byName := make(map[string]api.ModelRecommendation, len(got))
	for _, rec := range got {
		byName[rec.Model] = rec
	}
	if rec := byName["kimi-k2.6:cloud"]; rec.RequiredPlan != "" {
		t.Fatalf("kimi required plan should not be synthesized: %#v", rec)
	}
	if rec := byName["qwen3.5:cloud"]; rec.RequiredPlan != "" {
		t.Fatalf("qwen required plan should not be synthesized: %#v", rec)
	}
	if rec := byName["custom:cloud"]; rec.RequiredPlan != "" {
		t.Fatalf("custom required plan should not be synthesized: %#v", rec)
	}
	if rec := byName["minimax-m2.7:cloud"]; rec.RequiredPlan != "team" {
		t.Fatalf("explicit required plan should not be overwritten: %#v", rec)
	}
}
// TestModelRecommendationsHandlerReturnsDefaults calls the handler on a
// Server with no caches configured and checks that it answers 200 with the
// built-in default recommendation names.
func TestModelRecommendationsHandlerReturnsDefaults(t *testing.T) {
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)

	srv := &Server{}
	srv.ModelRecommendationsExperimentalHandler(ctx)

	if recorder.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK)
	}

	got := decodeRecommendationNames(t, recorder)
	if want := modelRecommendationNames(defaultModelRecommendations); !slices.Equal(got, want) {
		t.Fatalf("models = %v, want %v", got, want)
	}
}
// TestModelRecommendationsHandlerUsesCache checks that the handler serves the
// contents of the Server's recommendations cache rather than the defaults.
func TestModelRecommendationsHandlerUsesCache(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setupModelRecommendationsTestEnv(t, "1")

	cache := newModelRecommendationsCache()
	cache.set([]api.ModelRecommendation{{Model: "test-model", Description: "test description"}})
	srv := &Server{modelCaches: &modelCaches{recommendations: cache}}

	recorder := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(recorder)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
	srv.ModelRecommendationsExperimentalHandler(ctx)

	if recorder.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK)
	}

	want := []string{"test-model"}
	if got := decodeRecommendationNames(t, recorder); !slices.Equal(got, want) {
		t.Fatalf("models = %v, want %v", got, []string{"test-model"})
	}

	// Wait for the cache to go idle before the test returns.
	waitForCacheIdle(t, cache)
}
// TestModelRecommendationsRouteRegistration builds the real router and checks
// that the recommendations path answers GET with the cached list and rejects
// POST with 405 Method Not Allowed.
func TestModelRecommendationsRouteRegistration(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setupModelRecommendationsTestEnv(t, "1")
	cache := newModelRecommendationsCache()
	cache.set([]api.ModelRecommendation{{Model: "route-model", Description: "route description"}})
	s := &Server{modelCaches: &modelCaches{recommendations: cache}}
	router, err := s.GenerateRoutes(nil)
	if err != nil {
		t.Fatalf("GenerateRoutes failed: %v", err)
	}
	// GET is served from the cache.
	getReq := httptest.NewRequest(http.MethodGet, "/api/experimental/model-recommendations", nil)
	getResp := httptest.NewRecorder()
	router.ServeHTTP(getResp, getReq)
	if getResp.Code != http.StatusOK {
		t.Fatalf("GET status = %d, want %d", getResp.Code, http.StatusOK)
	}
	if got := decodeRecommendationNames(t, getResp); !slices.Equal(got, []string{"route-model"}) {
		t.Fatalf("GET models = %v, want %v", got, []string{"route-model"})
	}
	// POST to the same path must be rejected.
	postReq := httptest.NewRequest(http.MethodPost, "/api/experimental/model-recommendations", nil)
	postResp := httptest.NewRecorder()
	router.ServeHTTP(postResp, postReq)
	if postResp.Code != http.StatusMethodNotAllowed {
		t.Fatalf("POST status = %d, want %d", postResp.Code, http.StatusMethodNotAllowed)
	}
	// Wait for the cache to go idle before the test returns.
	waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRTriggersRefreshOnRead checks the
// stale-while-revalidate contract of GetSWR: the call returns the currently
// cached list immediately, and a refresh running in the background later
// replaces the cache with the remote payload.
func TestModelRecommendationsGetSWRTriggersRefreshOnRead(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")
	cache := newModelRecommendationsCache()
	old := []api.ModelRecommendation{{Model: "old", Description: "old"}}
	newRecs := []api.ModelRecommendation{{Model: "new-cloud:cloud", Description: "new", ContextLength: 1024, MaxOutputTokens: 256}}
	cache.set(old)
	// refreshDone is closed when the stub transport returns. It is closed
	// unconditionally, so the test relies on the transport being hit once.
	refreshDone := make(chan struct{})
	cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		defer close(refreshDone)
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"new-cloud:cloud","description":"new","context_length":1024,"max_output_tokens":256}]}`), nil
	})}
	gotImmediate := cache.GetSWR(context.Background())
	if !slices.Equal(gotImmediate, old) {
		t.Fatalf("GetSWR should return current cache immediately: got %#v, want %#v", gotImmediate, old)
	}
	select {
	case <-refreshDone:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for async refresh")
	}
	// The transport returning does not mean the cache is updated yet; poll
	// until the new list is visible.
	waitForCondition(t, 2*time.Second, func() bool {
		return slices.Equal(cache.Get(), newRecs)
	})
	waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRSkipsWhenRefreshAlreadyInFlight holds one
// read-triggered refresh open inside the transport and checks that further
// GetSWR calls made meanwhile do not start additional remote requests.
func TestModelRecommendationsGetSWRSkipsWhenRefreshAlreadyInFlight(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")
	cache := newModelRecommendationsCache()
	cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
	// started is closed when the first request reaches the transport;
	// release unblocks every request waiting inside it.
	started := make(chan struct{})
	release := make(chan struct{})
	var calls atomic.Int32
	cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		n := calls.Add(1)
		if n == 1 {
			close(started)
		}
		<-release
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
	})}
	cache.GetSWR(context.Background())
	select {
	case <-started:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for first refresh call")
	}
	// The first refresh is now blocked in the transport; these reads must
	// not start another one.
	for range 5 {
		cache.GetSWR(context.Background())
	}
	// Give any wrongly started refresh goroutine time to reach the transport.
	time.Sleep(50 * time.Millisecond)
	if got := calls.Load(); got != 1 {
		t.Fatalf("calls during in-flight refresh = %d, want 1", got)
	}
	close(release)
	waitForCacheIdle(t, cache)
}
// TestModelRecommendationsGetSWRThrottlesRefreshAfterCompletion checks that
// the read-refresh cooldown is counted from when a refresh finishes, not from
// when it started: the first refresh is kept in flight for twice the cooldown,
// and a read right after it completes must still be throttled.
func TestModelRecommendationsGetSWRThrottlesRefreshAfterCompletion(t *testing.T) {
	setupModelRecommendationsTestEnv(t, "")
	// NOTE(review): this helper presumably overrides
	// modelRecommendationsReadRefreshCooldown for the duration of the test —
	// its definition is outside this view.
	withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)
	cache := newModelRecommendationsCache()
	cache.set([]api.ModelRecommendation{{Model: "old", Description: "old"}})
	started := make(chan struct{})
	release := make(chan struct{})
	var calls atomic.Int32
	cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
		// Only the first request blocks; any later one returns at once.
		if calls.Add(1) == 1 {
			close(started)
			<-release
		}
		return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"updated","description":"ok"}]}`), nil
	})}
	cache.GetSWR(context.Background())
	select {
	case <-started:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for first refresh call")
	}
	// Keep the refresh in flight for longer than the cooldown.
	time.Sleep(2 * modelRecommendationsReadRefreshCooldown)
	close(release)
	waitForCacheIdle(t, cache)
	// The refresh has only just finished, so this read falls inside the
	// cooldown and must not reach the transport.
	cache.GetSWR(context.Background())
	time.Sleep(25 * time.Millisecond)
	if got := calls.Load(); got != 1 {
		t.Fatalf("calls during read refresh cooldown = %d, want 1", got)
	}
}
// TestModelRecommendationsGetSWRRetriesAfterReadRefreshCooldown checks three
// things in sequence: a failed refresh leaves the cached value untouched, a
// GetSWR issued immediately afterwards does not reach the transport, and
// repeated GetSWR calls eventually trigger a second transport call whose
// result replaces the cached value.
func TestModelRecommendationsGetSWRRetriesAfterReadRefreshCooldown(t *testing.T) {
setupModelRecommendationsTestEnv(t, "")
// Shorten the cooldown for this test; the helper restores it on cleanup.
withModelRecommendationsReadRefreshCooldown(t, 100*time.Millisecond)
cache := newModelRecommendationsCache()
old := []api.ModelRecommendation{{Model: "old", Description: "old"}}
cache.set(old)
var calls atomic.Int32
cache.client = &http.Client{Transport: roundTripFunc(func(*http.Request) (*http.Response, error) {
// First transport call fails; every later call succeeds.
if calls.Add(1) == 1 {
return nil, errors.New("temporary upstream failure")
}
return jsonHTTPResponse(http.StatusOK, `{"recommendations":[{"model":"recovered","description":"ok"}]}`), nil
})}
cache.GetSWR(context.Background())
waitForCondition(t, 2*time.Second, func() bool { return calls.Load() >= 1 })
waitForCacheIdle(t, cache)
if !slices.Equal(cache.Get(), old) {
t.Fatalf("cache should remain unchanged after failed refresh, got %#v", cache.Get())
}
cache.GetSWR(context.Background())
// Allow a wrongly-started refresh time to reach the transport.
time.Sleep(25 * time.Millisecond)
if got := calls.Load(); got != 1 {
t.Fatalf("calls during read refresh cooldown after failure = %d, want 1", got)
}
// Keep issuing reads until one of them causes a second transport call.
waitForCondition(t, 2*time.Second, func() bool {
cache.GetSWR(context.Background())
return calls.Load() >= 2
})
waitForCondition(t, 2*time.Second, func() bool {
return slices.Equal(cache.Get(), []api.ModelRecommendation{{Model: "recovered", Description: "ok"}})
})
waitForCacheIdle(t, cache)
}
// roundTripFunc adapts a plain function to http.RoundTripper so tests can stub
// the HTTP transport inline.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by delegating to the wrapped function.
func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := fn(r)
	return resp, err
}
// jsonHTTPResponse builds a minimal *http.Response with the given status code,
// an empty (non-nil) header map, and body as its readable payload.
func jsonHTTPResponse(statusCode int, body string) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = statusCode
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(body))
	return resp
}
// setupModelRecommendationsTestEnv points the home-directory environment
// variables at a fresh temp dir, sets OLLAMA_NO_CLOUD (an empty noCloudEnv is
// treated as "false"), and reloads the server config so it observes them.
func setupModelRecommendationsTestEnv(t *testing.T, noCloudEnv string) {
	t.Helper()

	// Cleanups run last-registered-first. Registering the reload before any
	// t.Setenv call makes it run after the environment has been restored, so
	// the config is reloaded from the original values rather than from this
	// test's overrides.
	t.Cleanup(envconfig.ReloadServerConfig)

	home := t.TempDir()
	t.Setenv("HOME", home)
	t.Setenv("USERPROFILE", home)
	t.Setenv("HOMEDRIVE", filepath.VolumeName(home))
	t.Setenv("HOMEPATH", strings.TrimPrefix(home, filepath.VolumeName(home)))

	// Use explicit false rather than empty to avoid platform/env ambiguity.
	if noCloudEnv == "" {
		noCloudEnv = "false"
	}
	t.Setenv("OLLAMA_NO_CLOUD", noCloudEnv)

	envconfig.ReloadServerConfig()
}
// withModelRecommendationsReadRefreshCooldown overrides the package-level
// read-refresh cooldown with d for the duration of the test and restores the
// previous value on cleanup.
func withModelRecommendationsReadRefreshCooldown(t *testing.T, d time.Duration) {
	t.Helper()
	previous := modelRecommendationsReadRefreshCooldown
	restore := func() { modelRecommendationsReadRefreshCooldown = previous }
	modelRecommendationsReadRefreshCooldown = d
	t.Cleanup(restore)
}
// waitForCondition polls cond every 10ms until it reports true or timeout has
// elapsed, failing the test on timeout.
//
// cond is always evaluated at least once, including for a zero or negative
// timeout, and once more after the final sleep, so a condition that becomes
// true just before the deadline is not reported as a timeout.
func waitForCondition(t *testing.T, timeout time.Duration, cond func() bool) {
	t.Helper()
	deadline := time.Now().Add(timeout)
	for {
		if cond() {
			return
		}
		if !time.Now().Before(deadline) {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}
	t.Fatal("timed out waiting for condition")
}
// waitForCacheIdle blocks (for up to two seconds) until the cache's refreshing
// flag, read under its read lock, is false.
func waitForCacheIdle(t *testing.T, cache *modelRecommendationsCache) {
	t.Helper()
	idle := func() bool {
		cache.mu.RLock()
		defer cache.mu.RUnlock()
		return !cache.refreshing
	}
	waitForCondition(t, 2*time.Second, idle)
}
// decodeRecommendationNames decodes the recorder's JSON body and returns the
// "model" value of every entry under "recommendations", in order. It fails the
// test if the body cannot be decoded.
func decodeRecommendationNames(t *testing.T, w *httptest.ResponseRecorder) []string {
	t.Helper()

	type recommendation struct {
		Model string `json:"model"`
	}
	var payload struct {
		Recommendations []recommendation `json:"recommendations"`
	}
	if err := json.NewDecoder(w.Body).Decode(&payload); err != nil {
		t.Fatalf("decode failed: %v", err)
	}

	names := make([]string, len(payload.Recommendations))
	for i := range payload.Recommendations {
		names[i] = payload.Recommendations[i].Model
	}
	return names
}
// modelRecommendationNames returns the Model field of each recommendation, in
// input order. The result is always non-nil.
func modelRecommendationNames(recs []api.ModelRecommendation) []string {
	names := make([]string, 0, len(recs))
	for idx := range recs {
		names = append(names, recs[idx].Model)
	}
	return names
}

81
server/model_resolver.go Normal file
View File

@@ -0,0 +1,81 @@
package server
import (
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/types/model"
)
// modelSource is a package-local alias for modelref.ModelSource.
type modelSource = modelref.ModelSource
// Package-local names for the modelref source values.
const (
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
modelSourceLocal modelSource = modelref.ModelSourceLocal
modelSourceCloud modelSource = modelref.ModelSourceCloud
)
// Package-local names for the modelref sentinel errors, so callers in this
// package can match them with errors.Is.
var (
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
errModelRequired = modelref.ErrModelRequired
)
// parsedModelRef is the result of parsing a caller-supplied model string: the
// raw input, its normalized base, the resolved model.Name, and any explicit
// source requested in the input.
type parsedModelRef struct {
// Original is the caller-provided model string before source parsing.
// Example: "gpt-oss:20b:cloud".
Original string
// Base is the model string after source suffix normalization.
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
Base string
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
// Example: "registry.ollama.ai/library/gpt-oss:20b".
Name model.Name
// Source captures explicit source intent from the original input.
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
Source modelSource
}
// parseAndValidateModelRef parses raw with modelref.ParseRef and resolves the
// resulting base into a model.Name. It returns the ParseRef error unchanged,
// or model.Unqualified(name) when the resolved name is not valid.
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
	ref, err := modelref.ParseRef(raw)
	if err != nil {
		return parsedModelRef{}, err
	}

	resolved := model.ParseName(ref.Base)
	if !resolved.IsValid() {
		return parsedModelRef{}, model.Unqualified(resolved)
	}

	result := parsedModelRef{
		Original: ref.Original,
		Base:     ref.Base,
		Name:     resolved,
		Source:   ref.Source,
	}
	return result, nil
}
// parseNormalizePullModelRef parses raw for pulling. Original and Source come
// from modelref.ParseRef, while Base and Name come from the name returned by
// modelref.NormalizePullName. Errors from either call are returned unchanged;
// an invalid resolved name yields model.Unqualified(name).
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
	ref, err := modelref.ParseRef(raw)
	if err != nil {
		return parsedModelRef{}, err
	}

	pullName, _, err := modelref.NormalizePullName(raw)
	if err != nil {
		return parsedModelRef{}, err
	}

	resolved := model.ParseName(pullName)
	if !resolved.IsValid() {
		return parsedModelRef{}, model.Unqualified(resolved)
	}

	result := parsedModelRef{
		Original: ref.Original,
		Base:     pullName,
		Name:     resolved,
		Source:   ref.Source,
	}
	return result, nil
}

View File

@@ -0,0 +1,170 @@
package server
import (
"errors"
"strings"
"testing"
)
// TestParseModelSelector exercises parseAndValidateModelRef across explicit
// source suffixes, legacy "-cloud" tags, plain names, and invalid input.
// Failure messages name the function actually under test.
func TestParseModelSelector(t *testing.T) {
	t.Run("cloud suffix", func(t *testing.T) {
		got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
		if err != nil {
			t.Fatalf("parseAndValidateModelRef returned error: %v", err)
		}
		if got.Source != modelSourceCloud {
			t.Fatalf("expected source cloud, got %v", got.Source)
		}
		if got.Base != "gpt-oss:20b" {
			t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
		}
		if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
			t.Fatalf("unexpected resolved name: %q", got.Name.String())
		}
	})

	t.Run("legacy cloud suffix", func(t *testing.T) {
		got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
		if err != nil {
			t.Fatalf("parseAndValidateModelRef returned error: %v", err)
		}
		if got.Source != modelSourceCloud {
			t.Fatalf("expected source cloud, got %v", got.Source)
		}
		if got.Base != "gpt-oss:20b" {
			t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
		}
	})

	t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
		got, err := parseAndValidateModelRef("my-cloud-model")
		if err != nil {
			t.Fatalf("parseAndValidateModelRef returned error: %v", err)
		}
		if got.Source != modelSourceUnspecified {
			t.Fatalf("expected source unspecified, got %v", got.Source)
		}
		if got.Base != "my-cloud-model" {
			t.Fatalf("expected base my-cloud-model, got %q", got.Base)
		}
	})

	t.Run("local suffix", func(t *testing.T) {
		got, err := parseAndValidateModelRef("qwen3:8b:local")
		if err != nil {
			t.Fatalf("parseAndValidateModelRef returned error: %v", err)
		}
		if got.Source != modelSourceLocal {
			t.Fatalf("expected source local, got %v", got.Source)
		}
		if got.Base != "qwen3:8b" {
			t.Fatalf("expected base qwen3:8b, got %q", got.Base)
		}
	})

	t.Run("conflicting source suffixes fail", func(t *testing.T) {
		_, err := parseAndValidateModelRef("foo:cloud:local")
		if !errors.Is(err, errConflictingModelSource) {
			t.Fatalf("expected errConflictingModelSource, got %v", err)
		}
	})

	t.Run("unspecified source", func(t *testing.T) {
		got, err := parseAndValidateModelRef("llama3")
		if err != nil {
			t.Fatalf("parseAndValidateModelRef returned error: %v", err)
		}
		if got.Source != modelSourceUnspecified {
			t.Fatalf("expected source unspecified, got %v", got.Source)
		}
		if got.Name.Tag != "latest" {
			t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
		}
	})

	t.Run("unknown suffix is treated as tag", func(t *testing.T) {
		got, err := parseAndValidateModelRef("gpt-oss:clod")
		if err != nil {
			t.Fatalf("parseAndValidateModelRef returned error: %v", err)
		}
		if got.Source != modelSourceUnspecified {
			t.Fatalf("expected source unspecified, got %v", got.Source)
		}
		if got.Name.Tag != "clod" {
			t.Fatalf("expected tag clod, got %q", got.Name.Tag)
		}
	})

	t.Run("empty model fails", func(t *testing.T) {
		_, err := parseAndValidateModelRef("")
		if !errors.Is(err, errModelRequired) {
			t.Fatalf("expected errModelRequired, got %v", err)
		}
	})

	t.Run("invalid model fails", func(t *testing.T) {
		_, err := parseAndValidateModelRef("::cloud")
		if err == nil {
			t.Fatal("expected error for invalid model")
		}
		if !strings.Contains(err.Error(), "unqualified") {
			t.Fatalf("expected unqualified model error, got %v", err)
		}
	})
}
// TestParsePullModelRef covers how parseNormalizePullModelRef maps explicit
// local and cloud source suffixes onto the name used for pulling.
func TestParsePullModelRef(t *testing.T) {
	t.Run("explicit local is normalized", func(t *testing.T) {
		ref, err := parseNormalizePullModelRef("gpt-oss:20b:local")
		if err != nil {
			t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
		}
		if src := ref.Source; src != modelSourceLocal {
			t.Fatalf("expected source local, got %v", src)
		}
		if base := ref.Base; base != "gpt-oss:20b" {
			t.Fatalf("expected base gpt-oss:20b, got %q", base)
		}
	})

	t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
		ref, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
		if err != nil {
			t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
		}
		if base := ref.Base; base != "gpt-oss:20b-cloud" {
			t.Fatalf("expected base gpt-oss:20b-cloud, got %q", base)
		}
		if ref.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
			t.Fatalf("unexpected resolved name: %q", ref.Name.String())
		}
	})

	t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
		ref, err := parseNormalizePullModelRef("qwen3:cloud")
		if err != nil {
			t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
		}
		if base := ref.Base; base != "qwen3:cloud" {
			t.Fatalf("expected base qwen3:cloud, got %q", base)
		}
		if ref.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
			t.Fatalf("unexpected resolved name: %q", ref.Name.String())
		}
	})
}

692
server/model_show_cache.go Normal file
View File

@@ -0,0 +1,692 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
"net/url"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/internal/modelref"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
/*
The /api/show cache stores full api.ShowResponse values because callers use
more than capabilities: launch flows also need context length, embeddings
metadata, quantization details, remote metadata, and model-specific fields.
TODO(parthsareen): Consider removing show cache if /api/tags grows to cover
the remaining callers.
Local model entries are stored lazily by canonical model name and verbose flag,
with the manifest digest recorded in the entry. The manifest digest is the
freshness boundary: if the model content changes, the digest changes, so the
previous response is replaced instead of accumulating under an old digest key.
Requests with System or Options overlays bypass the cache because those overlays
mutate the effective show response.
Cloud model entries are keyed by normalized cloud base model name and verbose.
They use stale-while-revalidate behavior: a warm read returns the cached
response immediately and starts a throttled background refresh for that model.
Cold cloud reads preserve existing proxy behavior. Local and cloud entries live
in separate maps, so a local "qwen3.5" and an explicit "qwen3.5:cloud" cannot
collide. The cloud suffix is request routing intent; api.ShowResponse does not
carry a model-name field to reconstruct on the way out.
The cache is process-local. Cloud startup hydration runs asynchronously from
cloud tags, while local show responses are populated on demand. No show
responses are written to or read from ~/.ollama/cache/show. That keeps cache
lifetime tied to the server process and avoids snapshot freshness and
invalidation cases for this iteration.
*/
const (
// modelShowCloudFetchTimeout bounds each individual cloud request issued
// by doCloudJSON.
modelShowCloudFetchTimeout = 3 * time.Second
// modelShowCloudReadRefreshCooldown is how long after a read-triggered
// refresh finishes before another one may start for the same key.
modelShowCloudReadRefreshCooldown = 5 * time.Second
// modelShowCloudHydrationConcurrency caps the number of worker goroutines
// hydrateCloud uses to fetch show responses.
modelShowCloudHydrationConcurrency = 4
)
// errModelShowNoCloud is returned by cloud refresh and hydration when
// internalcloud.Status reports cloud as disabled.
var errModelShowNoCloud = errors.New("cloud disabled")
// modelShowCache owns process-local show response caches for local and cloud
// models. All cached responses are cloned at read/write boundaries so
// handler-specific mutations, such as user-agent compatibility tweaks, cannot
// leak back into the cache.
type modelShowCache struct {
// mu is held by this type's methods around every access to the maps below.
mu sync.RWMutex
// local holds local show responses together with the manifest digest they
// were computed for.
local map[modelShowLocalKey]modelShowLocalEntry
// cloud holds the most recently stored cloud show response per key.
cloud map[modelShowCloudKey]*api.ShowResponse
// cloudRefreshing marks keys with a read-triggered refresh in flight.
cloudRefreshing map[modelShowCloudKey]bool
// cloudNextReadRefreshAfter is the earliest time another read-triggered
// refresh may start for a key.
cloudNextReadRefreshAfter map[modelShowCloudKey]time.Time
// once ensures Start launches startup hydration a single time.
once sync.Once
// client performs the cache's direct cloud HTTP requests.
client *http.Client
// getModelInfo computes a local show response on a cache miss; it is a
// field so tests can substitute it.
getModelInfo func(api.ShowRequest) (*api.ShowResponse, error)
}
// modelShowLocalKey describes the local cache slot for a model response. The
// manifest digest is stored in the entry instead of the key so a pulled or
// recreated model overwrites the previous response for the same model/verbose
// variant instead of leaving stale digest-keyed entries behind.
//
// Deleted models are not eagerly pruned from this process-local cache. Manifest
// resolution happens before local cache lookup, so stale entries for deleted
// models are not served and disappear on process restart.
type modelShowLocalKey struct {
// Model is the canonical model name string.
Model string
// Verbose separates verbose and non-verbose responses for the same model.
Verbose bool
}
// modelShowLocalEntry is one cached local show response together with the
// manifest digest it was stored under; lookups only hit when the digest
// matches.
type modelShowLocalEntry struct {
// Digest is the manifest digest recorded when Response was stored.
Digest string
// Response is the cached show response (a private clone).
Response *api.ShowResponse
}
// modelShowCloudKey intentionally excludes any local digest because cloud
// models are refreshed through SWR and normalized by cloud base model name.
type modelShowCloudKey struct {
// Model is the cloud model name after modelShowNormalizeCloudModel.
Model string
// Verbose separates verbose and non-verbose responses for the same model.
Verbose bool
}
// newModelShowCache returns an empty cache with all maps allocated, wired to
// http.DefaultClient for cloud requests and GetModelInfo for local misses.
func newModelShowCache() *modelShowCache {
	cache := new(modelShowCache)
	cache.local = map[modelShowLocalKey]modelShowLocalEntry{}
	cache.cloud = map[modelShowCloudKey]*api.ShowResponse{}
	cache.cloudRefreshing = map[modelShowCloudKey]bool{}
	cache.cloudNextReadRefreshAfter = map[modelShowCloudKey]time.Time{}
	cache.client = http.DefaultClient
	cache.getModelInfo = GetModelInfo
	return cache
}
// modelShowCacheable reports whether req may be served from the shared show
// cache. Requests carrying a System prompt or any Options produce
// request-specific responses and therefore bypass the cache.
func modelShowCacheable(req api.ShowRequest) bool {
	if req.System != "" {
		return false
	}
	return len(req.Options) == 0
}
// Start launches cloud startup hydration in a background goroutine. It is safe
// to call more than once; only the first call has any effect. Local show
// responses are not hydrated here and stay lazy.
func (c *modelShowCache) Start(ctx context.Context) {
	launch := func() {
		slog.Debug("starting model show cache")
		go c.runStartup(ctx)
	}
	c.once.Do(launch)
}
// runStartup hydrates the cloud cache and logs the outcome. Cancellation is
// silent, a disabled cloud is logged at debug level, and any other failure is
// logged as a warning. Start runs it in its own goroutine.
func (c *modelShowCache) runStartup(ctx context.Context) {
	err := c.hydrateCloud(ctx)
	if err == nil || errors.Is(err, context.Canceled) {
		return
	}
	if errors.Is(err, errModelShowNoCloud) {
		slog.Debug("skipping model show cloud cache hydration because cloud is disabled")
		return
	}
	slog.Warn("model show cloud cache hydration failed", "error", err)
}
// GetLocal serves a local show request. A cached entry is returned when one
// exists for the current manifest digest; otherwise the response is computed
// through getModelInfo, stored unless it carries a RemoteHost, and returned as
// a clone.
func (c *modelShowCache) GetLocal(req api.ShowRequest) (*api.ShowResponse, error) {
	key, digest, err := modelShowLocalKeyForRequest(req)
	if err != nil {
		return nil, err
	}

	if cached, hit := c.getLocal(key, digest); hit {
		return cached, nil
	}

	// Query by the canonical name the key was built from.
	req.Model = key.Model
	fresh, err := c.getModelInfo(req)
	if err != nil {
		return nil, err
	}

	if isLocal := fresh.RemoteHost == ""; isLocal {
		c.setLocal(key, digest, fresh)
	}
	return cloneShowResponse(fresh), nil
}
// GetCloudSWR looks up the cached cloud show response for req. On a hit it
// also kicks off a throttled background refresh and returns (resp, true); on a
// cold miss it returns (nil, false) so the caller can fall back to its
// synchronous path.
func (c *modelShowCache) GetCloudSWR(ctx context.Context, req api.ShowRequest) (*api.ShowResponse, bool) {
	key := modelShowCloudKeyForModel(req.Model, req.Verbose)
	if cached, hit := c.getCloud(key); hit {
		c.triggerCloudRefreshOnRead(ctx, key)
		return cached, true
	}
	return nil, false
}
// getLocal returns a clone of the cached local response for key when an entry
// exists, its digest equals digest, and it holds a response.
func (c *modelShowCache) getLocal(key modelShowLocalKey, digest string) (*api.ShowResponse, bool) {
	c.mu.RLock()
	cached, found := c.local[key]
	c.mu.RUnlock()

	switch {
	case !found, cached.Digest != digest, cached.Response == nil:
		return nil, false
	}
	return cloneShowResponse(cached.Response), true
}
// setLocal stores a clone of resp for key, tagged with digest, replacing any
// previous entry for the same key.
func (c *modelShowCache) setLocal(key modelShowLocalKey, digest string, resp *api.ShowResponse) {
	c.mu.Lock()
	defer c.mu.Unlock()

	var entry modelShowLocalEntry
	entry.Digest = digest
	entry.Response = cloneShowResponse(resp)
	c.local[key] = entry
}
// hasLocal reports whether a usable local entry for key exists with the given
// digest, without cloning the response.
func (c *modelShowCache) hasLocal(key modelShowLocalKey, digest string) bool {
	c.mu.RLock()
	cached, found := c.local[key]
	c.mu.RUnlock()

	if !found || cached.Response == nil {
		return false
	}
	return cached.Digest == digest
}
// getCloud returns a clone of the cached cloud response for key, or
// (nil, false) when no non-nil response is stored.
func (c *modelShowCache) getCloud(key modelShowCloudKey) (*api.ShowResponse, bool) {
	c.mu.RLock()
	cached, found := c.cloud[key]
	c.mu.RUnlock()

	if found && cached != nil {
		return cloneShowResponse(cached), true
	}
	return nil, false
}
// setCloud stores a clone of resp as the cloud entry for key.
func (c *modelShowCache) setCloud(key modelShowCloudKey, resp *api.ShowResponse) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.cloud[key] = cloneShowResponse(resp)
}
// beginCloudReadRefresh claims the refresh slot for key. It returns false when
// a refresh for key is already in flight or the cooldown has not elapsed;
// otherwise it marks key as refreshing and returns true.
func (c *modelShowCache) beginCloudReadRefresh(key modelShowCloudKey) bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	now := time.Now()
	if c.cloudRefreshing[key] {
		return false
	}
	if notBefore := c.cloudNextReadRefreshAfter[key]; now.Before(notBefore) {
		return false
	}

	c.cloudRefreshing[key] = true
	return true
}
// endCloudReadRefresh releases the refresh slot for key and starts its
// cooldown from the current time.
func (c *modelShowCache) endCloudReadRefresh(key modelShowCloudKey) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cloudRefreshing[key] = false
	c.cloudNextReadRefreshAfter[key] = time.Now().Add(modelShowCloudReadRefreshCooldown)
}
// triggerCloudRefreshOnRead is the revalidation half of SWR. If the refresh
// slot for key can be claimed, it refreshes the entry in a background
// goroutine using a context detached from the caller's cancellation, so a
// finished client request does not abort the cache update it started.
func (c *modelShowCache) triggerCloudRefreshOnRead(ctx context.Context, key modelShowCloudKey) {
	if started := c.beginCloudReadRefresh(key); !started {
		return
	}

	if ctx == nil {
		ctx = context.Background()
	}
	detached := context.WithoutCancel(ctx)

	slog.Debug("triggering model show cloud refresh on read", "model", key.Model, "verbose", key.Verbose)
	go func() {
		// Always release the slot and start the cooldown, even on failure.
		defer c.endCloudReadRefresh(key)

		err := c.refreshCloud(detached, key)
		if err == nil {
			return
		}
		if errors.Is(err, errModelShowNoCloud) {
			slog.Debug("skipping model show cloud read refresh because cloud is disabled", "model", key.Model)
			return
		}
		slog.Warn("model show cloud read refresh failed", "model", key.Model, "error", err)
	}()
}
// refreshCloud fetches the cloud show response for key and stores it. It
// returns errModelShowNoCloud when cloud is disabled. On any error the
// existing cached entry is left as is, so stale data keeps being served.
func (c *modelShowCache) refreshCloud(ctx context.Context, key modelShowCloudKey) error {
	disabled, _ := internalcloud.Status()
	if disabled {
		return errModelShowNoCloud
	}

	fresh, err := c.fetchCloudShow(ctx, key.Model, key.Verbose)
	if err == nil {
		c.setCloud(key, fresh)
	}
	return err
}
// hydrateLocal walks the manifests returned by manifest.Manifests and fills
// in the non-verbose local entry for every model that is missing one for its
// current digest. Manifests that describe remote models, and responses that
// carry a RemoteHost, are skipped. Per-model failures are logged and do not
// stop the scan; only a manifest listing error or context cancellation is
// returned.
//
// Only non-verbose responses are hydrated, avoiding an expensive tensor walk
// for users who have never asked for verbose show data. Within this file,
// runStartup does not call this method: startup hydrates cloud entries only.
func (c *modelShowCache) hydrateLocal(ctx context.Context) error {
	all, err := manifest.Manifests(true)
	if err != nil {
		return err
	}

	for name, mf := range all {
		if cancelErr := ctx.Err(); cancelErr != nil {
			return cancelErr
		}
		if modelShowManifestIsRemote(mf) {
			continue
		}

		canonical := name.String()
		digest := mf.Digest()
		key := modelShowLocalKey{Model: canonical, Verbose: false}
		if c.hasLocal(key, digest) {
			continue
		}

		resp, err := c.getModelInfo(api.ShowRequest{Model: canonical})
		switch {
		case err != nil:
			slog.Warn("failed to hydrate local model show cache", "model", canonical, "error", err)
		case resp.RemoteHost != "":
			// Remote-backed responses are not stored in the local cache.
		default:
			c.setLocal(key, digest, resp)
		}
	}
	return nil
}
// hydrateCloud warms the cloud cache: it lists cloud tags and fetches the
// non-verbose /api/show response for each model using a bounded worker pool.
// A failed fetch for one model is logged and skipped so it cannot prevent the
// rest from warming. It returns errModelShowNoCloud when cloud is disabled,
// the tag-listing error if listing fails, and otherwise ctx.Err().
func (c *modelShowCache) hydrateCloud(ctx context.Context) error {
	if disabled, _ := internalcloud.Status(); disabled {
		return errModelShowNoCloud
	}

	names, err := c.fetchCloudTags(ctx)
	if err != nil {
		return err
	}

	hydrateOne := func(name string) {
		key := modelShowCloudKeyForModel(name, false)
		resp, err := c.fetchCloudShow(ctx, key.Model, key.Verbose)
		if err != nil {
			slog.Warn("failed to hydrate cloud model show cache", "model", key.Model, "error", err)
			return
		}
		c.setCloud(key, resp)
	}

	queue := make(chan string)
	var pending sync.WaitGroup

	poolSize := min(modelShowCloudHydrationConcurrency, max(1, len(names)))
	pending.Add(poolSize)
	for range poolSize {
		go func() {
			defer pending.Done()
			for name := range queue {
				// After cancellation keep draining the queue without fetching.
				if ctx.Err() == nil {
					hydrateOne(name)
				}
			}
		}()
	}

feed:
	for _, name := range names {
		select {
		case <-ctx.Done():
			break feed
		case queue <- name:
		}
	}
	close(queue)
	pending.Wait()

	return ctx.Err()
}
// fetchCloudTags lists cloud models via GET /api/tags and returns their names
// normalized to show-cache key form, de-duplicated, in first-seen order. Each
// item's Model field is preferred; the legacy Name field is used when Model is
// blank. Entries that normalize to an empty name are dropped.
func (c *modelShowCache) fetchCloudTags(ctx context.Context) ([]string, error) {
	var listing api.ListResponse
	if err := c.doCloudJSON(ctx, http.MethodGet, "/api/tags", nil, &listing); err != nil {
		return nil, err
	}

	known := make(map[string]struct{}, len(listing.Models))
	names := make([]string, 0, len(listing.Models))
	for _, entry := range listing.Models {
		candidate := strings.TrimSpace(entry.Model)
		if candidate == "" {
			candidate = strings.TrimSpace(entry.Name)
		}

		candidate = modelShowNormalizeCloudModel(candidate)
		if candidate == "" {
			continue
		}
		if _, dup := known[candidate]; dup {
			continue
		}

		known[candidate] = struct{}{}
		names = append(names, candidate)
	}
	return names, nil
}
// fetchCloudShow POSTs a show request for the normalized model name to the
// cloud /api/show endpoint and returns the decoded response. A nil ModelInfo
// in the response is replaced with an empty map.
func (c *modelShowCache) fetchCloudShow(ctx context.Context, modelName string, verbose bool) (*api.ShowResponse, error) {
	request := api.ShowRequest{
		Model:   modelShowNormalizeCloudModel(modelName),
		Verbose: verbose,
	}

	result := new(api.ShowResponse)
	if err := c.doCloudJSON(ctx, http.MethodPost, "/api/show", request, result); err != nil {
		return nil, err
	}
	if result.ModelInfo == nil {
		result.ModelInfo = map[string]any{}
	}
	return result, nil
}
// doCloudJSON is the cache's direct cloud client. It sends method/path to the
// cloud base URL with payload JSON-encoded as the body (no body when payload
// is nil), signs the request, and decodes a successful JSON response into out
// (skipped when out is nil). Each call is bounded by
// modelShowCloudFetchTimeout so hydration and refreshes cannot hang
// indefinitely. Responses with status >= 400 are converted by
// modelShowStatusError.
func (c *modelShowCache) doCloudJSON(ctx context.Context, method, path string, payload any, out any) error {
	timeoutCtx, cancel := context.WithTimeout(ctx, modelShowCloudFetchTimeout)
	defer cancel()

	base, err := url.Parse(cloudProxyBaseURL)
	if err != nil {
		return err
	}
	endpoint := base.ResolveReference(&url.URL{Path: path}).String()

	hasPayload := payload != nil
	var reqBody io.Reader
	if hasPayload {
		encoded, err := json.Marshal(payload)
		if err != nil {
			return err
		}
		reqBody = bytes.NewReader(encoded)
	}

	httpReq, err := http.NewRequestWithContext(timeoutCtx, method, endpoint, reqBody)
	if err != nil {
		return err
	}

	httpReq.Header.Set("Accept", "application/json")
	if hasPayload {
		httpReq.Header.Set("Content-Type", "application/json")
	}
	if v := strings.TrimSpace(version.Version); v != "" {
		httpReq.Header.Set(cloudProxyClientVersionHeader, v)
	}

	if err := cloudProxySignRequest(httpReq.Context(), httpReq); err != nil {
		return err
	}

	httpResp, err := c.client.Do(httpReq)
	if err != nil {
		return err
	}
	defer httpResp.Body.Close()

	raw, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return err
	}

	switch {
	case httpResp.StatusCode >= http.StatusBadRequest:
		return modelShowStatusError(httpResp, raw)
	case out == nil:
		return nil
	default:
		return json.Unmarshal(raw, out)
	}
}
// modelShowStatusError converts a failed cloud response into an error value:
// api.AuthorizationError for 401, api.StatusError for everything else. The
// body is decoded into the error when it is JSON; for StatusError the trimmed
// raw body is used as the message when decoding fails or yields none.
func modelShowStatusError(resp *http.Response, body []byte) error {
	if resp.StatusCode != http.StatusUnauthorized {
		statusErr := api.StatusError{
			StatusCode: resp.StatusCode,
			Status:     resp.Status,
		}
		decodeErr := json.Unmarshal(body, &statusErr)
		if decodeErr != nil || statusErr.ErrorMessage == "" {
			statusErr.ErrorMessage = strings.TrimSpace(string(body))
		}
		return statusErr
	}

	authErr := api.AuthorizationError{
		StatusCode: resp.StatusCode,
		Status:     resp.Status,
	}
	// Best-effort decode: a non-JSON body leaves the defaults in place.
	_ = json.Unmarshal(body, &authErr)
	if authErr.Status == "" {
		authErr.Status = resp.Status
	}
	return authErr
}
// modelShowLocalKeyForRequest resolves req.Model to the existing model name
// and returns the matching cache key together with the current manifest
// digest, which callers use to validate a cached entry. It fails when the
// name is invalid, no existing name is found, or the manifest cannot be
// parsed.
func modelShowLocalKeyForRequest(req api.ShowRequest) (modelShowLocalKey, string, error) {
	var none modelShowLocalKey

	requested := model.ParseName(req.Model)
	if !requested.IsValid() {
		return none, "", model.Unqualified(requested)
	}

	existing, err := getExistingName(requested)
	if err != nil {
		return none, "", err
	}

	mf, err := manifest.ParseNamedManifest(existing)
	if err != nil {
		return none, "", err
	}

	key := modelShowLocalKey{Model: existing.String(), Verbose: req.Verbose}
	return key, mf.Digest(), nil
}
// modelShowCloudKeyForModel builds the cloud cache key for modelName, using
// the normalized cloud model name and the verbose flag.
func modelShowCloudKeyForModel(modelName string, verbose bool) modelShowCloudKey {
	var key modelShowCloudKey
	key.Model = modelShowNormalizeCloudModel(modelName)
	key.Verbose = verbose
	return key
}
// modelShowNormalizeCloudModel trims modelName and removes an explicit cloud
// source tag when modelref.StripCloudSourceTag finds one, so the :cloud and
// legacy -cloud spellings share a cache entry. Blank input yields "".
func modelShowNormalizeCloudModel(modelName string) string {
	trimmed := strings.TrimSpace(modelName)
	if trimmed == "" {
		return ""
	}

	base, stripped := modelref.StripCloudSourceTag(trimmed)
	if !stripped {
		return trimmed
	}
	return strings.TrimSpace(base)
}
// modelShowManifestIsRemote reports whether the manifest's config names a
// remote host or remote model, i.e. it is a local stub for a remote model.
// hydrateLocal skips such manifests so the local cache does not hold entries
// whose freshness is governed by cloud state. A nil manifest, a missing config
// digest, or a config that cannot be opened or decoded all count as not
// remote (the latter two are logged).
func modelShowManifestIsRemote(mf *manifest.Manifest) bool {
	if mf == nil {
		return false
	}
	if mf.Config.Digest == "" {
		return false
	}

	configFile, err := mf.Config.Open()
	if err != nil {
		slog.Warn("failed to open manifest config while checking model show cache eligibility", "error", err)
		return false
	}
	defer configFile.Close()

	var cfg model.ConfigV2
	if err := json.NewDecoder(configFile).Decode(&cfg); err != nil {
		slog.Warn("failed to decode manifest config while checking model show cache eligibility", "error", err)
		return false
	}

	if cfg.RemoteHost != "" {
		return true
	}
	return cfg.RemoteModel != ""
}
// cloneShowResponse returns a copy of in whose slice and map fields
// (Details.Families, Messages, Capabilities, ModelInfo, ProjectorInfo,
// Tensors) are duplicated, so the copy can be mutated without affecting in.
// It is used whenever a response enters or leaves the cache. A nil input
// yields nil.
func cloneShowResponse(in *api.ShowResponse) *api.ShowResponse {
	if in == nil {
		return nil
	}

	clone := new(api.ShowResponse)
	*clone = *in

	clone.Details.Families = slices.Clone(in.Details.Families)
	clone.Messages = cloneMessages(in.Messages)
	clone.Capabilities = slices.Clone(in.Capabilities)
	clone.ModelInfo = cloneAnyMap(in.ModelInfo)
	clone.ProjectorInfo = cloneAnyMap(in.ProjectorInfo)
	clone.Tensors = cloneTensors(in.Tensors)
	return clone
}
// cloneMessages copies a message slice, duplicating each message's Images
// (including every image's bytes) and its ToolCalls slice. A nil input yields
// nil.
func cloneMessages(in []api.Message) []api.Message {
	if in == nil {
		return nil
	}

	out := make([]api.Message, len(in))
	copy(out, in)
	for i := range out {
		if srcImages := in[i].Images; srcImages != nil {
			images := make([]api.ImageData, len(srcImages))
			for j := range srcImages {
				images[j] = slices.Clone(srcImages[j])
			}
			out[i].Images = images
		}
		out[i].ToolCalls = slices.Clone(in[i].ToolCalls)
	}
	return out
}
// cloneTensors copies a tensor slice, duplicating each tensor's Shape slice.
// A nil input yields nil.
func cloneTensors(in []api.Tensor) []api.Tensor {
	if in == nil {
		return nil
	}

	out := make([]api.Tensor, len(in))
	copy(out, in)
	for i := range out {
		out[i].Shape = slices.Clone(in[i].Shape)
	}
	return out
}
// cloneAnyMap returns a deep copy of a map[string]any, cloning every value
// with cloneAny. A nil input yields nil.
func cloneAnyMap(in map[string]any) map[string]any {
	if in == nil {
		return nil
	}
	out := make(map[string]any, len(in))
	for k, v := range in {
		out[k] = cloneAny(v)
	}
	return out
}

// cloneAny deep-copies the container shapes handled below: nested
// map[string]any, []any, and slices of strings, bools, and numeric types.
// Any other value is returned as is.
//
// Nil slices stay nil. A nil and an empty slice encode differently as JSON
// (null vs []), so a clone must not turn one into the other.
func cloneAny(v any) any {
	switch v := v.(type) {
	case map[string]any:
		return cloneAnyMap(v)
	case []any:
		if v == nil {
			// Preserve nil-ness, matching what slices.Clone does for the
			// typed slices below.
			return v
		}
		out := make([]any, len(v))
		for i, item := range v {
			out[i] = cloneAny(item)
		}
		return out
	case []string:
		return slices.Clone(v)
	case []bool:
		return slices.Clone(v)
	case []int:
		return slices.Clone(v)
	case []int8:
		return slices.Clone(v)
	case []int16:
		return slices.Clone(v)
	case []int32:
		return slices.Clone(v)
	case []int64:
		return slices.Clone(v)
	case []uint:
		return slices.Clone(v)
	case []uint8:
		return slices.Clone(v)
	case []uint16:
		return slices.Clone(v)
	case []uint32:
		return slices.Clone(v)
	case []uint64:
		return slices.Clone(v)
	case []float32:
		return slices.Clone(v)
	case []float64:
		return slices.Clone(v)
	default:
		return v
	}
}

View File

@@ -0,0 +1,520 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/manifest"
modelpkg "github.com/ollama/ollama/types/model"
)
// TestModelShowCacheLocalHitUsesManifestDigest verifies that two identical
// local show requests against an unchanged manifest compute the response only
// once and both observe the first computed result.
func TestModelShowCacheLocalHitUsesManifestDigest(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createShowCacheModel(t, "show-cache-local", map[string]any{"test.context_length": uint32(1024)})

	var lookups int
	cache := newModelShowCache()
	cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
		lookups++
		return showCacheTestResponse(lookups, req.Verbose), nil
	}

	first, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-local"})
	if err != nil {
		t.Fatalf("first GetLocal failed: %v", err)
	}
	second, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-local"})
	if err != nil {
		t.Fatalf("second GetLocal failed: %v", err)
	}

	if lookups != 1 {
		t.Fatalf("getModelInfo calls = %d, want 1", lookups)
	}
	if first.ModelInfo["call"] != 1 || second.ModelInfo["call"] != 1 {
		t.Fatalf("cached call markers = %v / %v, want both 1", first.ModelInfo["call"], second.ModelInfo["call"])
	}
}
// TestModelShowCacheLocalManifestDigestChangeRefreshes verifies that changing
// a model's manifest causes the next GetLocal to recompute the response and
// overwrite the existing cache slot rather than adding a second one.
func TestModelShowCacheLocalManifestDigestChangeRefreshes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createShowCacheModel(t, "show-cache-refresh", map[string]any{"test.context_length": uint32(1024)})

	var lookups int
	cache := newModelShowCache()
	cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
		lookups++
		return showCacheTestResponse(lookups, req.Verbose), nil
	}

	_, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-refresh"})
	if err != nil {
		t.Fatalf("first GetLocal failed: %v", err)
	}

	changeShowCacheManifest(t, "show-cache-refresh")

	refreshed, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-refresh"})
	if err != nil {
		t.Fatalf("refreshed GetLocal failed: %v", err)
	}
	if lookups != 2 {
		t.Fatalf("getModelInfo calls = %d, want 2", lookups)
	}
	if marker := refreshed.ModelInfo["call"]; marker != 2 {
		t.Fatalf("refreshed call marker = %v, want 2", marker)
	}
	if entries := len(cache.local); entries != 1 {
		t.Fatalf("local cache entries = %d, want 1", entries)
	}
}
// TestModelShowCacheLocalVerboseVariantsAreSeparate verifies that verbose and
// non-verbose requests for the same model occupy separate cache slots: each
// variant is computed once, and repeating the plain request hits the cache.
func TestModelShowCacheLocalVerboseVariantsAreSeparate(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createShowCacheModel(t, "show-cache-verbose", map[string]any{"test.context_length": uint32(1024)})

	var lookups int
	cache := newModelShowCache()
	cache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
		lookups++
		return showCacheTestResponse(lookups, req.Verbose), nil
	}

	plain, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-verbose"})
	if err != nil {
		t.Fatalf("plain GetLocal failed: %v", err)
	}
	verbose, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-verbose", Verbose: true})
	if err != nil {
		t.Fatalf("verbose GetLocal failed: %v", err)
	}
	plainAgain, err := cache.GetLocal(api.ShowRequest{Model: "show-cache-verbose"})
	if err != nil {
		t.Fatalf("plain repeat GetLocal failed: %v", err)
	}

	if lookups != 2 {
		t.Fatalf("getModelInfo calls = %d, want 2", lookups)
	}
	if plain.ModelInfo["verbose"] != false || verbose.ModelInfo["verbose"] != true || plainAgain.ModelInfo["call"] != 1 {
		t.Fatalf("unexpected verbose cache markers: plain=%v verbose=%v plainAgainCall=%v", plain.ModelInfo, verbose.ModelInfo, plainAgain.ModelInfo["call"])
	}
}
// TestModelShowCacheLocalHydrationSkipsUnchangedInMemory hydrates the local
// show cache twice without touching the model in between and expects the
// second pass, as well as the following GetLocal, to reuse the entry created
// by the first pass (a single getModelInfo call overall).
func TestModelShowCacheLocalHydrationSkipsUnchangedInMemory(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createShowCacheModel(t, "show-cache-hydrate", map[string]any{"test.context_length": uint32(1024)})

	showCache := newModelShowCache()
	var lookups int
	showCache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
		lookups++
		return showCacheTestResponse(lookups, req.Verbose), nil
	}

	ctx := context.Background()
	err := showCache.hydrateLocal(ctx)
	if err != nil {
		t.Fatalf("first hydrateLocal failed: %v", err)
	}
	err = showCache.hydrateLocal(ctx)
	if err != nil {
		t.Fatalf("second hydrateLocal failed: %v", err)
	}

	hydrated, err := showCache.GetLocal(api.ShowRequest{Model: "show-cache-hydrate"})
	if err != nil {
		t.Fatalf("GetLocal after hydration failed: %v", err)
	}

	if lookups != 1 {
		t.Fatalf("getModelInfo calls after unchanged in-memory hydration = %d, want 1", lookups)
	}
	if marker := hydrated.ModelInfo["call"]; marker != 1 {
		t.Fatalf("hydrated call marker = %v, want 1", marker)
	}
}
// TestModelShowCacheStartupSkipsLocalHydration runs the cache startup routine
// with OLLAMA_NO_CLOUD set and a local model present, and expects that
// getModelInfo is never invoked and that no local entries are created.
func TestModelShowCacheStartupSkipsLocalHydration(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	t.Setenv("OLLAMA_NO_CLOUD", "1")
	createShowCacheModel(t, "show-cache-startup", map[string]any{"test.context_length": uint32(1024)})

	showCache := newModelShowCache()
	// Any lookup during startup is a failure.
	showCache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
		t.Fatalf("startup should not hydrate local show cache, got request: %+v", req)
		return nil, nil
	}

	showCache.runStartup(context.Background())

	if entries := len(showCache.local); entries != 0 {
		t.Fatalf("local cache entries = %d, want 0", entries)
	}
}
// TestModelShowCacheBypassesSystemAndOptionsOverlays seeds the local show
// cache with a response whose System is "cached" and then issues show
// requests carrying a System or Options overlay. Neither request may be
// answered with the seeded cache entry.
func TestModelShowCacheBypassesSystemAndOptionsOverlays(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createShowCacheModel(t, "show-cache-overlay", map[string]any{"test.context_length": uint32(1024)})

	showCache := newModelShowCache()
	cacheKey, cacheDigest, err := modelShowLocalKeyForRequest(api.ShowRequest{Model: "show-cache-overlay"})
	if err != nil {
		t.Fatalf("local key failed: %v", err)
	}
	showCache.setLocal(cacheKey, cacheDigest, &api.ShowResponse{System: "cached", ModelInfo: map[string]any{}})

	srv := Server{modelCaches: &modelCaches{show: showCache}}
	var resp api.ShowResponse

	// System overlay: the response must echo the overlay, not the cache.
	systemOverlay := api.ShowRequest{
		Model:  "show-cache-overlay",
		System: "overlay-system",
	}
	rec := createRequest(t, srv.ShowHandler, systemOverlay)
	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want 200: %s", rec.Code, rec.Body.String())
	}
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	if resp.System != "overlay-system" {
		t.Fatalf("system = %q, want overlay-system", resp.System)
	}

	// Options overlay: decoded into the same resp value as above.
	optionsOverlay := api.ShowRequest{
		Model:   "show-cache-overlay",
		Options: map[string]any{"num_ctx": float64(8192)},
	}
	rec = createRequest(t, srv.ShowHandler, optionsOverlay)
	if rec.Code != http.StatusOK {
		t.Fatalf("options overlay status = %d, want 200: %s", rec.Code, rec.Body.String())
	}
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatalf("decode options response failed: %v", err)
	}
	if resp.System == "cached" {
		t.Fatalf("options overlay unexpectedly returned cached response")
	}
}
// TestModelShowCacheLocalAndCloudSameBaseDoNotCollide checks that a local
// model and a ":cloud" model sharing the same base name are served from
// different cache entries: the cloud request must see the seeded "cloud"
// response and the local request the "local" one from getModelInfo.
func TestModelShowCacheLocalAndCloudSameBaseDoNotCollide(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	createShowCacheModel(t, "show-cache-dual", map[string]any{"test.context_length": uint32(1024)})

	showCache := newModelShowCache()
	showCache.getModelInfo = func(req api.ShowRequest) (*api.ShowResponse, error) {
		local := &api.ShowResponse{
			Details:   api.ModelDetails{Format: "local"},
			ModelInfo: map[string]any{},
		}
		return local, nil
	}

	// Seed the cloud side and push its next read-refresh an hour out.
	key := modelShowCloudKeyForModel("show-cache-dual:cloud", false)
	showCache.setCloud(key, &api.ShowResponse{
		Details:   api.ModelDetails{Format: "cloud"},
		ModelInfo: map[string]any{},
	})
	showCache.mu.Lock()
	showCache.cloudNextReadRefreshAfter[key] = time.Now().Add(time.Hour)
	showCache.mu.Unlock()

	srv := Server{modelCaches: &modelCaches{show: showCache}}

	rec := createRequest(t, srv.ShowHandler, api.ShowRequest{Model: "show-cache-dual:cloud"})
	if rec.Code != http.StatusOK {
		t.Fatalf("cloud status = %d, want 200: %s", rec.Code, rec.Body.String())
	}
	var fromCloud api.ShowResponse
	if err := json.NewDecoder(rec.Body).Decode(&fromCloud); err != nil {
		t.Fatalf("decode cloud response failed: %v", err)
	}
	if got := fromCloud.Details.Format; got != "cloud" {
		t.Fatalf("cloud format = %q, want cloud", got)
	}

	rec = createRequest(t, srv.ShowHandler, api.ShowRequest{Model: "show-cache-dual"})
	if rec.Code != http.StatusOK {
		t.Fatalf("local status = %d, want 200: %s", rec.Code, rec.Body.String())
	}
	var fromLocal api.ShowResponse
	if err := json.NewDecoder(rec.Body).Decode(&fromLocal); err != nil {
		t.Fatalf("decode local response failed: %v", err)
	}
	if got := fromLocal.Details.Format; got != "local" {
		t.Fatalf("local format = %q, want local", got)
	}
}
// TestModelShowCacheCloudWarmHitReturnsStaleAndRefreshes seeds the cloud show
// cache with a "cached" response and expects the handler to answer with that
// stale entry immediately while a refresh request reaches the upstream server
// and eventually replaces the entry with the "updated" response.
func TestModelShowCacheCloudWarmHitReturnsStaleAndRefreshes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	refreshDone := make(chan struct{})
	// Guard the close so an additional upstream request cannot panic with
	// "close of closed channel" inside the handler.
	var closeRefreshDone sync.Once
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/api/show" {
			// The handler runs on a server goroutine, where FailNow (and so
			// Fatalf) must not be called; record the failure and reply.
			t.Errorf("unexpected upstream path %q", r.URL.Path)
			http.NotFound(w, r)
			return
		}
		defer closeRefreshDone.Do(func() { close(refreshDone) })
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"details":{"format":"updated"},"model_info":{"source":"fresh"}}`))
	}))
	defer upstream.Close()
	withCloudProxyBaseURL(t, upstream.URL)

	cache := newModelShowCache()
	cache.client = upstream.Client()
	cache.setCloud(modelShowCloudKeyForModel("kimi-k2.5:cloud", false), &api.ShowResponse{
		Details:   api.ModelDetails{Format: "cached"},
		ModelInfo: map[string]any{"source": "stale"},
	})

	s := Server{modelCaches: &modelCaches{show: cache}}
	w := createRequest(t, s.ShowHandler, api.ShowRequest{Model: "kimi-k2.5:cloud"})
	if w.Code != http.StatusOK {
		t.Fatalf("status = %d, want 200: %s", w.Code, w.Body.String())
	}

	var resp api.ShowResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	// The immediate answer must be the stale entry.
	if resp.Details.Format != "cached" {
		t.Fatalf("format = %q, want cached", resp.Details.Format)
	}

	select {
	case <-refreshDone:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for cloud refresh")
	}

	// The refreshed response is stored some time after the upstream replied.
	waitForCondition(t, 2*time.Second, func() bool {
		resp, ok := cache.getCloud(modelShowCloudKeyForModel("kimi-k2.5:cloud", false))
		return ok && resp.Details.Format == "updated"
	})
}
// TestModelShowCacheCloudColdMissFallsBackToProxy sends a show request for a
// cloud model through the full router with an empty show cache and expects
// the request to be forwarded to the upstream /api/show endpoint with the
// ":cloud" suffix removed from the model name.
func TestModelShowCacheCloudColdMissFallsBackToProxy(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	var (
		upstreamPath string
		upstreamBody string
	)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		upstreamPath = r.URL.Path
		raw, _ := io.ReadAll(r.Body)
		upstreamBody = string(raw)
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"details":{"format":"cold"},"model_info":{}}`))
	}))
	defer upstream.Close()
	withCloudProxyBaseURL(t, upstream.URL)

	srv := &Server{modelCaches: &modelCaches{show: newModelShowCache()}}
	router, err := srv.GenerateRoutes(nil)
	if err != nil {
		t.Fatalf("GenerateRoutes failed: %v", err)
	}
	local := httptest.NewServer(router)
	defer local.Close()

	payload := bytes.NewBufferString(`{"model":"kimi-k2.5:cloud"}`)
	req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+"/api/show", payload)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := local.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		failure, _ := io.ReadAll(resp.Body)
		t.Fatalf("status = %d, want 200: %s", resp.StatusCode, string(failure))
	}
	if upstreamPath != "/api/show" {
		t.Fatalf("upstream path = %q, want /api/show", upstreamPath)
	}
	if normalized := strings.Contains(upstreamBody, `"model":"kimi-k2.5"`); !normalized {
		t.Fatalf("expected normalized model in upstream body, got %q", upstreamBody)
	}
}
// TestModelShowCacheCloudHydrationUsesTagsAndShow runs hydrateCloud against a
// fake upstream and expects it to list models through GET /api/tags, fetch
// each listed model through /api/show using the normalized model name, and
// store every show response in the cloud cache.
func TestModelShowCacheCloudHydrationUsesTagsAndShow(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	// mu guards the values recorded by the upstream handler goroutines.
	var mu sync.Mutex
	var showModels []string
	var tagsCalled bool

	// The handler runs on server goroutines, where FailNow (and so Fatalf)
	// must not be called. Failures are recorded with Errorf and answered
	// with an HTTP error instead.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		switch r.URL.Path {
		case "/api/tags":
			if r.Method != http.MethodGet {
				t.Errorf("tags method = %s, want GET", r.Method)
				http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
				return
			}
			mu.Lock()
			tagsCalled = true
			mu.Unlock()
			_, _ = w.Write([]byte(`{"models":[{"name":"alpha:cloud"},{"model":"beta"}]}`))
		case "/api/show":
			var req api.ShowRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				t.Errorf("decode show request failed: %v", err)
				http.Error(w, "bad request", http.StatusBadRequest)
				return
			}
			mu.Lock()
			showModels = append(showModels, req.Model)
			mu.Unlock()
			// Echo the requested model so the cached entry can be matched
			// to its key below.
			_ = json.NewEncoder(w).Encode(api.ShowResponse{
				Details:   api.ModelDetails{Format: req.Model},
				ModelInfo: map[string]any{"model": req.Model},
			})
		default:
			t.Errorf("unexpected upstream path %q", r.URL.Path)
			http.NotFound(w, r)
		}
	}))
	defer upstream.Close()
	withCloudProxyBaseURL(t, upstream.URL)

	cache := newModelShowCache()
	cache.client = upstream.Client()
	if err := cache.hydrateCloud(context.Background()); err != nil {
		t.Fatalf("hydrateCloud failed: %v", err)
	}

	mu.Lock()
	gotTagsCalled := tagsCalled
	gotShowModels := slices.Clone(showModels)
	mu.Unlock()
	// Request order is not asserted; sort before comparing.
	slices.Sort(gotShowModels)

	if !gotTagsCalled {
		t.Fatal("expected /api/tags to be called")
	}
	if !slices.Equal(gotShowModels, []string{"alpha", "beta"}) {
		t.Fatalf("show models = %v, want [alpha beta]", gotShowModels)
	}

	for _, modelName := range gotShowModels {
		resp, ok := cache.getCloud(modelShowCloudKeyForModel(modelName, false))
		if !ok {
			t.Fatalf("missing cached cloud show response for %s", modelName)
		}
		if resp.Details.Format != modelName {
			t.Fatalf("cached format for %s = %q", modelName, resp.Details.Format)
		}
	}
}
// TestModelShowCacheCloudKeyNormalizesSourceTags checks that the cloud cache
// key's Model field has surrounding whitespace and the ":cloud" / "-cloud"
// source markers removed, and that a plain name passes through unchanged.
func TestModelShowCacheCloudKeyNormalizesSourceTags(t *testing.T) {
	cases := []struct {
		input string
		want  string
	}{
		{input: " kimi-k2.5:cloud ", want: "kimi-k2.5"},
		{input: "gpt-oss:20b-cloud", want: "gpt-oss:20b"},
		{input: "qwen3", want: "qwen3"},
	}

	for _, tc := range cases {
		got := modelShowCloudKeyForModel(tc.input, false).Model
		if got != tc.want {
			t.Fatalf("cloud key model for %q = %q, want %q", tc.input, got, tc.want)
		}
	}
}
// TestModelShowCacheCloudDisabledDoesNotServeStale seeds a cloud cache entry
// and then, with OLLAMA_NO_CLOUD=1 in effect, expects hydrateCloud to return
// errModelShowNoCloud and the show handler to answer 403 with the
// cloud-disabled message rather than the seeded entry.
func TestModelShowCacheCloudDisabledDoesNotServeStale(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	// Reload once more at cleanup time so later tests see a fresh config.
	t.Cleanup(envconfig.ReloadServerConfig)
	t.Setenv("OLLAMA_NO_CLOUD", "1")
	envconfig.ReloadServerConfig()

	showCache := newModelShowCache()
	seeded := &api.ShowResponse{
		Details:   api.ModelDetails{Format: "cached"},
		ModelInfo: map[string]any{},
	}
	showCache.setCloud(modelShowCloudKeyForModel("kimi-k2.5:cloud", false), seeded)

	err := showCache.hydrateCloud(context.Background())
	if !errors.Is(err, errModelShowNoCloud) {
		t.Fatalf("hydrateCloud error = %v, want %v", err, errModelShowNoCloud)
	}

	srv := Server{modelCaches: &modelCaches{show: showCache}}
	rec := createRequest(t, srv.ShowHandler, api.ShowRequest{Model: "kimi-k2.5:cloud"})
	if rec.Code != http.StatusForbidden {
		t.Fatalf("status = %d, want 403: %s", rec.Code, rec.Body.String())
	}

	wantMsg := internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable)
	if !strings.Contains(rec.Body.String(), wantMsg) {
		t.Fatalf("unexpected disabled response: %s", rec.Body.String())
	}
}
// createShowCacheModel creates a model called name through CreateHandler,
// backed by a freshly written model file built from the kv metadata, and
// fails the test unless the handler answers 200.
func createShowCacheModel(t *testing.T, name string, kv map[string]any) {
	t.Helper()

	_, blobDigest := createBinFile(t, kv, nil)
	req := api.CreateRequest{
		Model:  name,
		Files:  map[string]string{"model.gguf": blobDigest},
		Stream: &stream,
	}

	var srv Server
	rec := createRequest(t, srv.CreateHandler, req)
	if rec.Code == http.StatusOK {
		return
	}
	t.Fatalf("create model status = %d, want 200: %s", rec.Code, rec.Body.String())
}
// changeShowCacheManifest rewrites the manifest of the existing model name
// with one extra system layer appended, keeping the config and all existing
// layers, so that the stored manifest differs from the previous one.
func changeShowCacheManifest(t *testing.T, name string) {
	t.Helper()

	existing, err := getExistingName(modelpkg.ParseName(name))
	if err != nil {
		t.Fatalf("get existing name: %v", err)
	}

	current, err := manifest.ParseNamedManifest(existing)
	if err != nil {
		t.Fatalf("parse manifest: %v", err)
	}

	extra, err := manifest.NewLayer(strings.NewReader("changed"), "application/vnd.ollama.image.system")
	if err != nil {
		t.Fatalf("new layer: %v", err)
	}

	// Build a fresh slice so the parsed manifest's layers are not aliased.
	updated := make([]manifest.Layer, 0, len(current.Layers)+1)
	updated = append(updated, current.Layers...)
	updated = append(updated, extra)

	if err := manifest.WriteManifest(existing, current.Config, updated); err != nil {
		t.Fatalf("write manifest: %v", err)
	}
}
// showCacheTestResponse builds a canned show response whose ModelInfo records
// the call number and the verbose flag it was created with, so tests can tell
// which lookup produced a given response.
func showCacheTestResponse(call int, verbose bool) *api.ShowResponse {
	markers := map[string]any{
		"call":    call,
		"verbose": verbose,
	}
	details := api.ModelDetails{
		Format: "gguf",
		Family: "test",
	}

	resp := new(api.ShowResponse)
	resp.Details = details
	resp.Capabilities = []modelpkg.Capability{modelpkg.CapabilityCompletion}
	resp.ModelInfo = markers
	return resp
}
// withCloudProxyBaseURL points the package-level cloudProxyBaseURL at url for
// the duration of the test and restores the previous value during cleanup.
func withCloudProxyBaseURL(t *testing.T, url string) {
	t.Helper()

	previous := cloudProxyBaseURL
	t.Cleanup(func() { cloudProxyBaseURL = previous })
	cloudProxyBaseURL = url
}

144
server/prompt.go Normal file
View File

@@ -0,0 +1,144 @@
package server
import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/template"
)
// tokenizeFunc turns a string into a slice of token IDs. chatPrompt uses the
// length of the returned slice as the token count of a rendered prompt.
type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// When truncate is set, messages are dropped from the front until the rendered prompt fits in opts.NumCtx, while
// always keeping 1) the latest message and 2) any system messages found among the dropped ones. When truncate is
// false, every message is rendered even if the result exceeds the context window.
//
// Images attached to the retained messages are returned with sequential IDs. For template-based models the message
// content is rewritten to reference them as "[img-N]"; when m.Config.Renderer is set the content is left untouched.
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
// system holds the system messages that precede the first retained message; it stays nil when truncate is false.
var system []api.Message
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
// Clip images are represented as 768 tokens, each an embedding
imageNumTokens := 768
lastMsgIdx := len(msgs) - 1
// currMsgIdx is the index of the first message that will be rendered.
currMsgIdx := 0
if truncate {
// Start with all messages and remove from the front until it fits in context
for i := 0; i <= lastMsgIdx; i++ {
// Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
// Render and tokenize the candidate window to measure its size.
p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
if err != nil {
return "", nil, err
}
s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
ctxLen := len(s)
// Images only count toward the context when the model has a projector.
if m.ProjectorPaths != nil {
for _, msg := range msgs[i:] {
ctxLen += imageNumTokens * len(msg.Images)
}
}
if ctxLen <= opts.NumCtx {
currMsgIdx = i
break
}
// Must always include at least the last message
if i == lastMsgIdx {
currMsgIdx = lastMsgIdx
break
}
}
}
if currMsgIdx > 0 {
// NOTE(review): the "truncated" attribute is len(msgs[currMsgIdx:]), i.e. the number of messages kept
// rather than the number dropped — confirm which one is intended.
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
}
// Work on a copy of the slice so rewriting Content below does not modify the caller's messages.
renderMsgs := slices.Clone(msgs)
for cnt, msg := range renderMsgs[currMsgIdx:] {
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
return "", nil, errors.New("this model only supports one image while more than one image requested")
}
var prefix string
prompt := msg.Content
for _, i := range msg.Images {
// Image IDs are assigned in the order images appear across the retained messages.
imgData := llm.ImageData{
ID: len(images),
Data: i,
}
images = append(images, imgData)
// With a renderer configured, only collect the image; the content is not rewritten here.
if m.Config.Renderer != "" {
continue
}
// Replace the next "[img]" placeholder if one is left, otherwise put the tag in front of the content.
imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
if !strings.Contains(prompt, "[img]") {
prefix += imgTag
} else {
prompt = strings.Replace(prompt, "[img]", imgTag, 1)
}
}
if m.Config.Renderer != "" {
continue
}
renderMsgs[currMsgIdx+cnt].Content = prefix + prompt
}
// Render the final prompt from the preserved system messages plus the retained messages.
p, err := renderPrompt(m, append(system, renderMsgs[currMsgIdx:]...), tools, think)
if err != nil {
return "", nil, err
}
return p, images, nil
}
// renderPrompt turns msgs (plus tools and the thinking setting) into the
// prompt string for model m. A model with a configured renderer is rendered
// through the renderers package; every other model goes through its template.
func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	if m.Config.Renderer != "" {
		name := resolveRendererName(m)
		out, err := renderers.RenderWithRenderer(name, msgs, tools, think)
		if err != nil {
			return "", err
		}
		return out, nil
	}

	// Think and ThinkLevel keep their zero values unless think was provided.
	values := template.Values{
		Messages:   msgs,
		Tools:      tools,
		IsThinkSet: think != nil,
	}
	if think != nil {
		values.Think = think.Bool()
		values.ThinkLevel = think.String()
	}

	var rendered bytes.Buffer
	if err := m.Template.Execute(&rendered, values); err != nil {
		return "", err
	}
	return rendered.String(), nil
}

606
server/prompt_test.go Normal file
View File

@@ -0,0 +1,606 @@
package server
import (
"bytes"
"context"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// testConfigWithRenderer returns a model config with only Renderer set.
func testConfigWithRenderer(renderer string) model.ConfigV2 {
	var cfg model.ConfigV2
	cfg.Renderer = renderer
	return cfg
}
// testConfigWithRendererAndType returns a model config with only Renderer
// and ModelType set.
func testConfigWithRendererAndType(renderer, modelType string) model.ConfigV2 {
	var cfg model.ConfigV2
	cfg.Renderer = renderer
	cfg.ModelType = modelType
	return cfg
}
// TestChatPrompt is a table-driven test of chatPrompt for a template-based
// model with a vision projector. Each case supplies a message list, a context
// limit (used as NumCtx) and the truncate flag, and states the exact prompt
// and the image payloads that chatPrompt is expected to return.
func TestChatPrompt(t *testing.T) {
// expect is the expected outcome of a single case.
type expect struct {
prompt string
images [][]byte
error error
}
// Minimal template: system, prompt and response, each followed by a space.
tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
if err != nil {
t.Fatal(err)
}
// A non-nil ProjectorPaths makes chatPrompt count images toward the context size.
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
cases := []struct {
name string
model Model
limit int
truncate bool
msgs []api.Message
expect
}{
{
name: "messages",
model: visionModel,
limit: 64,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
// A limit of 1 cannot fit anything, so only the last message is kept.
name: "truncate messages",
model: visionModel,
limit: 1,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "truncate messages with image",
model: visionModel,
limit: 64,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
},
expect: expect{
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
},
},
},
{
// The image of the dropped first message is not returned; the remaining image gets ID 0.
name: "truncate messages with images",
model: visionModel,
limit: 64,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
},
{
name: "messages with images",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "[img-0]You're a test, Harry! I-I'm a what? [img-1]A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
// An explicit "[img]" placeholder is replaced in place instead of the tag being prefixed.
name: "message with image tag",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
},
expect: expect{
prompt: "You're a test, Harry! [img-0] I-I'm a what? [img-1]A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
name: "messages with interleaved images",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry!\n\n[img-0]\n\n[img-1] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("something"),
[]byte("somethingelse"),
},
},
},
{
name: "truncate message with interleaved images",
model: visionModel,
limit: 1024,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
{Role: "user", Images: []api.ImageData{[]byte("somethingelse")}},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
images: [][]byte{
[]byte("somethingelse"),
},
},
},
{
name: "message with system prompt",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "out of order system",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "multiple images same prompt",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
},
expect: expect{
prompt: "[img-0][img-1]Compare these two pictures of hotdogs ",
images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
},
},
{
// With truncate disabled the limit is ignored and every message is rendered.
name: "no truncate with limit exceeded",
model: visionModel,
limit: 10,
truncate: false,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
think := false
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
}
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if len(images) != len(tt.images) {
t.Fatalf("expected %d images, got %d", len(tt.images), len(images))
}
// Image IDs must be sequential from zero in the order returned.
for i := range images {
if images[i].ID != i {
t.Errorf("expected ID %d, got %d", i, images[i].ID)
}
// Image bytes are only compared for models without ModelFamilies set.
if len(model.Config.ModelFamilies) == 0 {
if !bytes.Equal(images[i].Data, tt.images[i]) {
t.Errorf("expected %q, got %q", tt.images[i], images[i].Data)
}
}
}
})
}
}
// TestChatPromptTokenizeCalls bounds how many times chatPrompt may invoke the
// tokenizer while truncating: once when the whole conversation fits, and at
// most once per message when it has to shrink down to the last message.
func TestChatPromptTokenizeCalls(t *testing.T) {
	tmpl, err := template.Parse(`
{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}{{ .Prompt }} {{ end }}
{{- if .Response }}{{ .Response }} {{ end }}`)
	if err != nil {
		t.Fatal(err)
	}
	m := Model{Template: tmpl}

	// Both cases use the same five-message conversation.
	conversation := func() []api.Message {
		return []api.Message{
			{Role: "user", Content: "message 1"},
			{Role: "assistant", Content: "response 1"},
			{Role: "user", Content: "message 2"},
			{Role: "assistant", Content: "response 2"},
			{Role: "user", Content: "message 3"},
		}
	}

	cases := []struct {
		name         string
		limit        int
		msgs         []api.Message
		maxTokenizes int
	}{
		{name: "all messages fit", limit: 2048, msgs: conversation(), maxTokenizes: 1},
		{name: "truncate to last message", limit: 5, msgs: conversation(), maxTokenizes: 5},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var calls int
			counting := func(ctx context.Context, s string) ([]int, error) {
				calls++
				return mockRunner{}.Tokenize(ctx, s)
			}

			opts := api.Options{Runner: api.Runner{NumCtx: tc.limit}}
			thinkOff := &api.ThinkValue{Value: false}
			if _, _, err := chatPrompt(t.Context(), &m, counting, &opts, tc.msgs, nil, thinkOff, true); err != nil {
				t.Fatal(err)
			}

			if calls > tc.maxTokenizes {
				t.Errorf("tokenize called %d times, expected at most %d", calls, tc.maxTokenizes)
			}
		})
	}
}
// TestChatPromptRendererDoesNotRewriteMessageContent checks that, for a model
// with a renderer configured, chatPrompt returns every attached image and a
// non-empty prompt while leaving the caller's message content unchanged.
func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
	const question = "what do these photos have in common?"
	msgs := []api.Message{{
		Role:    "user",
		Content: question,
		Images:  []api.ImageData{[]byte("img-1"), []byte("img-2"), []byte("img-3")},
	}}

	m := Model{
		Config:         model.ConfigV2{Renderer: "qwen3-vl-instruct"},
		ProjectorPaths: []string{"vision"},
	}
	opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
	thinkOff := &api.ThinkValue{Value: false}

	prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, thinkOff, true)
	if err != nil {
		t.Fatal(err)
	}

	if got := msgs[0].Content; got != question {
		t.Fatalf("renderer path should not mutate message content: got %q, want %q", got, question)
	}
	if got, want := len(images), 3; got != want {
		t.Fatalf("len(images) = %d, want %d", got, want)
	}
	if len(prompt) == 0 {
		t.Fatal("prompt is empty")
	}
}
// TestChatPromptGLMOcrRendererAddsImageTags checks that a user message with
// two images rendered through the "glm-ocr" renderer yields both images and
// a prompt in which the user turn starts with "[img-0][img-1]".
func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
	msgs := []api.Message{{
		Role:    "user",
		Content: "extract text",
		Images:  []api.ImageData{[]byte("img-1"), []byte("img-2")},
	}}

	m := Model{
		Config:         model.ConfigV2{Renderer: "glm-ocr"},
		ProjectorPaths: []string{"vision"},
	}
	opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
	thinkOff := &api.ThinkValue{Value: false}

	prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, thinkOff, true)
	if err != nil {
		t.Fatal(err)
	}

	if got, want := len(images), 2; got != want {
		t.Fatalf("len(images) = %d, want %d", got, want)
	}
	if tagged := strings.Contains(prompt, "<|user|>\n[img-0][img-1] extract text"); !tagged {
		t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
	}
}
// TestChatPromptRendererAddsToolImageTags runs one conversation — a user
// message with an image, an assistant tool call, and a tool result with a
// second image — through several renderers and checks that each rendered
// prompt contains "[img-0]" in the user turn and "[img-1]" in the tool result.
func TestChatPromptRendererAddsToolImageTags(t *testing.T) {
msgs := []api.Message{
{
Role: "user",
Content: "look at this file",
Images: []api.ImageData{[]byte("img-1")},
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_read",
Function: api.ToolCallFunction{
Name: "Read",
},
},
},
},
{
Role: "tool",
Content: "attached image",
Images: []api.ImageData{[]byte("img-2")},
ToolCallID: "call_read",
},
}
// wantUserTag and wantToolContent are substrings expected in the rendered prompt.
tests := []struct {
name string
renderer string
wantUserTag string
wantToolContent string
}{
{
name: "gemma4",
renderer: "gemma4",
wantUserTag: "<|turn>user\n[img-0] look at this file<turn|>\n",
wantToolContent: "[img-1] attached image",
},
{
name: "qwen3-vl",
renderer: "qwen3-vl-instruct",
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
},
{
name: "qwen3.5",
renderer: "qwen3.5",
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
},
{
name: "glm-ocr",
renderer: "glm-ocr",
wantUserTag: "<|user|>\n[img-0] look at this file",
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
},
{
name: "nemotron-3-nano",
renderer: "nemotron-3-nano",
wantUserTag: "<|im_start|>user\n[img-0] look at this file<|im_end|>\n",
wantToolContent: "<tool_response>\n[img-1] attached image\n</tool_response>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := Model{
Config: model.ConfigV2{Renderer: tt.renderer},
ProjectorPaths: []string{"vision"},
}
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
think := false
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
if err != nil {
t.Fatal(err)
}
// Both the user image and the tool image must be returned.
if got, want := len(images), 2; got != want {
t.Fatalf("len(images) = %d, want %d", got, want)
}
if !strings.Contains(prompt, tt.wantUserTag) {
t.Fatalf("prompt missing user image tag, got: %q", prompt)
}
if !strings.Contains(prompt, tt.wantToolContent) {
t.Fatalf("prompt missing tool image tag, got: %q", prompt)
}
})
}
}
// TestChatPromptRendererPreservesExplicitImagePlaceholders renders a user
// message that already contains two "[img]" placeholders through several
// renderers and checks that each placeholder is replaced in place by
// "[img-0]" and "[img-1]" respectively.
func TestChatPromptRendererPreservesExplicitImagePlaceholders(t *testing.T) {
msgs := []api.Message{
{
Role: "user",
Content: "compare [img] and [img]",
Images: []api.ImageData{[]byte("img-1"), []byte("img-2")},
},
}
// wantSnippet is the substring expected in the rendered prompt.
tests := []struct {
name string
renderer string
wantSnippet string
}{
{
name: "gemma4",
renderer: "gemma4",
wantSnippet: "<|turn>user\ncompare [img-0] and [img-1]<turn|>\n",
},
{
name: "qwen3-vl",
renderer: "qwen3-vl-instruct",
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
},
{
name: "qwen3.5",
renderer: "qwen3.5",
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
},
{
name: "glm-ocr",
renderer: "glm-ocr",
wantSnippet: "<|user|>\ncompare [img-0] and [img-1]",
},
{
name: "nemotron-3-nano",
renderer: "nemotron-3-nano",
wantSnippet: "<|im_start|>user\ncompare [img-0] and [img-1]<|im_end|>\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := Model{
Config: model.ConfigV2{Renderer: tt.renderer},
ProjectorPaths: []string{"vision"},
}
opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
think := false
prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
if err != nil {
t.Fatal(err)
}
if got, want := len(images), 2; got != want {
t.Fatalf("len(images) = %d, want %d", got, want)
}
if !strings.Contains(prompt, tt.wantSnippet) {
t.Fatalf("prompt missing replaced placeholders, got: %q", prompt)
}
})
}
}
// TestRenderPromptResolvesDynamicGemma4Renderer renders the same one-message
// conversation for two models configured with the legacy gemma4 renderer
// name — one identified by its name, one by its model type — and compares
// the output with the exact prompt expected for each.
func TestRenderPromptResolvesDynamicGemma4Renderer(t *testing.T) {
	hello := []api.Message{{Role: "user", Content: "Hello"}}

	cases := []struct {
		name  string
		model Model
		want  string
	}{
		{
			name: "small from name",
			model: Model{
				Name:      "gemma4:e4b",
				ShortName: "gemma4:e4b",
				Config:    testConfigWithRenderer(gemma4RendererLegacy),
			},
			want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n",
		},
		{
			name: "large from model type",
			model: Model{
				Config: testConfigWithRendererAndType(gemma4RendererLegacy, "25.2B"),
			},
			want: "<bos><|turn>user\nHello<turn|>\n<|turn>model\n<|channel>thought\n<channel|>",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rendered, err := renderPrompt(&tc.model, hello, nil, nil)
			if err != nil {
				t.Fatal(err)
			}
			diff := cmp.Diff(rendered, tc.want)
			if diff != "" {
				t.Fatalf("rendered prompt mismatch (-got +want):\n%s", diff)
			}
		})
	}
}

414
server/quantization.go Normal file
View File

@@ -0,0 +1,414 @@
package server
import (
"fmt"
"io"
"log/slog"
"maps"
"os"
"slices"
"strconv"
"strings"
"unsafe"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml/backend/ggml"
)
// quantizer writes a single tensor from a source file to a writer, converting
// it to a different tensor kind on the way when the source and target kinds
// differ (see WriteTo).
type quantizer struct {
// Source file the tensor bytes are read from.
*os.File
// offset is the byte position in the file at which the tensor data starts.
offset uint64
// from describes the tensor as stored in the file; to describes the tensor to produce.
from, to *fsggml.Tensor
// progressFn is called by WriteTo with the size of the source tensor once the write has been attempted.
progressFn func(n uint64)
}
// WriteTo writes the tensor described by q.from to w.
//
// When the source and destination kinds are equal the tensor bytes are
// streamed through unchanged. Otherwise the whole tensor is read into memory,
// converted to float32 (a no-op reinterpretation when the source is already
// F32), quantized to the destination kind and written in one call.
//
// progressFn is invoked with the source tensor size after the write attempt.
// The returned count is the number of bytes written to w. A read failure is
// logged and returned wrapped, so callers can inspect the underlying error
// with errors.Is / errors.As.
func (q quantizer) WriteTo(w io.Writer) (int64, error) {
	sr := io.NewSectionReader(q, int64(q.offset), int64(q.from.Size()))

	// Same kind on both sides: plain byte copy, no conversion needed.
	if q.from.Kind == q.to.Kind {
		n, err := io.Copy(w, sr)
		q.progressFn(q.from.Size())
		return n, err
	}

	data, err := io.ReadAll(sr)
	if err != nil {
		slog.Warn("file read error", "tensor", q.from.Name, "file", q.Name(), "error", err)
		return 0, fmt.Errorf("unable to read tensor %s from %s: %w", q.from.Name, q.Name(), err)
	}

	// The section reader stops quietly at end of file, so a truncated source
	// shows up as short data rather than as a read error.
	if uint64(len(data)) < q.from.Size() {
		return 0, fmt.Errorf("tensor %s data size %d is less than expected %d from shape %v", q.from.Name, len(data), q.from.Size(), q.from.Shape)
	}

	var f32s []float32
	if fsggml.TensorType(q.from.Kind) == fsggml.TensorTypeF32 {
		// Already float32: view the bytes in place instead of copying.
		f32s = unsafe.Slice((*float32)(unsafe.Pointer(&data[0])), q.from.Elements())
	} else {
		f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
	}

	data = ggml.Quantize(fsggml.TensorType(q.to.Kind), f32s, q.from.Shape)
	n, err := w.Write(data)
	q.progressFn(q.from.Size())
	return int64(n), err
}
// quantizeState carries per-model information and running counters that are
// shared across the per-tensor type decisions made during one quantize call.
type quantizeState struct {
nAttnV int // Number of attn_*v* weight tensors
nFfnDown int // Number of ffn_down tensors
iAttnV int // Running counter of number of attn_v tensors that have been processed
iFfnDown int // Running counter of number of ffn_down tensors that have been processed
hasOutput bool // used to figure out if a model shares tok_embd with the output weight
// Set when the source has FP8-origin tensors and the target file type is Q8_0;
// newType then leaves tensors not listed in sourceFP8Tensors at their original type.
preserveSourceFP8ToQ8 bool
// Set when the source has FP8-origin tensors and the target file type is Q4_K_M
// or Q4_K_S; newType then requests Q8_0 for tensors not listed in sourceFP8Tensors.
preserveSourceQ4 bool
// Names taken from the "source_fp8_tensors" metadata key; nil when the key is empty.
sourceFP8Tensors map[string]struct{}
}
func useMoreBits(iLayer, nLayers int) bool {
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}
// qwen3LinearAttnQuantType returns the fixed tensor type used for a tensor
// whose name ends in one of the known suffixes, and true. For any other name
// it returns 0 and false.
func qwen3LinearAttnQuantType(name string) (fsggml.TensorType, bool) {
	// Suffixes mapped to Q6_K: attention values and the expert / shared-expert
	// down projections.
	q6kSuffixes := []string{
		".attn_v.weight",
		".ffn_down_exps.weight",
		".ffn_down_shexp.weight",
	}
	// Suffixes mapped to Q4_K.
	q4kSuffixes := []string{
		// Full attention
		".attn_q.weight",
		".attn_k.weight",
		".attn_output.weight",
		// Linear attention (Gated Delta Net) after split
		".attn_qkv.weight",
		".attn_gate.weight",
		// SSM
		".ssm_ba.weight",
		".ssm_beta.weight",
		".ssm_alpha.weight",
		".ssm_out.weight",
		// MoE experts + shared experts
		".ffn_gate_exps.weight",
		".ffn_gate_shexp.weight",
		".ffn_up_exps.weight",
		".ffn_up_shexp.weight",
	}

	for _, suffix := range q6kSuffixes {
		if strings.HasSuffix(name, suffix) {
			return fsggml.TensorTypeQ6_K, true
		}
	}
	for _, suffix := range q4kSuffixes {
		if strings.HasSuffix(name, suffix) {
			return fsggml.TensorTypeQ4_K, true
		}
	}
	return 0, false
}
// isLagunaGGUFRoutedExpertWeight reports whether name ends in ".weight" and
// mentions one of the routed-expert projections (gate, up or down).
func isLagunaGGUFRoutedExpertWeight(name string) bool {
	if !strings.HasSuffix(name, ".weight") {
		return false
	}
	for _, marker := range []string{"ffn_gate_exps", "ffn_up_exps", "ffn_down_exps"} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
// lagunaGGUFBlockIndex extracts N from a tensor name of the form
// "blk.N.<rest>". It returns false when the prefix is missing, when nothing
// follows the index separator, or when N is not an integer.
func lagunaGGUFBlockIndex(name string) (int, bool) {
	rest, hasPrefix := strings.CutPrefix(name, "blk.")
	if !hasPrefix {
		return 0, false
	}

	// The index is everything up to the next dot; a dot must be present.
	digits, _, hasDot := strings.Cut(rest, ".")
	if !hasDot {
		return 0, false
	}

	index, err := strconv.Atoi(digits)
	if err != nil {
		return 0, false
	}
	return index, true
}
// lagunaGGUFQuantization picks the tensor type for a laguna tensor.
//
// Tensors that are not routed-expert weights are reported as (originalType,
// false). Routed-expert weights get requestedType, except that a
// ".ffn_down_exps.weight" tensor with a parseable block index is raised to
// Q6_K (Q4_K_M, layers selected by useMoreBits) or Q5_K (Q4_K_S, first eighth
// of the blocks) when the request is not already Q8_0 and blockCount is
// positive. The second result is true in all routed-expert cases.
func lagunaGGUFQuantization(name string, originalType, requestedType fsggml.TensorType, ftype fsggml.FileType, blockCount int) (fsggml.TensorType, bool) {
	if !isLagunaGGUFRoutedExpertWeight(name) {
		return originalType, false
	}

	// Only down projections with a usable layer position are ever promoted.
	if !strings.HasSuffix(name, ".ffn_down_exps.weight") || blockCount <= 0 || requestedType == fsggml.TensorTypeQ8_0 {
		return requestedType, true
	}
	block, ok := lagunaGGUFBlockIndex(name)
	if !ok {
		return requestedType, true
	}

	switch {
	case ftype == fsggml.FileTypeQ4_K_M && useMoreBits(block, blockCount):
		return fsggml.TensorTypeQ6_K, true
	case ftype == fsggml.FileTypeQ4_K_S && block < blockCount/8:
		return fsggml.TensorTypeQ5_K, true
	}
	return requestedType, true
}
// getTensorNewType refines the requested tensor type newType for the tensor
// called name and returns the type that should actually be used.
//
// The decision depends on the tensor name, the target file type ftype, the
// "expert_count" metadata value and the counters held in qs. Calling it
// advances qs.iAttnV for attn_v tensors and qs.iFfnDown for non-expert
// ffn_down tensors, so the order of calls matters. After the name-based
// adjustment, a quantized result whose block size does not divide shape[0]
// is replaced by a fallback type and a warning is logged.
//
// shape must have at least one element for the output/token-embedding
// branch and whenever the chosen type is quantized, since shape[0] is read.
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
// Ported from llama_tensor_get_type, removed unsupported quantization types
nExperts := max(1, kv.Uint("expert_count", 0))
if name == "output.weight" || name == "output_norm.weight" || (!qs.hasOutput && name == "token_embd.weight") {
// Output-side tensors: Q8_0 when the row length does not fit the requested
// block size, otherwise Q6_K unless Q8_0 was requested.
nx := shape[0]
qk_k := newType.BlockSize()
if nx%qk_k != 0 {
newType = fsggml.TensorTypeQ8_0
} else if newType != fsggml.TensorTypeQ8_0 {
newType = fsggml.TensorTypeQ6_K
}
} else if strings.Contains(name, "attn_v.weight") {
if newType != fsggml.TensorTypeQ8_0 && (ftype == fsggml.FileTypeQ4_K_M) &&
useMoreBits(qs.iAttnV, qs.nAttnV) {
newType = fsggml.TensorTypeQ6_K
} else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && qs.iAttnV < 4 {
newType = fsggml.TensorTypeQ5_K
}
// TODO
// if (qs.model.type == LLM_TYPE_70B) {
// // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
// // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
// // nearly negligible increase in model size by quantizing this tensor with more bits:
// if (newType == GGML_TYPE_Q3_K || newType == GGML_TYPE_Q4_K) newType = GGML_TYPE_Q5_K;
// }
if nExperts == 8 {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
newType = fsggml.TensorTypeQ8_0
}
qs.iAttnV++
} else if strings.Contains(name, "attn_k.weight") {
if nExperts == 8 {
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
newType = fsggml.TensorTypeQ8_0
}
} else if strings.Contains(name, "attn_k_b.weight") ||
strings.Contains(name, "attn_v_b.weight") ||
strings.Contains(name, "attn_kv_a_mqa.weight") ||
strings.Contains(name, "attn_q_a.weight") ||
strings.Contains(name, "attn_q_b.weight") {
// MLA tensors need higher precision to avoid quality degradation
newType = fsggml.TensorTypeQ8_0
} else if strings.Contains(name, "ffn_down") {
// For MoE models, ffn_down.weight (dense) and ffn_down_exps.weight (expert) both
// exist per layer and should get the same useMoreBits treatment. Dense sorts before
// expert alphabetically, so dense increments the counter and expert uses counter-1.
var iLayer int
if strings.Contains(name, "_exps") {
if kv.Architecture() == "laguna" {
// laguna expert down projections get no adjustment here; only the
// block-size fallback below is applied.
goto finalize
}
iLayer = max(0, qs.iFfnDown-1)
} else {
iLayer = qs.iFfnDown
qs.iFfnDown++
}
n_layer := qs.nFfnDown
if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M {
if useMoreBits(iLayer, n_layer) {
newType = fsggml.TensorTypeQ6_K
}
} else if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_S && iLayer < n_layer/8 {
newType = fsggml.TensorTypeQ5_K
}
} else if strings.Contains(name, "attn_output.weight") {
if newType != fsggml.TensorTypeQ8_0 && nExperts == 8 {
if ftype == fsggml.FileTypeQ4_K_S || ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ5_K
}
}
} else if strings.Contains(name, "attn_qkv.weight") {
if newType != fsggml.TensorTypeQ8_0 && ftype == fsggml.FileTypeQ4_K_M {
newType = fsggml.TensorTypeQ5_K
}
}
finalize:
if newType.IsQuantized() {
nx := shape[0]
qk_k := newType.BlockSize()
// Check if first dimension is divisible by block size
if nx%qk_k != 0 {
// Store the original type for logging
originalType := newType
// Select appropriate fallback based on original type
switch newType {
case fsggml.TensorTypeQ4_K:
newType = fsggml.TensorTypeQ5_0
case fsggml.TensorTypeQ5_K:
newType = fsggml.TensorTypeQ5_1
case fsggml.TensorTypeQ6_K:
newType = fsggml.TensorTypeQ8_0
}
// Final check - if still incompatible, fall back to F16
if nx%newType.BlockSize() != 0 {
newType = fsggml.TensorTypeF16
}
slog.Warn(fmt.Sprintf("tensor cols %d are not divisible by %d, required for %s - using fallback quantization %s",
nx, qk_k, originalType.String(), newType.String()))
}
}
return newType
}
// quantize writes a copy of the model orig (whose tensor data lives in the
// file in) to out, with each tensor converted to the type selected for the
// target file type newFileType. progressFn is handed to every tensor writer
// and is called as tensors are written.
func quantize(in, out *os.File, orig *fsggml.GGML, newFileType fsggml.FileType, progressFn func(n uint64)) error {
	kv := maps.Clone(orig.KV())
	kv["general.file_type"] = newFileType
	// kv["general.quantization_version"] = ggml.QuantizationVersion()

	state := &quantizeState{
		sourceFP8Tensors: sourceFP8TensorSet(kv),
	}
	if hasSourceFP8Tensors(kv) {
		switch newFileType {
		case fsggml.FileTypeQ8_0:
			state.preserveSourceFP8ToQ8 = true
		case fsggml.FileTypeQ4_K_M, fsggml.FileTypeQ4_K_S:
			state.preserveSourceQ4 = true
		}
	}

	// Gather the counts the per-tensor type selection relies on.
	blocks := 0
	for group, layer := range orig.Tensors().GroupLayers() {
		if strings.HasPrefix(group, "blk.") {
			blocks++
		}
		for _, tensor := range layer {
			switch {
			case strings.Contains(tensor.Name, "attn_v.weight"),
				strings.Contains(tensor.Name, "attn_qkv.weight"),
				strings.Contains(tensor.Name, "attn_kv_b.weight"):
				state.nAttnV++
			case tensor.Name == "output.weight":
				state.hasOutput = true
			}
		}
	}
	state.nFfnDown = blocks

	sources := orig.Tensors().Items()
	converted := make([]*fsggml.Tensor, len(sources))
	for i, src := range sources {
		dst := &fsggml.Tensor{
			Name:  src.Name,
			Shape: src.Shape,
			Kind:  uint32(newType(src, kv, state, newFileType)),
		}
		// The writer needs both descriptors, so it is attached after dst exists.
		dst.WriterTo = quantizer{
			File:       in,
			offset:     orig.Tensors().Offset + src.Offset,
			from:       src,
			to:         dst,
			progressFn: progressFn,
		}
		converted[i] = dst
	}

	return fsggml.WriteGGUF(out, kv, converted)
}
// newType returns the tensor type that tensor t should have in a file of type
// ftype. Tensors that are excluded from quantization keep their current type;
// the rest start from the file type's default and are refined by the
// architecture-specific rules and getTensorNewType.
func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.FileType) fsggml.TensorType {
	defaultType := ftype.ToTensorType()
	name := t.Name
	current := fsggml.TensorType(t.Kind)

	// quantize only 2D and 3D tensors (experts) whose name ends in "weight"
	eligible := strings.HasSuffix(name, "weight") && len(t.Shape) >= 2
	// don't quantize vision or audio encoder tensors
	eligible = eligible && !strings.HasPrefix(name, "v.") && !strings.HasPrefix(name, "a.")
	// do not quantize positional embeddings and token types (BERT)
	eligible = eligible && name != "position_embd.weight" && name != "token_types.weight"
	if eligible {
		excluded := []string{
			// remaining vision / audio projector tensors
			"mm.",
			// norm tensors
			"_norm.weight",
			// expert gating tensors
			"ffn_gate_inp.weight",
			"ffn_gate_inp_shexp.weight",
			// Mamba's small yet 2D weights
			// NOTE: can't use LLM_TN here because the layer number is not known
			"ssm_conv1d.weight",
			// LFM2's shortconv kernel weights
			"shortconv.conv.weight",
			// RWKV's time_mix_first tensors
			"time_mix_first.weight",
			"time_mix_w1.weight",
			"time_mix_w2.weight",
			"time_mix_decay_w1.weight",
			"time_mix_decay_w2.weight",
			"time_mix_lerp_fused.weight",
			// relative position bias (T5)
			"attn_rel_b.weight",
			"per_layer_token_embd.weight",
		}
		for _, fragment := range excluded {
			if strings.Contains(name, fragment) {
				eligible = false
				break
			}
		}
	}
	if !eligible {
		return current
	}

	_, fromFP8 := qs.sourceFP8Tensors[name]
	if qs.preserveSourceFP8ToQ8 && !fromFP8 {
		return current
	}

	if slices.Contains([]string{"qwen3next", "qwen35", "qwen35moe"}, kv.Architecture()) && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
		if fixed, ok := qwen3LinearAttnQuantType(name); ok {
			return fixed
		}
	}

	// TODO: Consider extracting architecture-specific GGUF quantization policy
	// from server so different quantization backends can share one source of
	// truth for model-family specializations.

	// get more optimal quantization type based on the tensor shape, layer, etc.
	if qs.preserveSourceQ4 && !fromFP8 {
		defaultType = fsggml.TensorTypeQ8_0
	}
	if kv.Architecture() == "laguna" {
		adjusted, ok := lagunaGGUFQuantization(name, current, defaultType, ftype, int(kv.Uint("block_count", 0)))
		if !ok {
			return current
		}
		defaultType = adjusted
	}

	chosen := getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
	if chosen != defaultType {
		slog.Debug("tensor quantization adjusted for better quality", "name", t.Name, "requested", defaultType, "quantization", chosen)
	}
	return chosen
}
// sourceFP8TensorSet turns the "source_fp8_tensors" metadata list into a set
// keyed by tensor name. It returns nil when the list is empty.
func sourceFP8TensorSet(kv fsggml.KV) map[string]struct{} {
	listed := kv.Strings("source_fp8_tensors")
	if len(listed) == 0 {
		return nil
	}

	set := make(map[string]struct{}, len(listed))
	for i := range listed {
		set[listed[i]] = struct{}{}
	}
	return set
}

834
server/quantization_test.go Normal file
View File

@@ -0,0 +1,834 @@
package server
import (
"bytes"
"fmt"
"math"
"os"
"strings"
"testing"
fsggml "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml/backend/ggml"
)
// TestGetTensorNewType feeds getTensorNewType a table of tensor names, shapes,
// file types and quantize states and checks the returned tensor type. A case
// may instead name a panic message it expects the call to raise.
func TestGetTensorNewType(t *testing.T) {
	type testCase struct {
		name          string
		kv            map[string]any
		qs            quantizeState
		requested     fsggml.TensorType
		tensorName    string
		shape         []uint64
		ftype         fsggml.FileType
		expected      fsggml.TensorType
		expectedPanic string
	}

	// Metadata describing an 8-expert model of architecture "foo".
	eightExperts := func() map[string]any {
		return map[string]any{
			"general.architecture": "foo",
			"foo.expert_count":     uint32(8),
		}
	}

	tests := []testCase{
		{
			name:       "output_unsupported",
			kv:         map[string]any{},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "output.weight",
			shape:      []uint64{100, 100},
			ftype:      fsggml.FileTypeF32,
			expected:   fsggml.TensorTypeF16,
		},
		{
			name:       "output_Q8",
			kv:         map[string]any{},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "output.weight",
			shape:      []uint64{1024, 1024},
			ftype:      fsggml.FileTypeF32,
			expected:   fsggml.TensorTypeQ6_K,
		},
		{
			name:       "attn_v.weight_q4_k_m",
			kv:         map[string]any{},
			qs:         quantizeState{iAttnV: 2, nAttnV: 3 * 8},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "blk.0.attn_v.weight",
			shape:      []uint64{256},
			ftype:      fsggml.FileTypeQ4_K_M,
			expected:   fsggml.TensorTypeQ6_K,
		},
		{
			name:       "attn_v.weight_q4_k_s",
			kv:         map[string]any{},
			qs:         quantizeState{},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "blk.0.attn_v.weight",
			shape:      []uint64{256},
			ftype:      fsggml.FileTypeQ4_K_S,
			expected:   fsggml.TensorTypeQ5_K,
		},
		{
			name:       "attn_v.weight_8_expert",
			kv:         eightExperts(),
			qs:         quantizeState{},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "blk.0.attn_v.weight",
			shape:      []uint64{256},
			ftype:      fsggml.FileTypeF32,
			expected:   fsggml.TensorTypeQ8_0,
		},
		{
			name:       "attn_k.weight_8_expert",
			kv:         eightExperts(),
			qs:         quantizeState{},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "blk.0.attn_k.weight",
			shape:      []uint64{256},
			ftype:      fsggml.FileTypeF32,
			expected:   fsggml.TensorTypeQ8_0,
		},
		{
			name:       "ffn_down_q4_k_m",
			kv:         map[string]any{},
			qs:         quantizeState{iFfnDown: 1, nFfnDown: 8},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "ffn_down",
			shape:      []uint64{256},
			ftype:      fsggml.FileTypeQ4_K_M,
			expected:   fsggml.TensorTypeQ4_0,
		},
		{
			name:       "ffn_down_q4_k_m_6",
			kv:         map[string]any{},
			qs:         quantizeState{iFfnDown: 2, nFfnDown: 3 * 8},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "ffn_down",
			shape:      []uint64{256},
			ftype:      fsggml.FileTypeQ4_K_M,
			expected:   fsggml.TensorTypeQ6_K,
		},
		{
			name:       "ffn_down_q4_k_s",
			kv:         map[string]any{},
			qs:         quantizeState{iFfnDown: 2, nFfnDown: 3 * 8},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "ffn_down",
			shape:      []uint64{256},
			ftype:      fsggml.FileTypeQ4_K_S,
			expected:   fsggml.TensorTypeQ5_K,
		},
		{
			name:       "attn_qkv.weight_q4_k_m",
			kv:         map[string]any{},
			qs:         quantizeState{},
			requested:  fsggml.TensorTypeQ4_0,
			tensorName: "blk.0.attn_qkv.weight",
			shape:      []uint64{256},
			ftype:      fsggml.FileTypeQ4_K_M,
			expected:   fsggml.TensorTypeQ5_K,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// One deferred check covers both modes: a required panic message,
			// or no panic at all.
			defer func() {
				e := recover()
				if tc.expectedPanic != "" {
					if !strings.Contains(fmt.Sprintf("%v", e), tc.expectedPanic) {
						t.Fatalf("incorrect panic\ngot: %v\nexpected: %s", e, tc.expectedPanic)
					}
					return
				}
				if e != nil {
					t.Fatalf("hit unexpected panic %v", e)
				}
			}()

			got := getTensorNewType(tc.kv, &tc.qs, tc.requested, tc.tensorName, tc.shape, tc.ftype)
			if got != tc.expected {
				t.Fatalf("incorrect type returned\ngot: %d\nexpected: %d", got, tc.expected)
			}
		})
	}
}
// TestQwen3LinearAttentionQuantOverride checks the tensor type newType selects
// for a 256x256 F16 tensor under several architectures, including one
// architecture that is not covered by the qwen-specific override.
func TestQwen3LinearAttentionQuantOverride(t *testing.T) {
	type testCase struct {
		name     string
		arch     string
		tensor   string
		fileType fsggml.FileType
		expected fsggml.TensorType
	}

	tests := []testCase{
		{
			name:     "qwen35_beta",
			arch:     "qwen35",
			tensor:   "blk.0.ssm_beta.weight",
			fileType: fsggml.FileTypeQ4_K_M,
			expected: fsggml.TensorTypeQ4_K,
		},
		{
			name:     "qwen35_alpha",
			arch:     "qwen35",
			tensor:   "blk.0.ssm_alpha.weight",
			fileType: fsggml.FileTypeQ4_K_M,
			expected: fsggml.TensorTypeQ4_K,
		},
		{
			name:     "qwen35moe_attn_qkv",
			arch:     "qwen35moe",
			tensor:   "blk.0.attn_qkv.weight",
			fileType: fsggml.FileTypeQ4_K_M,
			expected: fsggml.TensorTypeQ4_K,
		},
		{
			name:     "non_qwen35_falls_back",
			arch:     "foo",
			tensor:   "blk.0.attn_qkv.weight",
			fileType: fsggml.FileTypeQ4_K_M,
			expected: fsggml.TensorTypeQ5_K,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			input := &fsggml.Tensor{
				Name:  tc.tensor,
				Shape: []uint64{256, 256},
				Kind:  uint32(fsggml.TensorTypeF16),
			}
			metadata := fsggml.KV{"general.architecture": tc.arch}

			got := newType(input, metadata, &quantizeState{}, tc.fileType)
			if got != tc.expected {
				t.Fatalf("unexpected tensor type for %s (%s): got %s want %s", tc.tensor, tc.arch, got, tc.expected)
			}
		})
	}
}
// TestQuantizeModel writes small models to disk, runs quantize on them and
// reloads the result to check the type recorded for every tensor. Cases with
// wantErr set supply less data than the tensor shape requires and must fail,
// either while decoding the input or inside quantize.
func TestQuantizeModel(t *testing.T) {
	type testCase struct {
		name      string
		kv        map[string]any
		tensors   []*fsggml.Tensor
		fileType  string
		wantTypes map[string]fsggml.TensorType
		wantErr   bool
	}

	// tensorOf builds an input tensor at offset 0 backed by data.
	tensorOf := func(name string, kind fsggml.TensorType, data []byte, shape ...uint64) *fsggml.Tensor {
		return &fsggml.Tensor{
			Name:     name,
			Kind:     uint32(kind),
			Offset:   uint64(0),
			Shape:    shape,
			WriterTo: bytes.NewReader(data),
		}
	}
	// sample returns the precomputed 256-element payload for kind, repeated n times.
	sample := func(kind fsggml.TensorType, n int) []byte {
		return bytes.Repeat(quantBytes[kind], n)
	}
	// bf16 builds a BF16 tensor holding exactly one precomputed payload.
	bf16 := func(name string, shape ...uint64) *fsggml.Tensor {
		return tensorOf(name, fsggml.TensorTypeBF16, quantBytes[fsggml.TensorTypeBF16], shape...)
	}

	tests := []testCase{
		{
			name: "f16_q4_k",
			kv: map[string]any{
				"general.architecture": "foo",
			},
			tensors: []*fsggml.Tensor{
				tensorOf("blk.0.attn.weight", fsggml.TensorTypeF16, sample(fsggml.TensorTypeF16, 4), 512, 2),
				tensorOf("output.weight", fsggml.TensorTypeF16, sample(fsggml.TensorTypeF16, 4), 256, 4),
			},
			fileType: "Q4_K",
			wantTypes: map[string]fsggml.TensorType{
				"blk.0.attn.weight": fsggml.TensorTypeQ4_K,
				"output.weight":     fsggml.TensorTypeQ6_K,
			},
		},
		{
			name: "f32_q4_k",
			kv: map[string]any{
				"general.architecture": "foo",
			},
			tensors: []*fsggml.Tensor{
				tensorOf("blk.0.attn_v.weight", fsggml.TensorTypeF32, sample(fsggml.TensorTypeF32, 4), 512, 2),
				tensorOf("output.weight", fsggml.TensorTypeF32, sample(fsggml.TensorTypeF32, 2), 512),
			},
			fileType: "Q4_K",
			wantTypes: map[string]fsggml.TensorType{
				"blk.0.attn_v.weight": fsggml.TensorTypeQ6_K,
				"output.weight":       fsggml.TensorTypeF32,
			},
		},
		{
			name: "f16_q8_0",
			kv: map[string]any{
				"general.architecture": "foo",
			},
			tensors: []*fsggml.Tensor{
				tensorOf("blk.0.attn.weight", fsggml.TensorTypeF16, sample(fsggml.TensorTypeF16, 4), 32, 16, 2),
				tensorOf("output.weight", fsggml.TensorTypeF16, sample(fsggml.TensorTypeF16, 4), 256, 4),
			},
			fileType: "Q8_0",
			wantTypes: map[string]fsggml.TensorType{
				"blk.0.attn.weight": fsggml.TensorTypeQ8_0,
				"output.weight":     fsggml.TensorTypeQ8_0,
			},
		},
		{
			name: "source_fp8_q8_preserves_bf16_tensors",
			kv: map[string]any{
				"general.architecture": "test",
				"source_quantization":  "hf_fp8",
				"source_fp8_tensors":   []string{"blk.1.ffn_down_exps.weight"},
			},
			tensors: []*fsggml.Tensor{
				bf16("blk.1.ffn_down_exps.weight", 256, 1),
				bf16("blk.1.attn_q.weight", 256, 1),
			},
			fileType: "Q8_0",
			wantTypes: map[string]fsggml.TensorType{
				"blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ8_0,
				"blk.1.attn_q.weight":        fsggml.TensorTypeBF16,
			},
		},
		{
			name: "source_fp8_q4_promotes_bf16_tensors_to_q8",
			kv: map[string]any{
				"general.architecture": "test",
				"source_quantization":  "hf_fp8",
				"source_fp8_tensors": []string{
					"blk.1.ffn_gate_exps.weight",
					"blk.1.ffn_down_exps.weight",
				},
			},
			tensors: []*fsggml.Tensor{
				bf16("blk.1.ffn_gate_exps.weight", 256, 1),
				bf16("blk.1.ffn_down_exps.weight", 256, 1),
				bf16("blk.1.attn_q.weight", 256, 1),
				bf16("blk.1.attn_v.weight", 256, 1),
				bf16("blk.1.ffn_down.weight", 256, 1),
				bf16("blk.1.attn_q_norm.weight", 256),
				bf16("blk.1.ffn_gate_inp.weight", 256, 1),
				bf16("output.weight", 256, 1),
			},
			fileType: "Q4_K_M",
			wantTypes: map[string]fsggml.TensorType{
				"blk.1.ffn_gate_exps.weight": fsggml.TensorTypeQ4_K,
				"blk.1.ffn_down_exps.weight": fsggml.TensorTypeQ6_K,
				"blk.1.attn_q.weight":        fsggml.TensorTypeQ8_0,
				"blk.1.attn_v.weight":        fsggml.TensorTypeQ8_0,
				"blk.1.ffn_down.weight":      fsggml.TensorTypeQ8_0,
				"blk.1.attn_q_norm.weight":   fsggml.TensorTypeBF16,
				"blk.1.ffn_gate_inp.weight":  fsggml.TensorTypeBF16,
				"output.weight":              fsggml.TensorTypeQ8_0,
			},
		},
		{
			name: "f32_short_data",
			kv: map[string]any{
				"general.architecture": "foo",
			},
			tensors: []*fsggml.Tensor{
				tensorOf("blk.0.attn.weight", fsggml.TensorTypeF32, make([]byte, 32), 512, 2),
			},
			fileType: "Q4_K",
			wantErr:  true,
		},
		{
			name: "f16_short_data",
			kv: map[string]any{
				"general.architecture": "foo",
			},
			tensors: []*fsggml.Tensor{
				tensorOf("blk.0.attn.weight", fsggml.TensorTypeF16, make([]byte, 32), 512, 2),
			},
			fileType: "Q4_K",
			wantErr:  true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			path, _ := createBinFile(t, tc.kv, tc.tensors)
			src, err := os.Open(path)
			if err != nil {
				t.Fatal(err.Error())
			}
			defer src.Close()

			meta, err := fsggml.Decode(src, -1)
			if err != nil {
				// A decode failure already satisfies an error-expecting case.
				if tc.wantErr {
					return
				}
				t.Fatal(err.Error())
			}

			reported := false
			onProgress := func(uint64) {
				// fmt.Fprintf(os.Stderr, "progress: %f\n", p)
				reported = true
			}

			dst, err := os.CreateTemp(t.TempDir(), tc.name+".out")
			if err != nil {
				t.Fatal(err.Error())
			}
			defer dst.Close()

			ftype, err := fsggml.ParseFileType(tc.fileType)
			if err != nil {
				t.Fatal(err.Error())
			}

			err = quantize(src, dst, meta, ftype, onProgress)
			switch {
			case tc.wantErr && err == nil:
				t.Fatal("expected quantize to return an error")
			case tc.wantErr:
				return
			case err != nil:
				t.Fatalf("error during quantize: %s", err)
			}
			if !reported {
				t.Fatalf("progress was not reported")
			}

			// Now attempt to load it back and make sure types match expected
			reopened, err := os.Open(dst.Name())
			if err != nil {
				t.Fatalf("failed to load the quantized model %s: %s", dst.Name(), err)
			}
			defer reopened.Close()

			quantized, err := fsggml.Decode(reopened, -1)
			if err != nil {
				t.Fatalf("failed to load the quantized model %s: %s", dst.Name(), err)
			}

			for _, layer := range quantized.Tensors().GroupLayers() {
				for _, tensor := range layer {
					got, want := fsggml.TensorType(tensor.Kind), tc.wantTypes[tensor.Name]
					if got != want {
						t.Fatalf("incorrect output type for %s\ngot:%s\nexpected:%s", tensor.Name, got, want)
					}
				}
			}
		})
	}
}
// TestConvertToF32 converts each precomputed payload in quantBytes back to
// float32 and requires the result to stay close (cosine similarity of at
// least 0.999) to the sequence 0..255.
func TestConvertToF32(t *testing.T) {
	want := make([]float32, 256)
	for i := range want {
		want[i] = float32(i)
	}

	for kind, payload := range quantBytes {
		if kind == fsggml.TensorTypeF32 {
			// Skip the no-op
			continue
		}
		t.Run(kind.String(), func(t *testing.T) {
			got := ggml.ConvertToF32(payload, uint32(kind), 256)
			if score := cosineSimilarity(want, got); score < 0.999 {
				t.Fatalf("Results not similar enough: %s %f", kind.String(), score)
			}
		})
	}
}
// dotProduct returns the sum of the element-wise products of v1 and v2.
// v2 must be at least as long as v1.
func dotProduct[V float32 | float64](v1, v2 []V) V {
	var sum V
	for i, a := range v1 {
		sum += a * v2[i]
	}
	return sum
}
// magnitude returns the Euclidean length of v: the square root of the sum of
// its squared elements, accumulated in V and rooted in float64.
func magnitude[V float32 | float64](v []V) V {
	var sumSquares V
	for i := range v {
		sumSquares += v[i] * v[i]
	}
	root := math.Sqrt(float64(sumSquares))
	return V(root)
}
// cosineSimilarity returns the dot product of v1 and v2 divided by the
// product of their magnitudes.
func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
	dot := dotProduct(v1, v2)
	norms := magnitude(v1) * magnitude(v2)
	return dot / norms
}
// Precomputed quantized data - arange 256
// # For gguf-py supported types
// import gguf
// import numpy as np
// print(repr(gguf.quantize(np.arange(256, dtype=np.float16), gguf.GGMLQuantizationType.Q4_0)))
//
// For types not supported by gguf-py converted via ggml_fp32_to_fp16_row and quantize_XXX
//
// data := make([]byte, 256*2)
// fp32 := make([]float32, 256)
// for i := range 256 {
// fp32[i] = float32(i)
// }
// l := C.quantize_q6_K((*C.float)(&fp32[0]), unsafe.Pointer(&data[0]), 1, 256, nil)
// for i := range data[:int(l)] {
// fmt.Printf("%d, ", data[i])
// }
var (
quantBytes = map[fsggml.TensorType][]byte{
fsggml.TensorTypeQ4_0: {
192, 195, 72, 72, 55, 55, 55, 55, 38, 38, 38, 38, 21,
21, 21, 21, 4, 4, 224, 199, 36, 36, 36, 36, 19, 19,
19, 19, 19, 19, 19, 19, 2, 2, 2, 2, 240, 201, 19,
19, 18, 18, 18, 18, 18, 18, 18, 18, 2, 2, 2, 2,
1, 1, 240, 203, 18, 18, 18, 18, 18, 18, 18, 18, 1,
1, 1, 1, 1, 1, 1, 1, 248, 204, 18, 18, 17, 17,
17, 17, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 248,
205, 17, 17, 17, 17, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 248, 206, 17, 17, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 248, 207, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1,
},
fsggml.TensorTypeQ4_1: {
34, 64, 0, 0, 128, 128, 145, 145, 162, 162, 179, 179, 196,
196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 80, 128, 128,
145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230, 230, 247,
247, 34, 64, 0, 84, 128, 128, 145, 145, 162, 162, 179, 179,
196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 86, 128,
128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230, 230,
247, 247, 34, 64, 0, 88, 128, 128, 145, 145, 162, 162, 179,
179, 196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0, 89,
128, 128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213, 230,
230, 247, 247, 34, 64, 0, 90, 128, 128, 145, 145, 162, 162,
179, 179, 196, 196, 213, 213, 230, 230, 247, 247, 34, 64, 0,
91, 128, 128, 145, 145, 162, 162, 179, 179, 196, 196, 213, 213,
230, 230, 247, 247,
},
fsggml.TensorTypeQ5_0: {
192, 191, 1, 0, 0, 0, 128, 127, 127, 110, 110, 93, 93,
76, 76, 59, 59, 42, 42, 25, 25, 8, 224, 195, 0, 0,
0, 0, 72, 72, 55, 55, 55, 55, 38, 38, 38, 38, 21,
21, 21, 21, 4, 4, 240, 197, 0, 0, 0, 0, 53, 37,
37, 37, 37, 36, 36, 20, 20, 20, 20, 19, 19, 3, 3,
3, 240, 199, 0, 0, 0, 0, 36, 36, 36, 36, 19, 19,
19, 19, 19, 19, 19, 19, 2, 2, 2, 2, 248, 200, 0,
0, 0, 0, 35, 19, 19, 19, 19, 19, 19, 18, 18, 18,
18, 2, 2, 2, 2, 2, 248, 201, 0, 0, 0, 0, 19,
19, 18, 18, 18, 18, 18, 18, 18, 18, 2, 2, 2, 2,
1, 1, 248, 202, 0, 0, 0, 0, 18, 18, 18, 18, 18,
18, 18, 18, 18, 2, 2, 1, 1, 1, 1, 1, 248, 203,
0, 0, 0, 0, 18, 18, 18, 18, 18, 18, 18, 18, 1,
1, 1, 1, 1, 1, 1, 1,
},
fsggml.TensorTypeQ5_1: {
0, 60, 0, 0, 0, 0, 255, 255, 0, 17, 34, 51, 68,
85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, 0, 60,
0, 80, 0, 0, 255, 255, 0, 17, 34, 51, 68, 85, 102,
119, 136, 153, 170, 187, 204, 221, 238, 255, 0, 60, 0, 84,
0, 0, 255, 255, 0, 17, 34, 51, 68, 85, 102, 119, 136,
153, 170, 187, 204, 221, 238, 255, 0, 60, 0, 86, 0, 0,
255, 255, 0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170,
187, 204, 221, 238, 255, 0, 60, 0, 88, 0, 0, 255, 255,
0, 17, 34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204,
221, 238, 255, 0, 60, 0, 89, 0, 0, 255, 255, 0, 17,
34, 51, 68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238,
255, 0, 60, 0, 90, 0, 0, 255, 255, 0, 17, 34, 51,
68, 85, 102, 119, 136, 153, 170, 187, 204, 221, 238, 255, 0,
60, 0, 91, 0, 0, 255, 255, 0, 17, 34, 51, 68, 85,
102, 119, 136, 153, 170, 187, 204, 221, 238, 255,
},
fsggml.TensorTypeQ8_0: {
208, 51, 0, 4, 8, 12, 16, 20, 25, 29, 33, 37, 41,
45, 49, 53, 57, 61, 66, 70, 74, 78, 82, 86, 90, 94,
98, 102, 107, 111, 115, 119, 123, 127, 240, 55, 65, 67, 69,
71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95,
97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119, 121,
123, 125, 127, 252, 57, 86, 87, 88, 90, 91, 92, 94, 95,
96, 98, 99, 100, 102, 103, 104, 106, 107, 108, 110, 111, 112,
114, 115, 116, 118, 119, 120, 122, 123, 124, 126, 127, 0, 60,
96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108,
109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121,
122, 123, 124, 125, 126, 127, 2, 61, 102, 103, 104, 105, 105,
106, 107, 108, 109, 109, 110, 111, 112, 113, 113, 114, 115, 116,
117, 117, 118, 119, 120, 121, 121, 122, 123, 124, 125, 125, 126,
127, 4, 62, 106, 107, 108, 108, 109, 110, 110, 111, 112, 112,
113, 114, 114, 115, 116, 116, 117, 118, 118, 119, 120, 120, 121,
122, 122, 123, 124, 124, 125, 126, 126, 127, 6, 63, 109, 110,
110, 111, 112, 112, 113, 113, 114, 114, 115, 116, 116, 117, 117,
118, 118, 119, 120, 120, 121, 121, 122, 122, 123, 124, 124, 125,
125, 126, 126, 127, 4, 64, 112, 112, 113, 113, 114, 114, 115,
115, 116, 116, 117, 117, 118, 118, 119, 119, 120, 120, 121, 121,
122, 122, 123, 123, 124, 124, 125, 125, 126, 126, 127, 127,
},
fsggml.TensorTypeBF16: {
0, 0, 128, 63, 0, 64, 64, 64, 128, 64, 160, 64, 192,
64, 224, 64, 0, 65, 16, 65, 32, 65, 48, 65, 64, 65,
80, 65, 96, 65, 112, 65, 128, 65, 136, 65, 144, 65, 152,
65, 160, 65, 168, 65, 176, 65, 184, 65, 192, 65, 200, 65,
208, 65, 216, 65, 224, 65, 232, 65, 240, 65, 248, 65, 0,
66, 4, 66, 8, 66, 12, 66, 16, 66, 20, 66, 24, 66,
28, 66, 32, 66, 36, 66, 40, 66, 44, 66, 48, 66, 52,
66, 56, 66, 60, 66, 64, 66, 68, 66, 72, 66, 76, 66,
80, 66, 84, 66, 88, 66, 92, 66, 96, 66, 100, 66, 104,
66, 108, 66, 112, 66, 116, 66, 120, 66, 124, 66, 128, 66,
130, 66, 132, 66, 134, 66, 136, 66, 138, 66, 140, 66, 142,
66, 144, 66, 146, 66, 148, 66, 150, 66, 152, 66, 154, 66,
156, 66, 158, 66, 160, 66, 162, 66, 164, 66, 166, 66, 168,
66, 170, 66, 172, 66, 174, 66, 176, 66, 178, 66, 180, 66,
182, 66, 184, 66, 186, 66, 188, 66, 190, 66, 192, 66, 194,
66, 196, 66, 198, 66, 200, 66, 202, 66, 204, 66, 206, 66,
208, 66, 210, 66, 212, 66, 214, 66, 216, 66, 218, 66, 220,
66, 222, 66, 224, 66, 226, 66, 228, 66, 230, 66, 232, 66,
234, 66, 236, 66, 238, 66, 240, 66, 242, 66, 244, 66, 246,
66, 248, 66, 250, 66, 252, 66, 254, 66, 0, 67, 1, 67,
2, 67, 3, 67, 4, 67, 5, 67, 6, 67, 7, 67, 8,
67, 9, 67, 10, 67, 11, 67, 12, 67, 13, 67, 14, 67,
15, 67, 16, 67, 17, 67, 18, 67, 19, 67, 20, 67, 21,
67, 22, 67, 23, 67, 24, 67, 25, 67, 26, 67, 27, 67,
28, 67, 29, 67, 30, 67, 31, 67, 32, 67, 33, 67, 34,
67, 35, 67, 36, 67, 37, 67, 38, 67, 39, 67, 40, 67,
41, 67, 42, 67, 43, 67, 44, 67, 45, 67, 46, 67, 47,
67, 48, 67, 49, 67, 50, 67, 51, 67, 52, 67, 53, 67,
54, 67, 55, 67, 56, 67, 57, 67, 58, 67, 59, 67, 60,
67, 61, 67, 62, 67, 63, 67, 64, 67, 65, 67, 66, 67,
67, 67, 68, 67, 69, 67, 70, 67, 71, 67, 72, 67, 73,
67, 74, 67, 75, 67, 76, 67, 77, 67, 78, 67, 79, 67,
80, 67, 81, 67, 82, 67, 83, 67, 84, 67, 85, 67, 86,
67, 87, 67, 88, 67, 89, 67, 90, 67, 91, 67, 92, 67,
93, 67, 94, 67, 95, 67, 96, 67, 97, 67, 98, 67, 99,
67, 100, 67, 101, 67, 102, 67, 103, 67, 104, 67, 105, 67,
106, 67, 107, 67, 108, 67, 109, 67, 110, 67, 111, 67, 112,
67, 113, 67, 114, 67, 115, 67, 116, 67, 117, 67, 118, 67,
119, 67, 120, 67, 121, 67, 122, 67, 123, 67, 124, 67, 125,
67, 126, 67, 127, 67,
},
fsggml.TensorTypeF16: {
0, 0, 0, 60, 0, 64, 0, 66, 0, 68, 0, 69, 0, 70, 0, 71, 0,
72, 128, 72, 0, 73, 128, 73, 0, 74, 128, 74, 0, 75, 128, 75,
0, 76, 64, 76, 128, 76, 192, 76, 0, 77, 64, 77, 128, 77, 192,
77, 0, 78, 64, 78, 128, 78, 192, 78, 0, 79, 64, 79, 128, 79,
192, 79, 0, 80, 32, 80, 64, 80, 96, 80, 128, 80, 160, 80,
192, 80, 224, 80, 0, 81, 32, 81, 64, 81, 96, 81, 128, 81,
160, 81, 192, 81, 224, 81, 0, 82, 32, 82, 64, 82, 96, 82,
128, 82, 160, 82, 192, 82, 224, 82, 0, 83, 32, 83, 64, 83,
96, 83, 128, 83, 160, 83, 192, 83, 224, 83, 0, 84, 16, 84,
32, 84, 48, 84, 64, 84, 80, 84, 96, 84, 112, 84, 128, 84,
144, 84, 160, 84, 176, 84, 192, 84, 208, 84, 224, 84, 240,
84, 0, 85, 16, 85, 32, 85, 48, 85, 64, 85, 80, 85, 96, 85,
112, 85, 128, 85, 144, 85, 160, 85, 176, 85, 192, 85, 208,
85, 224, 85, 240, 85, 0, 86, 16, 86, 32, 86, 48, 86, 64,
86, 80, 86, 96, 86, 112, 86, 128, 86, 144, 86, 160, 86,
176, 86, 192, 86, 208, 86, 224, 86, 240, 86, 0, 87, 16,
87, 32, 87, 48, 87, 64, 87, 80, 87, 96, 87, 112, 87, 128,
87, 144, 87, 160, 87, 176, 87, 192, 87, 208, 87, 224, 87,
240, 87, 0, 88, 8, 88, 16, 88, 24, 88, 32, 88, 40, 88,
48, 88, 56, 88, 64, 88, 72, 88, 80, 88, 88, 88, 96, 88,
104, 88, 112, 88, 120, 88, 128, 88, 136, 88, 144, 88, 152,
88, 160, 88, 168, 88, 176, 88, 184, 88, 192, 88, 200, 88,
208, 88, 216, 88, 224, 88, 232, 88, 240, 88, 248, 88, 0,
89, 8, 89, 16, 89, 24, 89, 32, 89, 40, 89, 48, 89, 56, 89,
64, 89, 72, 89, 80, 89, 88, 89, 96, 89, 104, 89, 112, 89,
120, 89, 128, 89, 136, 89, 144, 89, 152, 89, 160, 89, 168,
89, 176, 89, 184, 89, 192, 89, 200, 89, 208, 89, 216, 89,
224, 89, 232, 89, 240, 89, 248, 89, 0, 90, 8, 90, 16, 90,
24, 90, 32, 90, 40, 90, 48, 90, 56, 90, 64, 90, 72, 90, 80,
90, 88, 90, 96, 90, 104, 90, 112, 90, 120, 90, 128, 90,
136, 90, 144, 90, 152, 90, 160, 90, 168, 90, 176, 90, 184,
90, 192, 90, 200, 90, 208, 90, 216, 90, 224, 90, 232, 90,
240, 90, 248, 90, 0, 91, 8, 91, 16, 91, 24, 91, 32, 91, 40,
91, 48, 91, 56, 91, 64, 91, 72, 91, 80, 91, 88, 91, 96, 91,
104, 91, 112, 91, 120, 91, 128, 91, 136, 91, 144, 91, 152,
91, 160, 91, 168, 91, 176, 91, 184, 91, 192, 91, 200, 91,
208, 91, 216, 91, 224, 91, 232, 91, 240, 91, 248, 91,
},
fsggml.TensorTypeF32: {
0, 0, 0, 0, 0, 0, 128, 63, 0, 0, 0, 64, 0, 0, 64, 64, 0, 0, 128,
64, 0, 0, 160, 64, 0, 0, 192, 64, 0, 0, 224, 64, 0, 0, 0, 65, 0,
0, 16, 65, 0, 0, 32, 65, 0, 0, 48, 65, 0, 0, 64, 65, 0, 0, 80, 65,
0, 0, 96, 65, 0, 0, 112, 65, 0, 0, 128, 65, 0, 0, 136, 65, 0, 0,
144, 65, 0, 0, 152, 65, 0, 0, 160, 65, 0, 0, 168, 65, 0, 0, 176,
65, 0, 0, 184, 65, 0, 0, 192, 65, 0, 0, 200, 65, 0, 0, 208, 65, 0,
0, 216, 65, 0, 0, 224, 65, 0, 0, 232, 65, 0, 0, 240, 65, 0, 0, 248,
65, 0, 0, 0, 66, 0, 0, 4, 66, 0, 0, 8, 66, 0, 0, 12, 66, 0, 0, 16,
66, 0, 0, 20, 66, 0, 0, 24, 66, 0, 0, 28, 66, 0, 0, 32, 66, 0, 0,
36, 66, 0, 0, 40, 66, 0, 0, 44, 66, 0, 0, 48, 66, 0, 0, 52, 66, 0,
0, 56, 66, 0, 0, 60, 66, 0, 0, 64, 66, 0, 0, 68, 66, 0, 0, 72, 66,
0, 0, 76, 66, 0, 0, 80, 66, 0, 0, 84, 66, 0, 0, 88, 66, 0, 0, 92, 66,
0, 0, 96, 66, 0, 0, 100, 66, 0, 0, 104, 66, 0, 0, 108, 66, 0, 0, 112,
66, 0, 0, 116, 66, 0, 0, 120, 66, 0, 0, 124, 66, 0, 0, 128, 66, 0, 0,
130, 66, 0, 0, 132, 66, 0, 0, 134, 66, 0, 0, 136, 66, 0, 0, 138, 66,
0, 0, 140, 66, 0, 0, 142, 66, 0, 0, 144, 66, 0, 0, 146, 66, 0, 0, 148,
66, 0, 0, 150, 66, 0, 0, 152, 66, 0, 0, 154, 66, 0, 0, 156, 66, 0, 0,
158, 66, 0, 0, 160, 66, 0, 0, 162, 66, 0, 0, 164, 66, 0, 0, 166, 66,
0, 0, 168, 66, 0, 0, 170, 66, 0, 0, 172, 66, 0, 0, 174, 66, 0, 0, 176,
66, 0, 0, 178, 66, 0, 0, 180, 66, 0, 0, 182, 66, 0, 0, 184, 66, 0, 0,
186, 66, 0, 0, 188, 66, 0, 0, 190, 66, 0, 0, 192, 66, 0, 0, 194, 66, 0,
0, 196, 66, 0, 0, 198, 66, 0, 0, 200, 66, 0, 0, 202, 66, 0, 0, 204, 66,
0, 0, 206, 66, 0, 0, 208, 66, 0, 0, 210, 66, 0, 0, 212, 66, 0, 0, 214, 66,
0, 0, 216, 66, 0, 0, 218, 66, 0, 0, 220, 66, 0, 0, 222, 66, 0, 0, 224, 66,
0, 0, 226, 66, 0, 0, 228, 66, 0, 0, 230, 66, 0, 0, 232, 66, 0, 0, 234, 66,
0, 0, 236, 66, 0, 0, 238, 66, 0, 0, 240, 66, 0, 0, 242, 66, 0, 0, 244, 66,
0, 0, 246, 66, 0, 0, 248, 66, 0, 0, 250, 66, 0, 0, 252, 66, 0, 0, 254, 66,
0, 0, 0, 67, 0, 0, 1, 67, 0, 0, 2, 67, 0, 0, 3, 67, 0, 0, 4, 67, 0, 0, 5, 67,
0, 0, 6, 67, 0, 0, 7, 67, 0, 0, 8, 67, 0, 0, 9, 67, 0, 0, 10, 67, 0, 0, 11,
67, 0, 0, 12, 67, 0, 0, 13, 67, 0, 0, 14, 67, 0, 0, 15, 67, 0, 0, 16, 67,
0, 0, 17, 67, 0, 0, 18, 67, 0, 0, 19, 67, 0, 0, 20, 67, 0, 0, 21, 67, 0, 0,
22, 67, 0, 0, 23, 67, 0, 0, 24, 67, 0, 0, 25, 67, 0, 0, 26, 67, 0, 0, 27,
67, 0, 0, 28, 67, 0, 0, 29, 67, 0, 0, 30, 67, 0, 0, 31, 67, 0, 0, 32, 67,
0, 0, 33, 67, 0, 0, 34, 67, 0, 0, 35, 67, 0, 0, 36, 67, 0, 0, 37, 67, 0, 0,
38, 67, 0, 0, 39, 67, 0, 0, 40, 67, 0, 0, 41, 67, 0, 0, 42, 67, 0, 0, 43, 67,
0, 0, 44, 67, 0, 0, 45, 67, 0, 0, 46, 67, 0, 0, 47, 67, 0, 0, 48, 67, 0, 0,
49, 67, 0, 0, 50, 67, 0, 0, 51, 67, 0, 0, 52, 67, 0, 0, 53, 67, 0, 0, 54, 67,
0, 0, 55, 67, 0, 0, 56, 67, 0, 0, 57, 67, 0, 0, 58, 67, 0, 0, 59, 67, 0, 0,
60, 67, 0, 0, 61, 67, 0, 0, 62, 67, 0, 0, 63, 67, 0, 0, 64, 67, 0, 0, 65, 67,
0, 0, 66, 67, 0, 0, 67, 67, 0, 0, 68, 67, 0, 0, 69, 67, 0, 0, 70, 67, 0, 0, 71,
67, 0, 0, 72, 67, 0, 0, 73, 67, 0, 0, 74, 67, 0, 0, 75, 67, 0, 0, 76, 67, 0,
0, 77, 67, 0, 0, 78, 67, 0, 0, 79, 67, 0, 0, 80, 67, 0, 0, 81, 67, 0, 0, 82,
67, 0, 0, 83, 67, 0, 0, 84, 67, 0, 0, 85, 67, 0, 0, 86, 67, 0, 0, 87, 67, 0,
0, 88, 67, 0, 0, 89, 67, 0, 0, 90, 67, 0, 0, 91, 67, 0, 0, 92, 67, 0, 0, 93,
67, 0, 0, 94, 67, 0, 0, 95, 67, 0, 0, 96, 67, 0, 0, 97, 67, 0, 0, 98, 67, 0,
0, 99, 67, 0, 0, 100, 67, 0, 0, 101, 67, 0, 0, 102, 67, 0, 0, 103, 67, 0, 0,
104, 67, 0, 0, 105, 67, 0, 0, 106, 67, 0, 0, 107, 67, 0, 0, 108, 67, 0, 0, 109,
67, 0, 0, 110, 67, 0, 0, 111, 67, 0, 0, 112, 67, 0, 0, 113, 67, 0, 0, 114, 67,
0, 0, 115, 67, 0, 0, 116, 67, 0, 0, 117, 67, 0, 0, 118, 67, 0, 0, 119, 67, 0,
0, 120, 67, 0, 0, 121, 67, 0, 0, 122, 67, 0, 0, 123, 67, 0, 0, 124, 67, 0, 0,
125, 67, 0, 0, 126, 67, 0, 0, 127, 67,
},
fsggml.TensorTypeQ4_K: {
52, 52, 0, 0, 136, 208, 216, 223, 0, 0, 0, 0, 8, 0, 8, 15, 128,
128, 129, 129, 146, 146, 147, 147, 164, 164, 165, 165, 166, 182,
183, 183, 184, 200, 201, 201, 202, 218, 218, 219, 219, 236, 236,
237, 237, 254, 254, 255, 202, 202, 202, 203, 203, 203, 219, 219,
219, 220, 220, 220, 220, 220, 236, 237, 237, 237, 237, 237,
237, 237, 238, 254, 254, 254, 254, 254, 255, 255, 255, 255, 220,
220, 220, 220, 221, 221, 221, 221, 221, 221, 221, 237, 237, 237,
238, 238, 238, 238, 238, 238, 238, 238, 238, 254, 254, 255, 255,
255, 255, 255, 255, 255, 237, 237, 237, 237, 237, 237, 237, 238,
238, 238, 238, 238, 238, 238, 238, 238, 254, 254, 254, 254, 254,
254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
},
fsggml.TensorTypeQ2_K: {
1, 2, 3, 3, 4, 5, 7, 7, 8, 9, 10, 11, 12, 13, 14, 15, 184, 184,
184, 185, 249, 249, 249, 249, 249, 250, 250, 254, 254, 254, 254,
255, 253, 253, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 171, 69, 0, 0,
},
fsggml.TensorTypeQ5_K: {
32, 48, 0, 0, 136, 208, 216, 223, 0, 0, 0, 0, 8, 0, 7, 15, 254,
254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 254,
254, 254, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 0, 1, 2, 19, 20, 37, 38, 55, 56, 73, 74,
91, 92, 109, 110, 127, 112, 128, 129, 146, 147, 164, 165, 182, 183,
200, 201, 218, 219, 236, 237, 254, 133, 133, 149, 150, 150, 150,
167, 167, 167, 168, 184, 184, 185, 185, 201, 202, 202, 202, 219,
219, 219, 219, 236, 236, 236, 237, 253, 253, 254, 254, 254, 255,
169, 169, 169, 169, 186, 186, 186, 186, 186, 187, 187, 203, 203,
203, 204, 204, 204, 220, 220, 221, 221, 221, 221, 237, 237, 238,
238, 238, 238, 254, 255, 255, 203, 203, 203, 204, 204, 204, 204,
204, 220, 220, 220, 221, 221, 221, 221, 221, 237, 237, 238, 238,
238, 238, 238, 238, 254, 255, 255, 255, 255, 255, 255, 255,
},
fsggml.TensorTypeQ6_K: {
96, 110, 92, 90, 88, 70, 68, 50, 48, 46, 44, 42, 24, 22, 4, 2, 80,
95, 78, 77, 76, 59, 58, 57, 40, 39, 38, 21, 20, 19, 2, 1, 75, 75,
74, 57, 57, 56, 55, 39, 38, 37, 21, 20, 20, 19, 2, 2, 72, 55, 55,
54, 54, 37, 37, 36, 36, 19, 19, 18, 18, 1, 1, 0, 35, 35, 35, 35,
34, 18, 18, 18, 17, 17, 17, 1, 1, 0, 0, 0, 35, 35, 34, 34, 18,
18, 18, 17, 17, 17, 17, 1, 0, 0, 0, 0, 35, 35, 35, 19, 19, 18, 18,
18, 18, 18, 1, 1, 1, 1, 1, 1, 34, 34, 18, 18, 18, 18, 17, 17, 17,
17, 1, 1, 0, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 248, 240, 231, 224, 216, 208, 200, 192, 184, 176,
166, 160, 152, 144, 136, 128, 235, 43,
},
fsggml.TensorTypeQ3_K: {
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 20, 20, 23, 23, 7, 7, 6, 6, 6, 2,
1, 1, 1, 1, 0, 0, 22, 22, 6, 6, 5, 5, 5, 1, 1, 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 238, 204, 170, 136, 102, 68,
34, 1, 5, 5, 5, 5, 189, 63,
},
}
)

View File

@@ -0,0 +1,110 @@
package server
import (
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
// Renderer names for the Gemma 4 family, plus the parameter-count threshold
// used to choose between the small and large variants.
const (
// gemma4RendererLegacy is the size-agnostic renderer name. resolveRendererName
// maps it to one of the size-specific names below.
gemma4RendererLegacy = "gemma4"
// gemma4RendererSmall is the renderer name for the small prompt template.
gemma4RendererSmall = "gemma4-small"
// gemma4RendererLarge is the renderer name for the large prompt template.
gemma4RendererLarge = "gemma4-large"
// Gemma 4 small templates cover the e2b/e4b family, while 26b/31b use the
// large template. Default to the small prompt unless the model is clearly in
// the large range.
gemma4LargeMinParameterCount = 16_000_000_000
)
// resolveRendererName returns the renderer name to use for m. It returns ""
// when m is nil or has no renderer configured, resolves the legacy "gemma4"
// name through resolveGemma4Renderer, and otherwise returns the configured
// renderer unchanged.
func resolveRendererName(m *Model) string {
	if m == nil {
		return ""
	}

	configured := m.Config.Renderer
	if configured == gemma4RendererLegacy {
		return resolveGemma4Renderer(m)
	}
	// An empty renderer also lands here and is returned as "".
	return configured
}
// resolveGemma4Renderer picks the size-specific Gemma 4 renderer for a model
// configured with the legacy "gemma4" renderer.
//
// A nil model yields the legacy name, and a model whose renderer is not the
// legacy name gets its configured renderer back unchanged. Otherwise the
// choice is made from, in order: a size marker in m.ShortName, a size marker
// in m.Name, the parameter count parsed from m.Config.ModelType, and finally
// the small renderer as the default.
func resolveGemma4Renderer(m *Model) string {
	if m == nil {
		return gemma4RendererLegacy
	}
	if configured := m.Config.Renderer; configured != gemma4RendererLegacy {
		return configured
	}

	// Explicit size markers in the model names win, short name first.
	for _, candidate := range []string{m.ShortName, m.Name} {
		if renderer, found := gemma4RendererFromName(candidate); found {
			return renderer
		}
	}

	// Fall back to the declared model size, e.g. "31B".
	count, parsed := parseHumanParameterCount(m.Config.ModelType)
	if !parsed {
		return gemma4RendererSmall
	}
	return gemma4RendererForParameterCount(count)
}
// gemma4RendererForParameterCount returns the large renderer when
// parameterCount is at least gemma4LargeMinParameterCount, and the small
// renderer otherwise.
func gemma4RendererForParameterCount(parameterCount uint64) string {
	if parameterCount < gemma4LargeMinParameterCount {
		return gemma4RendererSmall
	}
	return gemma4RendererLarge
}
// gemma4RendererFromName infers the Gemma 4 renderer from size markers in a
// model name, ignoring case. Names containing "e2b" or "e4b" map to the small
// renderer; names containing "26b" or "31b" map to the large one. The small
// markers are checked first. The boolean is false when no marker is present.
func gemma4RendererFromName(name string) (string, bool) {
	normalized := strings.ToLower(name)
	containsAny := func(markers ...string) bool {
		for _, marker := range markers {
			if strings.Contains(normalized, marker) {
				return true
			}
		}
		return false
	}

	if containsAny("e2b", "e4b") {
		return gemma4RendererSmall, true
	}
	if containsAny("26b", "31b") {
		return gemma4RendererLarge, true
	}
	return "", false
}
// parseHumanParameterCount converts a human-readable parameter count such as
// "31B", "1.5b" or "350M" into an absolute count.
//
// The final byte of s selects the unit, case-insensitively: "B" scales by
// format.Billion, "M" by format.Million and "K" by format.Thousand. Everything
// before that byte must be accepted by strconv.ParseFloat.
//
// It returns (0, false) when s is empty, the unit is not recognized, the
// numeric part does not parse, or the scaled value is negative, NaN, or too
// large to fit in a uint64.
func parseHumanParameterCount(s string) (uint64, bool) {
	if s == "" {
		return 0, false
	}

	unitAt := len(s) - 1
	var multiplier float64
	switch strings.ToUpper(s[unitAt:]) {
	case "B":
		multiplier = float64(format.Billion)
	case "M":
		multiplier = float64(format.Million)
	case "K":
		multiplier = float64(format.Thousand)
	default:
		return 0, false
	}

	value, err := strconv.ParseFloat(s[:unitAt], 64)
	if err != nil {
		return 0, false
	}

	// ParseFloat accepts "NaN", "Inf" and negative numbers without error, and
	// converting a float that uint64 cannot represent yields an
	// implementation-dependent result. Reject those instead of returning
	// garbage. The condition is phrased positively so that NaN, for which every
	// comparison is false, is rejected as well.
	const uint64Limit = float64(1 << 64)
	scaled := value * multiplier
	if !(scaled >= 0 && scaled < uint64Limit) {
		return 0, false
	}
	return uint64(scaled), true
}
// isGemma4Renderer reports whether renderer is one of the Gemma 4 renderer
// names: the legacy name or either size-specific variant.
func isGemma4Renderer(renderer string) bool {
	return renderer == gemma4RendererLegacy ||
		renderer == gemma4RendererSmall ||
		renderer == gemma4RendererLarge
}

2842
server/routes.go Normal file

File diff suppressed because it is too large Load Diff

1147
server/routes_cloud_test.go Normal file

File diff suppressed because it is too large Load Diff

1506
server/routes_create_test.go Normal file

File diff suppressed because it is too large Load Diff

415
server/routes_debug_test.go Normal file
View File

@@ -0,0 +1,415 @@
package server
import (
"bytes"
"encoding/json"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
// TestGenerateDebugRenderOnly exercises GenerateHandler with DebugRenderOnly
// set and unset. It creates a minimal llama-architecture test model whose
// template is "{{ .Prompt }}", then runs each table case both with and without
// streaming. For debug cases it checks the response model name and, where the
// case specifies them, the rendered template and image count reported in
// DebugInfo.
func TestGenerateDebugRenderOnly(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
// Mock runner configured with an already-completed response.
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
// Server backed by a scheduler whose load function hands back the mock
// runner instead of loading a real model.
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
// The scheduler runs for the lifetime of the test context.
go s.sched.Run(t.Context())
// Create a test model
stream := false
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Register the model with a template that renders only the prompt.
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-model",
Files: map[string]string{"file.gguf": digest},
Template: "{{ .Prompt }}",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
// Table of requests. expectTemplate and expectNumImages are only checked
// when non-zero and when expectDebug is true.
tests := []struct {
name string
request api.GenerateRequest
expectDebug bool
expectTemplate string
expectNumImages int
}{
{
name: "debug render only enabled",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello, world!",
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "Hello, world!",
},
{
name: "debug render only disabled",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello, world!",
DebugRenderOnly: false,
},
expectDebug: false,
},
{
// The model template has no system section, so only the prompt is
// expected in the rendered output.
name: "debug render only with system prompt",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "User question",
System: "You are a helpful assistant",
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "User question",
},
{
// A template supplied on the request is expected to be used for rendering.
name: "debug render only with template",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Template: "PROMPT: {{ .Prompt }}",
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "PROMPT: Hello",
},
{
// An attached image is expected to appear as an "[img-0]" tag before
// the prompt and be counted in DebugInfo.
name: "debug render only with images",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Describe this image",
Images: []api.ImageData{[]byte("fake-image-data")},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "[img-0]Describe this image",
expectNumImages: 1,
},
{
name: "debug render only with raw mode",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Raw prompt text",
Raw: true,
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "Raw prompt text",
},
}
for _, tt := range tests {
// Test both with and without streaming
streamValues := []bool{false, true}
for _, stream := range streamValues {
streamSuffix := ""
if stream {
streamSuffix = " (streaming)"
}
t.Run(tt.name+streamSuffix, func(t *testing.T) {
// Copy the table request so setting Stream does not modify the table entry.
req := tt.request
req.Stream = &stream
w := createRequest(t, s.GenerateHandler, req)
if tt.expectDebug {
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
}
// The whole body is decoded as one GenerateResponse in both the
// streaming and non-streaming case.
var response api.GenerateResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if response.Model != tt.request.Model {
t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
}
if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
}
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
}
} else {
// When debug is disabled, it should attempt normal processing
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
})
}
}
}
// TestChatDebugRenderOnly exercises ChatHandler with DebugRenderOnly set and
// unset. It creates a minimal llama-architecture test model whose template
// prints the tools (when present) followed by one "role: content" line per
// message, then runs each table case both with and without streaming. For
// debug cases it checks the response model name and, where the case specifies
// them, the rendered template and image count reported in DebugInfo.
func TestChatDebugRenderOnly(t *testing.T) {
t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
// Mock runner configured with an already-completed response.
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
// Server backed by a scheduler whose load function hands back the mock
// runner instead of loading a real model.
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
// The scheduler runs for the lifetime of the test context.
go s.sched.Run(t.Context())
// Create a test model
stream := false
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Register the model with a template that prints tools (if any) and then
// each message as "role: content" on its own line.
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-model",
Files: map[string]string{"file.gguf": digest},
Template: "{{ if .Tools }}{{ .Tools }}{{ end }}{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
// Table of requests. expectTemplate and expectNumImages are only checked
// when non-zero and when expectDebug is true.
tests := []struct {
name string
request api.ChatRequest
expectDebug bool
expectTemplate string
expectNumImages int
}{
{
name: "chat debug render only enabled",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Hello"},
},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "system: You are a helpful assistant\nuser: Hello\n",
},
{
name: "chat debug render only disabled",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
DebugRenderOnly: false,
},
expectDebug: false,
},
{
name: "chat debug with assistant message",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "user: Hello\nassistant: Hi there!\nuser: How are you?\n",
},
{
// An attached image is expected to appear as an "[img-0]" tag before
// the message content and be counted in DebugInfo.
name: "chat debug with images",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's in this image?",
Images: []api.ImageData{[]byte("fake-image-data")},
},
},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "user: [img-0]What's in this image?\n",
expectNumImages: 1,
},
{
// The tool list is expected to be rendered as JSON ahead of the messages.
name: "chat debug with tools",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Get the weather"},
},
Tools: api.Tools{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather information",
},
},
},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "[{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather information\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]user: Get the weather\n",
},
}
for _, tt := range tests {
// Test both with and without streaming
streamValues := []bool{false, true}
for _, stream := range streamValues {
streamSuffix := ""
if stream {
streamSuffix = " (streaming)"
}
t.Run(tt.name+streamSuffix, func(t *testing.T) {
// Copy the table request so setting Stream does not modify the table entry.
req := tt.request
req.Stream = &stream
w := createRequest(t, s.ChatHandler, req)
if tt.expectDebug {
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
}
// The whole body is decoded as one ChatResponse in both the
// streaming and non-streaming case.
var response api.ChatResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if response.Model != tt.request.Model {
t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
}
if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
}
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
}
} else {
// When debug is disabled, it should attempt normal processing
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
})
}
}
}

View File

@@ -0,0 +1,142 @@
package server
import (
"bytes"
"encoding/json"
"net/http"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
// TestDelete creates two models from the same GGUF digest, then deletes them
// one at a time. After each step it checks that exactly the expected manifests
// and blobs remain under the OLLAMA_MODELS directory, ending with none.
func TestDelete(t *testing.T) {
	gin.SetMode(gin.TestMode)

	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	// Glob patterns and path builders shared by the checks below.
	manifestGlob := filepath.Join(modelsDir, "manifests", "*", "*", "*", "*")
	blobGlob := filepath.Join(modelsDir, "blobs", "*")
	latestManifest := func(name string) string {
		return filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", name, "latest")
	}
	blob := func(file string) string {
		return filepath.Join(modelsDir, "blobs", file)
	}

	var s Server

	_, digest := createBinFile(t, nil, nil)
	resp := createRequest(t, s.CreateHandler, api.CreateRequest{
		Name:  "test",
		Files: map[string]string{"test.gguf": digest},
	})
	if resp.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", resp.Code)
	}

	resp = createRequest(t, s.CreateHandler, api.CreateRequest{
		Name:     "test2",
		Files:    map[string]string{"test.gguf": digest},
		Template: "{{ .System }} {{ .Prompt }}",
	})
	if resp.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", resp.Code)
	}

	// Both models exist, along with every blob they reference.
	checkFileExists(t, manifestGlob, []string{
		latestManifest("test"),
		latestManifest("test2"),
	})
	checkFileExists(t, blobGlob, []string{
		blob("sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
		blob("sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
		blob("sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
		blob("sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
	})

	// Deleting "test" leaves "test2" and the blobs still expected on disk.
	resp = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test"})
	if resp.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", resp.Code)
	}
	checkFileExists(t, manifestGlob, []string{
		latestManifest("test2"),
	})
	checkFileExists(t, blobGlob, []string{
		blob("sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
		blob("sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
		blob("sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
	})

	// Deleting "test2" as well leaves no manifests and no blobs.
	resp = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test2"})
	if resp.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", resp.Code)
	}
	checkFileExists(t, manifestGlob, []string{})
	checkFileExists(t, blobGlob, []string{})
}
// TestDeleteDuplicateLayers checks that the delete handler succeeds for a
// model whose manifest references the same layer twice: the config layer is
// also passed as the only entry of the layer list. After the delete, no
// manifest file may remain under the models directory.
func TestDeleteDuplicateLayers(t *testing.T) {
	gin.SetMode(gin.TestMode)

	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	var s Server

	// Serialize an empty config; it becomes the single layer of the manifest.
	var encoded bytes.Buffer
	encodeErr := json.NewEncoder(&encoded).Encode(&model.ConfigV2{})
	if encodeErr != nil {
		t.Fatal(encodeErr)
	}

	configLayer, err := manifest.NewLayer(&encoded, "application/vnd.docker.container.image.v1+json")
	if err != nil {
		t.Fatal(err)
	}

	// create a manifest with duplicate layers
	name := model.ParseName("test")
	err = manifest.WriteManifest(name, configLayer, []manifest.Layer{configLayer})
	if err != nil {
		t.Fatal(err)
	}

	resp := createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "test"})
	if resp.Code != http.StatusOK {
		t.Errorf("expected status code 200, actual %d", resp.Code)
	}

	manifestGlob := filepath.Join(modelsDir, "manifests", "*", "*", "*", "*")
	checkFileExists(t, manifestGlob, []string{})
}
// TestDeleteCloudSourceNormalizesToLegacyName creates a model under the
// legacy name "gpt-oss:20b-cloud" and then deletes it using the
// "gpt-oss:20b:cloud" spelling. The delete must return 200 and leave no
// manifest behind, which shows both spellings address the same manifest.
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
	gin.SetMode(gin.TestMode)

	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	manifestGlob := filepath.Join(modelsDir, "manifests", "*", "*", "*", "*")

	var s Server

	_, digest := createBinFile(t, nil, nil)

	created := createRequest(t, s.CreateHandler, api.CreateRequest{
		Name:  "gpt-oss:20b-cloud",
		Files: map[string]string{"test.gguf": digest},
	})
	if created.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", created.Code)
	}

	// The manifest is stored under the legacy tag "20b-cloud".
	wantManifest := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud")
	checkFileExists(t, manifestGlob, []string{wantManifest})

	deleted := createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
	if deleted.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d (%s)", deleted.Code, deleted.Body.String())
	}

	checkFileExists(t, manifestGlob, []string{})
}

View File

@@ -0,0 +1,315 @@
package server
import (
"bytes"
"encoding/json"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
// TestGenerateWithBuiltinRenderer checks when the generate handler renders the
// prompt with a model's built-in renderer. A model created with the
// "qwen3-coder" renderer must produce a prompt containing the <|im_start|> /
// <|im_end|> markers for plain prompt (and prompt+system) requests, while a
// request-level template or a suffix must bypass the renderer and go through
// the template path instead.
//
// The subtests inspect mock.CompletionRequest.Prompt, i.e. the prompt the mock
// runner was last asked to complete.
func TestGenerateWithBuiltinRenderer(t *testing.T) {
	t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
	gin.SetMode(gin.TestMode)

	const (
		rendererModel = "test-renderer"
		suffixModel   = "test-suffix-renderer"
		userPrompt    = "Write a hello world function"
		startTag      = "<|im_start|>"
		endTag        = "<|im_end|>"
	)

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         llm.DoneReasonStop,
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:    make(chan *LlmRequest, 1),
			finishedReqCh:   make(chan *LlmRequest, 1),
			expiredCh:       make(chan *runnerRef, 1),
			unloadedCh:      make(chan any, 1),
			loaded:          make(map[string]*runnerRef),
			newServerFn:     newMockServer(&mock),
			getGpuFn:        getGpuFn,
			getSystemInfoFn: getSystemInfoFn,
			waitForRecovery: 250 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
				// Hand the mock runner back after a short delay.
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
				return false
			},
		},
	}

	go s.sched.Run(t.Context())

	// Every tensor in the fixture has shape [1] and four zero bytes of data.
	tensorNames := []string{
		"token_embd.weight",
		"blk.0.attn_norm.weight",
		"blk.0.ffn_down.weight",
		"blk.0.ffn_gate.weight",
		"blk.0.ffn_up.weight",
		"blk.0.ffn_norm.weight",
		"blk.0.attn_k.weight",
		"blk.0.attn_output.weight",
		"blk.0.attn_q.weight",
		"blk.0.attn_v.weight",
		"output.weight",
	}
	tensors := make([]*ggml.Tensor, 0, len(tensorNames))
	for _, name := range tensorNames {
		tensors = append(tensors, &ggml.Tensor{
			Name:     name,
			Shape:    []uint64{1},
			WriterTo: bytes.NewReader(make([]byte, 4)),
		})
	}

	// Minimal qwen3 model file used as the base for the renderer model.
	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":          "qwen3",
		"qwen3.block_count":             uint32(1),
		"qwen3.context_length":          uint32(8192),
		"qwen3.embedding_length":        uint32(4096),
		"qwen3.attention.head_count":    uint32(32),
		"qwen3.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, tensors)

	// Register the model with the qwen3-coder renderer.
	created := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:    rendererModel,
		Files:    map[string]string{"file.gguf": digest},
		Renderer: "qwen3-coder",
		Stream:   &stream,
	})
	if created.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", created.Code)
	}

	mock.CompletionResponse.Content = "Hi!"

	t.Run("chat-like flow uses renderer", func(t *testing.T) {
		// A plain prompt with no template and no suffix goes through the renderer.
		resp := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  rendererModel,
			Prompt: userPrompt,
			Stream: &stream,
		})
		if resp.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", resp.Code)
		}

		// The qwen3-coder renderer wraps messages in <|im_start|> ... <|im_end|>.
		rendered := mock.CompletionRequest.Prompt
		if !strings.Contains(rendered, startTag) {
			t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", rendered)
		}
		if !strings.Contains(rendered, endTag) {
			t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", rendered)
		}
	})

	t.Run("chat-like flow with system message uses renderer", func(t *testing.T) {
		// A system message must be rendered through the renderer as well.
		resp := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  rendererModel,
			Prompt: userPrompt,
			System: "You are a helpful coding assistant.",
			Stream: &stream,
		})
		if resp.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", resp.Code)
		}

		rendered := mock.CompletionRequest.Prompt
		if !strings.Contains(rendered, "<|im_start|>system") {
			t.Errorf("expected prompt to contain system message with renderer format, got: %s", rendered)
		}
		if !strings.Contains(rendered, "You are a helpful coding assistant.") {
			t.Errorf("expected prompt to contain system message content, got: %s", rendered)
		}
	})

	t.Run("custom template bypasses renderer", func(t *testing.T) {
		// A request-level template selects the legacy template flow.
		resp := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:    rendererModel,
			Prompt:   userPrompt,
			Template: "{{ .Prompt }}",
			Stream:   &stream,
		})
		if resp.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", resp.Code)
		}

		rendered := mock.CompletionRequest.Prompt
		if strings.Contains(rendered, startTag) {
			t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", rendered)
		}

		// The template is just {{ .Prompt }}, so the raw prompt is expected.
		if diff := cmp.Diff(rendered, "Write a hello world function"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})

	// Derive a second model whose template understands .Suffix.
	created = createRequest(t, s.CreateHandler, api.CreateRequest{
		Model: suffixModel,
		From:  rendererModel,
		Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`,
	})
	if created.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", created.Code)
	}

	t.Run("suffix bypasses renderer", func(t *testing.T) {
		// A suffix selects the legacy template flow.
		resp := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:  suffixModel,
			Prompt: "def add(",
			Suffix: "    return c",
		})
		if resp.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", resp.Code)
		}

		rendered := mock.CompletionRequest.Prompt
		if strings.Contains(rendered, startTag) {
			t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", rendered)
		}

		// The fill-in-the-middle branch of the template must have been used.
		if diff := cmp.Diff(rendered, "<PRE> def add( <SUF>    return c <MID>"); diff != "" {
			t.Errorf("mismatch (-got +want):\n%s", diff)
		}
	})
}
// TestGenerateWithDebugRenderOnly checks that a generate request with
// DebugRenderOnly set returns debug info whose rendered template was produced
// by the model's built-in renderer (qwen3-coder) and contains both the system
// message and the prompt.
func TestGenerateWithDebugRenderOnly(t *testing.T) {
	t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
	gin.SetMode(gin.TestMode)

	const modelName = "test-debug-renderer"

	mock := mockRunner{
		CompletionResponse: llm.CompletionResponse{
			Done:               true,
			DoneReason:         llm.DoneReasonStop,
			PromptEvalCount:    1,
			PromptEvalDuration: 1,
			EvalCount:          1,
			EvalDuration:       1,
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:    make(chan *LlmRequest, 1),
			finishedReqCh:   make(chan *LlmRequest, 1),
			expiredCh:       make(chan *runnerRef, 1),
			unloadedCh:      make(chan any, 1),
			loaded:          make(map[string]*runnerRef),
			newServerFn:     newMockServer(&mock),
			getGpuFn:        getGpuFn,
			getSystemInfoFn: getSystemInfoFn,
			waitForRecovery: 250 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
				// Hand the mock runner back after a short delay.
				time.Sleep(time.Millisecond)
				req.successCh <- &runnerRef{
					llama: &mock,
				}
				return false
			},
		},
	}

	go s.sched.Run(t.Context())

	// Every tensor in the fixture has shape [1] and four zero bytes of data.
	tensorNames := []string{
		"token_embd.weight",
		"blk.0.attn_norm.weight",
		"blk.0.ffn_down.weight",
		"blk.0.ffn_gate.weight",
		"blk.0.ffn_up.weight",
		"blk.0.ffn_norm.weight",
		"blk.0.attn_k.weight",
		"blk.0.attn_output.weight",
		"blk.0.attn_q.weight",
		"blk.0.attn_v.weight",
		"output.weight",
	}
	tensors := make([]*ggml.Tensor, 0, len(tensorNames))
	for _, name := range tensorNames {
		tensors = append(tensors, &ggml.Tensor{
			Name:     name,
			Shape:    []uint64{1},
			WriterTo: bytes.NewReader(make([]byte, 4)),
		})
	}

	// Minimal qwen3 model file used as the base for the renderer model.
	_, digest := createBinFile(t, ggml.KV{
		"general.architecture":          "qwen3",
		"qwen3.block_count":             uint32(1),
		"qwen3.context_length":          uint32(8192),
		"qwen3.embedding_length":        uint32(4096),
		"qwen3.attention.head_count":    uint32(32),
		"qwen3.attention.head_count_kv": uint32(8),
		"tokenizer.ggml.tokens":         []string{""},
		"tokenizer.ggml.scores":         []float32{0},
		"tokenizer.ggml.token_type":     []int32{0},
	}, tensors)

	created := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:    modelName,
		Files:    map[string]string{"file.gguf": digest},
		Renderer: "qwen3-coder",
		Stream:   &stream,
	})
	if created.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", created.Code)
	}

	t.Run("debug_render_only with renderer", func(t *testing.T) {
		rec := createRequest(t, s.GenerateHandler, api.GenerateRequest{
			Model:           modelName,
			Prompt:          "Write a hello world function",
			System:          "You are a coding assistant",
			DebugRenderOnly: true,
		})
		if rec.Code != http.StatusOK {
			t.Errorf("expected status 200, got %d", rec.Code)
		}

		var resp api.GenerateResponse
		err := json.NewDecoder(rec.Body).Decode(&resp)
		if err != nil {
			t.Fatal(err)
		}

		if resp.DebugInfo == nil {
			t.Fatalf("expected debug info, got nil")
		}

		// The rendered template must carry the renderer's markers plus both inputs.
		rendered := resp.DebugInfo.RenderedTemplate
		if !strings.Contains(rendered, "<|im_start|>") {
			t.Errorf("expected rendered template to use qwen3-coder renderer format, got: %s", rendered)
		}
		if !strings.Contains(rendered, "You are a coding assistant") {
			t.Errorf("expected rendered template to contain system message, got: %s", rendered)
		}
		if !strings.Contains(rendered, "Write a hello world function") {
			t.Errorf("expected rendered template to contain prompt, got: %s", rendered)
		}
	})
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,683 @@
package server
// this test file is to test integration of harmony parser into routes.go (as
// opposed to harmonyparser_test.go, which tests the parser in isolation)
import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)
// getTestTools returns the two function tools used by the harmony chat tests:
// "get_weather" (required string parameter "location") and "calculate"
// (required string parameter "expression").
func getTestTools() []api.Tool {
	// singleStringParamTool builds a "function" tool whose parameters object
	// has exactly one required property of type string.
	singleStringParamTool := func(name, description, param, paramDescription string) api.Tool {
		properties := map[string]api.ToolProperty{
			param: {
				Type:        api.PropertyType{"string"},
				Description: paramDescription,
			},
		}

		return api.Tool{
			Type: "function",
			Function: api.ToolFunction{
				Name:        name,
				Description: description,
				Parameters: api.ToolFunctionParameters{
					Type:       "object",
					Required:   []string{param},
					Properties: testPropsMap(properties),
				},
			},
		}
	}

	weather := singleStringParamTool(
		"get_weather",
		"Get the current weather in a given location",
		"location",
		"The city and state, e.g. San Francisco, CA",
	)
	calculator := singleStringParamTool(
		"calculate",
		"Calculate a mathematical expression",
		"expression",
		"The mathematical expression to calculate",
	)

	return []api.Tool{weather, calculator}
}
func createHarmonyTestModel(t *testing.T) (string, string) {
t.Helper()
return createBinFile(t, ggml.KV{
"general.architecture": "gptoss",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
}
// TestChatHarmonyParserStreamingRealtime verifies that chunks are emitted as soon as they're available.
//
// Each test case is a sequence of steps. A step feeds one completion response
// from the mock runner into the chat handler and states the content, thinking
// and tool calls expected in the next non-empty streamed chunk. Steps with no
// expectations are assumed to produce no non-empty chunk. Setting only on a
// case restricts the run to the cases marked that way.
func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type step struct {
		input         llm.CompletionResponse
		wantContent   string
		wantThinking  string
		wantToolCalls []api.ToolCall
	}

	testCases := []struct {
		name  string
		steps []step
		only  bool
	}{
		{
			name: "content streams as it arrives",
			steps: []step{
				{
					input:       llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
					wantContent: "Hello",
				},
				{
					input:       llm.CompletionResponse{Content: ", world", Done: false},
					wantContent: ", world",
				},
				{
					input:       llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "!",
				},
			},
		},
		{
			name: "thinking streams separately from content",
			steps: []step{
				{
					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
					wantThinking: "Thinking...",
				},
				{
					input: llm.CompletionResponse{Content: "<|end|>", Done: false},
					// No output expected - just closes the analysis message and resets state to normal
				},
				{
					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
					wantContent: "Answer", // After message end, state is reset to normal
				},
				{
					input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					// No output expected - just closes the assistant message
				},
			},
		},
		{
			name: "partial tags buffer until complete",
			steps: []step{
				{
					input: llm.CompletionResponse{Content: "<|chan", Done: false},
					// No output - partial tag
				},
				{
					input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
					// No output - still building tags
				},
				{
					input:        llm.CompletionResponse{Content: "age|>Deep ", Done: false},
					wantThinking: "Deep ",
				},
				{
					input:        llm.CompletionResponse{Content: "thought<|end|>", Done: false},
					wantThinking: "thought",
				},
				{
					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "Done", // After message end, state is reset to normal
				},
			},
		},
		{
			name: "simple assistant after analysis",
			steps: []step{
				{
					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent:  "Answer",
					wantThinking: "Think",
				},
			},
		},
		{
			name: "tool call parsed and returned correctly",
			steps: []step{
				{
					input:       llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "The weather is sunny",
					wantToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: testArgs(map[string]any{
									"location": "San Francisco",
								}),
							},
						},
					},
				},
			},
		},
		{
			name: "tool call with streaming JSON across chunks",
			steps: []step{
				{
					input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
					// No output yet - incomplete JSON
				},
				{
					input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
					// Still no output - incomplete JSON
				},
				{
					input: llm.CompletionResponse{Content: "2\"}", Done: true},
					wantToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "calculate",
								Arguments: testArgs(map[string]any{
									"expression": "2+2",
								}),
							},
						},
					},
				},
			},
		},
	}

	// If any case is marked only, run just the marked cases.
	anyOnlies := false
	for _, tc := range testCases {
		if tc.only {
			anyOnlies = true
		}
	}

	for _, tc := range testCases {
		if anyOnlies && !tc.only {
			continue
		}

		t.Run(tc.name, func(t *testing.T) {
			var chunks []api.ChatResponse
			chunkIdx := 0

			mockResponses := make([]llm.CompletionResponse, len(tc.steps))
			for i, step := range tc.steps {
				mockResponses[i] = step.input
			}

			mock := mockRunner{
				CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
					for _, resp := range mockResponses {
						fn(resp)
						// Give the handler time to process each response
						time.Sleep(30 * time.Millisecond)
					}
					return nil
				},
			}

			s := Server{
				sched: &Scheduler{
					pendingReqCh:    make(chan *LlmRequest, 1),
					finishedReqCh:   make(chan *LlmRequest, 1),
					expiredCh:       make(chan *runnerRef, 1),
					unloadedCh:      make(chan any, 1),
					loaded:          make(map[string]*runnerRef),
					newServerFn:     newMockServer(&mock),
					getGpuFn:        getGpuFn,
					getSystemInfoFn: getSystemInfoFn,
					waitForRecovery: 100 * time.Millisecond,
					loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
						req.successCh <- &runnerRef{
							llama: &mock,
						}
						return false
					},
				},
			}

			go s.sched.Run(t.Context())

			// Create a simple test model
			_, digest := createHarmonyTestModel(t)

			streamFalse := false
			w := createRequest(t, s.CreateHandler, api.CreateRequest{
				Model:    "harmony-test-streaming",
				Files:    map[string]string{"test.gguf": digest},
				Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
				Stream:   &streamFalse,
			})
			if w.Code != http.StatusOK {
				t.Fatalf("failed to create model: %d", w.Code)
			}

			// Test chat endpoint with streaming
			streamTrue := true
			w = createRequest(t, s.ChatHandler, api.ChatRequest{
				Model:    "harmony-test-streaming",
				Messages: []api.Message{{Role: "user", Content: "Hello"}},
				Stream:   &streamTrue,
				Tools:    getTestTools(),
			})
			if w.Code != http.StatusOK {
				t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
			}

			// Parse all chunks, keeping only those that carry content,
			// thinking or tool calls.
			decoder := json.NewDecoder(w.Body)
			for decoder.More() {
				var chunk api.ChatResponse
				if err := decoder.Decode(&chunk); err != nil {
					t.Fatalf("failed to decode chunk: %v", err)
				}
				if chunk.Message.Content != "" || chunk.Message.Thinking != "" || len(chunk.Message.ToolCalls) > 0 {
					chunks = append(chunks, chunk)
				}
			}

			// Verify chunks match expected steps
			for i, step := range tc.steps {
				// Skip steps that don't expect any output
				if step.wantContent == "" && step.wantThinking == "" && len(step.wantToolCalls) == 0 {
					continue
				}

				if chunkIdx >= len(chunks) {
					t.Errorf("step %d: expected chunk not received (wanted content=%q thinking=%q)",
						i, step.wantContent, step.wantThinking)
					continue
				}

				chunk := chunks[chunkIdx]
				if chunk.Message.Content != step.wantContent || chunk.Message.Thinking != step.wantThinking {
					t.Errorf("step %d: chunk mismatch: got (content=%q, thinking=%q), want (content=%q, thinking=%q)",
						i, chunk.Message.Content, chunk.Message.Thinking, step.wantContent, step.wantThinking)
				}

				// Check tool calls if expected
				if len(step.wantToolCalls) > 0 {
					if len(chunk.Message.ToolCalls) != len(step.wantToolCalls) {
						t.Errorf("step %d: tool calls count mismatch: got %d, want %d",
							i, len(chunk.Message.ToolCalls), len(step.wantToolCalls))
					} else {
						// Lengths are equal here, so indexing by j is in range.
						for j, wantCall := range step.wantToolCalls {
							gotCall := chunk.Message.ToolCalls[j]
							if gotCall.Function.Name != wantCall.Function.Name {
								t.Errorf("step %d, tool call %d: name mismatch: got %q, want %q",
									i, j, gotCall.Function.Name, wantCall.Function.Name)
							}
							// Compare arguments as JSON strings for simplicity
							gotArgs, err := json.Marshal(gotCall.Function.Arguments)
							if err != nil {
								t.Fatalf("step %d, tool call %d: marshal actual arguments: %v", i, j, err)
							}
							wantArgs, err := json.Marshal(wantCall.Function.Arguments)
							if err != nil {
								t.Fatalf("step %d, tool call %d: marshal expected arguments: %v", i, j, err)
							}
							if string(gotArgs) != string(wantArgs) {
								t.Errorf("step %d, tool call %d: arguments mismatch: got %s, want %s",
									i, j, string(gotArgs), string(wantArgs))
							}
						}
					}
				}

				chunkIdx++
			}

			// Check if we have extra chunks
			if chunkIdx < len(chunks) {
				t.Errorf("received %d extra chunks", len(chunks)-chunkIdx)
				for i := chunkIdx; i < len(chunks); i++ {
					t.Logf("  extra chunk %d: content=%q thinking=%q",
						i-chunkIdx, chunks[i].Message.Content, chunks[i].Message.Thinking)
				}
			}

			// Log received chunks for debugging. This runs after verification
			// so that t.Failed() reflects the assertions above.
			if t.Failed() || len(chunks) == 0 {
				t.Logf("Received %d chunks:", len(chunks))
				for i, chunk := range chunks {
					t.Logf("  Chunk %d: content=%q thinking=%q", i, chunk.Message.Content, chunk.Message.Thinking)
				}
			}
		})
	}
}
// TestChatHarmonyParserStreamingSimple is a basic streaming smoke test: three
// mock completion responses are fed through the chat handler and the streamed
// chunks must concatenate to "First chunk here", with the content spread over
// at least two chunks.
func TestChatHarmonyParserStreamingSimple(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const modelName = "gpt-oss"

	mockResponses := []llm.CompletionResponse{
		{Content: "<|message|>First ", Done: false},
		{Content: "chunk ", Done: false},
		{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
	}

	mock := mockRunner{
		CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
			t.Logf("Mock received prompt: %q", r.Prompt)
			t.Logf("Mock sending %d responses", len(mockResponses))
			for idx := range mockResponses {
				t.Logf("Sending response %d: %q", idx, mockResponses[idx].Content)
				fn(mockResponses[idx])
			}
			return nil
		},
	}

	s := Server{
		sched: &Scheduler{
			pendingReqCh:    make(chan *LlmRequest, 1),
			finishedReqCh:   make(chan *LlmRequest, 1),
			expiredCh:       make(chan *runnerRef, 1),
			unloadedCh:      make(chan any, 1),
			loaded:          make(map[string]*runnerRef),
			newServerFn:     newMockServer(&mock),
			getGpuFn:        getGpuFn,
			getSystemInfoFn: getSystemInfoFn,
			waitForRecovery: 100 * time.Millisecond,
			loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
				req.successCh <- &runnerRef{
					llama: &mock,
				}
				return false
			},
		},
	}

	go s.sched.Run(t.Context())

	// Register the model used by the chat request below.
	_, digest := createHarmonyTestModel(t)
	noStream, doStream := false, true

	created := createRequest(t, s.CreateHandler, api.CreateRequest{
		Model:    modelName,
		Files:    map[string]string{"test.gguf": digest},
		Template: `<|start|><|end|>{{ .Tools }}{{ .Prompt }}`,
		Stream:   &noStream,
	})
	if created.Code != http.StatusOK {
		t.Fatalf("failed to create model: %d", created.Code)
	}

	// Issue the streaming chat request.
	rec := createRequest(t, s.ChatHandler, api.ChatRequest{
		Model:    modelName,
		Messages: []api.Message{{Role: "user", Content: "Hello"}},
		Stream:   &doStream,
		Tools:    getTestTools(),
	})
	if rec.Code != http.StatusOK {
		t.Fatalf("chat request failed: %d - %s", rec.Code, rec.Body.String())
	}

	// Decode the response body as a sequence of JSON chunks.
	var chunks []api.ChatResponse
	for dec := json.NewDecoder(rec.Body); dec.More(); {
		var chunk api.ChatResponse
		if err := dec.Decode(&chunk); err != nil {
			t.Fatalf("failed to decode chunk: %v", err)
		}
		chunks = append(chunks, chunk)
		t.Logf("Received chunk %d: content=%q thinking=%q done=%v",
			len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
	}

	if len(chunks) == 0 {
		t.Fatal("expected streaming chunks, got none")
	}

	// One pass: accumulate the content and count the chunks that carried any.
	var combined strings.Builder
	contentChunks := 0
	for i := range chunks {
		piece := chunks[i].Message.Content
		combined.WriteString(piece)
		if piece != "" {
			contentChunks++
		}
	}

	const expectedContent = "First chunk here"
	if got := combined.String(); got != expectedContent {
		t.Errorf("content mismatch: got %q, want %q", got, expectedContent)
	}

	// Content arriving in a single chunk would mean nothing was streamed.
	if contentChunks < 2 {
		t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
	}
}
// TestChatHarmonyParserStreaming feeds harmony-formatted completion responses
// through the chat handler with streaming enabled and checks the aggregate
// result: the concatenated content, the concatenated thinking, and that the
// final chunk is marked done.
//
// NOTE(review): expectedChunks records the per-response chunking each case is
// meant to produce, but nothing in this test asserts it; only the aggregated
// content and thinking are verified. Per-chunk expectations are checked in
// TestChatHarmonyParserStreamingRealtime.
func TestChatHarmonyParserStreaming(t *testing.T) {
	gin.SetMode(gin.TestMode)

	type expectedChunk struct {
		afterResponse int    // Which mock response this chunk should appear after
		content       string // Expected content in this chunk
		thinking      string // Expected thinking in this chunk
	}

	testCases := []struct {
		name           string
		mockResponses  []llm.CompletionResponse
		expectedChunks []expectedChunk
		wantContent    string
		wantThinking   string
	}{
		{
			name: "simple message without thinking",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
				{Content: "how can I help?", Done: false},
				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 1, content: "Hello, "},
				{afterResponse: 2, content: "how can I help?"},
			},
			wantContent: "Hello, how can I help?",
		},
		{
			name: "message with analysis channel for thinking",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|channel|>analysis<|message|>", Done: false},
				{Content: "Let me think ", Done: false},
				{Content: "about this problem...", Done: false},
				{Content: "<|end|>", Done: false},
				{Content: "<|start|>assistant<|message|>", Done: false},
				{Content: "The answer ", Done: false},
				{Content: "is 42", Done: false},
				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 2, thinking: "Let me think "},
				{afterResponse: 3, thinking: "about this problem..."},
				{afterResponse: 6, content: "The answer "},
				{afterResponse: 7, content: "is 42"},
			},
			wantContent:  "The answer is 42",
			wantThinking: "Let me think about this problem...",
		},
		{
			name: "streaming with partial tags across boundaries",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|chan", Done: false},
				{Content: "nel|>analy", Done: false},
				{Content: "sis<|mess", Done: false},
				{Content: "age|>Think", Done: false},
				{Content: "ing deeply...<|end|>", Done: false},
				{Content: "<|start|>assi", Done: false},
				{Content: "stant<|message|>Result ", Done: false},
				{Content: "computed<|e", Done: false},
				{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 4, thinking: "Think"},
				{afterResponse: 5, thinking: "ing deeply..."},
				{afterResponse: 7, content: "Result "},
				{afterResponse: 8, content: "computed"},
			},
			wantContent:  "Result computed",
			wantThinking: "Thinking deeply...",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mock := mockRunner{
				CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
					// Deliver the mock responses to the handler in order.
					for _, resp := range tc.mockResponses {
						fn(resp)
					}
					return nil
				},
			}

			s := Server{
				sched: &Scheduler{
					pendingReqCh:    make(chan *LlmRequest, 1),
					finishedReqCh:   make(chan *LlmRequest, 1),
					expiredCh:       make(chan *runnerRef, 1),
					unloadedCh:      make(chan any, 1),
					loaded:          make(map[string]*runnerRef),
					newServerFn:     newMockServer(&mock),
					getGpuFn:        getGpuFn,
					getSystemInfoFn: getSystemInfoFn,
					waitForRecovery: 250 * time.Millisecond,
					loadFn: func(req *LlmRequest, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
						req.successCh <- &runnerRef{
							llama: &mock,
						}
						return false
					},
				},
			}

			go s.sched.Run(t.Context())

			// Create a minimal model
			_, digest := createHarmonyTestModel(t)

			// Create model with passthrough template
			stream := false
			w := createRequest(t, s.CreateHandler, api.CreateRequest{
				Model:    "harmony-test",
				Files:    map[string]string{"file.gguf": digest},
				Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
				Stream:   &stream,
			})
			if w.Code != http.StatusOK {
				t.Fatalf("failed to create model: %d", w.Code)
			}

			// Test chat endpoint with streaming
			streamTrue := true
			w = createRequest(t, s.ChatHandler, api.ChatRequest{
				Model:    "harmony-test",
				Messages: []api.Message{{Role: "user", Content: "Hello"}},
				Stream:   &streamTrue,
				Tools:    getTestTools(),
			})
			if w.Code != http.StatusOK {
				t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
			}

			// Parse streaming response
			var chunks []api.ChatResponse
			var content, thinking strings.Builder

			decoder := json.NewDecoder(w.Body)
			for decoder.More() {
				var chunk api.ChatResponse
				if err := decoder.Decode(&chunk); err != nil {
					t.Fatalf("failed to decode chunk: %v", err)
				}
				chunks = append(chunks, chunk)

				// Accumulate content and thinking from each chunk
				content.WriteString(chunk.Message.Content)
				thinking.WriteString(chunk.Message.Thinking)

				// Debug output
				t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
			}

			// Verify we got streaming chunks
			if len(chunks) == 0 {
				t.Fatal("expected streaming chunks, got none")
			}

			gotContent := content.String()
			gotThinking := thinking.String()

			if gotContent != tc.wantContent {
				t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
			}
			if gotThinking != tc.wantThinking {
				t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
			}

			// Verify last chunk has done=true
			lastChunk := chunks[len(chunks)-1]
			if !lastChunk.Done {
				t.Error("expected last chunk to have done=true")
			}
		})
	}
}

View File

@@ -0,0 +1,78 @@
package server
import (
"context"
"encoding/json"
"net/http"
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
)
// TestList creates one model per entry in expectNames and checks that the
// list handler returns exactly those names (compared order-independently) and
// that every listed model reports the "completion" capability.
func TestList(t *testing.T) {
	gin.SetMode(gin.TestMode)

	t.Setenv("OLLAMA_MODELS", t.TempDir())

	expectNames := []string{
		"mistral:7b-instruct-q4_0",
		"zephyr:7b-beta-q5_K_M",
		"apple/OpenELM:latest",
		"boreas:2b-code-v1.5-q6_K",
		"notus:7b-v1-IQ2_S",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
		"mynamespace/apeliotes:latest",
		"myhost/mynamespace/lips:code",
	}

	// Start the model list cache and wait for it before creating models.
	s := Server{modelCaches: &modelCaches{modelList: newModelListCache()}}
	s.modelCaches.modelList.Start(context.Background())
	if err := s.modelCaches.modelList.Wait(context.Background()); err != nil {
		t.Fatal(err)
	}

	for _, n := range expectNames {
		_, digest := createBinFile(t, nil, nil)

		// Fail at the create that went wrong rather than at the later
		// count comparison, which cannot say which model is missing.
		created := createRequest(t, s.CreateHandler, api.CreateRequest{
			Name:  n,
			Files: map[string]string{"test.gguf": digest},
		})
		if created.Code != http.StatusOK {
			t.Fatalf("create %q: expected status code 200, actual %d (%s)", n, created.Code, created.Body.String())
		}
	}

	w := createRequest(t, s.ListHandler, nil)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	var resp api.ListResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if len(resp.Models) != len(expectNames) {
		t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
	}

	actualNames := make([]string, len(resp.Models))
	for i, m := range resp.Models {
		actualNames[i] = m.Name
	}

	// Sort both sides so the comparison does not depend on listing order.
	slices.Sort(actualNames)
	slices.Sort(expectNames)

	if !slices.Equal(actualNames, expectNames) {
		t.Fatalf("model names mismatch: got %v, want %v", actualNames, expectNames)
	}

	for _, m := range resp.Models {
		if !slices.Contains(m.Capabilities, "completion") {
			t.Fatalf("capabilities for %q = %v, want completion", m.Name, m.Capabilities)
		}
	}
}

View File

@@ -0,0 +1,127 @@
package server
import (
"testing"
)
// TestModelOptionsNumCtxPriority checks the precedence used by
// Server.modelOptions when resolving NumCtx. As encoded in the table below:
// a request value wins over a model value, which wins over the
// OLLAMA_CONTEXT_LENGTH environment variable, which wins over the server's
// VRAM-based default.
func TestModelOptionsNumCtxPriority(t *testing.T) {
	tests := []struct {
		name           string
		envContextLen  string // empty means not set (uses 0 sentinel)
		defaultNumCtx  int    // VRAM-based default
		modelNumCtx    int    // 0 means not set in model
		requestNumCtx  int    // 0 means not set in request
		expectedNumCtx int
	}{
		{
			name:           "vram default when nothing else set",
			envContextLen:  "",
			defaultNumCtx:  32768,
			modelNumCtx:    0,
			requestNumCtx:  0,
			expectedNumCtx: 32768,
		},
		{
			name:           "env var overrides vram default",
			envContextLen:  "8192",
			defaultNumCtx:  32768,
			modelNumCtx:    0,
			requestNumCtx:  0,
			expectedNumCtx: 8192,
		},
		{
			name:           "model overrides vram default",
			envContextLen:  "",
			defaultNumCtx:  32768,
			modelNumCtx:    16384,
			requestNumCtx:  0,
			expectedNumCtx: 16384,
		},
		{
			name:           "model overrides env var",
			envContextLen:  "8192",
			defaultNumCtx:  32768,
			modelNumCtx:    16384,
			requestNumCtx:  0,
			expectedNumCtx: 16384,
		},
		{
			name:           "request overrides everything",
			envContextLen:  "8192",
			defaultNumCtx:  32768,
			modelNumCtx:    16384,
			requestNumCtx:  4096,
			expectedNumCtx: 4096,
		},
		{
			name:           "request overrides vram default",
			envContextLen:  "",
			defaultNumCtx:  32768,
			modelNumCtx:    0,
			requestNumCtx:  4096,
			expectedNumCtx: 4096,
		},
		{
			name:           "request overrides model",
			envContextLen:  "",
			defaultNumCtx:  32768,
			modelNumCtx:    16384,
			requestNumCtx:  4096,
			expectedNumCtx: 4096,
		},
		{
			name:           "low vram tier default",
			envContextLen:  "",
			defaultNumCtx:  4096,
			modelNumCtx:    0,
			requestNumCtx:  0,
			expectedNumCtx: 4096,
		},
		{
			name:           "high vram tier default",
			envContextLen:  "",
			defaultNumCtx:  262144,
			modelNumCtx:    0,
			requestNumCtx:  0,
			expectedNumCtx: 262144,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Always pin the variable, even to the empty string, so a value
			// inherited from the developer's or CI environment cannot leak
			// into the "not set" cases. t.Setenv restores the previous value
			// when the subtest ends.
			// NOTE(review): relies on envconfig treating an empty value as
			// unset — confirm against envconfig.
			t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)

			// Create server with VRAM-based default
			s := &Server{
				defaultNumCtx: tt.defaultNumCtx,
			}

			// Create model options (use float64 as FromMap expects JSON-style numbers)
			var modelOpts map[string]any
			if tt.modelNumCtx != 0 {
				modelOpts = map[string]any{"num_ctx": float64(tt.modelNumCtx)}
			}
			model := &Model{
				Options: modelOpts,
			}

			// Create request options (use float64 as FromMap expects JSON-style numbers)
			var requestOpts map[string]any
			if tt.requestNumCtx != 0 {
				requestOpts = map[string]any{"num_ctx": float64(tt.requestNumCtx)}
			}

			opts, err := s.modelOptions(model, requestOpts)
			if err != nil {
				t.Fatalf("modelOptions failed: %v", err)
			}

			if opts.NumCtx != tt.expectedNumCtx {
				t.Errorf("NumCtx = %d, want %d", opts.NumCtx, tt.expectedNumCtx)
			}
		})
	}
}

View File

@@ -0,0 +1,128 @@
package server
import (
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
// TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts sends one POST
// through the request-logger middleware and checks that the downstream
// handler still sees the full body, and that exactly one body file and one
// curl script were written to the log directory with the expected contents.
func TestInferenceRequestLoggerMiddlewareWritesReplayArtifacts(t *testing.T) {
	gin.SetMode(gin.TestMode)

	const (
		route       = "/v1/chat/completions"
		requestBody = `{"model":"test-model","messages":[{"role":"user","content":"hello"}]}`
	)

	dir := t.TempDir()
	logger := &inferenceRequestLogger{dir: dir}

	// seen records what the handler read after the middleware ran.
	var seen string
	handler := func(c *gin.Context) {
		raw, err := io.ReadAll(c.Request.Body)
		if err != nil {
			t.Fatalf("failed to read body in handler: %v", err)
		}
		seen = string(raw)
		c.Status(http.StatusOK)
	}

	engine := gin.New()
	engine.POST(route, logger.middleware(route), handler)

	req := httptest.NewRequest(http.MethodPost, route, strings.NewReader(requestBody))
	req.Host = "127.0.0.1:11434"
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected status 200, got %d", rec.Code)
	}
	if seen != requestBody {
		t.Fatalf("handler body mismatch:\nexpected: %s\ngot: %s", requestBody, seen)
	}

	// Exactly one body artifact must exist.
	bodyMatches, err := filepath.Glob(filepath.Join(dir, "*_v1_chat_completions_body.json"))
	if err != nil {
		t.Fatalf("failed to glob body logs: %v", err)
	}
	if n := len(bodyMatches); n != 1 {
		t.Fatalf("expected 1 body log, got %d (%v)", n, bodyMatches)
	}

	// Exactly one replay script must exist.
	curlMatches, err := filepath.Glob(filepath.Join(dir, "*_v1_chat_completions_request.sh"))
	if err != nil {
		t.Fatalf("failed to glob curl logs: %v", err)
	}
	if n := len(curlMatches); n != 1 {
		t.Fatalf("expected 1 curl log, got %d (%v)", n, curlMatches)
	}

	bodyPath, curlPath := bodyMatches[0], curlMatches[0]

	loggedBody, err := os.ReadFile(bodyPath)
	if err != nil {
		t.Fatalf("failed to read body log: %v", err)
	}
	if string(loggedBody) != requestBody {
		t.Fatalf("body log mismatch:\nexpected: %s\ngot: %s", requestBody, string(loggedBody))
	}

	rawScript, err := os.ReadFile(curlPath)
	if err != nil {
		t.Fatalf("failed to read curl log: %v", err)
	}
	script := string(rawScript)

	wantURL := "http://127.0.0.1:11434" + route
	if !strings.Contains(script, wantURL) {
		t.Fatalf("curl log does not contain expected route URL: %s", script)
	}

	// The script must point at the body file that sits next to it.
	wantRef := "@\"${SCRIPT_DIR}/" + filepath.Base(bodyPath) + "\""
	if !strings.Contains(script, wantRef) {
		t.Fatalf("curl log does not reference sibling body file: %s", script)
	}
}
// TestNewInferenceRequestLoggerCreatesDirectory checks that
// newInferenceRequestLogger returns a logger whose dir field names an
// existing directory. The directory is removed when the test finishes.
func TestNewInferenceRequestLoggerCreatesDirectory(t *testing.T) {
	requestLogger, err := newInferenceRequestLogger()
	if err != nil {
		t.Fatalf("expected no error creating request logger: %v", err)
	}

	// Validate the logger before registering cleanup: the cleanup closure
	// dereferences requestLogger, so a nil logger would otherwise panic
	// during cleanup and hide the real failure.
	if requestLogger == nil || requestLogger.dir == "" {
		t.Fatalf("expected request logger directory to be set")
	}

	t.Cleanup(func() {
		// Best-effort removal; a leftover directory must not fail the test.
		_ = os.RemoveAll(requestLogger.dir)
	})

	info, err := os.Stat(requestLogger.dir)
	if err != nil {
		t.Fatalf("expected directory to exist: %v", err)
	}
	if !info.IsDir() {
		t.Fatalf("expected %q to be a directory", requestLogger.dir)
	}
}
// TestSanitizeRouteForFilename checks the route-to-filename mapping for a
// few known routes, stopping at the first mismatch.
func TestSanitizeRouteForFilename(t *testing.T) {
	cases := []struct {
		route string
		want  string
	}{
		{"/api/generate", "api_generate"},
		{"/v1/chat/completions", "v1_chat_completions"},
		{"/v1/messages", "v1_messages"},
	}

	for _, c := range cases {
		got := sanitizeRouteForFilename(c.route)
		if got == c.want {
			continue
		}
		t.Fatalf("sanitizeRouteForFilename(%q) = %q, want %q", c.route, got, c.want)
	}
}

1159
server/routes_test.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,340 @@
package server
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/version"
)
// webExperimentalUpstreamCapture records what the fake upstream server
// created by newWebExperimentalUpstream received on its most recent request.
type webExperimentalUpstreamCapture struct {
// path is the URL path of the captured request.
path string
// body is the captured request body, read in full.
body string
// header is a clone of the captured request headers.
header http.Header
}
// newWebExperimentalUpstream starts an httptest server that stands in for the
// upstream service. Every request overwrites the returned capture with its
// path, body and headers, and is answered with 200, a JSON content type and
// responseBody. The caller is responsible for closing the server.
func newWebExperimentalUpstream(t *testing.T, responseBody string) (*httptest.Server, *webExperimentalUpstreamCapture) {
	t.Helper()

	captured := new(webExperimentalUpstreamCapture)

	handler := func(w http.ResponseWriter, r *http.Request) {
		// A read error leaves the captured body short; the test assertions
		// on the body will surface it.
		raw, _ := io.ReadAll(r.Body)
		captured.path = r.URL.Path
		captured.body = string(raw)
		captured.header = r.Header.Clone()

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(responseBody))
	}

	return httptest.NewServer(http.HandlerFunc(handler)), captured
}
// TestExperimentalWebEndpointsPassthrough posts to the local experimental web
// endpoints with cloudProxyBaseURL pointed at a fake upstream. It checks that
// the request reaches the expected upstream path with its body and caller
// headers intact, and that the client-version header is replaced with
// version.Version.
func TestExperimentalWebEndpointsPassthrough(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	cases := []struct {
		name         string
		localPath    string
		upstreamPath string
		requestBody  string
		responseBody string
		assertBody   string
	}{
		{
			name:         "web_search",
			localPath:    "/api/experimental/web_search",
			upstreamPath: "/api/web_search",
			requestBody:  `{"query":"what is ollama?","max_results":3}`,
			responseBody: `{"results":[{"title":"Ollama","url":"https://ollama.com","content":"Cloud models are now available"}]}`,
			assertBody:   `"query":"what is ollama?"`,
		},
		{
			name:         "web_fetch",
			localPath:    "/api/experimental/web_fetch",
			upstreamPath: "/api/web_fetch",
			requestBody:  `{"url":"https://ollama.com"}`,
			responseBody: `{"title":"Ollama","content":"Cloud models are now available","links":["https://ollama.com/"]}`,
			assertBody:   `"url":"https://ollama.com"`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			upstream, seen := newWebExperimentalUpstream(t, tc.responseBody)
			defer upstream.Close()

			// Redirect the proxy to the fake upstream for this subtest only.
			prevBaseURL := cloudProxyBaseURL
			cloudProxyBaseURL = upstream.URL
			t.Cleanup(func() { cloudProxyBaseURL = prevBaseURL })

			srv := &Server{}
			routes, err := srv.GenerateRoutes(nil)
			if err != nil {
				t.Fatal(err)
			}

			local := httptest.NewServer(routes)
			defer local.Close()

			target := local.URL + tc.localPath
			req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, target, bytes.NewBufferString(tc.requestBody))
			if err != nil {
				t.Fatal(err)
			}
			req.Header.Set("Content-Type", "application/json")
			req.Header.Set("Authorization", "Bearer should-forward")
			req.Header.Set("X-Test-Header", "web-experimental")
			req.Header.Set(cloudProxyClientVersionHeader, "should-be-overwritten")

			resp, err := local.Client().Do(req)
			if err != nil {
				t.Fatal(err)
			}
			defer resp.Body.Close()

			// The body is only used in the failure message below.
			payload, _ := io.ReadAll(resp.Body)
			if resp.StatusCode != http.StatusOK {
				t.Fatalf("expected status 200, got %d (%s)", resp.StatusCode, string(payload))
			}

			if seen.path != tc.upstreamPath {
				t.Fatalf("expected upstream path %q, got %q", tc.upstreamPath, seen.path)
			}

			if !bytes.Contains([]byte(seen.body), []byte(tc.assertBody)) {
				t.Fatalf("expected upstream body to contain %q, got %q", tc.assertBody, seen.body)
			}

			auth := seen.header.Get("Authorization")
			if auth != "Bearer should-forward" {
				t.Fatalf("expected forwarded Authorization header, got %q", auth)
			}

			custom := seen.header.Get("X-Test-Header")
			if custom != "web-experimental" {
				t.Fatalf("expected forwarded X-Test-Header=web-experimental, got %q", custom)
			}

			clientVersion := seen.header.Get(cloudProxyClientVersionHeader)
			if clientVersion != version.Version {
				t.Fatalf("expected %s=%q, got %q", cloudProxyClientVersionHeader, version.Version, clientVersion)
			}
		})
	}
}
// TestExperimentalWebEndpointsMissingBody posts without a body to each
// experimental web endpoint and expects a 400 response carrying the exact
// "missing request body" JSON error.
func TestExperimentalWebEndpointsMissingBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	srv := &Server{}
	routes, err := srv.GenerateRoutes(nil)
	if err != nil {
		t.Fatal(err)
	}

	local := httptest.NewServer(routes)
	defer local.Close()

	endpoints := []string{
		"/api/experimental/web_search",
		"/api/experimental/web_fetch",
	}

	for _, endpoint := range endpoints {
		t.Run(endpoint, func(t *testing.T) {
			// A nil body reader sends the request with no payload.
			req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, local.URL+endpoint, nil)
			if err != nil {
				t.Fatal(err)
			}
			req.Header.Set("Content-Type", "application/json")

			resp, err := local.Client().Do(req)
			if err != nil {
				t.Fatal(err)
			}
			defer resp.Body.Close()

			raw, _ := io.ReadAll(resp.Body)
			payload := string(raw)

			if resp.StatusCode != http.StatusBadRequest {
				t.Fatalf("expected status 400, got %d (%s)", resp.StatusCode, payload)
			}
			if payload != `{"error":"missing request body"}` {
				t.Fatalf("unexpected response body: %s", payload)
			}
		})
	}
}
// TestExperimentalWebEndpointsCloudDisabled sets OLLAMA_NO_CLOUD=1 and
// expects each experimental web endpoint to answer 403 with the
// operation-specific message from internalcloud.DisabledError.
func TestExperimentalWebEndpointsCloudDisabled(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())
	t.Setenv("OLLAMA_NO_CLOUD", "1")

	srv := &Server{}
	routes, err := srv.GenerateRoutes(nil)
	if err != nil {
		t.Fatal(err)
	}

	local := httptest.NewServer(routes)
	defer local.Close()

	cases := []struct {
		name      string
		path      string
		request   string
		operation string
	}{
		{
			name:      "web_search",
			path:      "/api/experimental/web_search",
			request:   `{"query":"latest ollama release"}`,
			operation: cloudErrWebSearchUnavailable,
		},
		{
			name:      "web_fetch",
			path:      "/api/experimental/web_fetch",
			request:   `{"url":"https://ollama.com"}`,
			operation: cloudErrWebFetchUnavailable,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			target := local.URL + tc.path
			req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, target, bytes.NewBufferString(tc.request))
			if err != nil {
				t.Fatal(err)
			}
			req.Header.Set("Content-Type", "application/json")

			resp, err := local.Client().Do(req)
			if err != nil {
				t.Fatal(err)
			}
			defer resp.Body.Close()

			raw, _ := io.ReadAll(resp.Body)
			if resp.StatusCode != http.StatusForbidden {
				t.Fatalf("expected status 403, got %d (%s)", resp.StatusCode, string(raw))
			}

			var decoded map[string]string
			if err := json.Unmarshal(raw, &decoded); err != nil {
				t.Fatalf("expected json error body, got: %q", string(raw))
			}

			want := internalcloud.DisabledError(tc.operation)
			if decoded["error"] != want {
				t.Fatalf("unexpected error message: %q", decoded["error"])
			}
		})
	}
}
// TestExperimentalWebEndpointSigningFailureReturnsUnauthorized stubs request
// signing to fail while the sign-in URL helper succeeds. The web_search
// endpoint must answer 401 with error "unauthorized" and the stubbed
// signin_url.
func TestExperimentalWebEndpointSigningFailureReturnsUnauthorized(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	// Swap in the stubs and restore the real hooks when the test ends.
	prevSign, prevSigninURL := cloudProxySignRequest, cloudProxySigninURL
	t.Cleanup(func() {
		cloudProxySignRequest = prevSign
		cloudProxySigninURL = prevSigninURL
	})
	cloudProxySignRequest = func(context.Context, *http.Request) error {
		return errors.New("ssh: no key found")
	}
	cloudProxySigninURL = func() (string, error) {
		return "https://ollama.com/signin/example", nil
	}

	srv := &Server{}
	routes, err := srv.GenerateRoutes(nil)
	if err != nil {
		t.Fatal(err)
	}

	local := httptest.NewServer(routes)
	defer local.Close()

	target := local.URL + "/api/experimental/web_search"
	req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, target, bytes.NewBufferString(`{"query":"hello"}`))
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := local.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	raw, _ := io.ReadAll(resp.Body)
	if resp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(raw))
	}

	var decoded map[string]any
	if err := json.Unmarshal(raw, &decoded); err != nil {
		t.Fatalf("expected json error body, got: %q", string(raw))
	}

	if decoded["error"] != "unauthorized" {
		t.Fatalf("unexpected error message: %v", decoded["error"])
	}
	if decoded["signin_url"] != "https://ollama.com/signin/example" {
		t.Fatalf("unexpected signin_url: %v", decoded["signin_url"])
	}
}
// TestExperimentalWebEndpointSigningFailureWithoutSigninURL stubs both request
// signing and the sign-in URL helper to fail. The web_fetch endpoint must
// still answer 401 with error "unauthorized", and the response must not
// contain a signin_url key.
func TestExperimentalWebEndpointSigningFailureWithoutSigninURL(t *testing.T) {
	gin.SetMode(gin.TestMode)
	setTestHome(t, t.TempDir())

	// Swap in the stubs and restore the real hooks when the test ends.
	prevSign, prevSigninURL := cloudProxySignRequest, cloudProxySigninURL
	t.Cleanup(func() {
		cloudProxySignRequest = prevSign
		cloudProxySigninURL = prevSigninURL
	})
	cloudProxySignRequest = func(context.Context, *http.Request) error {
		return errors.New("ssh: no key found")
	}
	cloudProxySigninURL = func() (string, error) {
		return "", errors.New("key missing")
	}

	srv := &Server{}
	routes, err := srv.GenerateRoutes(nil)
	if err != nil {
		t.Fatal(err)
	}

	local := httptest.NewServer(routes)
	defer local.Close()

	target := local.URL + "/api/experimental/web_fetch"
	req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, target, bytes.NewBufferString(`{"url":"https://ollama.com"}`))
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := local.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	raw, _ := io.ReadAll(resp.Body)
	if resp.StatusCode != http.StatusUnauthorized {
		t.Fatalf("expected status 401, got %d (%s)", resp.StatusCode, string(raw))
	}

	var decoded map[string]any
	if err := json.Unmarshal(raw, &decoded); err != nil {
		t.Fatalf("expected json error body, got: %q", string(raw))
	}

	if decoded["error"] != "unauthorized" {
		t.Fatalf("unexpected error message: %v", decoded["error"])
	}

	signinURL, present := decoded["signin_url"]
	if present {
		t.Fatalf("did not expect signin_url when helper fails, got %v", signinURL)
	}
}

932
server/sched.go Normal file
View File

@@ -0,0 +1,932 @@
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"reflect"
"slices"
"sort"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/mlxrunner"
)
// LlmRequest is a single request for a runner, created by
// Scheduler.GetRunner and passed through the scheduler's channels.
type LlmRequest struct {
// ctx is the caller's context; when it is done the request is reported on
// the scheduler's finished channel (see useLoadedRunner).
ctx context.Context //nolint:containedctx
// model is the model the caller wants a runner for.
model *Model
// opts are the options the runner is requested with.
opts api.Options
// sessionDuration, when non-nil, replaces the runner's session duration.
sessionDuration *api.Duration
// successCh receives the runner once one is available (created with a
// buffer of one in GetRunner).
successCh chan *runnerRef
// errCh receives the error if the request cannot be satisfied (created
// with a buffer of one in GetRunner).
errCh chan error
// schedAttempts is incremented each time processPending dequeues the request.
schedAttempts uint
}
// Scheduler tracks loaded runners and coordinates loading and unloading of
// models. Its function-valued fields are hooks assigned in InitScheduler.
type Scheduler struct {
// pendingReqCh carries requests that need a runner loaded or reloaded.
pendingReqCh chan *LlmRequest
// finishedReqCh carries requests whose context has completed.
finishedReqCh chan *LlmRequest
// expiredCh carries runners that should be unloaded.
expiredCh chan *runnerRef
// unloadedCh is signalled after a runner has been unloaded.
unloadedCh chan any
// loadedMu protects loaded and activeLoading
loadedMu sync.Mutex
// activeLoading is the model that we are currently working on loading,
// including by evicting one or more other models. We can only load
// one model at a time but new requests to models that already loaded can
// happen in parallel
activeLoading llm.LlamaServer
// loaded maps a scheduler model key (see schedulerModelKey) to its runner.
loaded map[string]*runnerRef
// loadFn loads the model for a request; its result reports whether an
// eviction is needed. InitScheduler sets it to the scheduler's load method.
loadFn func(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool
// newServerFn creates the server for a model; InitScheduler sets it to
// llm.NewLlamaServer.
newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
// getGpuFn returns the current GPU list; InitScheduler sets it to
// discover.GPUDevices.
getGpuFn func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo
// getSystemInfoFn returns system information; InitScheduler sets it to
// discover.GetSystemInfo.
getSystemInfoFn func() ml.SystemInfo
// waitForRecovery is set to 5 seconds by InitScheduler.
// NOTE(review): presumably the time budget for VRAM recovery after an
// unload — confirm against waitForVRAMRecovery.
waitForRecovery time.Duration
}
var (
	// defaultModelsPerGPU is the automatic per-GPU limit on loaded models.
	// A model still has to fit in VRAM, but loading many small models on a
	// large GPU can cause stalling.
	defaultModelsPerGPU = 3

	// ErrMaxQueue is returned when the pending request queue is full.
	ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded")
)
// InitScheduler builds a Scheduler whose channels are each buffered to
// envconfig.MaxQueue() and whose hooks point at the production
// implementations. The ctx argument is currently unused. Call Run to start
// the scheduler's goroutines.
func InitScheduler(ctx context.Context) *Scheduler {
	queueDepth := envconfig.MaxQueue()

	s := new(Scheduler)
	s.pendingReqCh = make(chan *LlmRequest, queueDepth)
	s.finishedReqCh = make(chan *LlmRequest, queueDepth)
	s.expiredCh = make(chan *runnerRef, queueDepth)
	s.unloadedCh = make(chan any, queueDepth)
	s.loaded = make(map[string]*runnerRef)

	s.newServerFn = llm.NewLlamaServer
	s.getGpuFn = discover.GPUDevices
	s.getSystemInfoFn = discover.GetSystemInfo
	s.waitForRecovery = 5 * time.Second
	s.loadFn = s.load

	return s
}
// schedulerModelKey derives the key under which a model is stored in the
// scheduler's loaded map. The first non-empty candidate wins, in this order:
// ModelPath (as is), Digest ("digest:" prefix), Name ("name:" prefix),
// ShortName ("short:" prefix). Models without a ModelPath, such as
// safetensors/image models, therefore key on their manifest digest so that
// distinct models don't collide. A nil model, or one with none of these
// fields set, yields "".
func schedulerModelKey(m *Model) string {
	switch {
	case m == nil:
		return ""
	case m.ModelPath != "":
		return m.ModelPath
	case m.Digest != "":
		return "digest:" + m.Digest
	case m.Name != "":
		return "name:" + m.Name
	case m.ShortName != "":
		return "short:" + m.ShortName
	default:
		return ""
	}
}
// GetRunner requests a runner for model m and returns the channels on which
// either the runner or an error will be delivered. A loaded runner that does
// not need a reload is handed out immediately; otherwise the request is
// queued, or rejected with ErrMaxQueue when the queue is full.
//
// The context must be canceled to decrement the ref count and release the
// runner.
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
	// Enforce a floor on the context size.
	opts.NumCtx = max(opts.NumCtx, 4)

	if capErr := m.CheckCapabilities(model.CapabilityVision); capErr == nil {
		// multimodal models require at least 2048 context
		opts.NumCtx = max(opts.NumCtx, 2048)
	}

	req := &LlmRequest{
		ctx:             c,
		model:           m,
		opts:            opts,
		sessionDuration: sessionDuration,
		successCh:       make(chan *runnerRef, 1),
		errCh:           make(chan error, 1),
	}

	key := schedulerModelKey(req.model)

	s.loadedMu.Lock()
	existing := s.loaded[key]
	s.loadedMu.Unlock()

	if existing == nil || existing.needsReload(c, req) {
		// Non-blocking enqueue: a full queue is reported to the caller
		// instead of stalling it.
		select {
		case s.pendingReqCh <- req:
		default:
			req.errCh <- ErrMaxQueue
		}
		return req.successCh, req.errCh
	}

	req.useLoadedRunner(existing, s.finishedReqCh)
	return req.successCh, req.errCh
}
// Run starts the scheduler's two background loops, one for pending requests
// and one for completed/expired events, and returns immediately. Both loops
// shut down when ctx is done.
func (s *Scheduler) Run(ctx context.Context) {
	slog.Debug("starting llm scheduler")

	go s.processPending(ctx)
	go s.processCompleted(ctx)
}
// processPending is the scheduler loop that services queued requests until
// ctx is done. For each request it repeatedly tries, in order: reuse an
// already loaded runner; evict a runner when the loaded count has reached
// maxRunners; otherwise refresh GPU/system information and attempt a load
// through loadFn. When a runner has to be evicted it is marked to expire
// immediately and the loop waits for an unloaded event before retrying.
func (s *Scheduler) processPending(ctx context.Context) {
maxRunners := envconfig.MaxRunners()
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case pending := <-s.pendingReqCh:
// Block other requests until we get this pending request running
pending.schedAttempts++
if pending.ctx.Err() != nil {
slog.Debug("pending request cancelled or timed out, skipping scheduling")
continue
}
logutil.Trace("processing incoming request", "model", pending.model.ModelPath)
// Retry loop for this one request; exits via break once the request
// has been handed to a runner or to loadFn.
for {
var runnerToExpire *runnerRef
pendingKey := schedulerModelKey(pending.model)
// Take a snapshot of the loaded runners under the lock so the rest of
// the iteration can work without holding loadedMu.
s.loadedMu.Lock()
runner := s.loaded[pendingKey]
loadedCount := len(s.loaded)
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
runnersSnapshot = append(runnersSnapshot, r)
}
s.loadedMu.Unlock()
if runner != nil {
if runner.needsReload(ctx, pending) {
slog.Debug("reloading", "runner", runner)
runnerToExpire = runner
} else {
// Runner is usable, return it
logutil.Trace("using existing loaded runner", "model", pendingKey)
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
} else if maxRunners > 0 && loadedCount >= int(maxRunners) {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload()
} else {
// Either no models are loaded or below envconfig.MaxRunners
// Get a refreshed GPU list
var gpus []ml.DeviceInfo
if pending.opts.NumGPU == 0 {
// NumGPU == 0: proceed with an empty GPU list instead of querying.
gpus = []ml.DeviceInfo{}
} else {
logutil.Trace("refreshing GPU list", "model", pending.model.ModelPath)
gpus = s.getGpuFn(ctx, runnersSnapshot)
}
logutil.Trace("refreshing system information", "model", pending.model.ModelPath)
systemInfo := s.getSystemInfoFn()
if maxRunners <= 0 {
// No user specified MaxRunners, so figure out what automatic setting to use for the next load attempt
if pending.opts.NumGPU == 0 {
// Need to get actual GPU list to set the correct default max models
logutil.Trace("refreshing GPU list", "model", pending.model.ModelPath)
g := s.getGpuFn(ctx, runnersSnapshot)
maxRunners = uint(defaultModelsPerGPU * max(len(g), 1))
} else {
maxRunners = uint(defaultModelsPerGPU * max(len(gpus), 1))
}
// Note: gpu_count logs len(gpus), which is 0 on the NumGPU == 0 path
// even though maxRunners was derived from the queried list g.
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
// Update free memory from currently loaded models
logutil.Trace("updating free space", "gpu_count", len(gpus), "model", pending.model.ModelPath)
s.updateFreeSpace(gpus)
if loadedCount == 0 {
// No models loaded. Load the model but prefer the best fit.
slog.Debug("loading first model", "model", pending.model.ModelPath)
s.loadFn(pending, systemInfo, gpus, false)
break
}
// More than one loaded model, so we have to see if the
// new one fits
logutil.Trace("loading additional model", "model", pending.model.ModelPath)
needEvict := s.loadFn(pending, systemInfo, gpus, true)
if !needEvict {
slog.Debug("new model fits with existing models, loading")
break
}
runnerToExpire = s.findRunnerToUnload()
}
if runnerToExpire == nil {
// While we were performing load calculations, the loaded runner(s) unloaded in parallel
// so findRunnerToUnload returned no runners. We'll try again and the loadedCount should be zero
slog.Debug("runner to expire was nil, retrying")
continue
}
// Trigger an expiration to unload once it's done
runnerToExpire.refMu.Lock()
slog.Debug("resetting model to expire immediately to make room", "runner", runnerToExpire, "refCount", runnerToExpire.refCount)
if runnerToExpire.expireTimer != nil {
runnerToExpire.expireTimer.Stop()
runnerToExpire.expireTimer = nil
}
runnerToExpire.sessionDuration = 0
// Only queue the expiry here when the runner is idle; a runner that is
// still referenced is left with a zero session duration.
if runnerToExpire.refCount <= 0 {
s.expiredCh <- runnerToExpire
}
runnerToExpire.refMu.Unlock()
// Wait for the unload to happen
slog.Debug("waiting for pending requests to complete and unload to occur", "runner", runnerToExpire)
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler pending loop")
return
case <-s.unloadedCh:
slog.Debug("unload completed", "runner", runnerToExpire)
continue
}
}
case <-s.unloadedCh:
// An unload request when there are no pending request can be ignored
slog.Debug("ignoring unload event with no pending requests")
}
}
}
// processCompleted is the scheduler loop that handles finished requests and
// expired runners until ctx is done. A finished request decrements its
// runner's ref count and, once the runner is idle, either expires it
// immediately (session duration <= 0) or arms/resets its expiry timer. An
// expired runner is unloaded and removed from the loaded map, after which an
// unloaded event is sent.
func (s *Scheduler) processCompleted(ctx context.Context) {
// Process completed requests, expired timers, and unloading models
for {
select {
case <-ctx.Done():
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
finishedKey := schedulerModelKey(finished.model)
s.loadedMu.Lock()
runner := s.loaded[finishedKey]
s.loadedMu.Unlock()
if runner == nil {
slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
continue
}
runner.refMu.Lock()
runner.refCount--
if runner.refCount <= 0 {
if runner.sessionDuration <= 0 {
slog.Debug("runner with zero duration has gone idle, expiring to unload", "runner", runner)
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
} else if runner.expireTimer == nil {
slog.Debug("runner with non-zero duration has gone idle, adding timer", "runner", runner, "duration", runner.sessionDuration)
// The timer callback takes refMu itself before queueing the expiry.
runner.expireTimer = time.AfterFunc(runner.sessionDuration, func() {
slog.Debug("timer expired, expiring to unload", "runner", runner)
runner.refMu.Lock()
defer runner.refMu.Unlock()
if runner.expireTimer != nil {
runner.expireTimer.Stop()
runner.expireTimer = nil
}
s.expiredCh <- runner
})
runner.expiresAt = time.Now().Add(runner.sessionDuration)
} else {
slog.Debug("runner with non-zero duration has gone idle, resetting timer", "runner", runner, "duration", runner.sessionDuration)
runner.expireTimer.Reset(runner.sessionDuration)
runner.expiresAt = time.Now().Add(runner.sessionDuration)
}
}
slog.Debug("after processing request finished event", "runner", runner, "refCount", runner.refCount)
runner.refMu.Unlock()
case runner := <-s.expiredCh:
slog.Debug("runner expired event received", "runner", runner)
runner.refMu.Lock()
if runner.refCount > 0 {
slog.Debug("expired event with positive ref count, retrying", "runner", runner, "refCount", runner.refCount)
go func(runner *runnerRef) {
// We can't unload yet, but want to as soon as the current request completes
// So queue up another expired event
time.Sleep(10 * time.Millisecond)
s.expiredCh <- runner
}(runner)
runner.refMu.Unlock()
continue
}
// From here refMu is held; each branch below releases loadedMu and then
// refMu on its own.
s.loadedMu.Lock()
slog.Debug("got lock to unload expired event", "runner", runner)
runnerToUnload := s.loaded[runner.modelKey]
if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial
// request is canceled and we're trying to load another model
// that requires this one to be evicted, or the settings change
// and require a reload
s.loadedMu.Unlock()
runner.refMu.Unlock()
slog.Debug("duplicate expired event, ignoring", "runner", runner)
} else if runner.pid != runnerToUnload.pid {
// If the pids do not match, we likely had multiple load
// failures for the same model in quick succession due to
// request context canceled and are draining the queue of
// events. Ensure the orphaned runner is properly shut down, but
// do not delete the mismatched loaded runner, or wait for VRAM
// convergence.
slog.Debug("orphaned runner shutting down", "orphan", runner, "loaded", runnerToUnload)
runner.unload()
s.loadedMu.Unlock()
runner.refMu.Unlock()
} else {
slog.Debug("starting background wait for VRAM recovery", "runner", runner)
// Snapshot the loaded runners (still including this one) before it is
// removed from the map.
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
runnersSnapshot = append(runnersSnapshot, r)
}
// The recovery wait is started before unload and awaited afterwards
// via the returned channel.
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
delete(s.loaded, runner.modelKey)
s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished
runner.refMu.Unlock()
slog.Debug("sending an unloaded event", "runner", runner)
s.unloadedCh <- struct{}{}
}
}
}
}
// useLoadedRunner satisfies the pending request with an already loaded
// runner. Under the runner's refMu it bumps the ref count, cancels any armed
// expiry timer, applies the request's session duration when one was given,
// and delivers the runner on successCh. A goroutine then waits for the
// request context to finish and reports the request on the finished channel.
func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *LlmRequest) {
	runner.refMu.Lock()
	defer runner.refMu.Unlock()

	runner.refCount++

	// The runner is in use again, so any scheduled expiry is cancelled.
	if timer := runner.expireTimer; timer != nil {
		timer.Stop()
		runner.expireTimer = nil
	}

	if d := pending.sessionDuration; d != nil {
		runner.sessionDuration = d.Duration
	}

	pending.successCh <- runner

	go func() {
		<-pending.ctx.Done()
		slog.Debug("context for request finished", "runner", runner)
		finished <- pending
	}()
}
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
func (s *Scheduler) load(req *LlmRequest, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool {
numParallel := max(int(envconfig.NumParallel()), 1)
// Embedding models should always be loaded with parallel=1
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {
numParallel = 1
}
// Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen35", "qwen35moe", "qwen3next", "lfm2", "lfm2moe", "nemotron_h", "nemotron_h_moe", "nemotron_h_omni"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}
sessionDuration := envconfig.KeepAlive()
if req.sessionDuration != nil {
sessionDuration = req.sessionDuration.Duration
}
s.loadedMu.Lock()
llama := s.activeLoading
if llama == nil {
var err error
if !req.model.IsMLX() {
f, loadErr := llm.LoadModel(req.model.ModelPath, 1024)
if loadErr != nil {
slog.Info("failed to load model metadata", "model", req.model.ModelPath, "error", loadErr)
req.errCh <- loadErr
s.loadedMu.Unlock()
return false
}
llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
// check for model compatibility
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
}
}
} else {
modelName := req.model.ShortName
if slices.Contains(req.model.Config.Capabilities, "image") {
llama, err = imagegen.NewServer(modelName)
} else {
llama, err = mlxrunner.NewClient(modelName)
}
}
if err != nil {
slog.Info("failed to create server", "model", req.model.ShortName, "error", err)
req.errCh <- err
s.loadedMu.Unlock()
return false
}
s.activeLoading = llama
} else {
wantPath := req.model.ModelPath
if wantPath == "" {
wantPath = req.model.ShortName
}
if s.activeLoading.ModelPath() != wantPath {
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), wantPath))
}
}
s.loadedMu.Unlock()
systemTotalMemory := systemInfo.TotalMemory
systemFreeMemory := systemInfo.FreeMemory
systemSwapFreeMemory := systemInfo.FreeSwap
slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
for _, gpu := range gpus {
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory()
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory() {
available = 0
}
slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library,
"available", format.HumanBytes2(available),
"free", format.HumanBytes2(gpu.FreeMemory),
"minimum", format.HumanBytes2(gpu.MinimumMemory()),
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
}
gpuIDs, err := llama.Load(req.ctx, systemInfo, gpus, requireFull)
if err != nil {
if errors.Is(err, llm.ErrLoadRequiredFull) {
if !requireFull {
// No other models loaded, yet we still don't fit, so report an error
slog.Info("model is too large for system memory", "requireFull", requireFull)
s.activeLoading.Close()
s.activeLoading = nil
req.errCh <- err
}
return true
}
slog.Info("Load failed", "model", req.model.ModelPath, "error", err)
s.activeLoading.Close()
s.activeLoading = nil
req.errCh <- err
return false
}
// Determine if we have discrete GPUs which we should monitor VRAM usage on during shutdown
discreteGPUs := false
iGPUScan:
for _, devid := range gpuIDs {
for _, dev := range gpus {
if dev.DeviceID == devid {
if !dev.Integrated {
discreteGPUs = true
break iGPUScan
}
}
}
}
totalSize, vramSize := llama.MemorySize()
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
modelKey: schedulerModelKey(req.model),
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
gpus: gpuIDs,
discreteGPUs: discreteGPUs,
isImagegen: slices.Contains(req.model.Config.Capabilities, "image"),
totalSize: totalSize,
vramSize: vramSize,
loading: true,
pid: llama.Pid(),
}
runner.numParallel = numParallel
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
if oldRunner, ok := s.loaded[runner.modelKey]; ok {
// Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock()
oldRunner.unload()
oldRunner.refMu.Unlock()
}
s.activeLoading = nil
s.loaded[runner.modelKey] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
go func() {
defer runner.refMu.Unlock()
if err = llama.WaitUntilRunning(req.ctx); err != nil {
slog.Error("error loading llama server", "error", err)
req.errCh <- err
slog.Debug("triggering expiration for failed load", "runner", runner)
s.expiredCh <- runner
return
}
slog.Debug("finished setting up", "runner", runner)
if runner.pid < 0 {
runner.pid = llama.Pid()
}
runner.refCount++
runner.loading = false
go func() {
<-req.ctx.Done()
slog.Debug("context for request finished")
s.finishedReqCh <- req
}()
req.successCh <- runner
}()
return false
}
// updateFreeSpace lowers the FreeMemory reported for each entry of allGpus so
// that it never exceeds what is left after subtracting the VRAM predicted for
// all currently loaded runners. The slice is updated in place.
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
	if len(allGpus) == 0 {
		return
	}

	// Snapshot the loaded runners so loadedMu is not held while each runner is queried.
	s.loadedMu.Lock()
	snapshot := make([]*runnerRef, 0, len(s.loaded))
	for _, ref := range s.loaded {
		snapshot = append(snapshot, ref)
	}
	s.loadedMu.Unlock()

	// Sum up the total predicted usage per GPU for all runners
	predicted := make(map[ml.DeviceID]uint64)
	for _, ref := range snapshot {
		ref.refMu.Lock()
		if ref.llama == nil {
			slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
		} else {
			for i := range allGpus {
				id := allGpus[i].DeviceID
				predicted[id] += ref.llama.VRAMByGPU(id)
			}
		}
		ref.refMu.Unlock()
	}

	// Now that we've summed up all the GPU usage predictions across all the loaded runners, update the gpu list
	for i := range allGpus {
		gpu := &allGpus[i]
		usage, tracked := predicted[gpu.DeviceID]
		if !tracked {
			continue
		}
		slog.Debug("gpu reported", "gpu", gpu.ID, "library", gpu.Library, "available", format.HumanBytes2(gpu.FreeMemory))
		switch {
		case usage > gpu.TotalMemory:
			// Shouldn't happen
			slog.Warn("predicted usage exceeds VRAM", "gpu", gpu.ID, "totalMemory", gpu.TotalMemory, "predicted", usage)
			gpu.FreeMemory = 0
		case gpu.TotalMemory-usage < gpu.FreeMemory:
			// Predicted free is smaller than reported free, use it.
			// TODO maybe we should just always trust our numbers, since cuda's free memory reporting is laggy
			// and we might unload models we didn't actually need to. The risk is if some other GPU intensive app is loaded
			// after we start our first runner, then we'll never account for that, so picking the smallest free value seems prudent.
			gpu.FreeMemory = gpu.TotalMemory - usage
		}
		slog.Info("updated VRAM based on existing loaded models", "gpu", gpu.ID, "library", gpu.Library, "total", format.HumanBytes2(gpu.TotalMemory), "available", format.HumanBytes2(gpu.FreeMemory))
	}
}
// TODO consolidate sched_types.go

// runnerRef is the scheduler's record of one loaded (or still loading) model
// runner: the server handle plus the reference counting, expiration and
// sizing data used to decide when it can be reused or unloaded.
type runnerRef struct {
// refMu is held while refCount and the expiration fields are read or
// updated, and for the whole duration of the initial load.
refMu sync.Mutex
refCount uint // prevent unloading if > 0
// llama is the underlying server; may be nil, so users check before calling it.
llama llm.LlamaServer
pid int
loading bool // True only during initial load, then false forever
gpus []ml.DeviceID // Recorded at time of provisioning
discreteGPUs bool // True if at least one assigned device is a non-integrated (discrete) GPU - when false the VRAM recovery wait is skipped
isImagegen bool // True if loaded via imagegen runner (vs mlxrunner)
// vramSize and totalSize are the memory sizes reported by the server after loading.
vramSize uint64
totalSize uint64
// sessionDuration is the primary sort key when picking a runner to unload;
// expireRunner resets it to 0.
sessionDuration time.Duration
expireTimer *time.Timer
expiresAt time.Time
model *Model
modelPath string
// modelKey is the key under which this runner is stored in Scheduler.loaded.
modelKey string
numParallel int
*api.Options
}
// unload stops any pending expiration timer, closes the underlying server when
// one is attached, and drops the model, options and GPU references.
// The refMu must already be held when calling unload.
func (runner *runnerRef) unload() {
	if timer := runner.expireTimer; timer != nil {
		timer.Stop()
		runner.expireTimer = nil
	}
	if server := runner.llama; server != nil {
		server.Close()
	}
	runner.model, runner.Options, runner.gpus = nil, nil, nil
}
// needsReload reports whether the already loaded runner cannot serve req as
// is: the runner kind differs, the runner has been unloaded, the adapters,
// projectors or (for non-MLX models) runner options changed, or the server
// does not answer a ping within the timeout.
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
	slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
	runner.refMu.Lock()
	defer runner.refMu.Unlock()

	// Check if runner type (imagegen vs mlxrunner) matches what's requested.
	if slices.Contains(req.model.Config.Capabilities, "image") != runner.isImagegen {
		return true
	}
	if runner.Options == nil {
		return true
	}

	pingTimeout := 10 * time.Second
	if runner.loading {
		// Initial load can take a long time for big models on slow systems...
		pingTimeout = 2 * time.Minute
	}

	// Don't reload runner if num_gpu=-1 was provided: neutralise the field on
	// both copies so it never makes the comparison below fail.
	existing, requested := runner.Options.Runner, req.opts.Runner
	if requested.NumGPU < 0 {
		existing.NumGPU, requested.NumGPU = -1, -1
	}

	pingCtx, cancel := context.WithTimeout(ctx, pingTimeout)
	defer cancel()

	switch {
	case !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths):
		// have the adapters changed?
		return true
	case !reflect.DeepEqual(runner.model.ProjectorPaths, req.model.ProjectorPaths):
		// have the projectors changed?
		return true
	case !runner.model.IsMLX() && !reflect.DeepEqual(existing, requested):
		// have the runner options changed?
		return true
	}
	return runner.llama.Ping(pingCtx) != nil
}
// Free memory reporting on GPUs can lag for a while even after the runner
// exits, so we have to keep checking until we see the available memory recover,
// otherwise subsequent model loads will get far less layers loaded or worse
// case, may completely fall back to CPU mode.
// This routine must be called before the runner unloads so it can establish
// a before and after GPU memory allocation. The returned channel
// will be notified when we're done waiting, or have timed out and should
// proceed anyway
//
// The returned channel is buffered (capacity 1) and receives exactly one
// value, so the sender never blocks even if nobody is listening.
func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []ml.FilteredRunnerDiscovery) chan any {
	finished := make(chan any, 1)

	// CPU, Metal and iGPUs don't need checking, so no waiting required
	if len(runner.gpus) == 0 || !runner.discreteGPUs ||
		(len(runner.gpus) == 1 && runner.gpus[0].Library == "Metal") {
		finished <- struct{}{}
		slog.Debug("no need to wait for VRAM recovery", "runner", runner)
		return finished
	}
	start := time.Now()

	// Establish a baseline before we unload
	var freeMemoryBefore uint64
	for _, gpu := range s.getGpuFn(context.Background(), runners) {
		freeMemoryBefore += gpu.FreeMemory
	}

	go func() {
		// typical convergence is 0.5-1.5s - If it takes too long to discover and converge, let the scheduler estimate VRAM usage
		ctx, cancel := context.WithTimeout(context.Background(), s.waitForRecovery)
		defer cancel()
		ticker := time.NewTicker(250 * time.Millisecond)
		defer ticker.Stop()

		freeMemoryNow := freeMemoryBefore
		for {
			select {
			case <-ticker.C:
				// Query GPUs, look for free to go back up
				freeMemoryNow = 0
				for _, gpu := range s.getGpuFn(ctx, runners) {
					freeMemoryNow += gpu.FreeMemory
				}

				// The counters are unsigned: if free memory dropped below the
				// baseline (e.g. another process allocated VRAM) a plain
				// subtraction would wrap around to a huge value and be
				// mistaken for a full recovery. Clamp to zero instead.
				var recovered uint64
				if freeMemoryNow > freeMemoryBefore {
					recovered = freeMemoryNow - freeMemoryBefore
					logutil.Trace("gpu VRAM convergence", "percent", int(float32(recovered)/float32(runner.vramSize)*100))
				} else {
					logutil.Trace("gpu VRAM convergence", "percent", 0)
				}

				// If we're within ~75% of the estimated memory usage recovered, bail out
				if float32(recovered) > float32(runner.vramSize)*0.75 {
					slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
					finished <- struct{}{}
					return
				}
			case <-ctx.Done():
				slog.Debug("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
				finished <- struct{}{}
				return
			}
		}
	}()
	return finished
}
// LogValue renders the runner as a group of structured log attributes. A nil
// runner is logged as the string "nil"; the model name, GPU list and num_ctx
// are only included when the corresponding data is present.
func (runner *runnerRef) LogValue() slog.Value {
	if runner == nil {
		return slog.StringValue("nil")
	}

	var attrs []slog.Attr
	if m := runner.model; m != nil {
		attrs = append(attrs, slog.String("name", m.Name))
	}
	if len(runner.gpus) > 0 {
		attrs = append(attrs, slog.Any("inference", runner.gpus))
	}

	// Identify the model by its path, falling back to the scheduler key.
	modelID := runner.modelPath
	if modelID == "" {
		modelID = runner.modelKey
	}
	attrs = append(attrs,
		slog.String("size", format.HumanBytes2(runner.totalSize)),
		slog.String("vram", format.HumanBytes2(runner.vramSize)),
		slog.Int("parallel", runner.numParallel),
		slog.Int("pid", runner.pid),
		slog.String("model", modelID),
	)

	if opts := runner.Options; opts != nil {
		attrs = append(attrs, slog.Int("num_ctx", opts.NumCtx))
	}
	return slog.GroupValue(attrs...)
}
// Implements discover.RunnerDiscovery

// GetPort returns the port of the underlying server, or -1 when no server is attached.
func (runner *runnerRef) GetPort() int {
	if runner.llama == nil {
		return -1
	}
	return runner.llama.GetPort()
}
// GetDeviceInfos returns the device information reported by the underlying
// server, or nil when no server is attached.
func (runner *runnerRef) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	if runner.llama == nil {
		return nil
	}
	return runner.llama.GetDeviceInfos(ctx)
}
// GetActiveDeviceIDs returns the device IDs recorded for this runner when it
// was provisioned (nil after unload).
func (runner *runnerRef) GetActiveDeviceIDs() []ml.DeviceID {
return runner.gpus
}
// HasExited reports whether the underlying server has exited. A runner with
// no server attached is treated as exited.
func (runner *runnerRef) HasExited() bool {
	return runner.llama == nil || runner.llama.HasExited()
}
// ByDurationAndName implements sort.Interface, ordering runners by session
// duration and then by model path (or model key when the path is empty).
type ByDurationAndName []*runnerRef

// Len returns the number of runners.
func (a ByDurationAndName) Len() int {
	return len(a)
}

// Swap exchanges the runners at positions i and j.
func (a ByDurationAndName) Swap(i, j int) {
	tmp := a[i]
	a[i] = a[j]
	a[j] = tmp
}

// Less reports whether the runner at i sorts before the runner at j.
func (a ByDurationAndName) Less(i, j int) bool {
	left, right := a[i], a[j]

	// Primary sort by session duration (uint64 to handle negatives)
	if dl, dr := uint64(left.sessionDuration), uint64(right.sessionDuration); dl != dr {
		return dl < dr
	}

	// Secondary sort by model key/path lex order
	sortName := func(r *runnerRef) string {
		if r.modelPath != "" {
			return r.modelPath
		}
		return r.modelKey
	}
	return sortName(left) < sortName(right)
}
// TODO - future consideration to pick runners based on size
// type BySize []*runnerRef
// func (a BySize) Len() int { return len(a) }
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
// func (a BySize) Less(i, j int) bool { return a[i].vramSize < a[j].vramSize }
// findRunnerToUnload finds a runner to unload to make room for a new model.
// It returns nil when nothing is loaded; otherwise it prefers the first idle
// runner in ByDurationAndName order and falls back to the first runner in
// that order when none is idle.
func (s *Scheduler) findRunnerToUnload() *runnerRef {
	s.loadedMu.Lock()
	candidates := make([]*runnerRef, 0, len(s.loaded))
	for _, ref := range s.loaded {
		candidates = append(candidates, ref)
	}
	s.loadedMu.Unlock()

	if len(candidates) == 0 {
		slog.Debug("no loaded runner to unload")
		return nil
	}

	// In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
	// e.g., if we have multiple options, will one make room for the request?
	sort.Sort(ByDurationAndName(candidates))

	// First try to find a runner that's already idle
	for i := range candidates {
		candidate := candidates[i]
		candidate.refMu.Lock()
		idle := candidate.refCount == 0
		candidate.refMu.Unlock()
		if idle {
			slog.Debug("found an idle runner to unload", "runner", candidate)
			return candidate
		}
	}

	// None appear idle, just wait for the one with the shortest duration
	shortest := candidates[0]
	slog.Debug("no idle runners, picking the shortest duration", "runner_count", len(candidates), "runner", shortest)
	return shortest
}
// unloadAllRunners closes the server that is currently loading (if any) and
// the server of every loaded runner, all while holding loadedMu.
func (s *Scheduler) unloadAllRunners() {
	s.loadedMu.Lock()
	defer s.loadedMu.Unlock()

	if loading := s.activeLoading; loading != nil {
		slog.Debug("shutting down currently loading runner")
		loading.Close()
		s.activeLoading = nil
	}

	for key, ref := range s.loaded {
		if ref.llama == nil {
			continue
		}
		slog.Debug("shutting down runner", "model", key)
		ref.llama.Close()
	}
}
// expireRunner marks the runner loaded for model as expired right now: its
// timer is stopped, its session duration is zeroed and, when nobody holds a
// reference, it is handed to the expired channel. Unknown models are ignored.
func (s *Scheduler) expireRunner(model *Model) {
	key := schedulerModelKey(model)

	s.loadedMu.Lock()
	runner, found := s.loaded[key]
	s.loadedMu.Unlock()
	if !found {
		return
	}

	runner.refMu.Lock()
	defer runner.refMu.Unlock()

	runner.expiresAt = time.Now()
	if timer := runner.expireTimer; timer != nil {
		timer.Stop()
		runner.expireTimer = nil
	}
	runner.sessionDuration = 0
	if runner.refCount == 0 {
		s.expiredCh <- runner
	}
}

1018
server/sched_test.go Normal file

File diff suppressed because it is too large Load Diff

8
server/sparse_common.go Normal file
View File

@@ -0,0 +1,8 @@
//go:build !windows
package server
import "os"
// setSparse is a no-op in non-Windows builds; the file argument is ignored.
func setSparse(_ *os.File) {}

17
server/sparse_windows.go Normal file
View File

@@ -0,0 +1,17 @@
package server
import (
"os"
"golang.org/x/sys/windows"
)
// setSparse asks Windows to mark file as sparse via FSCTL_SET_SPARSE. The
// result is deliberately discarded.
func setSparse(file *os.File) {
	handle := windows.Handle(file.Fd())

	// exFat (and other FS types) don't support sparse files, so ignore errors
	windows.DeviceIoControl( //nolint:errcheck
		handle, windows.FSCTL_SET_SPARSE,
		nil, 0,
		nil, 0,
		nil, nil,
	)
}

15
server/test_home_test.go Normal file
View File

@@ -0,0 +1,15 @@
package server
import (
"testing"
"github.com/ollama/ollama/envconfig"
)
// setTestHome points HOME and USERPROFILE at home for the duration of the
// test, clears OLLAMA_MODELS, and reloads the server configuration so the
// new environment takes effect.
func setTestHome(t *testing.T, home string) {
	t.Helper()

	overrides := [][2]string{
		{"HOME", home},
		{"USERPROFILE", home},
		{"OLLAMA_MODELS", ""},
	}
	for _, kv := range overrides {
		t.Setenv(kv[0], kv[1])
	}
	envconfig.ReloadServerConfig()
}

405
server/upload.go Normal file
View File

@@ -0,0 +1,405 @@
package server
import (
"context"
"crypto/md5"
"errors"
"fmt"
"hash"
"io"
"log/slog"
"math"
"net/http"
"net/url"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
// blobUploadManager tracks in-flight uploads, keyed by layer digest with
// *blobUpload values, so concurrent pushes of the same blob share one upload.
var blobUploadManager sync.Map
// blobUpload holds the state of uploading a single layer blob to a registry.
type blobUpload struct {
manifest.Layer
// Total is the size in bytes of the blob file on disk (set by Prepare).
Total int64
// Completed counts the bytes sent so far; progressWriter adds to it and
// rolls it back when an attempt fails.
Completed atomic.Int64
// Parts lists the byte ranges the blob is split into for upload.
Parts []blobUploadPart
// nextURL carries the URL to use for the next request; it is created by
// Prepare only when parts actually need to be uploaded.
nextURL chan *url.URL
// CancelFunc cancels the context Run works under; it is assigned by Run.
context.CancelFunc
// file is the open blob, shared by all part uploads.
file *os.File
// done and err are the completion state polled by Wait.
done bool
err error
// references counts the callers currently blocked in Wait.
references atomic.Int32
}
// Limits used by Prepare and Run when splitting a blob into parts.
const (
// numUploadParts is the target number of parts per blob and also the limit
// on concurrently running part uploads.
numUploadParts = 16
// minUploadPartSize and maxUploadPartSize clamp the size of a single part.
minUploadPartSize int64 = 100 * format.MegaByte
maxUploadPartSize int64 = 1000 * format.MegaByte
)
// Prepare starts an upload session for the blob with a POST to requestURL
// (adding cross-repository mount parameters when the layer has a From
// source), records the blob size, and splits the blob into parts. When the
// registry answers 201 Created the blob was mounted and the upload is marked
// done without creating any parts.
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
	blobPath, err := manifest.BlobsPath(b.Digest)
	if err != nil {
		return err
	}

	if b.From != "" {
		query := requestURL.Query()
		query.Add("mount", b.Digest)
		query.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
		requestURL.RawQuery = query.Encode()
	}

	resp, err := makeRequestWithRetry(ctx, http.MethodPost, requestURL, nil, nil, opts)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Prefer the Docker-specific header, fall back to the standard one.
	location := resp.Header.Get("Docker-Upload-Location")
	if location == "" {
		location = resp.Header.Get("Location")
	}

	info, err := os.Stat(blobPath)
	if err != nil {
		return err
	}
	b.Total = info.Size()

	// http.StatusCreated indicates a blob has been mounted
	// ref: https://distribution.github.io/distribution/spec/api/#cross-repository-blob-mount
	if resp.StatusCode == http.StatusCreated {
		b.Completed.Store(b.Total)
		b.done = true
		return nil
	}

	// Aim for numUploadParts parts, clamped to the allowed part size range;
	// the final part is shortened to whatever remains.
	partSize := max(minUploadPartSize, min(b.Total/numUploadParts, maxUploadPartSize))
	for offset := int64(0); offset < b.Total; offset += partSize {
		partSize = min(partSize, b.Total-offset)
		// set part.N to the current number of parts
		b.Parts = append(b.Parts, blobUploadPart{N: len(b.Parts), Offset: offset, Size: partSize})
	}

	if len(b.Parts) > 0 {
		slog.Info(fmt.Sprintf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
	}

	uploadURL, err := url.Parse(location)
	if err != nil {
		return err
	}

	b.nextURL = make(chan *url.URL, 1)
	b.nextURL <- uploadURL
	return nil
}
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
//
// Run removes the upload from blobUploadManager when it returns. If Prepare
// already marked the upload done (the registry mounted the blob), Run returns
// immediately without opening the file or contacting the registry.
func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
	defer blobUploadManager.Delete(b.Digest)
	ctx, b.CancelFunc = context.WithCancel(ctx)

	// A mounted blob has no parts and Prepare never created b.nextURL.
	// Continuing would block forever on the nil channel receive below,
	// leaking this goroutine, the open file and the manager entry.
	if b.done {
		return
	}

	p, err := manifest.BlobsPath(b.Digest)
	if err != nil {
		b.err = err
		return
	}

	b.file, err = os.Open(p)
	if err != nil {
		b.err = err
		return
	}
	defer b.file.Close()

	g, inner := errgroup.WithContext(ctx)
	g.SetLimit(numUploadParts)
	for i := range b.Parts {
		part := &b.Parts[i]
		select {
		case <-inner.Done():
			// Cancelled or a part failed: skip the remaining parts.
		case requestURL := <-b.nextURL:
			// A URL only becomes available once the previous part has
			// published it, which serialises or parallelises the uploads
			// depending on how the upstream responds.
			g.Go(func() error {
				var err error
				for try := range maxRetries {
					err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
					switch {
					case errors.Is(err, context.Canceled):
						return err
					case errors.Is(err, errMaxRetriesExceeded):
						return err
					case err != nil:
						// Exponential backoff: 1s, 2s, 4s, ...
						sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
						slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
						time.Sleep(sleep)
						continue
					}
					return nil
				}
				return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
			})
		}
	}

	if err := g.Wait(); err != nil {
		b.err = err
		return
	}

	requestURL := <-b.nextURL

	// calculate md5 checksum and add it to the commit request
	md5sum := md5.New()
	for _, part := range b.Parts {
		md5sum.Write(part.Sum(nil))
	}

	values := requestURL.Query()
	values.Add("digest", b.Digest)
	values.Add("etag", fmt.Sprintf("%x-%d", md5sum.Sum(nil), len(b.Parts)))
	requestURL.RawQuery = values.Encode()

	headers := make(http.Header)
	headers.Set("Content-Type", "application/octet-stream")
	headers.Set("Content-Length", "0")

	// Commit the upload, retrying with exponential backoff unless cancelled.
	for try := range maxRetries {
		var resp *http.Response
		resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
		if errors.Is(err, context.Canceled) {
			break
		} else if err != nil {
			sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
			slog.Info(fmt.Sprintf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep))
			time.Sleep(sleep)
			continue
		}
		// Reached at most once: the loop exits right after a success.
		defer resp.Body.Close()
		break
	}

	b.err = err
	b.done = true
}
// uploadPart sends one part of the blob to requestURL using method and, on
// success, stores the part's MD5 hash in part.Hash. Progress is reported
// through b.Completed and rolled back whenever the attempt does not succeed.
// For PATCH requests the URL returned by the upstream is published on
// b.nextURL so the next part can start. It returns an error on transport
// failure, on any HTTP status >= 400, or when a redirected upload exhausts
// its retries.
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", strconv.FormatInt(part.Size, 10))
if method == http.MethodPatch {
headers.Set("X-Redirect-Uploads", "1")
headers.Set("Content-Range", fmt.Sprintf("%d-%d", part.Offset, part.Offset+part.Size-1))
}
// Stream only this part's byte range; everything read is also fed to the
// progress counter and the MD5 hash.
sr := io.NewSectionReader(b.file, part.Offset, part.Size)
md5sum := md5.New()
w := &progressWriter{blobUpload: b}
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil {
w.Rollback()
return err
}
defer resp.Body.Close()
// Prefer the Docker-specific header, fall back to the standard one.
location := resp.Header.Get("Docker-Upload-Location")
if location == "" {
location = resp.Header.Get("Location")
}
nextURL, err := url.Parse(location)
if err != nil {
w.Rollback()
return err
}
switch {
case resp.StatusCode == http.StatusTemporaryRedirect:
// The upstream wants the data sent elsewhere: undo the progress counted
// for this attempt and release the next URL before re-sending the part
// to the redirect target.
w.Rollback()
b.nextURL <- nextURL
redirectURL, err := resp.Location()
if err != nil {
return err
}
// retry uploading to the redirect URL
// NOTE(review): a fresh, empty registryOptions is used here, presumably so
// registry credentials are not forwarded to the redirect target - confirm.
for try := range maxRetries {
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, &registryOptions{})
switch {
case errors.Is(err, context.Canceled):
return err
case errors.Is(err, errMaxRetriesExceeded):
return err
case err != nil:
// Exponential backoff: 1s, 2s, 4s, ...
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
time.Sleep(sleep)
continue
}
return nil
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
case resp.StatusCode == http.StatusUnauthorized:
// Fetch a new token and store it in opts, then fall through so this
// attempt is still reported as a failure to the caller.
w.Rollback()
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge, requestURL.Host)
if err != nil {
return err
}
opts.Token = token
fallthrough
case resp.StatusCode >= http.StatusBadRequest:
// Rollback resets its own counter, so the second call on the
// fallthrough path subtracts nothing.
w.Rollback()
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
return fmt.Errorf("http status %s: %s", resp.Status, body)
}
// Success: for PATCH uploads hand the follow-up URL to the next part.
if method == http.MethodPatch {
b.nextURL <- nextURL
}
part.Hash = md5sum
return nil
}
// acquire registers one more waiter on the upload.
func (b *blobUpload) acquire() {
b.references.Add(1)
}
// release drops one waiter; when the last waiter leaves, the upload's
// context is cancelled.
func (b *blobUpload) release() {
	if b.references.Add(-1) != 0 {
		return
	}
	// CancelFunc is only assigned once Run has started. A waiter can leave
	// before that (or on an upload whose Prepare failed and never runs), and
	// calling a nil func would panic.
	if b.CancelFunc != nil {
		b.CancelFunc()
	}
}
// Wait blocks until the upload finishes, fails, or ctx is done, calling fn
// with a progress snapshot roughly every 60ms. It returns the upload's error
// (nil on success) or ctx.Err() when the context ends first. The caller is
// counted as a waiter for the duration of the call.
func (b *blobUpload) Wait(ctx context.Context, fn func(api.ProgressResponse)) error {
	b.acquire()
	defer b.release()

	ticker := time.NewTicker(60 * time.Millisecond)
	// Release the ticker's resources once we stop polling.
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}

		fn(api.ProgressResponse{
			Status:    fmt.Sprintf("pushing %s", b.Digest[7:19]),
			Digest:    b.Digest,
			Total:     b.Total,
			Completed: b.Completed.Load(),
		})

		if b.done || b.err != nil {
			return b.err
		}
	}
}
// blobUploadPart describes one contiguous byte range of a blob upload.
type blobUploadPart struct {
// N is the part number
N int
// Offset is the position of the part's first byte within the blob file.
Offset int64
// Size is the length of the part in bytes.
Size int64
// Hash is the MD5 hash of the part's data; uploadPart assigns it after a
// successful upload.
hash.Hash
}
// progressWriter is an io.Writer that discards its input and only counts it,
// adding every byte to the owning upload's Completed counter. It remembers
// how much it contributed so a failed attempt can be undone.
type progressWriter struct {
	written int64
	*blobUpload
}

// Write records len(b) bytes of progress; it never fails.
func (p *progressWriter) Write(b []byte) (n int, err error) {
	size := int64(len(b))
	p.written += size
	p.Completed.Add(size)
	return len(b), nil
}

// Rollback subtracts everything this writer has counted so far from the
// upload's Completed counter and resets the writer.
func (p *progressWriter) Rollback() {
	counted := p.written
	p.written = 0
	p.Completed.Add(-counted)
}
// uploadBlob pushes layer to the registry identified by n, reporting progress
// through fn. If a HEAD request shows the blob already exists, it reports the
// layer as complete and returns. Otherwise it joins an upload already in
// flight for the same digest, or prepares and starts a new one, and waits for
// it to finish.
func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
	blobURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)

	resp, err := makeRequestWithRetry(ctx, http.MethodHead, blobURL, nil, nil, opts)
	if err == nil {
		// The blob is already present upstream; nothing to send.
		defer resp.Body.Close()
		fn(api.ProgressResponse{
			Status:    fmt.Sprintf("pushing %s", layer.Digest[7:19]),
			Digest:    layer.Digest,
			Total:     layer.Size,
			Completed: layer.Size,
		})
		return nil
	}
	if !errors.Is(err, os.ErrNotExist) {
		return err
	}

	entry, existing := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
	upload := entry.(*blobUpload)
	if !existing {
		// We created the entry, so we are responsible for starting the upload.
		sessionURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
		if err := upload.Prepare(ctx, sessionURL, opts); err != nil {
			blobUploadManager.Delete(layer.Digest)
			return err
		}

		//nolint:contextcheck
		go upload.Run(context.Background(), opts)
	}

	return upload.Wait(ctx, fn)
}